Compare commits
71 Commits
release-v1
...
master
Author | SHA1 | Date | |
---|---|---|---|
|
8bc494ff12 | ||
|
4dd1fbe0e5 | ||
|
0c10e8d197 | ||
|
20e1e8cb62 | ||
|
09ab5b98e6 | ||
|
4e49a99e7d | ||
|
a4ad691bf6 | ||
|
53db05171c | ||
|
e155bc057c | ||
|
ae53b24bd8 | ||
|
736ed6e9d2 | ||
|
fff25482f9 | ||
|
18288c5315 | ||
|
eac836b541 | ||
|
78ae92c90b | ||
|
f16b45fd49 | ||
|
aa73c23b30 | ||
|
667e14cd7e | ||
|
44d07346c7 | ||
|
83ed249c14 | ||
|
a4fb5298de | ||
|
5493302c59 | ||
|
3dda2d53a1 | ||
|
76c3eaa7e0 | ||
|
91bf0b8853 | ||
|
142828a6d2 | ||
|
2ce99a7b4e | ||
|
4eca8aa2cf | ||
|
5e9348e59c | ||
|
847dda6e8d | ||
|
7d433fc2f0 | ||
|
924e55fd52 | ||
|
abfec2bf95 | ||
|
6a38f4da7b | ||
|
c86be0aa47 | ||
|
1d94f4a7d9 | ||
|
cb0e1e8bf1 | ||
|
7f98ce76d1 | ||
|
9df7048eb9 | ||
|
751c3d8ca1 | ||
|
5299eada33 | ||
|
c4d76b1c10 | ||
|
21494e1da0 | ||
|
33e71d1917 | ||
|
4f73ce548a | ||
|
a018ad9b73 | ||
|
6115faf2e0 | ||
|
14b4f7c869 | ||
|
e98525c647 | ||
|
e3c6858eba | ||
|
17b2d05b1e | ||
|
6007220504 | ||
|
1524934754 | ||
|
940fb3c8b2 | ||
|
5837776bb1 | ||
|
e568942b63 | ||
|
c2c74bf320 | ||
|
c7cfd43c5c | ||
|
2754ce7703 | ||
|
516f0a53c9 | ||
|
f69ee6fcf4 | ||
|
aa61f23b28 | ||
|
2f524dabf7 | ||
|
eefd24b696 | ||
|
a5c0160e6a | ||
|
9a13983a8a | ||
|
8fd8338e2a | ||
|
78ad2f05f0 | ||
|
8923f6155a | ||
|
99dbc5d102 | ||
|
dcae836f45 |
104
README.md
@ -1,6 +1,6 @@
|
||||
## 人脸搜索M:N
|
||||
|
||||
* 本项目是阿里云视觉智能开放平台的人脸1:N的开源替代,项目中使用的模型均为开源模型,项目支持milvus和proxima向量存储库,并具有较高的自定义能力。
|
||||
* 本项目是阿里云视觉智能开放平台的人脸1:N的开源替代,项目中使用的模型均为开源模型,项目支持opensearch(1.x版本支持milvus和proxima)向量存储库,并具有较高的自定义能力。
|
||||
|
||||
* 项目使用纯Java开发,免去使用Python带来的服务不稳定性。
|
||||
|
||||
@ -22,9 +22,7 @@
|
||||
|
||||
    2、[onnx](https://github.com/onnx/onnx)
|
||||
|
||||
    3、[milvus](https://github.com/milvus-io/milvus/)
|
||||
|
||||
    4、[proxima](https://github.com/alibaba/proximabilin)
|
||||
    3、[opensearch](https://opensearch.org/)
|
||||
|
||||
* 深度学习模型
|
||||
|
||||
@ -32,15 +30,27 @@
|
||||
|
||||
    2、[PCN](https://github.com/Rock-100/FaceKit/tree/master/PCN)
|
||||
|
||||
### 版本1.1.0更新
|
||||
### 版本2.1.0更新
|
||||
|
||||
* 1、修复已知BUG
|
||||
* 2、添加人脸比对1:1接口,详见文档:[05、人脸比对服务](https://gitee.com/open-visual/face-search/blob/dev-1.1.0/scripts/docs/doc-1.1.0.md#05%E4%BA%BA%E8%84%B8%E6%AF%94%E5%AF%B9%E6%9C%8D%E5%8A%A1)
|
||||
* 1、InsightScrfdFaceDetection升级模型,使检测更加稳定,同时添加了人脸角度检测。
|
||||
* 2、InsightScrfdFaceDetection正对不能正常检出人脸的图片增加了补边操作,防止因为人脸过大导致不能检测到人脸。
|
||||
* 3、添加SeetaFaceOpenRecognition的人脸特征提取器,目前人脸特征提取器支持InsightArcFaceRecognition与SeetaFaceOpenRecognition。
|
||||
* 4、修复由于人脸过小,导致对齐异常的BUG。
|
||||
* 5、程序添加了SeetaFace6的人脸关键点遮挡模型。
|
||||
* 6、升级opencv、opensearch、onnxruntime的maven依赖版本。
|
||||
|
||||
### 版本2.0.1更新
|
||||
|
||||
* 1、修复PCN模型存在的潜在内存泄露问题
|
||||
|
||||
### 版本2.0.0更新
|
||||
|
||||
* 1、添加对opensearch的支持,删除对proxima与milvus向量引擎的支持
|
||||
* 2、更新:删除搜索结果中的距离指标,仅保留置信度指标(余弦相似度)
|
||||
|
||||
### 项目文档
|
||||
|
||||
* 在线文档:[文档-1.1.0](https://gitee.com/open-visual/face-search/blob/v1.1.0/scripts/docs/doc-1.1.0.md)
|
||||
* 在线文档:[文档-2.1.0](scripts/docs/2.1.0.md)
|
||||
|
||||
* swagger文档:启动项目且开启swagger,访问:host:port/doc.html, 如 http://127.0.0.1:8080/doc.html
|
||||
|
||||
@ -51,71 +61,78 @@
|
||||
<dependency>
|
||||
<groupId>com.visual.face.search</groupId>
|
||||
<artifactId>face-search-client</artifactId>
|
||||
<version>1.1.0</version>
|
||||
<version>2.1.0</version>
|
||||
</dependency>
|
||||
```
|
||||
* 其他语言依赖
|
||||
|
||||
   使用restful接口:[文档-1.1.0](https://gitee.com/open-visual/face-search/blob/v1.1.0/scripts/docs/doc-1.1.0.md)
|
||||
   使用restful接口:[文档-2.1.0](scripts/docs/2.1.0.md)
|
||||
|
||||
|
||||
### 项目部署
|
||||
|
||||
* docker部署,脚本目录:face-search/scripts
|
||||
```
|
||||
1、使用milvus作为向量搜索引擎
|
||||
docker-compose -f docker-compose-milvus.yml --compatibility up -d
|
||||
1、配置环境变量:FACESEARCH_VOLUME_DIRECTORY,指定当前的挂载根路径,默认为当前路径
|
||||
|
||||
2、使用proxima作为向量搜索引擎
|
||||
docker-compose -f docker-compose-proxima.yml --compatibility up -d
|
||||
2、对opensearch的挂载目录进行赋权:
|
||||
新建目录:${FACESEARCH_VOLUME_DIRECTORY:-.}/volumes-face-search/opensearch/data
|
||||
目录赋权:chmod 777 ${FACESEARCH_VOLUME_DIRECTORY:-.}/volumes-face-search/opensearch/data
|
||||
|
||||
3、使用opensearch作为向量搜索引擎
|
||||
docker-compose -f docker-compose-opensearch.yml --compatibility up -d
|
||||
|
||||
4、服务访问:
|
||||
opensearch自带的可视化工具:http://127.0.0.1:5601
|
||||
facesearch的swagger文档: http://127.0.0.1:56789/doc.html
|
||||
```
|
||||
|
||||
* 项目编译
|
||||
* 项目编译,并打包为docker镜像
|
||||
```
|
||||
1、克隆项目
|
||||
1、java版本最低为:11;安装maven编译工具。安装docker。
|
||||
2、克隆项目
|
||||
git clone https://gitee.com/open-visual/face-search.git
|
||||
2、项目打包
|
||||
3、项目打包
|
||||
cd face-search && sh scripts/docker_build.sh
|
||||
```
|
||||
|
||||
* 部署参数
|
||||
|
||||
| 参数 | 描述 | 默认值 | 可选值|
|
||||
| -------- | -----: | :----: |--------|
|
||||
| VISUAL_SWAGGER_ENABLE | 是否开启swagger | true | |
|
||||
| SPRING_DATASOURCE_URL | 数据库地址 | | |
|
||||
| SPRING_DATASOURCE_USERNAME | 数据库用户名 | | |
|
||||
| SPRING_DATASOURCE_PASSWORD | 数据库密码 | | |
|
||||
| VISUAL_ENGINE_SELECTED | 向量存储引擎 | proxima |proxima,milvus |
|
||||
| VISUAL_ENGINE_PROXIMA_HOST | PROXIMA地址 | |VISUAL_ENGINE_SELECTED=proxima时生效 |
|
||||
| VISUAL_ENGINE_PROXIMA_PORT | PROXIMA端口 | 16000 |VISUAL_ENGINE_SELECTED=proxima时生效 |
|
||||
| VISUAL_ENGINE_MILVUS_HOST | MILVUS地址 | |VISUAL_ENGINE_SELECTED=milvus时生效 |
|
||||
| VISUAL_ENGINE_MILVUS_PORT | MILVUS端口 | 19530 |VISUAL_ENGINE_SELECTED=milvus时生效 |
|
||||
| VISUAL_MODEL_FACEDETECTION_NAME | 人脸检测模型名称 | PcnNetworkFaceDetection |PcnNetworkFaceDetection,InsightScrfdFaceDetection |
|
||||
| VISUAL_MODEL_FACEDETECTION_BACKUP_NAME | 备用人脸检测模型名称 | InsightScrfdFaceDetection |PcnNetworkFaceDetection,InsightScrfdFaceDetection |
|
||||
| VISUAL_MODEL_FACEKEYPOINT_NAME | 人脸关键点模型名称 | InsightCoordFaceKeyPoint |InsightCoordFaceKeyPoint |
|
||||
| VISUAL_MODEL_FACEALIGNMENT_NAME | 人脸对齐模型名称 | Simple106pFaceAlignment |Simple106pFaceAlignment,Simple005pFaceAlignment |
|
||||
| VISUAL_MODEL_FACERECOGNITION_NAME | 人脸特征提取模型名称 | InsightArcFaceRecognition |InsightArcFaceRecognition |
|
||||
| 参数 | 描述 | 默认值 | 可选值 |
|
||||
| -------- | -----: | :----: |---------------------------------------------------|
|
||||
| VISUAL_SWAGGER_ENABLE | 是否开启swagger | true | |
|
||||
| SPRING_DATASOURCE_URL | 数据库地址 | | |
|
||||
| SPRING_DATASOURCE_USERNAME | 数据库用户名 | root | |
|
||||
| SPRING_DATASOURCE_PASSWORD | 数据库密码 | root | |
|
||||
| VISUAL_ENGINE_OPENSEARCH_HOST | OPENSEARCH地址 | | |
|
||||
| VISUAL_ENGINE_OPENSEARCH_PORT | OPENSEARCH端口 | 9200 | |
|
||||
| VISUAL_ENGINE_OPENSEARCH_SCHEME | OPENSEARCH协议 | https | |
|
||||
| VISUAL_ENGINE_OPENSEARCH_USERNAME | OPENSEARCH用户名 | admin | |
|
||||
| VISUAL_ENGINE_OPENSEARCH_PASSWORD | OPENSEARCH密码 | admin | |
|
||||
| VISUAL_MODEL_FACEDETECTION_NAME | 人脸检测模型名称 | InsightScrfdFaceDetection | PcnNetworkFaceDetection,InsightScrfdFaceDetection |
|
||||
| VISUAL_MODEL_FACEDETECTION_BACKUP_NAME | 备用人脸检测模型名称 | PcnNetworkFaceDetection | PcnNetworkFaceDetection,InsightScrfdFaceDetection |
|
||||
| VISUAL_MODEL_FACEKEYPOINT_NAME | 人脸关键点模型名称 | InsightCoordFaceKeyPoint | InsightCoordFaceKeyPoint |
|
||||
| VISUAL_MODEL_FACEALIGNMENT_NAME | 人脸对齐模型名称 | Simple106pFaceAlignment | Simple106pFaceAlignment,Simple005pFaceAlignment |
|
||||
| VISUAL_MODEL_FACERECOGNITION_NAME | 人脸特征提取模型名称 | InsightArcFaceRecognition | InsightArcFaceRecognition,SeetaFaceOpenRecognition |
|
||||
|
||||
### 性能优化
|
||||
|
||||
* 项目中为了提高人脸的检出率,使用了主要和次要的人脸检测模型,目前实现了两种人脸检测模型insightface和PCN,在docker的服务中,默认主服务为PCN,备用服务为insightface。insightface的效率高,但针对于旋转了大角度的人脸检出率不高,而pcn则可以识别大角度旋转的图片,但效率低一些。若图像均为正脸的图像,建议使用insightface为主模型,pcn为备用模型,如何切换,请查看部署参数。
|
||||
|
||||
* 在测试过程中,针对milvus和proxima,发现proxima的速度比milvus稍快,但稳定性没有milvus好,线上服务使用时,还是建议使用milvus作为向量检索引擎。
|
||||
* 项目中为了提高人脸的检出率,使用了主要和次要的人脸检测模型,目前实现了两种人脸检测模型Insightface和PCN,在docker的服务中,默认主服务为Insightface,备用服务为PCN。insightface的效率高,但针对于旋转了大角度的人脸检出率不高,而pcn则可以识别大角度旋转的图片,但效率低一些。若图像均为正脸的图像,建议使用insightface为主模型,pcn为备用模型,如何切换,请查看部署参数。
|
||||
|
||||
### 项目演示
|
||||
|
||||
* 1.1.0 测试用例:face-search-test[测试用例-FaceSearchExample](https://gitee.com/open-visual/face-search/blob/master/face-search-test/src/main/java/com/visual/face/search/valid/exps/FaceSearchExample.java)
|
||||
* 2.1.0 测试用例:face-search-test[测试用例-FaceSearchExample](https://gitee.com/open-visual/face-search/blob/master/face-search-test/src/main/java/com/visual/face/search/valid/exps/FaceSearchExample.java)
|
||||
|
||||
* 
|
||||
* 
|
||||
|
||||
* 1.2.0 测试用例(做了优化,增强了搜索结果的区分度):face-search-test[测试用例-FaceSearchExample](https://gitee.com/open-visual/face-search/blob/master/face-search-test/src/main/java/com/visual/face/search/valid/exps/FaceSearchExample.java)
|
||||
### 演员识别(手机打开体验更好)
|
||||
* [http://actor-search.divenswu.com](http://actor-search.divenswu.com)
|
||||
* 
|
||||
|
||||
* 
|
||||
|
||||
### 交流群
|
||||
|
||||
* 钉钉交流群
|
||||
* 钉钉交流群(已解散)
|
||||
|
||||
关注微信公众号回复:钉钉群
|
||||
|
||||
@ -129,4 +146,7 @@
|
||||
|
||||
|
||||
### 项目开源前端:感谢`HeX`的开源
|
||||
* [https://gitee.com/hexpang/face-search-web](https://gitee.com/hexpang/face-search-web)
|
||||
* [https://gitee.com/hexpang/face-search-web](https://gitee.com/hexpang/face-search-web)
|
||||
|
||||
### 欢迎来访我的其他开源项目
|
||||
* [车牌识别:https://gitee.com/open-visual/open-anpr](https://gitee.com/open-visual/open-anpr)
|
7
face-search-client/pom.xml
Executable file → Normal file
@ -2,12 +2,11 @@
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<artifactId>face-search-client</artifactId>
|
||||
<groupId>com.visual.face.search</groupId>
|
||||
<version>1.1.0</version>
|
||||
<artifactId>face-search-client</artifactId>
|
||||
<version>2.1.0</version>
|
||||
|
||||
<properties>
|
||||
<java.version>1.8</java.version>
|
||||
@ -19,7 +18,7 @@
|
||||
<dependency>
|
||||
<groupId>com.alibaba</groupId>
|
||||
<artifactId>fastjson</artifactId>
|
||||
<version>1.2.58</version>
|
||||
<version>1.2.83</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.httpcomponents</groupId>
|
||||
|
@ -4,6 +4,7 @@ import java.util.Map;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import com.alibaba.fastjson.TypeReference;
|
||||
import com.visual.face.search.common.Api;
|
||||
import com.visual.face.search.http.HttpClient;
|
||||
import com.visual.face.search.model.*;
|
||||
@ -46,7 +47,6 @@ public class CollectHandler extends BaseHandler<CollectHandler>{
|
||||
.setMaxDocsPerSegment(collect.getMaxDocsPerSegment())
|
||||
.setSampleColumns(collect.getSampleColumns())
|
||||
.setFaceColumns(collect.getFaceColumns())
|
||||
.setSyncBinLog(collect.isSyncBinLog())
|
||||
.setShardsNum(collect.getShardsNum())
|
||||
.setStorageFaceInfo(collect.getStorageFaceInfo())
|
||||
.setStorageEngine(collect.getStorageEngine());
|
||||
@ -72,7 +72,7 @@ public class CollectHandler extends BaseHandler<CollectHandler>{
|
||||
MapParam param = MapParam.build()
|
||||
.put("namespace", namespace)
|
||||
.put("collectionName", collectionName);
|
||||
return HttpClient.get(Api.getUrl(this.serverHost, Api.collect_get), param);
|
||||
return HttpClient.get(Api.getUrl(this.serverHost, Api.collect_get), param, new TypeReference<Response<CollectRep>>(){});
|
||||
}
|
||||
|
||||
/**
|
||||
@ -81,7 +81,7 @@ public class CollectHandler extends BaseHandler<CollectHandler>{
|
||||
*/
|
||||
public Response<List<CollectRep>> collectList(){
|
||||
MapParam param = MapParam.build().put("namespace", namespace);
|
||||
return HttpClient.get(Api.getUrl(this.serverHost, Api.collect_list), param);
|
||||
return HttpClient.get(Api.getUrl(this.serverHost, Api.collect_list), param, new TypeReference<Response<List<CollectRep>>>(){});
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
package com.visual.face.search.handle;
|
||||
|
||||
import com.alibaba.fastjson.TypeReference;
|
||||
import com.visual.face.search.common.Api;
|
||||
import com.visual.face.search.http.HttpClient;
|
||||
import com.visual.face.search.model.*;
|
||||
@ -46,7 +47,7 @@ public class FaceHandler extends BaseHandler<FaceHandler>{
|
||||
.setFaceScoreThreshold(face.getFaceScoreThreshold())
|
||||
.setMinConfidenceThresholdWithThisSample(face.getMinConfidenceThresholdWithThisSample())
|
||||
.setMaxConfidenceThresholdWithOtherSample(face.getMaxConfidenceThresholdWithOtherSample());
|
||||
return HttpClient.post(Api.getUrl(this.serverHost, Api.face_create), faceReq);
|
||||
return HttpClient.post(Api.getUrl(this.serverHost, Api.face_create), faceReq, new TypeReference<Response<FaceRep>>(){});
|
||||
}
|
||||
|
||||
|
||||
|
@ -18,8 +18,6 @@ public class Collect<ExtendsVo extends Collect<ExtendsVo>> implements Serializab
|
||||
private List<FiledColumn> sampleColumns = new ArrayList<>();
|
||||
/**自定义的人脸字段**/
|
||||
private List<FiledColumn> faceColumns = new ArrayList<>();
|
||||
/**启用binlog同步**/
|
||||
private Boolean syncBinLog = false;
|
||||
/**是否保留图片及人脸信息**/
|
||||
private Boolean storageFaceInfo = false;
|
||||
/**保留图片及人脸信息的存储组件**/
|
||||
@ -86,17 +84,6 @@ public class Collect<ExtendsVo extends Collect<ExtendsVo>> implements Serializab
|
||||
return (ExtendsVo) this;
|
||||
}
|
||||
|
||||
public boolean isSyncBinLog() {
|
||||
return null == syncBinLog ? false : syncBinLog;
|
||||
}
|
||||
|
||||
public ExtendsVo setSyncBinLog(Boolean syncBinLog) {
|
||||
if(null != syncBinLog){
|
||||
this.syncBinLog = syncBinLog;
|
||||
}
|
||||
return (ExtendsVo) this;
|
||||
}
|
||||
|
||||
public boolean getStorageFaceInfo() {
|
||||
return null == storageFaceInfo ? false : storageFaceInfo;
|
||||
}
|
||||
|
@ -11,8 +11,6 @@ public class SampleFace implements Comparable<SampleFace>, Serializable {
|
||||
/**人脸人数质量**/
|
||||
private Float faceScore;
|
||||
/**转换后的置信度**/
|
||||
private Float distance;
|
||||
/**转换后的置信度**/
|
||||
private Float confidence;
|
||||
/**样本扩展的额外数据**/
|
||||
private KeyValues sampleData;
|
||||
@ -67,14 +65,6 @@ public class SampleFace implements Comparable<SampleFace>, Serializable {
|
||||
this.faceScore = faceScore;
|
||||
}
|
||||
|
||||
public Float getDistance() {
|
||||
return distance;
|
||||
}
|
||||
|
||||
public void setDistance(Float distance) {
|
||||
this.distance = distance;
|
||||
}
|
||||
|
||||
public Float getConfidence() {
|
||||
return confidence;
|
||||
}
|
||||
|
5
face-search-core/pom.xml
Executable file → Normal file
@ -5,8 +5,9 @@
|
||||
<parent>
|
||||
<artifactId>face-search</artifactId>
|
||||
<groupId>com.visual.face.search</groupId>
|
||||
<version>1.2.0</version>
|
||||
<version>2.1.0</version>
|
||||
</parent>
|
||||
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<artifactId>face-search-core</artifactId>
|
||||
|
||||
@ -30,6 +31,6 @@
|
||||
<groupId>com.alibaba</groupId>
|
||||
<artifactId>fastjson</artifactId>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
</project>
|
@ -0,0 +1,18 @@
|
||||
package com.visual.face.search.core.base;
|
||||
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.domain.FaceInfo.Attribute;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public interface FaceAttribute {
|
||||
|
||||
/**
|
||||
* 人脸属性信息
|
||||
* @param imageMat 图像数据
|
||||
* @param params 参数信息
|
||||
* @return
|
||||
*/
|
||||
Attribute inference(ImageMat imageMat, Map<String, Object> params);
|
||||
|
||||
}
|
@ -0,0 +1,21 @@
|
||||
package com.visual.face.search.core.base;
|
||||
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.domain.QualityInfo;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 人脸关键点检测
|
||||
*/
|
||||
public interface FaceMaskPoint {
|
||||
|
||||
/**
|
||||
* 人脸关键点检测
|
||||
* @param imageMat 图像数据
|
||||
* @param params 参数信息
|
||||
* @return
|
||||
*/
|
||||
QualityInfo.MaskPoints inference(ImageMat imageMat, Map<String, Object> params);
|
||||
|
||||
}
|
@ -4,6 +4,7 @@ import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
||||
|
||||
public class FaceInfo implements Comparable<FaceInfo>, Serializable {
|
||||
/**人脸分数**/
|
||||
public float score;
|
||||
@ -15,6 +16,8 @@ public class FaceInfo implements Comparable<FaceInfo>, Serializable {
|
||||
public Points points;
|
||||
/**人脸特征向量**/
|
||||
public Embedding embedding;
|
||||
/**人脸属性信息**/
|
||||
public Attribute attribute;
|
||||
|
||||
/**
|
||||
* 构造函数
|
||||
@ -120,9 +123,9 @@ public class FaceInfo implements Comparable<FaceInfo>, Serializable {
|
||||
* @return 旋转后的角
|
||||
*/
|
||||
public Point rotation(Point center, float angle){
|
||||
double k = new Float(Math.toRadians(angle));
|
||||
float nx1 = new Float((this.x-center.x)*Math.cos(k) +(this.y-center.y)*Math.sin(k)+center.x);
|
||||
float ny1 = new Float(-(this.x-center.x)*Math.sin(k) + (this.y-center.y)*Math.cos(k)+center.y);
|
||||
double k = Math.toRadians(angle);
|
||||
float nx1 = (float) ((this.x - center.x) * Math.cos(k) + (this.y - center.y) * Math.sin(k) + center.x);
|
||||
float ny1 = (float) (-(this.x - center.x) * Math.sin(k) + (this.y - center.y) * Math.cos(k) + center.y);
|
||||
return new Point(nx1, ny1);
|
||||
}
|
||||
|
||||
@ -134,6 +137,18 @@ public class FaceInfo implements Comparable<FaceInfo>, Serializable {
|
||||
public float distance(Point that){
|
||||
return (float) Math.sqrt(Math.pow((this.x-that.x), 2)+Math.pow((this.y-that.y), 2));
|
||||
}
|
||||
|
||||
/**
|
||||
* 将点进行平移
|
||||
* @param top 向上移动的像素点数
|
||||
* @param bottom 向下移动的像素点数
|
||||
* @param left 向左移动的像素点数
|
||||
* @param right 向右移动的像素点数
|
||||
* @return 平移后的点
|
||||
*/
|
||||
public Point move(int left, int right, int top, int bottom){
|
||||
return new Point(x - left + right, y - top + bottom);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -243,6 +258,22 @@ public class FaceInfo implements Comparable<FaceInfo>, Serializable {
|
||||
}
|
||||
return points;
|
||||
}
|
||||
|
||||
/**
|
||||
* 将点进行平移
|
||||
* @param top 向上移动的像素点数
|
||||
* @param bottom 向下移动的像素点数
|
||||
* @param left 向左移动的像素点数
|
||||
* @param right 向右移动的像素点数
|
||||
* @return 平移后的点
|
||||
*/
|
||||
public Points move(int left, int right, int top, int bottom){
|
||||
Points points = build();
|
||||
for(Point item : this){
|
||||
points.add(item.move(left, right, top, bottom));
|
||||
}
|
||||
return points;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -333,7 +364,7 @@ public class FaceInfo implements Comparable<FaceInfo>, Serializable {
|
||||
* 判断当前的人脸框是否是标准的人脸框,即非旋转后的人脸框。
|
||||
* @return 否是标准的人脸框
|
||||
*/
|
||||
public boolean isNormal(){
|
||||
public boolean normal(){
|
||||
if((int)leftTop.x == (int)leftBottom.x && (int)leftTop.y == (int)rightTop.y){
|
||||
if((int)rightBottom.x == (int)rightTop.x && (int)rightBottom.y == (int)leftBottom.y){
|
||||
return true;
|
||||
@ -416,6 +447,23 @@ public class FaceInfo implements Comparable<FaceInfo>, Serializable {
|
||||
new Point(leftBottom.x + change_x_p2_p4, leftBottom.y + change_y_p2_p4)
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* 将框进行平移
|
||||
* @param top 向上移动的像素点数
|
||||
* @param bottom 向下移动的像素点数
|
||||
* @param left 向左移动的像素点数
|
||||
* @param right 向右移动的像素点数
|
||||
* @return 平移后的框
|
||||
*/
|
||||
public FaceBox move(int left, int right, int top, int bottom){
|
||||
return new FaceBox(
|
||||
new Point(leftTop.x - left + right, leftTop.y - top + bottom),
|
||||
new Point(rightTop.x - left + right, rightTop.y - top + bottom),
|
||||
new Point(rightBottom.x - left + right, rightBottom.y - top + bottom),
|
||||
new Point(leftBottom.x - left + right, leftBottom.y - top + bottom)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -446,4 +494,66 @@ public class FaceInfo implements Comparable<FaceInfo>, Serializable {
|
||||
return new Embedding(image, embeds);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 人脸属性信息
|
||||
*/
|
||||
public static class Attribute implements Serializable {
|
||||
public Integer age;
|
||||
public Integer gender;
|
||||
|
||||
/**
|
||||
* 构造函数
|
||||
* @param gender 前图片的base64编码值
|
||||
* @param age 当前图片的人脸向量信息
|
||||
*/
|
||||
private Attribute(Gender gender, Integer age) {
|
||||
this.age = age;
|
||||
this.gender = null == gender ? -1 : gender.getCode();
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取枚举值
|
||||
* @return
|
||||
*/
|
||||
public Gender valueOfGender(){
|
||||
return Gender.valueOf(this.gender);
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建人脸属性信息
|
||||
* @param gender 前图片的base64编码值
|
||||
* @param age 当前图片的人脸向量信息
|
||||
*/
|
||||
public static Attribute build(Gender gender, Integer age){
|
||||
return new Attribute(gender, age);
|
||||
}
|
||||
}
|
||||
|
||||
public static enum Gender {
|
||||
MALE(0), //男性
|
||||
FEMALE(1), //女性
|
||||
UNKNOWN(-1); //未知
|
||||
|
||||
private int code;
|
||||
|
||||
Gender(int code) {
|
||||
this.code = code;
|
||||
}
|
||||
|
||||
public int getCode() {
|
||||
return this.code;
|
||||
}
|
||||
|
||||
public static Gender valueOf(Integer code) {
|
||||
code = null == code ? -1 : code;
|
||||
if(code == 0){
|
||||
return MALE;
|
||||
}
|
||||
if(code == 1){
|
||||
return FEMALE;
|
||||
}
|
||||
return UNKNOWN;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -66,9 +66,9 @@ public class ImageMat implements Serializable {
|
||||
public static ImageMat fromBase64(String base64Str){
|
||||
InputStream inputStream = null;
|
||||
try {
|
||||
// 新版本JDK被移除,替换为Base64.Decoder
|
||||
// BASE64Decoder decoder = new BASE64Decoder();
|
||||
// byte[] data = decoder.decodeBuffer(base64Str);
|
||||
if(base64Str.contains(",")){
|
||||
base64Str = base64Str.substring(base64Str.indexOf(",")+1);
|
||||
}
|
||||
Base64.Decoder decoder = Base64.getMimeDecoder();
|
||||
byte[] data = decoder.decode(base64Str);
|
||||
inputStream = new ByteArrayInputStream(data);
|
||||
@ -252,6 +252,55 @@ public class ImageMat implements Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 对图像进行补边操作,不释放原始的图片
|
||||
* @param top 向上扩展的高度
|
||||
* @param bottom 向下扩展的高度
|
||||
* @param left 向左扩展的宽度
|
||||
* @param right 向右扩展的宽度
|
||||
* @param borderType 补边的类型
|
||||
* @return 补边后的图像
|
||||
*/
|
||||
public ImageMat copyMakeBorderAndNotReleaseMat(int top, int bottom, int left, int right, int borderType){
|
||||
return this.copyMakeBorder(top, bottom, left, right, borderType, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* 对图像进行补边操作,并且释放原始的图片
|
||||
* @param top 向上扩展的高度
|
||||
* @param bottom 向下扩展的高度
|
||||
* @param left 向左扩展的宽度
|
||||
* @param right 向右扩展的宽度
|
||||
* @param borderType 补边的类型
|
||||
* @return 补边后的图像
|
||||
*/
|
||||
public ImageMat copyMakeBorderAndDoReleaseMat(int top, int bottom, int left, int right, int borderType){
|
||||
return this.copyMakeBorder(top, bottom, left, right, borderType, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* 对图像进行补边操作
|
||||
* @param top 向上扩展的高度
|
||||
* @param bottom 向下扩展的高度
|
||||
* @param left 向左扩展的宽度
|
||||
* @param right 向右扩展的宽度
|
||||
* @param borderType 补边的类型
|
||||
* @param release 是否释放原始的图片
|
||||
* @return 补边后的图像
|
||||
*/
|
||||
private ImageMat copyMakeBorder(int top, int bottom, int left, int right, int borderType, boolean release){
|
||||
try {
|
||||
Mat tempMat = new Mat();
|
||||
Core.copyMakeBorder(mat, tempMat, top, bottom, left, right, borderType);
|
||||
return new ImageMat(tempMat);
|
||||
}finally {
|
||||
if(release){
|
||||
this.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 对图像进行预处理,不释放原始图片数据
|
||||
* @param scale 图像各通道数值的缩放比例
|
||||
@ -264,7 +313,7 @@ public class ImageMat implements Serializable {
|
||||
}
|
||||
|
||||
/**
|
||||
* 对图像进行预处理,并释放原始图片数据
|
||||
* 对图像进行预处理,并释放原始图片数据:(先交换RB通道(swapRB),再减法(mean),最后缩放(scale))
|
||||
* @param scale 图像各通道数值的缩放比例
|
||||
* @param mean 用于各通道减去的值,以降低光照的影响
|
||||
* @param swapRB 交换RB通道,默认为False.
|
||||
@ -681,8 +730,12 @@ public class ImageMat implements Serializable {
|
||||
*/
|
||||
public void release(){
|
||||
if(this.mat != null){
|
||||
this.mat.release();
|
||||
this.mat = null;
|
||||
try {
|
||||
this.mat.release();
|
||||
this.mat = null;
|
||||
}catch (Exception e){
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,44 @@
|
||||
package com.visual.face.search.core.domain;
|
||||
|
||||
import org.opencv.core.Mat;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
||||
public class Mats extends ArrayList<Mat> {
|
||||
|
||||
public static Mats build(){
|
||||
return new Mats();
|
||||
}
|
||||
|
||||
public Mats append(Mat mat){
|
||||
this.add(mat);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Mats clone(){
|
||||
Mats mats = new Mats();
|
||||
for(Mat mat : this){
|
||||
mats.add(mat.clone());
|
||||
}
|
||||
return mats;
|
||||
}
|
||||
|
||||
public void release(){
|
||||
if(this.isEmpty()){
|
||||
return;
|
||||
}
|
||||
for(Mat mat : this){
|
||||
if(null != mat){
|
||||
try {
|
||||
mat.release();
|
||||
}catch (Exception e){
|
||||
e.printStackTrace();
|
||||
}finally {
|
||||
mat = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
this.clear();
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,129 @@
|
||||
package com.visual.face.search.core.domain;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
||||
public class QualityInfo {
|
||||
|
||||
public MaskPoints maskPoints;
|
||||
|
||||
private QualityInfo(MaskPoints maskPoints) {
|
||||
this.maskPoints = maskPoints;
|
||||
}
|
||||
|
||||
public static QualityInfo build(MaskPoints maskPoints){
|
||||
return new QualityInfo(maskPoints);
|
||||
}
|
||||
|
||||
public MaskPoints getMaskPoints() {
|
||||
return maskPoints;
|
||||
}
|
||||
|
||||
|
||||
public boolean isMask(){
|
||||
return null != this.maskPoints && this.maskPoints.isMask();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 遮挡类
|
||||
*/
|
||||
public static class Mask implements Serializable {
|
||||
/**遮挡分数*/
|
||||
public float score;
|
||||
|
||||
public static Mask build(float score){
|
||||
return new QualityInfo.Mask(score);
|
||||
}
|
||||
|
||||
private Mask(float score) {
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
public float getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
public boolean isMask(){
|
||||
return this.score >= 0.5;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Mask{" + "score=" + score + '}';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 点遮挡类
|
||||
*/
|
||||
public static class MaskPoint extends Mask{
|
||||
/**坐标X的值**/
|
||||
public float x;
|
||||
/**坐标Y的值**/
|
||||
public float y;
|
||||
|
||||
public static MaskPoint build(float x, float y, float score){
|
||||
return new MaskPoint(x, y, score);
|
||||
}
|
||||
|
||||
private MaskPoint(float x, float y, float score) {
|
||||
super(score);
|
||||
this.x = x;
|
||||
this.y = y;
|
||||
}
|
||||
|
||||
public float getX() {
|
||||
return x;
|
||||
}
|
||||
|
||||
public float getY() {
|
||||
return y;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "MaskPoint{" + "x=" + x + ", y=" + y + ", score=" + score + '}';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 点遮挡类集合
|
||||
*/
|
||||
public static class MaskPoints extends ArrayList<MaskPoint> {
|
||||
|
||||
private MaskPoints(){}
|
||||
|
||||
/**
|
||||
* 构建一个集合
|
||||
* @return
|
||||
*/
|
||||
public static MaskPoints build(){
|
||||
return new MaskPoints();
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加点
|
||||
* @param point
|
||||
* @return
|
||||
*/
|
||||
public MaskPoints add(MaskPoint...point){
|
||||
super.addAll(Arrays.asList(point));
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* 判定是否存在遮挡
|
||||
* @return
|
||||
*/
|
||||
public boolean isMask(){
|
||||
for(MaskPoint point : this){
|
||||
if(point.isMask()){
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
@ -2,19 +2,15 @@ package com.visual.face.search.core.extract;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import com.visual.face.search.core.base.FaceAlignment;
|
||||
import com.visual.face.search.core.base.FaceDetection;
|
||||
import com.visual.face.search.core.base.FaceKeyPoint;
|
||||
import com.visual.face.search.core.base.FaceRecognition;
|
||||
import org.opencv.core.Mat;
|
||||
import com.visual.face.search.core.base.*;
|
||||
import com.visual.face.search.core.domain.ExtParam;
|
||||
import com.visual.face.search.core.domain.FaceImage;
|
||||
import com.visual.face.search.core.domain.FaceInfo;
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.models.InsightCoordFaceKeyPoint;
|
||||
import com.visual.face.search.core.utils.CropUtil;
|
||||
import com.visual.face.search.core.utils.MaskUtil;
|
||||
import org.opencv.core.Mat;
|
||||
import com.visual.face.search.core.models.InsightCoordFaceKeyPoint;
|
||||
|
||||
/**
|
||||
* 人脸特征提取器实现
|
||||
@ -28,6 +24,7 @@ public class FaceFeatureExtractorImpl implements FaceFeatureExtractor {
|
||||
private FaceAlignment faceAlignment;
|
||||
private FaceRecognition faceRecognition;
|
||||
private FaceDetection backupFaceDetection;
|
||||
private FaceAttribute faceAttribute;
|
||||
|
||||
/**
|
||||
* 构造函数
|
||||
@ -37,10 +34,14 @@ public class FaceFeatureExtractorImpl implements FaceFeatureExtractor {
|
||||
* @param faceAlignment 人脸对齐模型
|
||||
* @param faceRecognition 人脸特征提取模型
|
||||
*/
|
||||
public FaceFeatureExtractorImpl(FaceDetection faceDetection, FaceDetection backupFaceDetection, FaceKeyPoint faceKeyPoint, FaceAlignment faceAlignment, FaceRecognition faceRecognition) {
|
||||
public FaceFeatureExtractorImpl(
|
||||
FaceDetection faceDetection, FaceDetection backupFaceDetection,
|
||||
FaceKeyPoint faceKeyPoint, FaceAlignment faceAlignment,
|
||||
FaceRecognition faceRecognition, FaceAttribute faceAttribute) {
|
||||
this.faceKeyPoint = faceKeyPoint;
|
||||
this.faceDetection = faceDetection;
|
||||
this.faceAlignment = faceAlignment;
|
||||
this.faceAttribute = faceAttribute;
|
||||
this.faceRecognition = faceRecognition;
|
||||
this.backupFaceDetection = backupFaceDetection;
|
||||
}
|
||||
@ -71,13 +72,19 @@ public class FaceFeatureExtractorImpl implements FaceFeatureExtractor {
|
||||
ImageMat cropImageMat = null;
|
||||
ImageMat alignmentImage = null;
|
||||
try {
|
||||
//缩放人脸框的比例
|
||||
float scaling = extParam.getScaling() <= 0 ? defScaling : extParam.getScaling();
|
||||
//通过旋转角度获取正脸坐标,并进行图像裁剪
|
||||
FaceInfo.FaceBox box = faceInfo.rotateFaceBox().scaling(scaling);
|
||||
cropFace = CropUtil.crop(image.toCvMat(), box);
|
||||
//人脸标记关键点
|
||||
FaceInfo.FaceBox rotateFaceBox = faceInfo.rotateFaceBox();
|
||||
cropFace = CropUtil.crop(image.toCvMat(), rotateFaceBox);
|
||||
cropImageMat = ImageMat.fromCVMat(cropFace);
|
||||
//人脸属性检测
|
||||
FaceInfo.Attribute attribute = this.faceAttribute.inference(cropImageMat, params);
|
||||
faceInfo.attribute = attribute;
|
||||
//进行缩放人脸区域,并裁剪图片
|
||||
float scaling = extParam.getScaling() <= 0 ? defScaling : extParam.getScaling();
|
||||
FaceInfo.FaceBox box = rotateFaceBox.scaling(scaling);
|
||||
cropFace = CropUtil.crop(image.toCvMat(), box);
|
||||
cropImageMat = ImageMat.fromCVMat(cropFace);
|
||||
//人脸标记关键点
|
||||
FaceInfo.Points corpPoints = this.faceKeyPoint.inference(cropImageMat, params);
|
||||
//还原原始图片中的关键点
|
||||
FaceInfo.Point corpImageCenter = FaceInfo.Point.build((float)cropImageMat.center().x, (float)cropImageMat.center().y);
|
||||
|
@ -42,7 +42,7 @@ public class InsightArcFaceRecognition extends BaseOnnxInfer implements FaceRec
|
||||
.to4dFloatOnnxTensorAndDoReleaseMat(true);
|
||||
output = getSession().run(Collections.singletonMap(getInputName(), tensor));
|
||||
float[][] embeds = (float[][]) output.get(0).getValue();
|
||||
return FaceInfo.Embedding.build(image.toBase64AndNoReleaseMat(), embeds[0]);
|
||||
return Embedding.build(image.toBase64AndNoReleaseMat(), embeds[0]);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}finally {
|
||||
|
@ -0,0 +1,115 @@
|
||||
package com.visual.face.search.core.models;
|
||||
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import com.visual.face.search.core.base.BaseOnnxInfer;
|
||||
import com.visual.face.search.core.base.FaceAttribute;
|
||||
import com.visual.face.search.core.domain.FaceInfo;
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.utils.MathUtil;
|
||||
import org.apache.commons.math3.linear.RealMatrix;
|
||||
import org.opencv.core.*;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 人脸属性检测:性别+年龄
|
||||
* git:https://github.com/deepinsight/insightface/tree/master/attribute
|
||||
*/
|
||||
public class InsightAttributeDetection extends BaseOnnxInfer implements FaceAttribute {
|
||||
|
||||
private static final int[] inputSize = new int[]{96, 96};
|
||||
|
||||
/**
|
||||
* 构造函数
|
||||
* @param modelPath 模型路径
|
||||
* @param threads 线程数
|
||||
*/
|
||||
public InsightAttributeDetection(String modelPath, int threads) {
|
||||
super(modelPath, threads);
|
||||
}
|
||||
|
||||
/**
|
||||
* 人脸属性信息
|
||||
* @param imageMat 图像数据
|
||||
* @param params 参数信息
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public FaceInfo.Attribute inference(ImageMat imageMat, Map<String, Object> params) {
|
||||
Mat M =null;
|
||||
Mat img = null;
|
||||
OnnxTensor tensor = null;
|
||||
OrtSession.Result output = null;
|
||||
try {
|
||||
Mat image = imageMat.toCvMat();
|
||||
int w = image.size(1);
|
||||
int h = image.size(0);
|
||||
float cx = 1.0f * w / 2;
|
||||
float cy = 1.0f * h / 2;
|
||||
float[]center = new float[]{cx, cy};
|
||||
float rotate = 0;
|
||||
float _scale = (float) (1.0f * inputSize[0] / (Math.max(w, h)*1.5));
|
||||
Mat[] transform = transform(image, center, inputSize, _scale, rotate);
|
||||
img = transform[0];
|
||||
M = transform[1];
|
||||
tensor = ImageMat.fromCVMat(img)
|
||||
.blobFromImageAndDoReleaseMat(1.0, new Scalar(0, 0, 0), true)
|
||||
.to4dFloatOnnxTensorAndDoReleaseMat(true);
|
||||
output = this.getSession().run(Collections.singletonMap(this.getInputName(), tensor));
|
||||
float[] value = ((float[][]) output.get(0).getValue())[0];
|
||||
Integer age = Double.valueOf(Math.floor(value[2] * 100)).intValue();
|
||||
FaceInfo.Gender gender = (value[0] > value[1]) ? FaceInfo.Gender.FEMALE : FaceInfo.Gender.MALE;
|
||||
return FaceInfo.Attribute.build(gender, age);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}finally {
|
||||
if(null != tensor){
|
||||
tensor.close();
|
||||
}
|
||||
if(null != output){
|
||||
output.close();
|
||||
}
|
||||
if(null != M){
|
||||
M.release();
|
||||
}
|
||||
if(null != img){
|
||||
img.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 获取人脸数据和仿射矩阵
|
||||
* @param image
|
||||
* @param center
|
||||
* @param outputSize
|
||||
* @param scale
|
||||
* @param rotation
|
||||
* @return
|
||||
*/
|
||||
private static Mat[] transform(Mat image, float[]center, int[]outputSize, float scale, float rotation){
|
||||
double scale_ratio = scale;
|
||||
double rot = rotation * Math.PI / 180.0;
|
||||
double cx = center[0] * scale_ratio;
|
||||
double cy = center[1] * scale_ratio;
|
||||
//矩阵构造
|
||||
RealMatrix t1 = MathUtil.similarityTransform((Double[][]) null, scale_ratio, null, null);
|
||||
RealMatrix t2 = MathUtil.similarityTransform((Double[][]) null, null, null, new Double[]{- cx, - cy});
|
||||
RealMatrix t3 = MathUtil.similarityTransform((Double[][]) null, null, rot, null);
|
||||
RealMatrix t4 = MathUtil.similarityTransform((Double[][]) null, null, null, new Double[]{1.0*outputSize[0]/2, 1.0*outputSize[1]/2});
|
||||
RealMatrix tx = MathUtil.dotProduct(t4, MathUtil.dotProduct(t3, MathUtil.dotProduct(t2, t1)));
|
||||
RealMatrix tm = tx.getSubMatrix(0, 1, 0, 2);
|
||||
//仿射矩阵
|
||||
Mat matMTemp = new MatOfDouble(MathUtil.flatMatrix(tm, 1).toArray());
|
||||
Mat matM = new Mat(2, 3, CvType.CV_32FC3);
|
||||
matMTemp.reshape(1,2).copyTo(matM);
|
||||
matMTemp.release();
|
||||
//使用open cv做仿射变换
|
||||
Mat dst = new Mat();
|
||||
Imgproc.warpAffine(image, dst, matM, new Size(outputSize[0], outputSize[1]));
|
||||
return new Mat[]{dst, matM};
|
||||
}
|
||||
}
|
@ -7,6 +7,9 @@ import com.visual.face.search.core.base.BaseOnnxInfer;
|
||||
import com.visual.face.search.core.base.FaceDetection;
|
||||
import com.visual.face.search.core.domain.FaceInfo;
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.utils.ReleaseUtil;
|
||||
import org.opencv.core.Core;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.core.Scalar;
|
||||
|
||||
import java.util.*;
|
||||
@ -24,6 +27,18 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
|
||||
public final static float defScoreTh = 0.5f;
|
||||
//人脸重叠iou阈值
|
||||
public final static float defIouTh = 0.7f;
|
||||
//给人脸框一个默认的缩放
|
||||
public final static float defBoxScale = 1.0f;
|
||||
//人脸框缩放参数KEY
|
||||
public final static String scrfdFaceboxScaleParamKey = "scrfdFaceboxScale";
|
||||
//人脸框默认需要进行角度检测
|
||||
public final static boolean defNeedCheckFaceAngle = true;
|
||||
//是否需要进行角度检测的参数KEY
|
||||
public final static String scrfdFaceNeedCheckFaceAngleParamKey = "scrfdFaceNeedCheckFaceAngle";
|
||||
//人脸框默认需要进行角度检测
|
||||
public final static boolean defNoFaceImageNeedMakeBorder = true;
|
||||
//是否需要进行角度检测的参数KEY
|
||||
public final static String scrfdNoFaceImageNeedMakeBorderParamKey = "scrfdNoFaceImageNeedMakeBorder";
|
||||
|
||||
/**
|
||||
* 构造函数
|
||||
@ -43,11 +58,46 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
|
||||
*/
|
||||
@Override
|
||||
public List<FaceInfo> inference(ImageMat image, float scoreTh, float iouTh, Map<String, Object> params) {
|
||||
List<FaceInfo> faceInfos = this.modelInference(image, scoreTh,iouTh, params);
|
||||
//对图像进行补边操作,进行二次识别
|
||||
if(this.getNoFaceImageNeedMakeBorder(params) && faceInfos.isEmpty()){
|
||||
//防止由于人脸占用大,导致检测模型识别失败
|
||||
int t = Double.valueOf(image.toCvMat().height() * 0.2).intValue();
|
||||
int b = Double.valueOf(image.toCvMat().height() * 0.2).intValue();
|
||||
int l = Double.valueOf(image.toCvMat().width() * 0.2).intValue();
|
||||
int r = Double.valueOf(image.toCvMat().width() * 0.2).intValue();
|
||||
ImageMat tempMat = null;
|
||||
try {
|
||||
//补边识别
|
||||
tempMat=image.copyMakeBorderAndNotReleaseMat(t, b, l, r, Core.BORDER_CONSTANT);
|
||||
faceInfos = this.modelInference(tempMat, scoreTh,iouTh, params);
|
||||
for(FaceInfo faceInfo : faceInfos){
|
||||
//还原原始的坐标
|
||||
faceInfo.box = faceInfo.box.move(l, 0, t, 0);
|
||||
faceInfo.points = faceInfo.points.move(l, 0, t, 0);
|
||||
}
|
||||
}finally {
|
||||
ReleaseUtil.release(tempMat);
|
||||
}
|
||||
}
|
||||
return faceInfos;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 模型推理,获取人脸信息
|
||||
* @param image 图像信息
|
||||
* @param scoreTh 人脸人数阈值
|
||||
* @param iouTh 人脸iou阈值
|
||||
* @return 人脸模型
|
||||
*/
|
||||
public List<FaceInfo> modelInference(ImageMat image, float scoreTh, float iouTh, Map<String, Object> params) {
|
||||
OnnxTensor tensor = null;
|
||||
OrtSession.Result output = null;
|
||||
ImageMat imageMat = image.clone();
|
||||
try {
|
||||
float imgScale = 1.0f;
|
||||
float boxScale = getBoxScale(params);
|
||||
iouTh = iouTh <= 0 ? defIouTh : iouTh;
|
||||
scoreTh = scoreTh <= 0 ? defScoreTh : scoreTh;
|
||||
int imageWidth = imageMat.getWidth(), imageHeight = imageMat.getHeight();
|
||||
@ -68,7 +118,10 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
|
||||
.blobFromImageAndDoReleaseMat(1.0/128, new Scalar(127.5, 127.5, 127.5), true)
|
||||
.to4dFloatOnnxTensorAndDoReleaseMat(true);
|
||||
output = getSession().run(Collections.singletonMap(getInputName(), tensor));
|
||||
return fitterBoxes(output, scoreTh, iouTh, tensor.getInfo().getShape()[3], imgScale);
|
||||
//获取人脸信息
|
||||
List<FaceInfo> faceInfos = fitterBoxes(output, scoreTh, iouTh, tensor.getInfo().getShape()[3], imgScale, boxScale);
|
||||
//对人脸进行角度检查
|
||||
return this.checkFaceAngle(faceInfos, this.getNeedCheckFaceAngle(params));
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}finally {
|
||||
@ -94,7 +147,7 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
|
||||
* @return
|
||||
* @throws OrtException
|
||||
*/
|
||||
private List<FaceInfo> fitterBoxes(OrtSession.Result output, float scoreTh, float iouTh, long tensorWidth, float imgScale) throws OrtException {
|
||||
private List<FaceInfo> fitterBoxes(OrtSession.Result output, float scoreTh, float iouTh, long tensorWidth, float imgScale, float boxScale) throws OrtException {
|
||||
//分数过滤及计算正确的人脸框值
|
||||
List<FaceInfo> faceInfos = new ArrayList<>();
|
||||
for(int index=0; index< 3; index++) {
|
||||
@ -122,7 +175,7 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
|
||||
float pointY = (point[2*pointIndex+1] * strides[index] + anchorY) * imgScale;
|
||||
keyPoints.add(FaceInfo.Point.build(pointX, pointY));
|
||||
}
|
||||
faceInfos.add(FaceInfo.build(scores[i][0], 0, FaceInfo.FaceBox.build(x1,y1,x2,y2), keyPoints));
|
||||
faceInfos.add(FaceInfo.build(scores[i][0], 0, FaceInfo.FaceBox.build(x1,y1,x2,y2).scaling(boxScale), keyPoints));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -147,4 +200,117 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
|
||||
return faces;
|
||||
}
|
||||
|
||||
/**
|
||||
* 对人脸进行角度检测,这里通过5个关键点来确定当前人脸的角度
|
||||
* @param faceInfos 人脸信息
|
||||
* @param needCheckFaceAngle 是否启用检测
|
||||
* @return
|
||||
*/
|
||||
private List<FaceInfo> checkFaceAngle(List<FaceInfo> faceInfos, boolean needCheckFaceAngle){
|
||||
if(!needCheckFaceAngle || null == faceInfos || faceInfos.isEmpty()){
|
||||
return faceInfos;
|
||||
}
|
||||
for(FaceInfo faceInfo : faceInfos){
|
||||
//计算当前人脸的角度数据
|
||||
float ax1 = faceInfo.points.get(1).x;
|
||||
float ay1 = faceInfo.points.get(1).y;
|
||||
float ax2 = faceInfo.points.get(0).x;
|
||||
float ay2 = faceInfo.points.get(0).y;
|
||||
int atan = Double.valueOf(Math.atan2((ay2-ay1), (ax2-ax1)) / Math.PI * 180).intValue();
|
||||
int angle = (180 - atan + 360) % 360;
|
||||
int ki = (angle + 45) % 360 / 90;
|
||||
int rotate = angle - (90 * ki); //
|
||||
float scaling = 1 + Double.valueOf(Math.abs(Math.sin(Math.toRadians(rotate)))).floatValue() / 3;
|
||||
faceInfo.angle = angle;
|
||||
//重组坐标点, 旋转及缩放
|
||||
if(ki == 0){
|
||||
FaceInfo.Point leftTop = FaceInfo.Point.build(faceInfo.box.x1(),faceInfo.box.y1());
|
||||
FaceInfo.Point rightTop = FaceInfo.Point.build(faceInfo.box.x2(),faceInfo.box.y1());
|
||||
FaceInfo.Point rightBottom = FaceInfo.Point.build(faceInfo.box.x2(),faceInfo.box.y2());
|
||||
FaceInfo.Point leftBottom = FaceInfo.Point.build(faceInfo.box.x1(),faceInfo.box.y2());
|
||||
faceInfo.box = new FaceInfo.FaceBox(leftTop, rightTop, rightBottom, leftBottom);
|
||||
faceInfo.box = faceInfo.box.rotate(rotate).scaling(scaling).rotate(-angle);
|
||||
}else if(ki == 1){
|
||||
FaceInfo.Point leftTop = FaceInfo.Point.build(faceInfo.box.x1(),faceInfo.box.y2());
|
||||
FaceInfo.Point rightTop = FaceInfo.Point.build(faceInfo.box.x1(),faceInfo.box.y1());
|
||||
FaceInfo.Point rightBottom = FaceInfo.Point.build(faceInfo.box.x2(),faceInfo.box.y1());
|
||||
FaceInfo.Point leftBottom = FaceInfo.Point.build(faceInfo.box.x2(),faceInfo.box.y2());
|
||||
faceInfo.box = new FaceInfo.FaceBox(leftTop, rightTop, rightBottom, leftBottom);
|
||||
faceInfo.box = faceInfo.box.rotate(rotate).scaling(scaling).rotate(-angle);
|
||||
}else if(ki == 2){
|
||||
FaceInfo.Point leftTop = FaceInfo.Point.build(faceInfo.box.x2(),faceInfo.box.y2());
|
||||
FaceInfo.Point rightTop = FaceInfo.Point.build(faceInfo.box.x1(),faceInfo.box.y2());
|
||||
FaceInfo.Point rightBottom = FaceInfo.Point.build(faceInfo.box.x1(),faceInfo.box.y1());
|
||||
FaceInfo.Point leftBottom = FaceInfo.Point.build(faceInfo.box.x2(),faceInfo.box.y1());
|
||||
faceInfo.box = new FaceInfo.FaceBox(leftTop, rightTop, rightBottom, leftBottom);
|
||||
faceInfo.box = faceInfo.box.rotate(rotate).scaling(scaling).rotate(-angle);
|
||||
}else if(ki == 3){
|
||||
FaceInfo.Point leftTop = FaceInfo.Point.build(faceInfo.box.x2(),faceInfo.box.y1());
|
||||
FaceInfo.Point rightTop = FaceInfo.Point.build(faceInfo.box.x2(),faceInfo.box.y2());
|
||||
FaceInfo.Point rightBottom = FaceInfo.Point.build(faceInfo.box.x1(),faceInfo.box.y2());
|
||||
FaceInfo.Point leftBottom = FaceInfo.Point.build(faceInfo.box.x1(),faceInfo.box.y1());
|
||||
faceInfo.box = new FaceInfo.FaceBox(leftTop, rightTop, rightBottom, leftBottom);
|
||||
faceInfo.box = faceInfo.box.rotate(rotate).scaling(scaling).rotate(-angle);
|
||||
}
|
||||
}
|
||||
return faceInfos;
|
||||
}
|
||||
|
||||
/**人脸框的默认缩放比例**/
|
||||
private float getBoxScale(Map<String, Object> params){
|
||||
float boxScale = 0;
|
||||
try {
|
||||
if(null != params && params.containsKey(scrfdFaceboxScaleParamKey)){
|
||||
Object value = params.get(scrfdFaceboxScaleParamKey);
|
||||
if(null != value){
|
||||
if (value instanceof Number){
|
||||
boxScale = ((Number) value).floatValue();
|
||||
}else{
|
||||
boxScale = Float.parseFloat(value.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
}catch (Exception e){}
|
||||
return boxScale > 0 ? boxScale : defBoxScale;
|
||||
}
|
||||
|
||||
/**获取是否需要进行角度探测**/
|
||||
private boolean getNeedCheckFaceAngle(Map<String, Object> params){
|
||||
boolean needCheckFaceAngle = defNeedCheckFaceAngle;
|
||||
try {
|
||||
if(null != params && params.containsKey(scrfdFaceNeedCheckFaceAngleParamKey)){
|
||||
Object value = params.get(scrfdFaceNeedCheckFaceAngleParamKey);
|
||||
if(null != value){
|
||||
if (value instanceof Boolean){
|
||||
needCheckFaceAngle = (boolean) value;
|
||||
}else{
|
||||
needCheckFaceAngle = Boolean.parseBoolean(value.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
}catch (Exception e){
|
||||
e.printStackTrace();
|
||||
}
|
||||
return needCheckFaceAngle;
|
||||
}
|
||||
|
||||
/**获取是否需要对没有检测到人脸的图像进行补边二次识别**/
|
||||
private boolean getNoFaceImageNeedMakeBorder(Map<String, Object> params){
|
||||
boolean noFaceImageNeedMakeBorder = defNoFaceImageNeedMakeBorder;
|
||||
try {
|
||||
if(null != params && params.containsKey(scrfdNoFaceImageNeedMakeBorderParamKey)){
|
||||
Object value = params.get(scrfdNoFaceImageNeedMakeBorderParamKey);
|
||||
if(null != value){
|
||||
if (value instanceof Boolean){
|
||||
noFaceImageNeedMakeBorder = (boolean) value;
|
||||
}else{
|
||||
noFaceImageNeedMakeBorder = Boolean.parseBoolean(value.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
}catch (Exception e){
|
||||
e.printStackTrace();
|
||||
}
|
||||
return noFaceImageNeedMakeBorder;
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,8 @@ import com.visual.face.search.core.base.BaseOnnxInfer;
|
||||
import com.visual.face.search.core.base.FaceDetection;
|
||||
import com.visual.face.search.core.domain.FaceInfo;
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.domain.Mats;
|
||||
import com.visual.face.search.core.utils.ReleaseUtil;
|
||||
import org.opencv.core.*;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
|
||||
@ -25,6 +27,8 @@ import java.util.Map;
|
||||
*/
|
||||
public class PcnNetworkFaceDetection extends BaseOnnxInfer implements FaceDetection {
|
||||
|
||||
public final static int ValueOfPcnMaxSide = 512;
|
||||
public final static String KeyOfPcnMaxSide = "pcn-max-side";
|
||||
//常量参数
|
||||
private final static float stride_ = 8;
|
||||
private final static float minFace_ = 28;
|
||||
@ -60,114 +64,137 @@ public class PcnNetworkFaceDetection extends BaseOnnxInfer implements FaceDetect
|
||||
public List<FaceInfo> inference(ImageMat image, float scoreTh, float iouTh, Map<String, Object> params) {
|
||||
Mat mat = null;
|
||||
Mat imgPad = null;
|
||||
Mat resizeMat = null;
|
||||
ImageMat imageMat = image.clone();
|
||||
try {
|
||||
//防止分辨率过高,导致内存持续增加
|
||||
int maxSide = getMaxSide(params);
|
||||
mat = imageMat.toCvMat();
|
||||
imgPad = pad_img_not_release_mat(mat);
|
||||
Size size = mat.size();
|
||||
float scale = 1.0f;
|
||||
if(size.height > maxSide && size.width > maxSide){
|
||||
scale = (float) ((size.height > size.width) ? size.width/maxSide: size.height/maxSide);
|
||||
}
|
||||
//推理
|
||||
resizeMat = resize_img_release_mat(mat.clone(), scale);
|
||||
imgPad = pad_img_release_mat(resizeMat.clone());
|
||||
float[] iouThs = iouTh <= 0 ? defIouThs : new float[]{iouTh, iouTh, 0.3f};
|
||||
float[] scoreThs = scoreTh <= 0 ? defScoreThs : new float[]{0.375f * scoreTh, 0.5f * scoreTh, scoreTh};
|
||||
List<PcnNetworkFaceDetection.Window2> willis = detect(this.getSessions(), mat, imgPad, scoreThs, iouThs);
|
||||
return trans_window(mat, imgPad, willis);
|
||||
List<Window2> willis = detect_release_mat(this.getSessions(), resizeMat.clone(), imgPad.clone(), scoreThs, iouThs);
|
||||
return trans_window(resizeMat, imgPad, willis, scale);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}finally {
|
||||
if(null != mat){
|
||||
mat.release();
|
||||
}
|
||||
if(null != imgPad){
|
||||
imgPad.release();
|
||||
}
|
||||
if(null != imageMat){
|
||||
imageMat.release();
|
||||
}
|
||||
ReleaseUtil.release(imgPad, resizeMat, mat);
|
||||
ReleaseUtil.release(imageMat);
|
||||
}
|
||||
}
|
||||
|
||||
/********************************分割线*************************************/
|
||||
|
||||
private static Mat pad_img_not_release_mat(Mat mat){
|
||||
int row = Math.min((int)(mat.size().height * 0.2), 100);
|
||||
int col = Math.min((int)(mat.size().width * 0.2), 100);
|
||||
Mat dst = new Mat();
|
||||
Core.copyMakeBorder(mat, dst, row, row, col, col, Core.BORDER_CONSTANT);
|
||||
return dst;
|
||||
private static int getMaxSide(Map<String, Object> params){
|
||||
try {
|
||||
if(null != params && params.containsKey(KeyOfPcnMaxSide)){
|
||||
Object value = params.get(KeyOfPcnMaxSide);
|
||||
if(null != value && !value.toString().trim().isEmpty()){
|
||||
return Integer.parseInt(value.toString());
|
||||
}
|
||||
}
|
||||
}catch (Exception e){
|
||||
e.printStackTrace();
|
||||
}
|
||||
return ValueOfPcnMaxSide;
|
||||
}
|
||||
|
||||
private static Mat resize_img(Mat mat, float scale){
|
||||
double h = mat.size().height;
|
||||
double w = mat.size().width;
|
||||
int h_ = (int) (h / scale);
|
||||
int w_ = (int) (w / scale);
|
||||
Mat matF32 = new Mat();
|
||||
if(mat.type() != CvType.CV_32FC3){
|
||||
mat.convertTo(matF32, CvType.CV_32FC3);
|
||||
}else{
|
||||
mat.copyTo(matF32);
|
||||
private static Mat pad_img_release_mat(Mat mat){
|
||||
try {
|
||||
int row = Math.min((int)(mat.size().height * 0.2), 100);
|
||||
int col = Math.min((int)(mat.size().width * 0.2), 100);
|
||||
Mat dst = new Mat();
|
||||
Core.copyMakeBorder(mat, dst, row, row, col, col, Core.BORDER_CONSTANT);
|
||||
return dst;
|
||||
}finally {
|
||||
ReleaseUtil.release(mat);
|
||||
}
|
||||
Mat dst = new Mat();
|
||||
Imgproc.resize(matF32, dst, new Size(w_, h_), 0,0, Imgproc.INTER_NEAREST);
|
||||
mat.release();
|
||||
matF32.release();
|
||||
return dst;
|
||||
}
|
||||
|
||||
private static Mat preprocess_img(Mat mat, int dim){
|
||||
Mat matTmp = new Mat();
|
||||
if(dim > 0){
|
||||
Imgproc.resize(mat, matTmp, new Size(dim, dim), 0, 0, Imgproc.INTER_NEAREST);
|
||||
}else{
|
||||
mat.copyTo(matTmp);
|
||||
private static Mat resize_img_release_mat(Mat mat, float scale){
|
||||
Mat matF32 = null;
|
||||
try{
|
||||
double h = mat.size().height;
|
||||
double w = mat.size().width;
|
||||
int h_ = (int) (h / scale);
|
||||
int w_ = (int) (w / scale);
|
||||
matF32 = new Mat();
|
||||
if(mat.type() != CvType.CV_32FC3){
|
||||
mat.convertTo(matF32, CvType.CV_32FC3);
|
||||
}else{
|
||||
mat.copyTo(matF32);
|
||||
}
|
||||
Mat dst = new Mat();
|
||||
Imgproc.resize(matF32, dst, new Size(w_, h_), 0,0, Imgproc.INTER_NEAREST);
|
||||
return dst;
|
||||
}finally {
|
||||
ReleaseUtil.release(matF32, mat);
|
||||
}
|
||||
|
||||
//格式转化
|
||||
Mat matF32 = new Mat();
|
||||
if(mat.type() != CvType.CV_32FC3){
|
||||
matTmp.convertTo(matF32, CvType.CV_32FC3);
|
||||
}else{
|
||||
matTmp.copyTo(matF32);
|
||||
}
|
||||
|
||||
Mat dst = new Mat();
|
||||
Core.subtract(matF32, new Scalar(104, 117, 123), dst);
|
||||
|
||||
mat.release();
|
||||
matTmp.release();
|
||||
matF32.release();
|
||||
|
||||
return dst;
|
||||
}
|
||||
|
||||
private static OnnxTensor set_input(Mat mat){
|
||||
private static Mat preprocess_img_release_mat(Mat mat, int dim){
|
||||
Mat matTmp = null;
|
||||
Mat matF32 = null;
|
||||
try {
|
||||
//resize
|
||||
matTmp = new Mat();
|
||||
if(dim > 0){
|
||||
Imgproc.resize(mat, matTmp, new Size(dim, dim), 0, 0, Imgproc.INTER_NEAREST);
|
||||
}else{
|
||||
mat.copyTo(matTmp);
|
||||
}
|
||||
//格式转化
|
||||
matF32 = new Mat();
|
||||
if(mat.type() != CvType.CV_32FC3){
|
||||
matTmp.convertTo(matF32, CvType.CV_32FC3);
|
||||
}else{
|
||||
matTmp.copyTo(matF32);
|
||||
}
|
||||
//减法运算
|
||||
Mat dst = new Mat();
|
||||
Core.subtract(matF32, new Scalar(104, 117, 123), dst);
|
||||
return dst;
|
||||
}finally {
|
||||
ReleaseUtil.release(matF32, matTmp, mat);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static OnnxTensor set_input_release_mat(Mat mat){
|
||||
Mat dst = null;
|
||||
try {
|
||||
dst = new Mat();
|
||||
mat.copyTo(dst);
|
||||
return ImageMat.fromCVMat(dst).to4dFloatOnnxTensorAndDoReleaseMat(true);
|
||||
}finally {
|
||||
if(null != dst){
|
||||
dst.release();
|
||||
}
|
||||
ReleaseUtil.release(dst, mat);
|
||||
}
|
||||
}
|
||||
|
||||
private static OnnxTensor set_input(List<Mat> mats){
|
||||
float[][][][] arrays = new float[mats.size()][][][];
|
||||
for(int i=0; i< mats.size(); i++){
|
||||
Mat dst = new Mat();
|
||||
mats.get(i).copyTo(dst);
|
||||
float[][][][] array = ImageMat.fromCVMat(dst).to4dFloatArrayAndDoReleaseMat(true);
|
||||
arrays[i] = array[0];
|
||||
dst.release();
|
||||
}
|
||||
private static OnnxTensor set_input_release_mat(Mats mats){
|
||||
try {
|
||||
float[][][][] arrays = new float[mats.size()][][][];
|
||||
for(int i=0; i< mats.size(); i++){
|
||||
float[][][][] array = ImageMat.fromCVMat(mats.get(i)).to4dFloatArrayAndDoReleaseMat(true);
|
||||
arrays[i] = array[0];
|
||||
}
|
||||
return OnnxTensor.createTensor(OrtEnvironment.getEnvironment(), arrays);
|
||||
}catch (Exception e){
|
||||
throw new RuntimeException(e);
|
||||
}finally {
|
||||
ReleaseUtil.release(mats);
|
||||
}
|
||||
}
|
||||
|
||||
private static boolean legal(int x, int y, Mat mat){
|
||||
if(0 <= x && x < mat.size().width && 0 <= y && y< mat.size().height){
|
||||
private static boolean legal(int x, int y, Size size){
|
||||
if(0 <= x && x < size.width && 0 <= y && y< size.height){
|
||||
return true;
|
||||
}else{
|
||||
return false;
|
||||
@ -191,87 +218,106 @@ public class PcnNetworkFaceDetection extends BaseOnnxInfer implements FaceDetect
|
||||
}
|
||||
|
||||
private static List<Window2> NMS(List<Window2> winlist, boolean local, float threshold){
|
||||
if(winlist==null || winlist.isEmpty()){
|
||||
return new ArrayList<>();
|
||||
}
|
||||
//排序
|
||||
Collections.sort(winlist);
|
||||
int [] flag = new int[winlist.size()];
|
||||
for(int i=0; i< winlist.size(); i++){
|
||||
if(flag[i] > 0){
|
||||
continue;
|
||||
try{
|
||||
if(winlist==null || winlist.isEmpty()){
|
||||
return new ArrayList<>();
|
||||
}
|
||||
for(int j=i+1; j<winlist.size(); j++){
|
||||
if(local && Math.abs(winlist.get(i).scale - winlist.get(j).scale) > EPS){
|
||||
//排序
|
||||
Collections.sort(winlist);
|
||||
int [] flag = new int[winlist.size()];
|
||||
for(int i=0; i< winlist.size(); i++){
|
||||
if(flag[i] > 0){
|
||||
continue;
|
||||
}
|
||||
if(IoU(winlist.get(i), winlist.get(j)) > threshold){
|
||||
flag[j] = 1;
|
||||
for(int j=i+1; j<winlist.size(); j++){
|
||||
if(local && Math.abs(winlist.get(i).scale - winlist.get(j).scale) > EPS){
|
||||
continue;
|
||||
}
|
||||
if(IoU(winlist.get(i), winlist.get(j)) > threshold){
|
||||
flag[j] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
List<Window2> list = new ArrayList<>();
|
||||
for(int i=0; i< flag.length; i++){
|
||||
if(flag[i] == 0){
|
||||
list.add(winlist.get(i));
|
||||
List<Window2> list = new ArrayList<>();
|
||||
for(int i=0; i< flag.length; i++){
|
||||
if(flag[i] == 0){
|
||||
list.add(winlist.get(i));
|
||||
}
|
||||
}
|
||||
return list;
|
||||
}finally {
|
||||
if(null != winlist && !winlist.isEmpty()){
|
||||
winlist.clear();
|
||||
}
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
private static List<Window2> deleteFP(List<Window2> winlist){
|
||||
if (winlist == null || winlist.isEmpty()) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
//排序
|
||||
Collections.sort(winlist);
|
||||
int [] flag = new int[winlist.size()];
|
||||
for(int i=0; i< winlist.size(); i++){
|
||||
if(flag[i] > 0){
|
||||
continue;
|
||||
try {
|
||||
if(null == winlist || winlist.isEmpty()) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
for(int j=i+1; j<winlist.size(); j++){
|
||||
Window2 win = winlist.get(j);
|
||||
if(inside(win.x, win.y, winlist.get(i)) && inside(win.x + win.w - 1, win.y + win.h - 1, winlist.get(i))){
|
||||
flag[j] = 1;
|
||||
//排序
|
||||
Collections.sort(winlist);
|
||||
int [] flag = new int[winlist.size()];
|
||||
for(int i=0; i< winlist.size(); i++){
|
||||
if(flag[i] > 0){
|
||||
continue;
|
||||
}
|
||||
for(int j=i+1; j<winlist.size(); j++){
|
||||
Window2 win = winlist.get(j);
|
||||
if(inside(win.x, win.y, winlist.get(i)) && inside(win.x + win.w - 1, win.y + win.h - 1, winlist.get(i))){
|
||||
flag[j] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
List<Window2> list = new ArrayList<>();
|
||||
for(int i=0; i< flag.length; i++){
|
||||
if(flag[i] == 0){
|
||||
list.add(winlist.get(i));
|
||||
List<Window2> list = new ArrayList<>();
|
||||
for(int i=0; i< flag.length; i++){
|
||||
if(flag[i] == 0){
|
||||
list.add(winlist.get(i));
|
||||
}
|
||||
}
|
||||
return list;
|
||||
}finally {
|
||||
if(null != winlist && !winlist.isEmpty()){
|
||||
winlist.clear();
|
||||
}
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
private static List<FaceInfo> trans_window(Mat img, Mat imgPad, List<Window2> winlist){
|
||||
int row = (imgPad.size(0) - img.size(0)) / 2;
|
||||
int col = (imgPad.size(1) - img.size(1)) / 2;
|
||||
private static List<FaceInfo> trans_window(Mat img, Mat imgPad, List<Window2> winlist, float scale){
|
||||
List<FaceInfo> ret = new ArrayList<>();
|
||||
for(Window2 win : winlist){
|
||||
if( win.w > 0 && win.h > 0){
|
||||
int x1 = win.x - col;
|
||||
int y1 = win.y - row;
|
||||
int x2 = win.x - col + win.w;
|
||||
int y2 = win.y - row + win.w;
|
||||
int angle = (win.angle + 360) % 360;
|
||||
//扩展人脸高度
|
||||
float rw = 0f;
|
||||
float rh = 0.1f;
|
||||
int w = Math.abs(x2 - x1);
|
||||
int h = Math.abs(y2 - y1);
|
||||
x1 = Math.max(x1 - (int)(w * rw), 1);
|
||||
y1 = Math.max(y1 - (int)(h * rh), 1);
|
||||
x2 = Math.min(x2 + (int)(w * rw), img.size(1)-1);
|
||||
y2 = Math.min(y2 + (int)(h * rh), img.size(0)-1);
|
||||
//构建人脸信息
|
||||
FaceInfo faceInfo = FaceInfo.build(win.conf, angle, FaceInfo.FaceBox.build(x1, y1, x2, y2), FaceInfo.Points.build());
|
||||
ret.add(faceInfo);
|
||||
try {
|
||||
int row = (imgPad.size(0) - img.size(0)) / 2;
|
||||
int col = (imgPad.size(1) - img.size(1)) / 2;
|
||||
for(Window2 win : winlist){
|
||||
if( win.w > 0 && win.h > 0){
|
||||
int x1 = win.x - col;
|
||||
int y1 = win.y - row;
|
||||
int x2 = win.x - col + win.w;
|
||||
int y2 = win.y - row + win.w;
|
||||
int angle = (win.angle + 360) % 360;
|
||||
//扩展人脸高度
|
||||
float rw = 0f;
|
||||
float rh = 0.1f;
|
||||
int w = Math.abs(x2 - x1);
|
||||
int h = Math.abs(y2 - y1);
|
||||
x1 = Math.max(Float.valueOf((x1 - (int)(w * rw)) * scale).intValue(), 1);
|
||||
y1 = Math.max(Float.valueOf((y1 - (int)(h * rh)) * scale).intValue(), 1);
|
||||
x2 = Math.min(Float.valueOf((x2 + (int)(w * rw)) * scale).intValue(), Float.valueOf((img.size(1)) * scale).intValue()-1);
|
||||
y2 = Math.min(Float.valueOf((y2 + (int)(h * rh)) * scale).intValue(), Float.valueOf((img.size(0)) * scale).intValue()-1);
|
||||
//构建人脸信息
|
||||
FaceInfo faceInfo = FaceInfo.build(win.conf, angle, FaceInfo.FaceBox.build(x1, y1, x2, y2), FaceInfo.Points.build());
|
||||
ret.add(faceInfo);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}finally {
|
||||
if(null != winlist && !winlist.isEmpty()){
|
||||
winlist.clear();
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
@ -284,34 +330,30 @@ public class PcnNetworkFaceDetection extends BaseOnnxInfer implements FaceDetect
|
||||
* @throws OrtException
|
||||
* @throws IOException
|
||||
*/
|
||||
private static List<Window2> stage1(Mat img, Mat imgPad, OrtSession net, float thres) throws RuntimeException {
|
||||
int netSize = 24;
|
||||
float curScale = minFace_ / netSize;
|
||||
int row = (int) ((imgPad.size().height - img.size().height) / 2);
|
||||
int col = (int) ((imgPad.size().width - img.size().width) / 2);
|
||||
|
||||
private static List<Window2> stage1_release_mat(Mat img, Mat imgPad, OrtSession net, float thres) throws RuntimeException {
|
||||
Mat img_resized = null;
|
||||
OnnxTensor net_input1 = null;
|
||||
OrtSession.Result output = null;
|
||||
List<Window2> winlist = new ArrayList<>();
|
||||
try {
|
||||
img_resized = resize_img(img.clone(), curScale);
|
||||
//获取必要参数
|
||||
int netSize = 24;
|
||||
float curScale = minFace_ / netSize;
|
||||
int row = (int) ((imgPad.size().height - img.size().height) / 2);
|
||||
int col = (int) ((imgPad.size().width - img.size().width) / 2);
|
||||
img_resized = resize_img_release_mat(img.clone(), curScale);
|
||||
//循环处理
|
||||
while(Math.min(img_resized.size().height, img_resized.size().width) >= netSize){
|
||||
img_resized = preprocess_img(img_resized, 0);
|
||||
net_input1 = set_input(img_resized);
|
||||
img_resized = preprocess_img_release_mat(img_resized, 0);
|
||||
net_input1 = set_input_release_mat(img_resized.clone());
|
||||
//推理网络
|
||||
output = net.run(Collections.singletonMap(net.getInputNames().iterator().next(), net_input1));
|
||||
float[][][][] cls_prob = (float[][][][]) output.get(0).getValue();
|
||||
float[][][][] rotate = (float[][][][]) output.get(1).getValue();
|
||||
float[][][][] bbox = (float[][][][]) output.get(2).getValue();
|
||||
//关闭对象
|
||||
if(null != net_input1){
|
||||
net_input1.close();
|
||||
net_input1 = null;
|
||||
}
|
||||
if(null != output){
|
||||
output.close();
|
||||
output = null;
|
||||
}
|
||||
ReleaseUtil.release(output); output = null;
|
||||
ReleaseUtil.release(net_input1); net_input1 = null;
|
||||
//计算业务逻辑
|
||||
float w = netSize * curScale;
|
||||
for(int i=0; i< cls_prob[0][0].length; i++){
|
||||
@ -323,7 +365,7 @@ public class PcnNetworkFaceDetection extends BaseOnnxInfer implements FaceDetect
|
||||
int rx = (int)(j * curScale * stride_ - 0.5 * sn * w + sn * xn * w + 0.5 * w) + col;
|
||||
int ry = (int)(i * curScale * stride_ - 0.5 * sn * w + sn * yn * w + 0.5 * w) + row;
|
||||
int rw = (int)(w * sn);
|
||||
if (legal(rx, ry, imgPad) && legal(rx + rw - 1, ry + rw - 1, imgPad)){
|
||||
if (legal(rx, ry, imgPad.size()) && legal(rx + rw - 1, ry + rw - 1, imgPad.size())){
|
||||
if (rotate[0][1][i][j] > 0.5){
|
||||
winlist.add(new Window2(rx, ry, rw, rw, 0, curScale, cls_prob[0][1][i][j]));
|
||||
}else{
|
||||
@ -333,21 +375,15 @@ public class PcnNetworkFaceDetection extends BaseOnnxInfer implements FaceDetect
|
||||
}
|
||||
}
|
||||
}
|
||||
img_resized = resize_img(img_resized, scale_);
|
||||
img_resized = resize_img_release_mat(img_resized, scale_);
|
||||
curScale = (float) (img.size().height / img_resized.size().height);
|
||||
}
|
||||
}catch (Exception e){
|
||||
throw new RuntimeException(e);
|
||||
}finally {
|
||||
if(null != net_input1){
|
||||
net_input1.close();
|
||||
}
|
||||
if(null != output){
|
||||
output.close();
|
||||
}
|
||||
if(null != img_resized){
|
||||
img_resized.release();
|
||||
}
|
||||
ReleaseUtil.release(output);
|
||||
ReleaseUtil.release(net_input1);
|
||||
ReleaseUtil.release(img_resized, imgPad, img);
|
||||
}
|
||||
//返回
|
||||
return winlist;
|
||||
@ -364,56 +400,58 @@ public class PcnNetworkFaceDetection extends BaseOnnxInfer implements FaceDetect
|
||||
* @return
|
||||
* @throws OrtException
|
||||
*/
|
||||
private static List<Window2> stage2(Mat img, Mat img180, OrtSession net, float thres, int dim, List<Window2> winlist) throws OrtException {
|
||||
private static List<Window2> stage2_release_mat(Mat img, Mat img180, OrtSession net, float thres, int dim, List<Window2> winlist) throws OrtException {
|
||||
if(winlist==null || winlist.isEmpty()){
|
||||
return new ArrayList<>();
|
||||
}
|
||||
//逻辑处理
|
||||
float[][] bbox;
|
||||
float[][] rotate;
|
||||
float[][] cls_prob;
|
||||
Mats datalist = null;
|
||||
OnnxTensor input = null;
|
||||
OrtSession.Result output = null;
|
||||
Size size = img.size();
|
||||
int height = img.size(0);
|
||||
List<Mat> datalist = new ArrayList<>();
|
||||
for(Window2 win : winlist){
|
||||
if(Math.abs(win.angle) < EPS){
|
||||
Mat cloneMat = img.clone();
|
||||
Mat corp = new Mat(cloneMat, new Rect(win.x, win.y, win.w, win.h));
|
||||
Mat corp1 = preprocess_img(corp, dim);
|
||||
datalist.add(corp1);
|
||||
if(null != corp){
|
||||
corp.release();
|
||||
}
|
||||
if(null != cloneMat){
|
||||
cloneMat.release();
|
||||
}
|
||||
}else{
|
||||
int y2 = win.y + win.h - 1;
|
||||
int y = height - 1 - y2;
|
||||
|
||||
Mat cloneMat = img180.clone();
|
||||
Mat corp = new Mat(cloneMat, new Rect(win.x, y, win.w, win.h));
|
||||
Mat corp1 = preprocess_img(corp, dim);
|
||||
datalist.add(corp1);
|
||||
if(null != corp){
|
||||
corp.release();
|
||||
}
|
||||
if(null != cloneMat){
|
||||
cloneMat.release();
|
||||
try {
|
||||
datalist = Mats.build();
|
||||
for(Window2 win : winlist){
|
||||
if(Math.abs(win.angle) < EPS){
|
||||
Mat corp = null;
|
||||
try {
|
||||
corp = new Mat(img, new Rect(win.x, win.y, win.w, win.h));
|
||||
Mat preprocess = preprocess_img_release_mat(corp.clone(), dim);
|
||||
datalist.add(preprocess);
|
||||
}finally {
|
||||
ReleaseUtil.release(corp);
|
||||
}
|
||||
}else{
|
||||
Mat corp = null;
|
||||
try {
|
||||
int y2 = win.y + win.h - 1;
|
||||
int y = height - 1 - y2;
|
||||
corp = new Mat(img180, new Rect(win.x, y, win.w, win.h));
|
||||
Mat preprocess = preprocess_img_release_mat(corp.clone(), dim);
|
||||
datalist.add(preprocess);
|
||||
}finally {
|
||||
ReleaseUtil.release(corp);
|
||||
}
|
||||
}
|
||||
}
|
||||
//模型推理
|
||||
input = set_input_release_mat(datalist.clone());
|
||||
output = net.run(Collections.singletonMap(net.getInputNames().iterator().next(), input));
|
||||
cls_prob = (float[][]) output.get(0).getValue();
|
||||
rotate = (float[][]) output.get(1).getValue();
|
||||
bbox = (float[][]) output.get(2).getValue();
|
||||
}finally {
|
||||
ReleaseUtil.release(datalist);
|
||||
ReleaseUtil.release(output);
|
||||
ReleaseUtil.release(input);
|
||||
ReleaseUtil.release(img180, img);
|
||||
output = null; input = null;
|
||||
}
|
||||
|
||||
OnnxTensor net_input = set_input(datalist);
|
||||
OrtSession.Result output = net.run(Collections.singletonMap(net.getInputNames().iterator().next(), net_input));
|
||||
float[][] cls_prob = (float[][]) output.get(0).getValue();
|
||||
float[][] rotate = (float[][]) output.get(1).getValue();
|
||||
float[][] bbox = (float[][]) output.get(2).getValue();
|
||||
//关闭对象
|
||||
for(Mat mat : datalist){
|
||||
mat.release();
|
||||
}
|
||||
if(null != net_input){
|
||||
net_input.close();
|
||||
}
|
||||
if(null != output){
|
||||
output.close();
|
||||
}
|
||||
//再次后处理数据
|
||||
List<Window2> ret = new ArrayList<>();
|
||||
for(int i=0; i<winlist.size(); i++){
|
||||
if(cls_prob[i][1] > thres){
|
||||
@ -439,7 +477,7 @@ public class PcnNetworkFaceDetection extends BaseOnnxInfer implements FaceDetect
|
||||
}
|
||||
}
|
||||
|
||||
if(legal(x, y, img) && legal(x + w - 1, y + w - 1, img)){
|
||||
if(legal(x, y, size) && legal(x + w - 1, y + w - 1, size)){
|
||||
int angle = 0;
|
||||
if(Math.abs(winlist.get(i).angle) < EPS){
|
||||
if(maxRotateIndex == 0){
|
||||
@ -479,49 +517,79 @@ public class PcnNetworkFaceDetection extends BaseOnnxInfer implements FaceDetect
|
||||
* @return
|
||||
* @throws OrtException
|
||||
*/
|
||||
private static List<Window2> stage3(Mat imgPad, Mat img180, Mat img90, Mat imgNeg90, OrtSession net, float thres, int dim, List<Window2> winlist) throws OrtException {
|
||||
private static List<Window2> stage3_release_mat(Mat imgPad, Mat img180, Mat img90, Mat imgNeg90, OrtSession net, float thres, int dim, List<Window2> winlist) throws OrtException {
|
||||
if (winlist == null || winlist.isEmpty()) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
//逻辑处理
|
||||
float[][] bbox;
|
||||
float[][] rotate;
|
||||
float[][] cls_prob;
|
||||
Mats datalist = null;
|
||||
OnnxTensor input = null;
|
||||
OrtSession.Result output = null;
|
||||
|
||||
int height = imgPad.size(0);
|
||||
int width = imgPad.size(1);
|
||||
List<Mat> datalist = new ArrayList<>();
|
||||
for(Window2 win : winlist){
|
||||
if(Math.abs(win.angle) < EPS){
|
||||
Mat corp = new Mat(imgPad, new Rect(win.x, win.y, win.w, win.h));
|
||||
datalist.add(preprocess_img(corp, dim));
|
||||
}else if(Math.abs(win.angle - 90) < EPS){
|
||||
Mat corp = new Mat(img90, new Rect(win.y, win.x, win.h, win.w));
|
||||
datalist.add(preprocess_img(corp, dim));
|
||||
}else if(Math.abs(win.angle + 90) < EPS){
|
||||
int x = win.y;
|
||||
int y = width - 1 - (win.x + win.w - 1);
|
||||
Mat corp = new Mat(imgNeg90, new Rect(x, y, win.w, win.h));
|
||||
datalist.add(preprocess_img(corp, dim));
|
||||
}else{
|
||||
int y2 = win.y + win.h - 1;
|
||||
int y = height - 1 - y2;
|
||||
Mat corp = new Mat(img180, new Rect(win.x, y, win.w, win.h));
|
||||
datalist.add(preprocess_img(corp, dim));
|
||||
Size imgPadSize = imgPad.size();
|
||||
Size img180Size = img180.size();
|
||||
Size img90Size = img90.size();
|
||||
Size imgNeg90Size = imgNeg90.size();
|
||||
|
||||
try {
|
||||
datalist = Mats.build();
|
||||
for(Window2 win : winlist){
|
||||
if(Math.abs(win.angle) < EPS){
|
||||
Mat corp = null;
|
||||
try {
|
||||
corp = new Mat(imgPad, new Rect(win.x, win.y, win.w, win.h));
|
||||
datalist.add(preprocess_img_release_mat(corp.clone(), dim));
|
||||
}finally {
|
||||
ReleaseUtil.release(corp);
|
||||
}
|
||||
}else if(Math.abs(win.angle - 90) < EPS){
|
||||
Mat corp = null;
|
||||
try {
|
||||
corp = new Mat(img90, new Rect(win.y, win.x, win.h, win.w));
|
||||
datalist.add(preprocess_img_release_mat(corp.clone(), dim));
|
||||
}finally {
|
||||
ReleaseUtil.release(corp);
|
||||
}
|
||||
}else if(Math.abs(win.angle + 90) < EPS){
|
||||
Mat corp = null;
|
||||
try {
|
||||
int x = win.y;
|
||||
int y = width - 1 - (win.x + win.w - 1);
|
||||
corp = new Mat(imgNeg90, new Rect(x, y, win.w, win.h));
|
||||
datalist.add(preprocess_img_release_mat(corp.clone(), dim));
|
||||
}finally {
|
||||
ReleaseUtil.release(corp);
|
||||
}
|
||||
}else{
|
||||
Mat corp = null;
|
||||
try {
|
||||
int y2 = win.y + win.h - 1;
|
||||
int y = height - 1 - y2;
|
||||
corp = new Mat(img180, new Rect(win.x, y, win.w, win.h));
|
||||
datalist.add(preprocess_img_release_mat(corp.clone(), dim));
|
||||
}finally {
|
||||
ReleaseUtil.release(corp);
|
||||
}
|
||||
}
|
||||
}
|
||||
input = set_input_release_mat(datalist);
|
||||
output = net.run(Collections.singletonMap(net.getInputNames().iterator().next(), input));
|
||||
cls_prob = (float[][]) output.get(0).getValue();
|
||||
rotate = (float[][]) output.get(1).getValue();
|
||||
bbox = (float[][]) output.get(2).getValue();
|
||||
}finally {
|
||||
ReleaseUtil.release(datalist);
|
||||
ReleaseUtil.release(output);
|
||||
ReleaseUtil.release(input);
|
||||
ReleaseUtil.release(imgPad, img180, img90, imgNeg90);
|
||||
output = null; input = null;
|
||||
}
|
||||
OnnxTensor net_input = set_input(datalist);
|
||||
OrtSession.Result output = net.run(Collections.singletonMap(net.getInputNames().iterator().next(), net_input));
|
||||
float[][] cls_prob = (float[][]) output.get(0).getValue();
|
||||
float[][] rotate = (float[][]) output.get(1).getValue();
|
||||
float[][] bbox = (float[][]) output.get(2).getValue();
|
||||
|
||||
//关闭对象
|
||||
for(Mat mat : datalist){
|
||||
mat.release();
|
||||
}
|
||||
if(null != net_input){
|
||||
net_input.close();
|
||||
}
|
||||
if(null != output){
|
||||
output.close();
|
||||
}
|
||||
|
||||
//模型后处理
|
||||
List<Window2> ret = new ArrayList<>();
|
||||
for(int i=0; i<winlist.size(); i++) {
|
||||
if (cls_prob[i][1] > thres) {
|
||||
@ -531,24 +599,24 @@ public class PcnNetworkFaceDetection extends BaseOnnxInfer implements FaceDetect
|
||||
float cropX = winlist.get(i).x;
|
||||
float cropY = winlist.get(i).y;
|
||||
float cropW = winlist.get(i).w;
|
||||
Mat img_tmp = imgPad;
|
||||
Size img_tmp_size = imgPadSize;
|
||||
if (Math.abs(winlist.get(i).angle - 180) < EPS) {
|
||||
cropY = height - 1 - (cropY + cropW - 1);
|
||||
img_tmp = img180;
|
||||
img_tmp_size = img180Size;
|
||||
}else if (Math.abs(winlist.get(i).angle - 90) < EPS) {
|
||||
cropX = winlist.get(i).y;
|
||||
cropY = winlist.get(i).x;
|
||||
img_tmp = img90;
|
||||
img_tmp_size = img90Size;
|
||||
}else if (Math.abs(winlist.get(i).angle + 90) < EPS) {
|
||||
cropX = winlist.get(i).y;
|
||||
cropY = width - 1 - (winlist.get(i).x + winlist.get(i).w - 1);
|
||||
img_tmp = imgNeg90;
|
||||
img_tmp_size = imgNeg90Size;
|
||||
}
|
||||
int w = (int) (sn * cropW);
|
||||
int x = (int) (cropX - 0.5 * sn * cropW + cropW * sn * xn + 0.5 * cropW);
|
||||
int y = (int) (cropY - 0.5 * sn * cropW + cropW * sn * yn + 0.5 * cropW);
|
||||
int angle = (int)(angleRange_ * rotate[i][0]);
|
||||
if(legal(x, y, img_tmp) && legal(x + w - 1, y + w - 1, img_tmp)){
|
||||
if(legal(x, y, img_tmp_size) && legal(x + w - 1, y + w - 1, img_tmp_size)){
|
||||
if(Math.abs(winlist.get(i).angle) < EPS){
|
||||
ret.add(new Window2(x, y, w, w, angle, winlist.get(i).scale, cls_prob[i][1]));
|
||||
}else if(Math.abs(winlist.get(i).angle - 180) < EPS){
|
||||
@ -573,38 +641,39 @@ public class PcnNetworkFaceDetection extends BaseOnnxInfer implements FaceDetect
|
||||
* @throws OrtException
|
||||
* @throws IOException
|
||||
*/
|
||||
private static List<PcnNetworkFaceDetection.Window2> detect(OrtSession[] sessions, Mat img, Mat imgPad, float[] scoreThs, float iouThs[]) throws OrtException, IOException {
|
||||
Mat img180 = new Mat();
|
||||
Core.flip(imgPad, img180, 0);
|
||||
|
||||
Mat img90 = new Mat();
|
||||
Core.transpose(imgPad, img90);
|
||||
|
||||
Mat imgNeg90 = new Mat();
|
||||
Core.flip(img90, imgNeg90, 0);
|
||||
|
||||
List<PcnNetworkFaceDetection.Window2> winlist = stage1(img, imgPad, sessions[0], scoreThs[0]);
|
||||
winlist = NMS(winlist, true, iouThs[0]);
|
||||
|
||||
winlist = stage2(imgPad, img180, sessions[1], scoreThs[1], 24, winlist);
|
||||
winlist = NMS(winlist, true, iouThs[1]);
|
||||
|
||||
winlist = stage3(imgPad, img180, img90, imgNeg90, sessions[2], scoreThs[2], 48, winlist);
|
||||
winlist = NMS(winlist, false, iouThs[2]);
|
||||
|
||||
winlist = deleteFP(winlist);
|
||||
|
||||
img90.release();
|
||||
img180.release();
|
||||
imgNeg90.release();
|
||||
|
||||
return winlist;
|
||||
private static List<Window2> detect_release_mat(OrtSession[] sessions, Mat img, Mat imgPad, float[] scoreThs, float iouThs[]) throws OrtException, IOException {
|
||||
Mat img180 = null;
|
||||
Mat img90 = null;
|
||||
Mat imgNeg90 = null;
|
||||
try {
|
||||
//前置处理
|
||||
img180 = new Mat();
|
||||
Core.flip(imgPad, img180, 0);
|
||||
img90 = new Mat();
|
||||
Core.transpose(imgPad, img90);
|
||||
imgNeg90 = new Mat();
|
||||
Core.flip(img90, imgNeg90, 0);
|
||||
//第一步
|
||||
List<Window2> winlist = stage1_release_mat(img.clone(), imgPad.clone(), sessions[0], scoreThs[0]);
|
||||
winlist = NMS(winlist, true, iouThs[0]);
|
||||
//第二步
|
||||
winlist = stage2_release_mat(imgPad.clone(), img180.clone(), sessions[1], scoreThs[1], 24, winlist);
|
||||
winlist = NMS(winlist, true, iouThs[1]);
|
||||
//第三步
|
||||
winlist = stage3_release_mat(imgPad.clone(), img180.clone(), img90.clone(), imgNeg90.clone(), sessions[2], scoreThs[2], 48, winlist);
|
||||
winlist = NMS(winlist, false, iouThs[2]);
|
||||
//后处理
|
||||
winlist = deleteFP(winlist);
|
||||
return winlist;
|
||||
}finally {
|
||||
ReleaseUtil.release(imgNeg90, img90, img180, imgPad, img);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 临时的人脸框
|
||||
*/
|
||||
private static class Window2 implements Comparable<PcnNetworkFaceDetection.Window2>{
|
||||
private static class Window2 implements Comparable<Window2>{
|
||||
public int x;
|
||||
public int y;
|
||||
public int w;
|
||||
@ -624,7 +693,7 @@ public class PcnNetworkFaceDetection extends BaseOnnxInfer implements FaceDetect
|
||||
}
|
||||
|
||||
@Override
|
||||
public int compareTo(PcnNetworkFaceDetection.Window2 o) {
|
||||
public int compareTo(Window2 o) {
|
||||
if(o.conf == this.conf){
|
||||
return new Integer(this.y).compareTo(o.y);
|
||||
}else{
|
||||
@ -642,7 +711,7 @@ public class PcnNetworkFaceDetection extends BaseOnnxInfer implements FaceDetect
|
||||
", angle=" + angle +
|
||||
", scale=" + scale +
|
||||
", conf=" + conf +
|
||||
'}' +"\n";
|
||||
"}" +"\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,59 @@
|
||||
package com.visual.face.search.core.models;
|
||||
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import com.visual.face.search.core.base.BaseOnnxInfer;
|
||||
import com.visual.face.search.core.base.FaceRecognition;
|
||||
import com.visual.face.search.core.domain.FaceInfo.Embedding;
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.utils.ArrayUtil;
|
||||
import org.opencv.core.Scalar;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 人脸识别-人脸特征提取
|
||||
* git:https://github.com/SeetaFace6Open/index
|
||||
*/
|
||||
public class SeetaFaceOpenRecognition extends BaseOnnxInfer implements FaceRecognition {
|
||||
|
||||
/**
|
||||
* 构造函数
|
||||
* @param modelPath 模型路径
|
||||
* @param threads 线程数
|
||||
*/
|
||||
public SeetaFaceOpenRecognition(String modelPath, int threads) {
|
||||
super(modelPath, threads);
|
||||
}
|
||||
|
||||
/**
|
||||
* 人脸识别,人脸特征向量
|
||||
* @param image 图像信息
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public Embedding inference(ImageMat image, Map<String, Object> params) {
|
||||
OnnxTensor tensor = null;
|
||||
OrtSession.Result output = null;
|
||||
try {
|
||||
tensor = image.resizeAndNoReleaseMat(112,112)
|
||||
.blobFromImageAndDoReleaseMat(1.0/255, new Scalar(0, 0, 0), false)
|
||||
.to4dFloatOnnxTensorAndDoReleaseMat(true);
|
||||
output = getSession().run(Collections.singletonMap(getInputName(), tensor));
|
||||
float[] embeds = ((float[][]) output.get(0).getValue())[0];
|
||||
double normValue = ArrayUtil.matrixNorm(embeds);
|
||||
float[] embedding = ArrayUtil.division(embeds, Double.valueOf(normValue).floatValue());
|
||||
return Embedding.build(image.toBase64AndNoReleaseMat(), embedding);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}finally {
|
||||
if(null != tensor){
|
||||
tensor.close();
|
||||
}
|
||||
if(null != output){
|
||||
output.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,94 @@
|
||||
package com.visual.face.search.core.models;
|
||||
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import com.visual.face.search.core.base.BaseOnnxInfer;
|
||||
import com.visual.face.search.core.base.FaceMaskPoint;
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.domain.QualityInfo;
|
||||
import com.visual.face.search.core.utils.SoftMaxUtil;
|
||||
import org.opencv.core.*;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
|
||||
public class SeetaMaskFaceKeyPoint extends BaseOnnxInfer implements FaceMaskPoint {
|
||||
|
||||
private static final int stride = 8;
|
||||
private static final int shape = 128;
|
||||
|
||||
/**
|
||||
* 构造函数
|
||||
* @param modelPath 模型路径
|
||||
* @param threads 线程数
|
||||
*/
|
||||
public SeetaMaskFaceKeyPoint(String modelPath, int threads) {
|
||||
super(modelPath, threads);
|
||||
}
|
||||
|
||||
/**
|
||||
* 人脸关键点检测
|
||||
*
|
||||
* @param imageMat 图像数据
|
||||
* @param params 参数信息
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public QualityInfo.MaskPoints inference(ImageMat imageMat, Map<String, Object> params) {
|
||||
Mat borderMat = null;
|
||||
Mat resizeMat = null;
|
||||
OnnxTensor tensor = null;
|
||||
OrtSession.Result output = null;
|
||||
try {
|
||||
Mat image = imageMat.toCvMat();
|
||||
//将图片转换为正方形
|
||||
int w = imageMat.getWidth();
|
||||
int h = imageMat.getHeight();
|
||||
int new_w = Math.max(h, w);
|
||||
int new_h = Math.max(h, w);
|
||||
if (Math.max(h, w) % stride != 0){
|
||||
new_w = new_w + (stride - Math.max(h, w) % stride);
|
||||
new_h = new_h + (stride - Math.max(h, w) % stride);
|
||||
}
|
||||
int ow = (new_w - w) / 2;
|
||||
int oh = (new_h - h) / 2;
|
||||
borderMat = new Mat();
|
||||
Core.copyMakeBorder(image, borderMat, oh, oh, ow, ow, Core.BORDER_CONSTANT, new Scalar(114, 114, 114));
|
||||
//对图片进行resize
|
||||
float ratio = 1.0f * shape / new_h;
|
||||
resizeMat = new Mat();
|
||||
Imgproc.resize(borderMat, resizeMat, new Size(shape, shape));
|
||||
//模型推理
|
||||
tensor = ImageMat.fromCVMat(resizeMat)
|
||||
.blobFromImageAndDoReleaseMat(1.0/32, new Scalar(104, 117, 123), false)
|
||||
.to4dFloatOnnxTensorAndDoReleaseMat(true);
|
||||
output = this.getSession().run(Collections.singletonMap(this.getInputName(), tensor));
|
||||
float[] value = ((float[][]) output.get(0).getValue())[0];
|
||||
//转换为标准的坐标点
|
||||
QualityInfo.MaskPoints pointList = QualityInfo.MaskPoints.build();
|
||||
for(int i=0; i<5; i++){
|
||||
float x = value[i * 4 + 0] / ratio * 128 - ow;
|
||||
float y = value[i * 4 + 1] / ratio * 128 - oh;
|
||||
double[] softMax = SoftMaxUtil.softMax(new double[]{value[i * 4 + 2], value[i * 4 + 3]});
|
||||
pointList.add(QualityInfo.MaskPoint.build(x, y, Double.valueOf(softMax[1]).floatValue()));
|
||||
}
|
||||
return pointList;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}finally {
|
||||
if(null != tensor){
|
||||
tensor.close();
|
||||
}
|
||||
if(null != output){
|
||||
output.close();
|
||||
}
|
||||
if(null != borderMat){
|
||||
borderMat.release();
|
||||
}
|
||||
if(null != resizeMat){
|
||||
resizeMat.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -12,6 +12,8 @@ import java.util.Map;
|
||||
* 五点对齐法
|
||||
*/
|
||||
public class Simple005pFaceAlignment implements FaceAlignment {
|
||||
/**最小边的长度**/
|
||||
private final static float minEdgeLength = 128;
|
||||
|
||||
/**对齐矩阵**/
|
||||
private final static double[][] dst_points = new double[][]{
|
||||
@ -31,16 +33,33 @@ public class Simple005pFaceAlignment implements FaceAlignment {
|
||||
*/
|
||||
@Override
|
||||
public ImageMat inference(ImageMat imageMat, FaceInfo.Points imagePoint, Map<String, Object> params) {
|
||||
double [][] image_points;
|
||||
if(imagePoint.size() == 5){
|
||||
image_points = imagePoint.toDoubleArray();
|
||||
}else if(imagePoint.size() == 106){
|
||||
image_points = imagePoint.select(38, 88, 80, 52, 61).toDoubleArray();
|
||||
}else{
|
||||
throw new RuntimeException("need 5 point, but get "+ imagePoint.size());
|
||||
ImageMat alignmentImageMat = null;
|
||||
try {
|
||||
FaceInfo.Points alignmentPoints = imagePoint;
|
||||
if(imageMat.getWidth() < minEdgeLength || imageMat.getHeight() < minEdgeLength){
|
||||
float scale = minEdgeLength / Math.min(imageMat.getWidth(), imageMat.getHeight());
|
||||
int newWidth = Float.valueOf(imageMat.getWidth() * scale).intValue();
|
||||
int newHeight = Float.valueOf(imageMat.getHeight() * scale).intValue();
|
||||
alignmentImageMat = imageMat.resizeAndNoReleaseMat(newWidth, newHeight);
|
||||
alignmentPoints = imagePoint.operateMultiply(scale);
|
||||
}else{
|
||||
alignmentImageMat = imageMat.clone();
|
||||
}
|
||||
double [][] image_points;
|
||||
if(alignmentPoints.size() == 5){
|
||||
image_points = alignmentPoints.toDoubleArray();
|
||||
}else if(alignmentPoints.size() == 106){
|
||||
image_points = alignmentPoints.select(38, 88, 80, 52, 61).toDoubleArray();
|
||||
}else{
|
||||
throw new RuntimeException("need 5 point, but get "+ imagePoint.size());
|
||||
}
|
||||
Mat alignMat = AlignUtil.alignedImage(alignmentImageMat.toCvMat(), image_points, 112, 112, dst_points);
|
||||
return ImageMat.fromCVMat(alignMat);
|
||||
}finally {
|
||||
if(null != alignmentImageMat){
|
||||
alignmentImageMat.release();
|
||||
}
|
||||
}
|
||||
Mat alignMat = AlignUtil.alignedImage(imageMat.toCvMat(), image_points, 112, 112, dst_points);
|
||||
return ImageMat.fromCVMat(alignMat);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -10,6 +10,9 @@ import java.util.Map;
|
||||
|
||||
public class Simple106pFaceAlignment implements FaceAlignment {
|
||||
|
||||
/**最小边的长度**/
|
||||
private final static float minEdgeLength = 128;
|
||||
|
||||
/**矫正的偏移**/
|
||||
private final static double x_offset = 0;
|
||||
private final static double y_offset = -8;
|
||||
@ -125,14 +128,31 @@ public class Simple106pFaceAlignment implements FaceAlignment {
|
||||
|
||||
@Override
|
||||
public ImageMat inference(ImageMat imageMat, FaceInfo.Points imagePoint, Map<String, Object> params) {
|
||||
double [][] image_points;
|
||||
if(imagePoint.size() == 106){
|
||||
image_points = imagePoint.toDoubleArray();
|
||||
}else{
|
||||
throw new RuntimeException("need 106 point, but get "+ imagePoint.size());
|
||||
ImageMat alignmentImageMat = null;
|
||||
try {
|
||||
FaceInfo.Points alignmentPoints = imagePoint;
|
||||
if(imageMat.getWidth() < minEdgeLength || imageMat.getHeight() < minEdgeLength){
|
||||
float scale = minEdgeLength / Math.min(imageMat.getWidth(), imageMat.getHeight());
|
||||
int newWidth = Float.valueOf(imageMat.getWidth() * scale).intValue();
|
||||
int newHeight = Float.valueOf(imageMat.getHeight() * scale).intValue();
|
||||
alignmentImageMat = imageMat.resizeAndNoReleaseMat(newWidth, newHeight);
|
||||
alignmentPoints = imagePoint.operateMultiply(scale);
|
||||
}else{
|
||||
alignmentImageMat = imageMat.clone();
|
||||
}
|
||||
double [][] image_points;
|
||||
if(alignmentPoints.size() == 106){
|
||||
image_points = alignmentPoints.toDoubleArray();
|
||||
}else{
|
||||
throw new RuntimeException("need 106 point, but get "+ alignmentPoints.size());
|
||||
}
|
||||
Mat alignMat = AlignUtil.alignedImage(alignmentImageMat.toCvMat(), image_points, 112, 112, dst_points);
|
||||
return ImageMat.fromCVMat(alignMat);
|
||||
}finally {
|
||||
if(null != alignmentImageMat){
|
||||
alignmentImageMat.release();
|
||||
}
|
||||
}
|
||||
Mat alignMat = AlignUtil.alignedImage(imageMat.toCvMat(), image_points, 112, 112, dst_points);
|
||||
return ImageMat.fromCVMat(alignMat);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -24,6 +24,22 @@ public class ArrayUtil {
|
||||
return output;
|
||||
}
|
||||
|
||||
public static float[] division(float[] input, float division){
|
||||
float[] output = new float[input.length];
|
||||
for(int i=0; i< input.length; i++){
|
||||
output[i] = input[i] / division;
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
public static double matrixNorm(float[] matrix){
|
||||
return matrixNorm(new double[][]{floatToDouble(matrix)});
|
||||
}
|
||||
|
||||
public static double matrixNorm(double[] matrix){
|
||||
return matrixNorm(new double[][]{matrix});
|
||||
}
|
||||
|
||||
public static double matrixNorm(double[][] matrix){
|
||||
double sum=0.0;
|
||||
for(double[] temp1:matrix){
|
||||
|
@ -0,0 +1,98 @@
|
||||
package com.visual.face.search.core.utils;
|
||||
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.domain.Mats;
|
||||
import org.opencv.core.Mat;
|
||||
|
||||
public class ReleaseUtil {
|
||||
|
||||
public static void release(Mat ...mats){
|
||||
for(Mat mat : mats){
|
||||
if(null != mat){
|
||||
try {
|
||||
mat.release();
|
||||
}catch (Exception e){
|
||||
e.printStackTrace();
|
||||
}finally {
|
||||
mat = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static void release(Mats mats){
|
||||
if(null == mats || mats.isEmpty()){
|
||||
return;
|
||||
}
|
||||
try {
|
||||
mats.release();
|
||||
}catch (Exception e){
|
||||
e.printStackTrace();
|
||||
}finally {
|
||||
mats = null;
|
||||
}
|
||||
}
|
||||
|
||||
public static void release(ImageMat ...imageMats){
|
||||
for(ImageMat imageMat : imageMats){
|
||||
if(null != imageMat){
|
||||
try {
|
||||
imageMat.release();
|
||||
}catch (Exception e){
|
||||
e.printStackTrace();
|
||||
}finally {
|
||||
imageMat = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static void release(OnnxTensor ...tensors){
|
||||
if(null == tensors || tensors.length == 0){
|
||||
return;
|
||||
}
|
||||
try {
|
||||
for(OnnxTensor tensor : tensors){
|
||||
try {
|
||||
if(null != tensor){
|
||||
tensor.close();
|
||||
}
|
||||
}catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}finally {
|
||||
tensor = null;
|
||||
}
|
||||
}
|
||||
}catch (Exception e){
|
||||
e.printStackTrace();
|
||||
}finally {
|
||||
tensors = null;
|
||||
}
|
||||
}
|
||||
|
||||
public static void release(OrtSession.Result ...results){
|
||||
if(null == results || results.length == 0){
|
||||
return;
|
||||
}
|
||||
try {
|
||||
for(OrtSession.Result result : results){
|
||||
try {
|
||||
if(null != result){
|
||||
result.close();
|
||||
}
|
||||
}catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}finally {
|
||||
result = null;
|
||||
}
|
||||
}
|
||||
}catch (Exception e){
|
||||
e.printStackTrace();
|
||||
}finally {
|
||||
results = null;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -76,4 +76,20 @@ public class Similarity {
|
||||
return Double.valueOf(sim).floatValue();
|
||||
}
|
||||
|
||||
/**
|
||||
* 对cos的原始值进行进行增强
|
||||
* @param cos
|
||||
* @return
|
||||
*/
|
||||
public static float cosEnhance(float cos){
|
||||
double sim = cos;
|
||||
if(cos >= 0.5){
|
||||
sim = cos + 2 * (cos - 0.5) * (1 - cos);
|
||||
}else if(cos >= 0){
|
||||
sim = cos - 2 * (cos - 0.5) * (0 - cos);
|
||||
}
|
||||
return Double.valueOf(sim).floatValue();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
BIN
face-search-core/src/main/resources/model/onnx/detection_face_scrfd/scrfd_500m_bnkps.onnx
Executable file → Normal file
@ -0,0 +1,80 @@
|
||||
package com.visual.face.search.core.test.extract;
|
||||
|
||||
import com.visual.face.search.core.base.*;
|
||||
import com.visual.face.search.core.domain.ExtParam;
|
||||
import com.visual.face.search.core.domain.FaceImage;
|
||||
import com.visual.face.search.core.domain.FaceInfo;
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.extract.FaceFeatureExtractor;
|
||||
import com.visual.face.search.core.extract.FaceFeatureExtractorImpl;
|
||||
import com.visual.face.search.core.models.*;
|
||||
import com.visual.face.search.core.test.base.BaseTest;
|
||||
import com.visual.face.search.core.utils.Similarity;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.imgcodecs.Imgcodecs;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class FaceCompareTest extends BaseTest {
|
||||
|
||||
private static String modelPcn1Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn1_sd.onnx";
|
||||
private static String modelPcn2Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn2_sd.onnx";
|
||||
private static String modelPcn3Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn3_sd.onnx";
|
||||
private static String modelScrfdPath = "face-search-core/src/main/resources/model/onnx/detection_face_scrfd/scrfd_500m_bnkps.onnx";
|
||||
private static String modelCoordPath = "face-search-core/src/main/resources/model/onnx/keypoint_coordinate/coordinate_106_mobilenet_05.onnx";
|
||||
private static String modelArcPath = "face-search-core/src/main/resources/model/onnx/recognition_face_arc/glint360k_cosface_r18_fp16_0.1.onnx";
|
||||
private static String modelSeetaPath = "face-search-core/src/main/resources/model/onnx/recognition_face_seeta/face_recognizer_512.onnx";
|
||||
private static String modelArrPath = "face-search-core/src/main/resources/model/onnx/attribute_gender_age/insight_gender_age.onnx";
|
||||
|
||||
private static String imagePath = "face-search-test/src/main/resources/image/validate/index/马化腾/";
|
||||
private static String imagePath3 = "face-search-test/src/main/resources/image/validate/index/雷军/";
|
||||
// private static String imagePath1 = "face-search-core/src/test/resources/images/faces/debug/debug_0001.jpg";
|
||||
// private static String imagePath2 = "face-search-core/src/test/resources/images/faces/debug/debug_0001.jpg";
|
||||
// private static String imagePath1 = "face-search-core/src/test/resources/images/faces/compare/1682052661610.jpg";
|
||||
// private static String imagePath2 = "face-search-core/src/test/resources/images/faces/compare/1682052669004.jpg";
|
||||
// private static String imagePath2 = "face-search-core/src/test/resources/images/faces/compare/1682053163961.jpg";
|
||||
// private static String imagePath1 = "face-search-test/src/main/resources/image/validate/index/张一鸣/1c7abcaf2dabdd2bc08e90c224d4c381.jpeg";
|
||||
private static String imagePath1 = "face-search-core/src/test/resources/images/faces/small/1.png";
|
||||
private static String imagePath2 = "face-search-core/src/test/resources/images/faces/small/2.png";
|
||||
public static void main(String[] args) {
|
||||
//口罩模型0.48,light模型0.52,normal模型0.62
|
||||
Map<String, String> map1 = getImagePathMap(imagePath1);
|
||||
Map<String, String> map2 = getImagePathMap(imagePath2);
|
||||
FaceDetection insightScrfdFaceDetection = new InsightScrfdFaceDetection(modelScrfdPath, 1);
|
||||
FaceKeyPoint insightCoordFaceKeyPoint = new InsightCoordFaceKeyPoint(modelCoordPath, 1);
|
||||
FaceRecognition insightArcFaceRecognition = new InsightArcFaceRecognition(modelArcPath, 1);
|
||||
FaceRecognition insightSeetaFaceRecognition = new SeetaFaceOpenRecognition(modelSeetaPath, 1);
|
||||
FaceAlignment simple005pFaceAlignment = new Simple005pFaceAlignment();
|
||||
FaceAlignment simple106pFaceAlignment = new Simple106pFaceAlignment();
|
||||
FaceDetection pcnNetworkFaceDetection = new PcnNetworkFaceDetection(new String[]{modelPcn1Path, modelPcn2Path, modelPcn3Path}, 1);
|
||||
FaceAttribute insightFaceAttribute = new InsightAttributeDetection(modelArrPath, 1);
|
||||
|
||||
FaceFeatureExtractor extractor = new FaceFeatureExtractorImpl(
|
||||
insightScrfdFaceDetection, pcnNetworkFaceDetection, insightCoordFaceKeyPoint,
|
||||
simple005pFaceAlignment, insightSeetaFaceRecognition, insightFaceAttribute);
|
||||
|
||||
for(String file1 : map1.keySet()){
|
||||
for(String file2 : map2.keySet()){
|
||||
Mat image1 = Imgcodecs.imread(map1.get(file1));
|
||||
long s = System.currentTimeMillis();
|
||||
ExtParam extParam = ExtParam.build().setMask(false).setTopK(20).setScoreTh(0).setIouTh(0);
|
||||
FaceImage faceImage1 = extractor.extract(ImageMat.fromCVMat(image1), extParam, null);
|
||||
List<FaceInfo> faceInfos1 = faceImage1.faceInfos();
|
||||
long e = System.currentTimeMillis();
|
||||
System.out.println("image1 extract cost:"+(e-s)+"ms");;
|
||||
|
||||
Mat image2 = Imgcodecs.imread(map2.get(file2));
|
||||
s = System.currentTimeMillis();
|
||||
FaceImage faceImage2 = extractor.extract(ImageMat.fromCVMat(image2), extParam, null);
|
||||
List<FaceInfo> faceInfos2 = faceImage2.faceInfos();
|
||||
e = System.currentTimeMillis();
|
||||
System.out.println("image2 extract cost:"+(e-s)+"ms");
|
||||
float similarity = Similarity.cosineSimilarityNorm(faceInfos1.get(0).embedding.embeds, faceInfos2.get(0).embedding.embeds);
|
||||
System.out.println(file1 + ","+ file2 + ",face similarity="+similarity);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@ -1,9 +1,6 @@
|
||||
package com.visual.face.search.core.test.extract;
|
||||
|
||||
import com.visual.face.search.core.base.FaceAlignment;
|
||||
import com.visual.face.search.core.base.FaceDetection;
|
||||
import com.visual.face.search.core.base.FaceKeyPoint;
|
||||
import com.visual.face.search.core.base.FaceRecognition;
|
||||
import com.visual.face.search.core.base.*;
|
||||
import com.visual.face.search.core.domain.ExtParam;
|
||||
import com.visual.face.search.core.domain.FaceImage;
|
||||
import com.visual.face.search.core.domain.FaceInfo;
|
||||
@ -24,6 +21,7 @@ public class FaceFeatureExtractOOMTest extends BaseTest {
|
||||
private static String modelScrfdPath = "face-search-core/src/main/resources/model/onnx/detection_face_scrfd/scrfd_500m_bnkps.onnx";
|
||||
private static String modelCoordPath = "face-search-core/src/main/resources/model/onnx/keypoint_coordinate/coordinate_106_mobilenet_05.onnx";
|
||||
private static String modelArcPath = "face-search-core/src/main/resources/model/onnx/recognition_face_arc/glint360k_cosface_r18_fp16_0.1.onnx";
|
||||
private static String modelArrPath = "face-search-core/src/main/resources/model/onnx/attribute_gender_age/insight_gender_age.onnx";
|
||||
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces";
|
||||
private static String imagePath = "face-search-core/src/test/resources/images/faces/debug/debug_0001.jpg";
|
||||
@ -47,7 +45,11 @@ public class FaceFeatureExtractOOMTest extends BaseTest {
|
||||
FaceAlignment simple005pFaceAlignment = new Simple005pFaceAlignment();
|
||||
FaceAlignment simple106pFaceAlignment = new Simple106pFaceAlignment();
|
||||
FaceDetection pcnNetworkFaceDetection = new PcnNetworkFaceDetection(new String[]{modelPcn1Path, modelPcn2Path, modelPcn3Path}, 1);
|
||||
FaceFeatureExtractor extractor = new FaceFeatureExtractorImpl(insightScrfdFaceDetection, pcnNetworkFaceDetection, insightCoordFaceKeyPoint, simple106pFaceAlignment, insightArcFaceRecognition);
|
||||
FaceAttribute insightFaceAttribute = new InsightAttributeDetection(modelArrPath, 1);
|
||||
|
||||
FaceFeatureExtractor extractor = new FaceFeatureExtractorImpl(
|
||||
insightScrfdFaceDetection, pcnNetworkFaceDetection, insightCoordFaceKeyPoint,
|
||||
simple106pFaceAlignment, insightArcFaceRecognition, insightFaceAttribute);
|
||||
// FaceFeatureExtractor extractor = new FaceFeatureExtractorImpl(insightScrfdFaceDetection, insightCoordFaceKeyPoint, simple106pFaceAlignment, insightArcFaceRecognition);
|
||||
for (int i = 0; i < 100000; i++) {
|
||||
for (String fileName : map.keySet()) {
|
||||
|
@ -1,9 +1,7 @@
|
||||
package com.visual.face.search.core.test.extract;
|
||||
|
||||
import com.visual.face.search.core.base.FaceAlignment;
|
||||
import com.visual.face.search.core.base.FaceDetection;
|
||||
import com.visual.face.search.core.base.FaceKeyPoint;
|
||||
import com.visual.face.search.core.base.FaceRecognition;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.visual.face.search.core.base.*;
|
||||
import com.visual.face.search.core.domain.ExtParam;
|
||||
import com.visual.face.search.core.domain.FaceImage;
|
||||
import com.visual.face.search.core.domain.FaceInfo;
|
||||
@ -12,7 +10,6 @@ import com.visual.face.search.core.extract.FaceFeatureExtractor;
|
||||
import com.visual.face.search.core.extract.FaceFeatureExtractorImpl;
|
||||
import com.visual.face.search.core.models.*;
|
||||
import com.visual.face.search.core.test.base.BaseTest;
|
||||
import com.visual.face.search.core.utils.CropUtil;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.core.Point;
|
||||
import org.opencv.core.Scalar;
|
||||
@ -31,9 +28,11 @@ public class FaceFeatureExtractTest extends BaseTest {
|
||||
private static String modelScrfdPath = "face-search-core/src/main/resources/model/onnx/detection_face_scrfd/scrfd_500m_bnkps.onnx";
|
||||
private static String modelCoordPath = "face-search-core/src/main/resources/model/onnx/keypoint_coordinate/coordinate_106_mobilenet_05.onnx";
|
||||
private static String modelArcPath = "face-search-core/src/main/resources/model/onnx/recognition_face_arc/glint360k_cosface_r18_fp16_0.1.onnx";
|
||||
|
||||
private static String modelArrPath = "face-search-core/src/main/resources/model/onnx/attribute_gender_age/insight_gender_age.onnx";
|
||||
|
||||
private static String imagePath = "face-search-core/src/test/resources/images/faces";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces/debug/debug_0001.jpg";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces/rotate/rotate_0002.jpg";
|
||||
|
||||
|
||||
public static void main(String[] args) {
|
||||
@ -44,7 +43,11 @@ public class FaceFeatureExtractTest extends BaseTest {
|
||||
FaceAlignment simple005pFaceAlignment = new Simple005pFaceAlignment();
|
||||
FaceAlignment simple106pFaceAlignment = new Simple106pFaceAlignment();
|
||||
FaceDetection pcnNetworkFaceDetection = new PcnNetworkFaceDetection(new String[]{modelPcn1Path, modelPcn2Path, modelPcn3Path}, 1);
|
||||
FaceFeatureExtractor extractor = new FaceFeatureExtractorImpl(pcnNetworkFaceDetection, insightScrfdFaceDetection, insightCoordFaceKeyPoint, simple106pFaceAlignment, insightArcFaceRecognition);
|
||||
FaceAttribute insightFaceAttribute = new InsightAttributeDetection(modelArrPath, 1);
|
||||
|
||||
FaceFeatureExtractor extractor = new FaceFeatureExtractorImpl(
|
||||
insightScrfdFaceDetection, pcnNetworkFaceDetection, insightCoordFaceKeyPoint,
|
||||
simple005pFaceAlignment, insightArcFaceRecognition, insightFaceAttribute);
|
||||
for(String fileName : map.keySet()){
|
||||
String imageFilePath = map.get(fileName);
|
||||
System.out.println(imageFilePath);
|
||||
@ -55,7 +58,8 @@ public class FaceFeatureExtractTest extends BaseTest {
|
||||
.setTopK(20)
|
||||
.setScoreTh(0)
|
||||
.setIouTh(0);
|
||||
FaceImage faceImage = extractor.extract(ImageMat.fromCVMat(image), extParam, null);
|
||||
Map<String, Object> params = new JSONObject().fluentPut(InsightScrfdFaceDetection.scrfdFaceNeedCheckFaceAngleParamKey, true);
|
||||
FaceImage faceImage = extractor.extract(ImageMat.fromCVMat(image), extParam, params);
|
||||
List<FaceInfo> faceInfos = faceImage.faceInfos();
|
||||
long e = System.currentTimeMillis();
|
||||
System.out.println("fileName="+fileName+",\tcost="+(e-s)+",\t"+faceInfos);
|
||||
@ -66,7 +70,7 @@ public class FaceFeatureExtractTest extends BaseTest {
|
||||
Imgproc.line(image, new Point(box.rightTop.x, box.rightTop.y), new Point(box.rightBottom.x, box.rightBottom.y), new Scalar(255,0,0), 1);
|
||||
Imgproc.line(image, new Point(box.rightBottom.x, box.rightBottom.y), new Point(box.leftBottom.x, box.leftBottom.y), new Scalar(255,0,0), 1);
|
||||
Imgproc.line(image, new Point(box.leftBottom.x, box.leftBottom.y), new Point(box.leftTop.x, box.leftTop.y), new Scalar(255,0,0), 1);
|
||||
Imgproc.putText(image, String.valueOf(faceInfo.angle), new Point(box.leftTop.x, box.leftTop.y), Imgproc.FONT_HERSHEY_PLAIN, 1, new Scalar(0,0,255));
|
||||
Imgproc.putText(image, String.valueOf(faceInfo.angle), new Point(box.leftTop.x, box.leftTop.y+15), Imgproc.FONT_HERSHEY_PLAIN, 1, new Scalar(0,0,255));
|
||||
// Imgproc.rectangle(image, new Point(faceInfo.box.x1(), faceInfo.box.y1()), new Point(faceInfo.box.x2(), faceInfo.box.y2()), new Scalar(255,0,255));
|
||||
|
||||
FaceInfo.FaceBox box1 = faceInfo.rotateFaceBox();
|
||||
@ -75,6 +79,11 @@ public class FaceFeatureExtractTest extends BaseTest {
|
||||
Imgproc.circle(image, new Point(box1.rightBottom.x, box1.rightBottom.y), 3, new Scalar(0,0,255), -1);
|
||||
Imgproc.circle(image, new Point(box1.leftBottom.x, box1.leftBottom.y), 3, new Scalar(0,0,255), -1);
|
||||
|
||||
FaceInfo.Attribute attribute = faceInfo.attribute;
|
||||
Imgproc.putText(image, attribute.valueOfGender().name(), new Point(box.center().x-10, box.center().y), Imgproc.FONT_HERSHEY_PLAIN, 1, new Scalar(255,0,0));
|
||||
Imgproc.putText(image, ""+attribute.age, new Point(box.center().x-10, box.center().y+20), Imgproc.FONT_HERSHEY_PLAIN, 1, new Scalar(255,0,0));
|
||||
|
||||
|
||||
int pointNum = 1;
|
||||
for(FaceInfo.Point keyPoint : faceInfo.points){
|
||||
Imgproc.circle(image, new Point(keyPoint.x, keyPoint.y), 1, new Scalar(0,0,255), -1);
|
||||
|
@ -0,0 +1,69 @@
|
||||
package com.visual.face.search.core.test.models;
|
||||
|
||||
import com.visual.face.search.core.domain.FaceInfo;
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.models.InsightAttributeDetection;
|
||||
import com.visual.face.search.core.models.InsightScrfdFaceDetection;
|
||||
import com.visual.face.search.core.test.base.BaseTest;
|
||||
import com.visual.face.search.core.utils.CropUtil;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.core.Point;
|
||||
import org.opencv.core.Scalar;
|
||||
import org.opencv.highgui.HighGui;
|
||||
import org.opencv.imgcodecs.Imgcodecs;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class InsightAttributeDetectionTest extends BaseTest {
|
||||
|
||||
private static String modelPathDetection = "face-search-core/src/main/resources/model/onnx/detection_face_scrfd/scrfd_500m_bnkps.onnx";
|
||||
private static String modelPathAttribute = "face-search-core/src/main/resources/model/onnx/attribute_gender_age/insight_gender_age.onnx";
|
||||
|
||||
private static String imagePath = "face-search-core/src/test/resources/images/faces";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces/rotate";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces/debug";
|
||||
|
||||
|
||||
public static void main(String[] args) {
|
||||
Map<String, String> map = getImagePathMap(imagePath);
|
||||
InsightScrfdFaceDetection inferDetection = new InsightScrfdFaceDetection(modelPathDetection, 2);
|
||||
InsightAttributeDetection inferAttribute = new InsightAttributeDetection(modelPathAttribute, 2);
|
||||
|
||||
for(String fileName : map.keySet()){
|
||||
String imageFilePath = map.get(fileName);
|
||||
System.out.println(imageFilePath);
|
||||
Mat image = Imgcodecs.imread(imageFilePath);
|
||||
long s = System.currentTimeMillis();
|
||||
List<FaceInfo> faceInfos = inferDetection.inference(ImageMat.fromCVMat(image), 0.5f, 0.7f, null);
|
||||
long e = System.currentTimeMillis();
|
||||
if(faceInfos.size() > 0){
|
||||
System.out.println("fileName="+fileName+",\tcost="+(e-s)+",\t"+faceInfos.get(0).score);
|
||||
}else{
|
||||
System.out.println("fileName="+fileName+",\tcost="+(e-s)+",\t"+faceInfos);
|
||||
}
|
||||
|
||||
for(FaceInfo faceInfo : faceInfos){
|
||||
Mat cropFace = CropUtil.crop(image, faceInfo.box);
|
||||
long a = System.currentTimeMillis();
|
||||
FaceInfo.Attribute attribute = inferAttribute.inference(ImageMat.fromCVMat(cropFace), null);
|
||||
System.out.println("ssss="+(System.currentTimeMillis() - a));
|
||||
Imgproc.putText(image, attribute.valueOfGender().name(), new Point(faceInfo.box.x1()+10, faceInfo.box.y1()+10), Imgproc.FONT_HERSHEY_PLAIN, 1, new Scalar(0,0,255));
|
||||
Imgproc.putText(image, ""+attribute.age, new Point(faceInfo.box.x1()+10, faceInfo.box.y1()+40), Imgproc.FONT_HERSHEY_PLAIN, 1, new Scalar(0,0,255));
|
||||
|
||||
Imgproc.rectangle(image, new Point(faceInfo.box.x1(), faceInfo.box.y1()), new Point(faceInfo.box.x2(), faceInfo.box.y2()), new Scalar(0,0,255));
|
||||
int pointNum = 1;
|
||||
for(FaceInfo.Point keyPoint : faceInfo.points){
|
||||
Imgproc.circle(image, new Point(keyPoint.x, keyPoint.y), 3, new Scalar(0,0,255), -1);
|
||||
Imgproc.putText(image, String.valueOf(pointNum), new Point(keyPoint.x+1, keyPoint.y), Imgproc.FONT_HERSHEY_PLAIN, 1, new Scalar(255,0,0));
|
||||
pointNum ++ ;
|
||||
}
|
||||
}
|
||||
HighGui.imshow(fileName, image);
|
||||
HighGui.waitKey();
|
||||
}
|
||||
System.exit(1);
|
||||
}
|
||||
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
package com.visual.face.search.core.test.models;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.visual.face.search.core.domain.FaceInfo;
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.models.InsightScrfdFaceDetection;
|
||||
@ -11,15 +12,17 @@ import org.opencv.highgui.HighGui;
|
||||
import org.opencv.imgcodecs.Imgcodecs;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class InsightScrfdFaceDetectionTest extends BaseTest {
|
||||
private static String modelPath = "face-search-core/src/main/resources/model/onnx/detection_face_scrfd/scrfd_500m_bnkps.onnx";
|
||||
|
||||
private static String imagePath = "face-search-core/src/test/resources/images/faces";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces/rotate";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces/debug";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces/rotate/rotate_0001.jpg";
|
||||
private static String imagePath = "face-search-core/src/test/resources/images/faces/rotate";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces/big/big_002.jpg";
|
||||
|
||||
|
||||
public static void main(String[] args) {
|
||||
@ -31,7 +34,9 @@ public class InsightScrfdFaceDetectionTest extends BaseTest {
|
||||
System.out.println(imageFilePath);
|
||||
Mat image = Imgcodecs.imread(imageFilePath);
|
||||
long s = System.currentTimeMillis();
|
||||
List<FaceInfo> faceInfos = infer.inference(ImageMat.fromCVMat(image), 0.5f, 0.7f, null);
|
||||
Map<String, Object> params = new JSONObject().fluentPut(InsightScrfdFaceDetection.scrfdFaceNeedCheckFaceAngleParamKey, true);
|
||||
|
||||
List<FaceInfo> faceInfos = infer.inference(ImageMat.fromCVMat(image), 0.48f, 0.7f, params);
|
||||
long e = System.currentTimeMillis();
|
||||
if(faceInfos.size() > 0){
|
||||
System.out.println("fileName="+fileName+",\tcost="+(e-s)+",\t"+faceInfos.get(0).score);
|
||||
@ -39,10 +44,20 @@ public class InsightScrfdFaceDetectionTest extends BaseTest {
|
||||
System.out.println("fileName="+fileName+",\tcost="+(e-s)+",\t"+faceInfos);
|
||||
}
|
||||
|
||||
//对坐标进行调整
|
||||
for(FaceInfo faceInfo : faceInfos){
|
||||
Imgproc.rectangle(image, new Point(faceInfo.box.x1(), faceInfo.box.y1()), new Point(faceInfo.box.x2(), faceInfo.box.y2()), new Scalar(0,0,255));
|
||||
FaceInfo.FaceBox box = faceInfo.rotateFaceBox();
|
||||
Imgproc.circle(image, new Point(box.leftTop.x, box.leftTop.y), 3, new Scalar(0,0,255), -1);
|
||||
Imgproc.circle(image, new Point(box.rightBottom.x, box.rightBottom.y), 3, new Scalar(0,0,255), -1);
|
||||
Imgproc.line(image, new Point(box.leftTop.x, box.leftTop.y), new Point(box.rightTop.x, box.rightTop.y), new Scalar(0,0,255), 1);
|
||||
Imgproc.line(image, new Point(box.rightTop.x, box.rightTop.y), new Point(box.rightBottom.x, box.rightBottom.y), new Scalar(255,0,0), 1);
|
||||
Imgproc.line(image, new Point(box.rightBottom.x, box.rightBottom.y), new Point(box.leftBottom.x, box.leftBottom.y), new Scalar(255,0,0), 1);
|
||||
Imgproc.line(image, new Point(box.leftBottom.x, box.leftBottom.y), new Point(box.leftTop.x, box.leftTop.y), new Scalar(255,0,0), 1);
|
||||
Imgproc.putText(image, String.valueOf(faceInfo.angle), new Point(box.leftTop.x, box.leftTop.y), Imgproc.FONT_HERSHEY_PLAIN, 1, new Scalar(0,0,255));
|
||||
|
||||
FaceInfo.Points points = faceInfo.points;
|
||||
int pointNum = 1;
|
||||
for(FaceInfo.Point keyPoint : faceInfo.points){
|
||||
for(FaceInfo.Point keyPoint : points){
|
||||
Imgproc.circle(image, new Point(keyPoint.x, keyPoint.y), 3, new Scalar(0,0,255), -1);
|
||||
Imgproc.putText(image, String.valueOf(pointNum), new Point(keyPoint.x+1, keyPoint.y), Imgproc.FONT_HERSHEY_PLAIN, 1, new Scalar(255,0,0));
|
||||
pointNum ++ ;
|
||||
|
@ -0,0 +1,47 @@
|
||||
package com.visual.face.search.core.test.models;
|
||||
|
||||
import ai.onnxruntime.OrtEnvironment;
|
||||
import com.visual.face.search.core.domain.FaceInfo;
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.models.PcnNetworkFaceDetection;
|
||||
import com.visual.face.search.core.test.base.BaseTest;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.imgcodecs.Imgcodecs;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class PcnNetworkFaceDetectionStress extends BaseTest {
|
||||
|
||||
static{ nu.pattern.OpenCV.loadShared(); }
|
||||
private OrtEnvironment env = OrtEnvironment.getEnvironment();
|
||||
|
||||
private static String model1Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn1_sd.onnx";
|
||||
private static String model2Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn2_sd.onnx";
|
||||
private static String model3Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn3_sd.onnx";
|
||||
|
||||
private static String imagePath = "face-search-test/src/main/resources/image/validate/noface";
|
||||
|
||||
|
||||
public static void main(String[] args) {
|
||||
Map<String, String> map = getImagePathMap(imagePath);
|
||||
PcnNetworkFaceDetection infer = new PcnNetworkFaceDetection(new String[]{model1Path, model2Path, model3Path}, 1);
|
||||
|
||||
int num = 0;
|
||||
for (int i = 0; i < 10000; i++) {
|
||||
for (String fileName : map.keySet()) {
|
||||
num = num + 1;
|
||||
String imageFilePath = map.get(fileName);
|
||||
System.out.println(num+":"+imageFilePath);
|
||||
Mat image = Imgcodecs.imread(imageFilePath);
|
||||
ImageMat imageMat = ImageMat.fromCVMat(image);
|
||||
|
||||
List<FaceInfo> faceInfos = infer.inference(imageMat, PcnNetworkFaceDetection.defScoreTh, PcnNetworkFaceDetection.defIouTh, null);
|
||||
faceInfos.clear();
|
||||
|
||||
imageMat.release();
|
||||
image.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -21,7 +21,9 @@ public class PcnNetworkFaceDetectionTest extends BaseTest {
|
||||
private static String model2Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn2_sd.onnx";
|
||||
private static String model3Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn3_sd.onnx";
|
||||
|
||||
private static String imagePath = "face-search-core/src/test/resources/images/faces";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces";
|
||||
private static String imagePath = "face-search-core/src/test/resources/images/faces/rotate/rotate_0001.jpg";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces/big/big_002.jpg";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces/rotate";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces/debug";
|
||||
|
||||
|
@ -0,0 +1,53 @@
|
||||
package com.visual.face.search.core.test.models;
|
||||
|
||||
import com.visual.face.search.core.base.FaceAlignment;
|
||||
import com.visual.face.search.core.base.FaceKeyPoint;
|
||||
import com.visual.face.search.core.base.FaceRecognition;
|
||||
import com.visual.face.search.core.domain.FaceInfo;
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.models.InsightCoordFaceKeyPoint;
|
||||
import com.visual.face.search.core.models.SeetaFaceOpenRecognition;
|
||||
import com.visual.face.search.core.models.Simple005pFaceAlignment;
|
||||
import com.visual.face.search.core.test.base.BaseTest;
|
||||
import com.visual.face.search.core.utils.CropUtil;
|
||||
import com.visual.face.search.core.utils.Similarity;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.imgcodecs.Imgcodecs;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class SeetaFaceOpenRecognitionTest extends BaseTest {
|
||||
private static String modelCoordPath = "face-search-core/src/main/resources/model/onnx/keypoint_coordinate/coordinate_106_mobilenet_05.onnx";
|
||||
private static String modelSeetaPath = "face-search-core/src/main/resources/model/onnx/recognition_fcae_seeta/face_recognizer_512.onnx";
|
||||
// private static String modelSeetaPath = "face-search-core/src/main/resources/model/onnx/recognition_fcae_seeta/face_recognizer_1024.onnx";
|
||||
|
||||
private static String imagePath = "face-search-core/src/test/resources/images/faces";
|
||||
// private static String imagePath1 = "face-search-core/src/test/resources/images/faces/debug/debug_0001.jpg";
|
||||
// private static String imagePath2 = "face-search-core/src/test/resources/images/faces/debug/debug_0004.jpeg";
|
||||
private static String imagePath1 = "face-search-core/src/test/resources/images/faces/compare/1682052661610.jpg";
|
||||
private static String imagePath2 = "face-search-core/src/test/resources/images/faces/compare/1682052669004.jpg";
|
||||
// private static String imagePath2 = "face-search-core/src/test/resources/images/faces/compare/1682053163961.jpg";
|
||||
|
||||
public static void main(String[] args) {
|
||||
FaceAlignment simple005pFaceAlignment = new Simple005pFaceAlignment();
|
||||
FaceKeyPoint insightCoordFaceKeyPoint = new InsightCoordFaceKeyPoint(modelCoordPath, 1);
|
||||
FaceRecognition insightSeetaFaceRecognition = new SeetaFaceOpenRecognition(modelSeetaPath, 1);
|
||||
|
||||
Mat image1 = Imgcodecs.imread(imagePath1);
|
||||
Mat image2 = Imgcodecs.imread(imagePath2);
|
||||
// image1 = CropUtil.crop(image1, FaceInfo.FaceBox.build(54,27,310,380));
|
||||
// image2 = CropUtil.crop(image2, FaceInfo.FaceBox.build(48,13,292,333));
|
||||
// image2 = CropUtil.crop(image2, FaceInfo.FaceBox.build(52,9,235,263));
|
||||
|
||||
// simple005pFaceAlignment.inference()
|
||||
|
||||
FaceInfo.Embedding embedding1 = insightSeetaFaceRecognition.inference(ImageMat.fromCVMat(image1), null);
|
||||
FaceInfo.Embedding embedding2 = insightSeetaFaceRecognition.inference(ImageMat.fromCVMat(image2), null);
|
||||
float similarity = Similarity.cosineSimilarity(embedding1.embeds, embedding2.embeds);
|
||||
System.out.println(similarity);
|
||||
// System.out.println(Arrays.toString(embedding1.embeds));
|
||||
// System.out.println(Arrays.toString(embedding2.embeds));
|
||||
}
|
||||
}
|
@ -0,0 +1,57 @@
|
||||
package com.visual.face.search.core.test.models;
|
||||
|
||||
import com.visual.face.search.core.domain.FaceInfo;
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.domain.QualityInfo;
|
||||
import com.visual.face.search.core.models.InsightCoordFaceKeyPoint;
|
||||
import com.visual.face.search.core.models.InsightScrfdFaceDetection;
|
||||
import com.visual.face.search.core.models.SeetaMaskFaceKeyPoint;
|
||||
import com.visual.face.search.core.test.base.BaseTest;
|
||||
import com.visual.face.search.core.utils.CropUtil;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.core.Point;
|
||||
import org.opencv.core.Scalar;
|
||||
import org.opencv.highgui.HighGui;
|
||||
import org.opencv.imgcodecs.Imgcodecs;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class SeetaMaskFaceKeyPointTest extends BaseTest {
|
||||
private static String modelDetectionPath = "face-search-core/src/main/resources/model/onnx/detection_face_scrfd/scrfd_500m_bnkps.onnx";
|
||||
private static String modelKeypointPath = "face-search-core/src/main/resources/model/onnx/keypoint_seeta_mask/landmarker_005_mask_pts5.onnx";
|
||||
private static String imagePath = "face-search-core/src/test/resources/images/faces";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces/compare";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces/compare/1694353163955.jpg";
|
||||
|
||||
public static void main(String[] args) {
|
||||
Map<String, String> map = getImagePathMap(imagePath);
|
||||
InsightScrfdFaceDetection detectionInfer = new InsightScrfdFaceDetection(modelDetectionPath, 1);
|
||||
SeetaMaskFaceKeyPoint keyPointInfer = new SeetaMaskFaceKeyPoint(modelKeypointPath, 1);
|
||||
for(String fileName : map.keySet()) {
|
||||
System.out.println(fileName);
|
||||
String imageFilePath = map.get(fileName);
|
||||
Mat image = Imgcodecs.imread(imageFilePath);
|
||||
List<FaceInfo> faceInfos = detectionInfer.inference(ImageMat.fromCVMat(image), 0.5f, 0.7f, null);
|
||||
for(FaceInfo faceInfo : faceInfos){
|
||||
FaceInfo.FaceBox rotateFaceBox = faceInfo.rotateFaceBox();
|
||||
Mat cropFace = CropUtil.crop(image, rotateFaceBox.scaling(1.0f));
|
||||
ImageMat cropImageMat = ImageMat.fromCVMat(cropFace);
|
||||
QualityInfo.MaskPoints maskPoints = keyPointInfer.inference(cropImageMat, null);
|
||||
System.out.println(maskPoints);
|
||||
for(QualityInfo.MaskPoint maskPoint : maskPoints){
|
||||
if(maskPoint.isMask()){
|
||||
Imgproc.circle(cropFace, new Point(maskPoint.x, maskPoint.y), 3, new Scalar(0, 0, 255), -1);
|
||||
}else{
|
||||
Imgproc.circle(cropFace, new Point(maskPoint.x, maskPoint.y), 3, new Scalar(255, 0, 0), -1);
|
||||
}
|
||||
}
|
||||
HighGui.imshow(fileName, cropFace);
|
||||
HighGui.waitKey();
|
||||
}
|
||||
}
|
||||
System.exit(1);
|
||||
}
|
||||
|
||||
}
|
BIN
face-search-core/src/test/resources/images/faces/big/big_001.jpg
Normal file
After Width: | Height: | Size: 32 KiB |
BIN
face-search-core/src/test/resources/images/faces/big/big_002.jpg
Normal file
After Width: | Height: | Size: 122 KiB |
After Width: | Height: | Size: 46 KiB |
After Width: | Height: | Size: 34 KiB |
After Width: | Height: | Size: 23 KiB |
After Width: | Height: | Size: 75 KiB |
After Width: | Height: | Size: 19 KiB |
After Width: | Height: | Size: 68 KiB |
BIN
face-search-core/src/test/resources/images/faces/small/1.png
Normal file
After Width: | Height: | Size: 616 KiB |
BIN
face-search-core/src/test/resources/images/faces/small/2.png
Normal file
After Width: | Height: | Size: 727 KiB |
104
face-search-engine/pom.xml
Executable file → Normal file
@ -5,118 +5,22 @@
|
||||
<parent>
|
||||
<artifactId>face-search</artifactId>
|
||||
<groupId>com.visual.face.search</groupId>
|
||||
<version>1.2.0</version>
|
||||
<version>2.1.0</version>
|
||||
</parent>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<artifactId>face-search-engine</artifactId>
|
||||
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<grpc.version>1.36.0</grpc.version>
|
||||
<protobuf.version>3.12.0</protobuf.version>
|
||||
<protoc.version>3.12.0</protoc.version>
|
||||
<commons-collections4.version>4.3</commons-collections4.version>
|
||||
<maven.compiler.source>1.8</maven.compiler.source>
|
||||
<maven.compiler.target>1.8</maven.compiler.target>
|
||||
</properties>
|
||||
|
||||
<dependencyManagement>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>io.grpc</groupId>
|
||||
<artifactId>grpc-bom</artifactId>
|
||||
<version>${grpc.version}</version>
|
||||
<type>pom</type>
|
||||
<scope>import</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</dependencyManagement>
|
||||
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>io.grpc</groupId>
|
||||
<artifactId>grpc-netty-shaded</artifactId>
|
||||
<version>1.36.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>io.grpc</groupId>
|
||||
<artifactId>grpc-stub</artifactId>
|
||||
<version>1.36.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>io.grpc</groupId>
|
||||
<artifactId>grpc-protobuf</artifactId>
|
||||
<version>1.36.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.guava</groupId>
|
||||
<artifactId>guava</artifactId>
|
||||
<version>21.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<version>1.16.16</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.protobuf</groupId>
|
||||
<artifactId>protobuf-java</artifactId>
|
||||
<version>3.14.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
<version>1.7.30</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.logging.log4j</groupId>
|
||||
<artifactId>log4j-slf4j-impl</artifactId>
|
||||
<version>2.12.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>io.milvus</groupId>
|
||||
<artifactId>milvus-java-sdk</artifactId>
|
||||
<version>2.0.4</version>
|
||||
<scope>system</scope>
|
||||
<systemPath>${project.basedir}/libs/milvus-java-sdk-2.0.4.jar</systemPath>
|
||||
</dependency>
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-gpg-plugin</artifactId>
|
||||
<version>1.6</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.protobuf</groupId>
|
||||
<artifactId>protobuf-java-util</artifactId>
|
||||
<version>${protobuf.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-text</artifactId>
|
||||
<version>1.6</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-collections4</artifactId>
|
||||
<version>${commons-collections4.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.json</groupId>
|
||||
<artifactId>json</artifactId>
|
||||
<version>20190722</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.alibaba.proxima</groupId>
|
||||
<artifactId>proxima-be-java-sdk</artifactId>
|
||||
<version>0.2.0</version>
|
||||
<scope>system</scope>
|
||||
<systemPath>${project.basedir}/libs/proxima-be-java-sdk-0.2.0.jar</systemPath>
|
||||
<groupId>org.opensearch.client</groupId>
|
||||
<artifactId>opensearch-rest-high-level-client</artifactId>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
</project>
|
@ -0,0 +1,30 @@
|
||||
package com.visual.face.search.engine.api;
|
||||
|
||||
|
||||
import com.visual.face.search.engine.model.MapParam;
|
||||
import com.visual.face.search.engine.model.SearchResponse;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface SearchEngine {
|
||||
|
||||
public Object getEngine();
|
||||
|
||||
public boolean exist(String collectionName);
|
||||
|
||||
public boolean dropCollection(String collectionName);
|
||||
|
||||
public boolean createCollection(String collectionName, MapParam param);
|
||||
|
||||
public boolean insertVector(String collectionName, String sampleId, String faceId, float[] vectors);
|
||||
|
||||
public boolean deleteVectorByKey(String collectionName, String faceId);
|
||||
|
||||
public boolean deleteVectorByKey(String collectionName, List<String> faceIds);
|
||||
|
||||
public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK);
|
||||
|
||||
public float searchMinScoreBySampleId(String collectionName, String sampleId,float[] feature, String algorithm);
|
||||
|
||||
public float searchMaxScoreBySampleId(String collectionName, String sampleId,float[] feature, String algorithm);
|
||||
}
|
@ -0,0 +1,13 @@
|
||||
package com.visual.face.search.engine.conf;
|
||||
|
||||
public class Constant {
|
||||
|
||||
public final static String IndexShardsNum = "shardsNum";
|
||||
public final static String IndexReplicasNum = "replicasNum";
|
||||
|
||||
public final static String ColumnNameFaceId = "face_id";
|
||||
public final static String ColumnNameSampleId = "sample_id";
|
||||
public final static String ColumnNameFaceVector = "face_vector";
|
||||
public final static String ColumnNameFaceScore = "face_score";
|
||||
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
package com.visual.face.search.engine.exps;
|
||||
|
||||
|
||||
public class SearchEngineException extends RuntimeException{
|
||||
|
||||
public SearchEngineException() {
|
||||
}
|
||||
|
||||
public SearchEngineException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public SearchEngineException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
|
||||
public SearchEngineException(Throwable cause) {
|
||||
super(cause);
|
||||
}
|
||||
|
||||
public SearchEngineException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) {
|
||||
super(message, cause, enableSuppression, writableStackTrace);
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,247 @@
|
||||
package com.visual.face.search.engine.impl;
|
||||
|
||||
import com.visual.face.search.engine.api.SearchEngine;
|
||||
import com.visual.face.search.engine.conf.Constant;
|
||||
import com.visual.face.search.engine.exps.SearchEngineException;
|
||||
import com.visual.face.search.engine.model.*;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.opensearch.action.DocWriteResponse;
|
||||
import org.opensearch.action.admin.indices.delete.DeleteIndexRequest;
|
||||
import org.opensearch.action.delete.DeleteRequest;
|
||||
import org.opensearch.action.delete.DeleteResponse;
|
||||
import org.opensearch.action.index.IndexRequest;
|
||||
import org.opensearch.action.index.IndexResponse;
|
||||
import org.opensearch.action.search.MultiSearchRequest;
|
||||
import org.opensearch.action.search.MultiSearchResponse;
|
||||
import org.opensearch.action.search.SearchRequest;
|
||||
import org.opensearch.client.RequestOptions;
|
||||
import org.opensearch.client.RestHighLevelClient;
|
||||
import org.opensearch.client.indices.CreateIndexRequest;
|
||||
import org.opensearch.client.indices.CreateIndexResponse;
|
||||
import org.opensearch.client.indices.GetIndexRequest;
|
||||
import org.opensearch.common.settings.Settings;
|
||||
import org.opensearch.index.query.*;
|
||||
import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder;
|
||||
import org.opensearch.index.reindex.BulkByScrollResponse;
|
||||
import org.opensearch.index.reindex.DeleteByQueryRequest;
|
||||
import org.opensearch.rest.RestStatus;
|
||||
import org.opensearch.script.Script;
|
||||
import org.opensearch.search.SearchHit;
|
||||
import org.opensearch.search.builder.SearchSourceBuilder;
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
|
||||
public class OpenSearchEngine implements SearchEngine {
|
||||
|
||||
private RestHighLevelClient client;
|
||||
private MapParam params = new MapParam();
|
||||
|
||||
public OpenSearchEngine(RestHighLevelClient client){
|
||||
this(client, null);
|
||||
}
|
||||
|
||||
public OpenSearchEngine(RestHighLevelClient client, MapParam params){
|
||||
this.client = client;
|
||||
if(null != params) { this.params = params; }
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getEngine() {
|
||||
return this.client;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean exist(String collectionName) {
|
||||
try {
|
||||
GetIndexRequest request = new GetIndexRequest(collectionName);
|
||||
return this.client.indices().exists(request, RequestOptions.DEFAULT);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean dropCollection(String collectionName) {
|
||||
try {
|
||||
DeleteIndexRequest request = new DeleteIndexRequest(collectionName);
|
||||
return this.client.indices().delete(request, RequestOptions.DEFAULT).isAcknowledged();
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean createCollection(String collectionName, MapParam param) {
|
||||
try {
|
||||
//构建请求
|
||||
CreateIndexRequest createIndexRequest = new CreateIndexRequest(collectionName);
|
||||
createIndexRequest.settings(Settings.builder()
|
||||
.put("index.number_of_shards", param.getIndexShardsNum())
|
||||
.put("index.number_of_replicas", param.getIndexReplicasNum())
|
||||
);
|
||||
HashMap<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constant.ColumnNameSampleId, Map.of("type", "keyword"));
|
||||
properties.put(Constant.ColumnNameFaceVector, Map.of("type", "knn_vector", "dimension", "512"));
|
||||
createIndexRequest.mapping(Map.of("properties", properties));
|
||||
//创建集合
|
||||
CreateIndexResponse createIndexResponse = client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
|
||||
return createIndexResponse.isAcknowledged();
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean insertVector(String collectionName, String sampleId, String faceId, float[] vectors) {
|
||||
try {
|
||||
//构建请求
|
||||
IndexRequest request = new IndexRequest(collectionName)
|
||||
.id(faceId)
|
||||
.source(Map.of(
|
||||
Constant.ColumnNameSampleId, sampleId,
|
||||
Constant.ColumnNameFaceVector, vectors
|
||||
));
|
||||
//插入数据
|
||||
IndexResponse indexResponse = client.index(request, RequestOptions.DEFAULT);
|
||||
DocWriteResponse.Result result = indexResponse.getResult();
|
||||
return DocWriteResponse.Result.CREATED == result || DocWriteResponse.Result.UPDATED == result;
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean deleteVectorByKey(String collectionName, String faceId) {
|
||||
try {
|
||||
DeleteRequest deleteDocumentRequest = new DeleteRequest(collectionName, faceId);
|
||||
DeleteResponse deleteResponse = client.delete(deleteDocumentRequest, RequestOptions.DEFAULT);
|
||||
return DocWriteResponse.Result.DELETED == deleteResponse.getResult();
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean deleteVectorByKey(String collectionName, List<String> keyIds) {
|
||||
try {
|
||||
String[] idArray = new String[keyIds.size()]; idArray = keyIds.toArray(idArray);
|
||||
QueryBuilder queryBuilder = new BoolQueryBuilder().must(QueryBuilders.idsQuery().addIds(idArray));
|
||||
DeleteByQueryRequest request = new DeleteByQueryRequest(collectionName).setQuery(queryBuilder);
|
||||
BulkByScrollResponse response = client.deleteByQuery(request, RequestOptions.DEFAULT);
|
||||
return response.getBulkFailures() != null && response.getBulkFailures().size() == 0;
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK) {
|
||||
try {
|
||||
//构建搜索请求
|
||||
MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
|
||||
for(float[] feature : features){
|
||||
QueryBuilder queryBuilder = new MatchAllQueryBuilder();
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("field", Constant.ColumnNameFaceVector);
|
||||
params.put("space_type", algorithm);
|
||||
params.put("query_value", feature);
|
||||
Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "knn", "knn_score", params);
|
||||
ScriptScoreQueryBuilder scriptScoreQueryBuilder = new ScriptScoreQueryBuilder(queryBuilder, script);
|
||||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
|
||||
.query(scriptScoreQueryBuilder).size(topK)
|
||||
.fetchSource(null, Constant.ColumnNameFaceVector); //是否需要向量字段
|
||||
SearchRequest searchRequest = new SearchRequest(collectionName).source(searchSourceBuilder);
|
||||
multiSearchRequest.add(searchRequest);
|
||||
}
|
||||
//查询索引
|
||||
MultiSearchResponse response = this.client.msearch(multiSearchRequest, RequestOptions.DEFAULT);
|
||||
MultiSearchResponse.Item[] responses = response.getResponses();
|
||||
if(features.length != responses.length){
|
||||
throw new SearchEngineException("features.length != responses.length");
|
||||
}
|
||||
//解析数据
|
||||
List<SearchResult> result = new ArrayList<>();
|
||||
for(MultiSearchResponse.Item item : response.getResponses()){
|
||||
List<SearchDocument> documents = new ArrayList<>();
|
||||
SearchHit[] searchHits = item.getResponse().getHits().getHits();
|
||||
if(searchHits != null){
|
||||
for(SearchHit searchHit : searchHits){
|
||||
String faceId = searchHit.getId();
|
||||
float score = searchHit.getScore()-1;
|
||||
Map<String, Object> sourceMap = searchHit.getSourceAsMap();
|
||||
String sampleId = MapUtils.getString(sourceMap, Constant.ColumnNameSampleId);
|
||||
Object faceVector = MapUtils.getObject(sourceMap, Constant.ColumnNameFaceVector);
|
||||
SearchDocument document = SearchDocument.build(sampleId, faceId, score).setVectors(faceVector);
|
||||
documents.add(document);
|
||||
}
|
||||
}
|
||||
result.add(SearchResult.build(documents));
|
||||
}
|
||||
//返回结果信息
|
||||
return SearchResponse.build(SearchStatus.build(0, "success"), result);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public float searchMinScoreBySampleId(String collectionName, String sampleId,float[] feature, String algorithm) {
|
||||
try {
|
||||
//构建请求
|
||||
QueryBuilder queryBuilder = new MatchQueryBuilder(Constant.ColumnNameSampleId, sampleId);
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("field", Constant.ColumnNameFaceVector);
|
||||
params.put("space_type", algorithm);
|
||||
params.put("query_value", feature);
|
||||
Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "knn", "knn_score", params);
|
||||
ScriptScoreQueryBuilder scriptScoreQueryBuilder = new ScriptScoreQueryBuilder(queryBuilder, script);
|
||||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
|
||||
.query(scriptScoreQueryBuilder)
|
||||
.fetchSource(false).size(10000); //是否需要索引字段
|
||||
SearchRequest searchRequest = new SearchRequest(collectionName).source(searchSourceBuilder);
|
||||
//搜索请求
|
||||
org.opensearch.action.search.SearchResponse response = this.client.search(searchRequest, RequestOptions.DEFAULT);
|
||||
if(RestStatus.OK == response.status()){
|
||||
SearchHit[] searchHits = response.getHits().getHits();
|
||||
Double minScore = Arrays.stream(searchHits).mapToDouble(SearchHit::getScore).min().orElse(2f);
|
||||
return minScore.floatValue()-1;
|
||||
}else{
|
||||
throw new RuntimeException("get score error!");
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new SearchEngineException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public float searchMaxScoreBySampleId(String collectionName, String sampleId,float[] feature, String algorithm) {
|
||||
try {
|
||||
//构建请求
|
||||
QueryBuilder queryBuilder = new BoolQueryBuilder()
|
||||
.mustNot(new MatchQueryBuilder(Constant.ColumnNameSampleId, sampleId));
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("field", Constant.ColumnNameFaceVector);
|
||||
params.put("space_type", algorithm);
|
||||
params.put("query_value", feature);
|
||||
Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "knn", "knn_score", params);
|
||||
ScriptScoreQueryBuilder scriptScoreQueryBuilder = new ScriptScoreQueryBuilder(queryBuilder, script);
|
||||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
|
||||
.query(scriptScoreQueryBuilder)
|
||||
.fetchSource(false).size(1); //是否需要索引字段
|
||||
SearchRequest searchRequest = new SearchRequest(collectionName).source(searchSourceBuilder);
|
||||
//搜索请求
|
||||
org.opensearch.action.search.SearchResponse response = this.client.search(searchRequest, RequestOptions.DEFAULT);
|
||||
if(RestStatus.OK == response.status()){
|
||||
SearchHit[] searchHits = response.getHits().getHits();
|
||||
Double maxScore = Arrays.stream(searchHits).mapToDouble(SearchHit::getScore).max().orElse(1f);
|
||||
return maxScore.floatValue()-1;
|
||||
}else{
|
||||
throw new RuntimeException("get score error!");
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new SearchEngineException(e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
package com.visual.face.search.server.engine.model;
|
||||
package com.visual.face.search.engine.model;
|
||||
|
||||
import com.visual.face.search.server.engine.conf.Constant;
|
||||
import com.visual.face.search.engine.conf.Constant;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
@ -11,7 +11,6 @@ public class MapParam extends ConcurrentHashMap<String, Object> {
|
||||
return new MapParam();
|
||||
}
|
||||
|
||||
|
||||
public MapParam put(String key, Object value){
|
||||
if(null != key && null != value){
|
||||
super.put(key, value);
|
||||
@ -60,15 +59,15 @@ public class MapParam extends ConcurrentHashMap<String, Object> {
|
||||
}
|
||||
|
||||
/******************************************************************************************************************/
|
||||
public Long getMaxDocsPerSegment(){
|
||||
Long maxDocsPerSegment = this.getLong(Constant.ParamKeyMaxDocsPerSegment, 0L);
|
||||
public Long getIndexReplicasNum(){
|
||||
Long maxDocsPerSegment = this.getLong(Constant.IndexReplicasNum, 1L);
|
||||
maxDocsPerSegment = (null == maxDocsPerSegment || maxDocsPerSegment < 0) ? 0 : maxDocsPerSegment;
|
||||
return maxDocsPerSegment;
|
||||
}
|
||||
|
||||
public Integer getShardsNum(){
|
||||
Integer shardsNum = this.getInteger(Constant.ParamKeyShardsNum, 0);
|
||||
shardsNum = (null == shardsNum || shardsNum <= 0) ? 2 : shardsNum;
|
||||
public Integer getIndexShardsNum(){
|
||||
Integer shardsNum = this.getInteger(Constant.IndexShardsNum, 4);
|
||||
shardsNum = (null == shardsNum || shardsNum <= 0) ? 4 : shardsNum;
|
||||
return shardsNum;
|
||||
}
|
||||
|
@ -0,0 +1,72 @@
|
||||
package com.visual.face.search.engine.model;
|
||||
|
||||
import java.io.Serializable;
|
||||
import com.visual.face.search.engine.utils.NumberUtils;
|
||||
|
||||
public class SearchDocument implements Serializable {
|
||||
private float score;
|
||||
private String faceId;
|
||||
private String sampleId;
|
||||
private Float[] vectors;
|
||||
|
||||
public SearchDocument(){}
|
||||
|
||||
public SearchDocument(String sampleId, String faceId, float score) {
|
||||
this(sampleId, faceId, score, null);
|
||||
}
|
||||
|
||||
public SearchDocument(String sampleId, String faceId, float score, Float[] vectors) {
|
||||
this.score = score;
|
||||
this.faceId = faceId;
|
||||
this.sampleId = sampleId;
|
||||
this.vectors = vectors;
|
||||
}
|
||||
|
||||
public static SearchDocument build(String sampleId, String faceId, float score){
|
||||
return build(sampleId, faceId, score, null);
|
||||
}
|
||||
|
||||
public static SearchDocument build(String sampleId, String faceId, float score, Float[] vectors){
|
||||
return new SearchDocument(sampleId, faceId, score, vectors);
|
||||
}
|
||||
|
||||
public float getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
public SearchDocument setScore(float score) {
|
||||
this.score = score;
|
||||
return this;
|
||||
}
|
||||
|
||||
public String getFaceId() {
|
||||
return faceId;
|
||||
}
|
||||
|
||||
public SearchDocument setFaceId(String faceId) {
|
||||
this.faceId = faceId;
|
||||
return this;
|
||||
}
|
||||
|
||||
public String getSampleId() {
|
||||
return sampleId;
|
||||
}
|
||||
|
||||
public SearchDocument setSampleId(String sampleId) {
|
||||
this.sampleId = sampleId;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Float[] getVectors() {
|
||||
return vectors;
|
||||
}
|
||||
|
||||
public SearchDocument setVectors(Object vectors) {
|
||||
if(vectors == null){
|
||||
return this;
|
||||
}else{
|
||||
this.vectors = NumberUtils.getFloatArray(vectors);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package com.visual.face.search.server.engine.model;
|
||||
package com.visual.face.search.engine.model;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
@ -1,4 +1,4 @@
|
||||
package com.visual.face.search.server.engine.model;
|
||||
package com.visual.face.search.engine.model;
|
||||
|
||||
|
||||
import java.util.ArrayList;
|
@ -1,8 +1,8 @@
|
||||
package com.visual.face.search.server.engine.model;
|
||||
package com.visual.face.search.engine.model;
|
||||
|
||||
public class SearchStatus {
|
||||
private int code;
|
||||
private String reason;
|
||||
private String reason;
|
||||
|
||||
public SearchStatus(){}
|
||||
|
@ -0,0 +1,74 @@
|
||||
package com.visual.face.search.engine.utils;
|
||||
|
||||
import java.text.NumberFormat;
|
||||
import java.text.ParseException;
|
||||
import java.util.List;
|
||||
|
||||
public class NumberUtils {
|
||||
|
||||
public static Number getNumber(Object value) {
|
||||
if (value != null) {
|
||||
if (value instanceof Number) {
|
||||
return (Number)value;
|
||||
}
|
||||
if (value instanceof String) {
|
||||
try {
|
||||
String text = (String)value;
|
||||
return NumberFormat.getInstance().parse(text);
|
||||
} catch (ParseException var4) {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public static Float getFloat(Object value) {
|
||||
Number answer = getNumber(value);
|
||||
if (answer == null) {
|
||||
return null;
|
||||
} else {
|
||||
return answer instanceof Float ? (Float)answer : answer.floatValue();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public static Float[] getFloatArray(Object values) {
|
||||
if(null != values){
|
||||
if(values.getClass().isArray()){
|
||||
return getFloatArray((Object[]) values);
|
||||
}else if(values instanceof List){
|
||||
return getFloatArray((List)values);
|
||||
}else{
|
||||
throw new RuntimeException("type error for:"+values.getClass());
|
||||
}
|
||||
}else{
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public static Float[] getFloatArray(List<Object> values) {
|
||||
if(null != values){
|
||||
Float[] floats = new Float[values.size()];
|
||||
for(int i=0; i<floats.length; i++){
|
||||
floats[i] = getFloat(values.get(i));
|
||||
}
|
||||
return floats;
|
||||
}else{
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public static Float[] getFloatArray(Object[] values) {
|
||||
if(null != values){
|
||||
Float[] floats = new Float[values.length];
|
||||
for(int i=0; i<floats.length; i++){
|
||||
floats[i] = getFloat(values[i]);
|
||||
}
|
||||
return floats;
|
||||
}else{
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package com.visual.face.search.server.engine.utils;
|
||||
package com.visual.face.search.engine.utils;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
@ -0,0 +1,130 @@
|
||||
package com.visual.face.search.engine.test;
|
||||
|
||||
import com.visual.face.search.engine.impl.OpenSearchEngine;
|
||||
import com.visual.face.search.engine.model.MapParam;
|
||||
import com.visual.face.search.engine.model.SearchResponse;
|
||||
import org.apache.http.HttpHost;
|
||||
import org.apache.http.auth.AuthScope;
|
||||
import org.apache.http.auth.UsernamePasswordCredentials;
|
||||
import org.apache.http.client.CredentialsProvider;
|
||||
import org.apache.http.impl.client.BasicCredentialsProvider;
|
||||
import org.opensearch.client.RestClient;
|
||||
import org.opensearch.client.RestClientBuilder;
|
||||
import org.opensearch.client.RestHighLevelClient;
|
||||
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.TrustManager;
|
||||
import javax.net.ssl.X509TrustManager;
|
||||
import java.security.KeyManagementException;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.SecureRandom;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.Arrays;
|
||||
|
||||
public class SearchEngineTest {
|
||||
|
||||
public static RestHighLevelClient getOpenSearchClientImpl(){
|
||||
// final String hostName = "192.168.10.201";
|
||||
final String hostName = "172.16.36.229";
|
||||
final Integer hostPort = 9200;
|
||||
final String hostScheme = "https";
|
||||
final String userName = "admin";
|
||||
final String password = "admin";
|
||||
|
||||
//认证参数
|
||||
final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
|
||||
credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(userName, password));
|
||||
|
||||
//ssl设置
|
||||
final SSLContext sslContext;
|
||||
try {
|
||||
sslContext = SSLContext.getInstance("SSL");
|
||||
sslContext.init(null, new TrustManager[] { new X509TrustManager() {
|
||||
public X509Certificate[] getAcceptedIssuers() { return null; }
|
||||
public void checkClientTrusted(X509Certificate[] certs, String authType) {}
|
||||
public void checkServerTrusted(X509Certificate[] certs, String authType) {}
|
||||
}}, new SecureRandom());
|
||||
} catch (NoSuchAlgorithmException | KeyManagementException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
//构建请求
|
||||
RestClientBuilder builder = RestClient.builder(new HttpHost(hostName, hostPort, hostScheme))
|
||||
.setHttpClientConfigCallback(httpClientBuilder -> httpClientBuilder
|
||||
.setDefaultCredentialsProvider(credentialsProvider)
|
||||
.setSSLHostnameVerifier((hostname, session) -> true)
|
||||
.setSSLContext(sslContext)
|
||||
.setMaxConnTotal(10)
|
||||
.setMaxConnPerRoute(10)
|
||||
);
|
||||
//构建client
|
||||
return new RestHighLevelClient(builder);
|
||||
}
|
||||
|
||||
|
||||
public static void main(String[] args) throws InterruptedException {
|
||||
// String index_name = "app-opensearch-index-001";
|
||||
// String index_name = "app-opensearch-index-002";
|
||||
// String index_name = "python-test-index2";
|
||||
String index_name = "visual_search_namespace_1_collect_20211201_v05_ohr6_vector";
|
||||
|
||||
OpenSearchEngine engine = new OpenSearchEngine(getOpenSearchClientImpl());
|
||||
|
||||
boolean exist = engine.exist(index_name);
|
||||
System.out.println(exist);
|
||||
if(exist){
|
||||
// boolean drop = engine.dropCollection(index_name);
|
||||
// System.out.println(drop);
|
||||
}
|
||||
|
||||
// boolean create = engine.createCollection(index_name, MapParam.build());
|
||||
// System.out.println(create);
|
||||
//
|
||||
// for(int i=0; i<10; i++){
|
||||
// float[] vectors = new float[512];
|
||||
// vectors[i] = 0.23333333f;
|
||||
// boolean insert = engine.insertVector(index_name, "simple-0001", String.valueOf(i), vectors);
|
||||
// System.out.println("insert="+insert);
|
||||
// }
|
||||
//
|
||||
// Thread.sleep(2000);
|
||||
|
||||
// boolean delete = engine.deleteVectorByKey(index_name, "0");
|
||||
// System.out.println(delete);
|
||||
|
||||
// boolean delete1 = engine.deleteVectorByKey(index_name, Arrays.asList("1", "5", "9"));
|
||||
// System.out.println(delete1);
|
||||
|
||||
|
||||
// float[][] a = new float[2][];
|
||||
//
|
||||
float[] vectors = new float[512];
|
||||
vectors[0] = 0.768888f;
|
||||
vectors[1] = 20000.768888f;
|
||||
vectors[162] = 33333f;
|
||||
// a[0] = vectors;
|
||||
//
|
||||
// float[] vectors1 = new float[512];
|
||||
// vectors1[2] = 10000.54444f;
|
||||
// a[1] = vectors1;
|
||||
// SearchResponse searchResponse = engine.search(index_name, a, "cosinesimil",1);
|
||||
// System.out.println(searchResponse);
|
||||
|
||||
// engine.searchCount(index_name, vectors, "", 1);
|
||||
float minScore = engine.searchMinScoreBySampleId(index_name, "d4395b36984926a1934a0f9b916b32d21", vectors, "cosinesimil");
|
||||
float maxScore = engine.searchMaxScoreBySampleId(index_name, "d4395b36984926a1934a0f9b916b32d21", vectors, "cosinesimil");
|
||||
|
||||
System.out.println(minScore);
|
||||
System.out.println(maxScore);
|
||||
System.exit(1);
|
||||
|
||||
// Object y = vectors;
|
||||
// Object[] y1 = (Object[]) y;
|
||||
// System.out.println(y1.length);
|
||||
//
|
||||
// System.out.println(y.getClass().getComponentType());
|
||||
// System.out.println(y.getClass().isArray());
|
||||
// System.out.println(cc instanceof Float);
|
||||
// System.exit(1);
|
||||
}
|
||||
}
|
26
face-search-server/pom.xml
Executable file → Normal file
@ -5,7 +5,7 @@
|
||||
<parent>
|
||||
<artifactId>face-search</artifactId>
|
||||
<groupId>com.visual.face.search</groupId>
|
||||
<version>1.2.0</version>
|
||||
<version>2.1.0</version>
|
||||
</parent>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
@ -17,6 +17,22 @@
|
||||
<artifactId>spring-boot-starter-web</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-validation</artifactId>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>org.hibernate.validator</groupId>
|
||||
<artifactId>hibernate-validator</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.hibernate.validator</groupId>
|
||||
<artifactId>hibernate-validator</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.visual.face.search</groupId>
|
||||
<artifactId>face-search-core</artifactId>
|
||||
@ -76,11 +92,11 @@
|
||||
<!--文档插件-->
|
||||
<dependency>
|
||||
<groupId>io.springfox</groupId>
|
||||
<artifactId>springfox-swagger2</artifactId>
|
||||
<artifactId>springfox-boot-starter</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.github.xiaoymin</groupId>
|
||||
<artifactId>swagger-bootstrap-ui</artifactId>
|
||||
<artifactId>knife4j-spring-boot-starter</artifactId>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
@ -90,9 +106,7 @@
|
||||
<plugin>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-maven-plugin</artifactId>
|
||||
<configuration>
|
||||
<includeSystemScope>true</includeSystemScope>
|
||||
</configuration>
|
||||
<version>2.6.0</version>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
|
@ -1,49 +1,74 @@
|
||||
package com.visual.face.search.server.bootstrap.conf;
|
||||
|
||||
import com.alibaba.proxima.be.client.ConnectParam;
|
||||
import com.alibaba.proxima.be.client.ProximaGrpcSearchClient;
|
||||
import com.alibaba.proxima.be.client.ProximaSearchClient;
|
||||
import com.visual.face.search.server.engine.api.SearchEngine;
|
||||
import com.visual.face.search.server.engine.impl.MilvusSearchEngine;
|
||||
import com.visual.face.search.server.engine.impl.ProximaSearchEngine;
|
||||
import io.milvus.client.MilvusServiceClient;
|
||||
import com.visual.face.search.engine.api.SearchEngine;
|
||||
import com.visual.face.search.engine.impl.OpenSearchEngine;
|
||||
import org.apache.http.HttpHost;
|
||||
import org.apache.http.auth.AuthScope;
|
||||
import org.apache.http.auth.UsernamePasswordCredentials;
|
||||
import org.apache.http.client.CredentialsProvider;
|
||||
import org.apache.http.impl.client.BasicCredentialsProvider;
|
||||
import org.opensearch.client.RestClient;
|
||||
import org.opensearch.client.RestClientBuilder;
|
||||
import org.opensearch.client.RestHighLevelClient;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.TrustManager;
|
||||
import javax.net.ssl.X509TrustManager;
|
||||
import java.security.KeyManagementException;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.SecureRandom;
|
||||
import java.security.cert.X509Certificate;
|
||||
|
||||
@Configuration("visualEngineConfig")
|
||||
public class EngineConfig {
|
||||
//日志
|
||||
public Logger logger = LoggerFactory.getLogger(getClass());
|
||||
|
||||
@Value("${visual.engine.selected:proxima}")
|
||||
private String selected;
|
||||
|
||||
@Value("${visual.engine.proxima.host}")
|
||||
private String proximaHost;
|
||||
@Value("${visual.engine.proxima.port:16000}")
|
||||
private Integer proximaPort;
|
||||
|
||||
@Value("${visual.engine.milvus.host}")
|
||||
private String milvusHost;
|
||||
@Value("${visual.engine.milvus.port:19530}")
|
||||
private Integer milvusPort;
|
||||
@Value("${visual.engine.open-search.host:localhost}")
|
||||
private String openSearchHost;
|
||||
@Value("${visual.engine.open-search.port:9200}")
|
||||
private Integer openSearchPort;
|
||||
@Value("${visual.engine.open-search.scheme:https}")
|
||||
private String openSearchScheme;
|
||||
@Value("${visual.engine.open-search.username:admin}")
|
||||
private String openSearchUserName;
|
||||
@Value("${visual.engine.open-search.password:admin}")
|
||||
private String openSearchPassword;
|
||||
|
||||
@Bean(name = "visualSearchEngine")
|
||||
public SearchEngine getSearchEngine(){
|
||||
if(selected.equalsIgnoreCase("milvus")){
|
||||
logger.info("current vector engine is milvus");
|
||||
io.milvus.param.ConnectParam connectParam = io.milvus.param.ConnectParam.newBuilder().withHost(milvusHost).withPort(milvusPort).build();
|
||||
MilvusServiceClient client = new MilvusServiceClient(connectParam);
|
||||
return new MilvusSearchEngine(client);
|
||||
}else{
|
||||
logger.info("current vector engine is proxima");
|
||||
ConnectParam connectParam = ConnectParam.newBuilder().withHost(proximaHost).withPort(proximaPort).build();
|
||||
ProximaSearchClient client = new ProximaGrpcSearchClient(connectParam);
|
||||
return new ProximaSearchEngine(client);
|
||||
public SearchEngine simpleSearchEngine(){
|
||||
//认证参数
|
||||
final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
|
||||
credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(openSearchUserName, openSearchPassword));
|
||||
//ssl设置
|
||||
final SSLContext sslContext;
|
||||
try {
|
||||
sslContext = SSLContext.getInstance("SSL");
|
||||
sslContext.init(null, new TrustManager[] { new X509TrustManager() {
|
||||
public X509Certificate[] getAcceptedIssuers() { return null; }
|
||||
public void checkClientTrusted(X509Certificate[] certs, String authType) {}
|
||||
public void checkServerTrusted(X509Certificate[] certs, String authType) {}
|
||||
}}, new SecureRandom());
|
||||
} catch (NoSuchAlgorithmException | KeyManagementException e) {
|
||||
logger.error("create SearchEngine error:", e);
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
//构建请求
|
||||
RestClientBuilder builder = RestClient.builder(new HttpHost(openSearchHost, openSearchPort, openSearchScheme))
|
||||
.setHttpClientConfigCallback(httpClientBuilder -> httpClientBuilder
|
||||
.setDefaultCredentialsProvider(credentialsProvider)
|
||||
.setSSLHostnameVerifier((hostname, session) -> true)
|
||||
.setSSLContext(sslContext)
|
||||
.setMaxConnTotal(10)
|
||||
.setMaxConnPerRoute(10)
|
||||
);
|
||||
//构建client
|
||||
return new OpenSearchEngine(new RestHighLevelClient(builder));
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,39 @@
|
||||
package com.visual.face.search.server.bootstrap.conf;
|
||||
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import springfox.documentation.spi.DocumentationType;
|
||||
import springfox.documentation.builders.PathSelectors;
|
||||
import springfox.documentation.builders.ApiInfoBuilder;
|
||||
import springfox.documentation.spring.web.plugins.Docket;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import springfox.documentation.oas.annotations.EnableOpenApi;
|
||||
import springfox.documentation.builders.RequestHandlerSelectors;
|
||||
import com.github.xiaoymin.knife4j.spring.annotations.EnableKnife4j;
|
||||
|
||||
@Configuration
|
||||
@EnableOpenApi
|
||||
@EnableKnife4j
|
||||
public class Knife4jConfig {
|
||||
|
||||
@Value("${visual.swagger.enable:true}")
|
||||
private Boolean enable;
|
||||
|
||||
@Bean
|
||||
public Docket createRestApi() {
|
||||
return new Docket(DocumentationType.OAS_30)
|
||||
.enable(enable)
|
||||
.apiInfo(new ApiInfoBuilder()
|
||||
.title("人脸搜索服务API")
|
||||
.description("人脸搜索服务API")
|
||||
.version("2.1.0")
|
||||
.build())
|
||||
.groupName("2.1.0")
|
||||
.select()
|
||||
.apis(RequestHandlerSelectors.basePackage("com.visual.face.search.server.controller.server"))
|
||||
.paths(PathSelectors.any())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,12 +1,10 @@
|
||||
package com.visual.face.search.server.bootstrap.conf;
|
||||
|
||||
import com.visual.face.search.core.base.FaceAlignment;
|
||||
import com.visual.face.search.core.base.FaceDetection;
|
||||
import com.visual.face.search.core.base.FaceKeyPoint;
|
||||
import com.visual.face.search.core.base.FaceRecognition;
|
||||
import com.visual.face.search.core.base.*;
|
||||
import com.visual.face.search.core.extract.FaceFeatureExtractor;
|
||||
import com.visual.face.search.core.extract.FaceFeatureExtractorImpl;
|
||||
import com.visual.face.search.core.models.*;
|
||||
import com.visual.face.search.server.utils.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
@ -15,8 +13,8 @@ import org.springframework.context.annotation.Configuration;
|
||||
@Configuration("visualModelConfig")
|
||||
public class ModelConfig {
|
||||
|
||||
@Value("${spring.profiles.active}")
|
||||
private String profile;
|
||||
@Value("${visual.model.baseModelPath}")
|
||||
private String baseModelPath;
|
||||
|
||||
@Value("${visual.model.faceDetection.name}")
|
||||
private String faceDetectionName;
|
||||
@ -50,6 +48,14 @@ public class ModelConfig {
|
||||
private Integer faceRecognitionNameThread;
|
||||
|
||||
|
||||
@Value("${visual.model.faceAttribute.name:InsightAttributeDetection}")
|
||||
private String faceAttributeDetectionName;
|
||||
@Value("${visual.model.faceAttribute.modelPath}")
|
||||
private String[] faceAttributeDetectionNameModel;
|
||||
@Value("${visual.model.faceAttribute.thread:4}")
|
||||
private Integer faceAttributeDetectionNameThread;
|
||||
|
||||
|
||||
/**
|
||||
* 获取人脸识别模型
|
||||
* @return
|
||||
@ -71,14 +77,12 @@ public class ModelConfig {
|
||||
*/
|
||||
@Bean(name = "visualBackupFaceDetection")
|
||||
public FaceDetection getBackupFaceDetection(){
|
||||
if(faceDetectionName.equalsIgnoreCase(backupFaceDetectionName)){
|
||||
return null;
|
||||
}else if(backupFaceDetectionName.equalsIgnoreCase("PcnNetworkFaceDetection")){
|
||||
if(backupFaceDetectionName.equalsIgnoreCase("PcnNetworkFaceDetection")){
|
||||
return new PcnNetworkFaceDetection(getModelPath(backupFaceDetectionName, backupFaceDetectionModel), backupFaceDetectionThread);
|
||||
}else if(backupFaceDetectionName.equalsIgnoreCase("InsightScrfdFaceDetection")){
|
||||
return new InsightScrfdFaceDetection(getModelPath(backupFaceDetectionName, backupFaceDetectionModel)[0], backupFaceDetectionThread);
|
||||
}else{
|
||||
return new PcnNetworkFaceDetection(backupFaceDetectionModel, backupFaceDetectionThread);
|
||||
return this.getFaceDetection();
|
||||
}
|
||||
}
|
||||
|
||||
@ -118,11 +122,26 @@ public class ModelConfig {
|
||||
public FaceRecognition getFaceRecognition(){
|
||||
if(faceRecognitionName.equalsIgnoreCase("InsightArcFaceRecognition")){
|
||||
return new InsightArcFaceRecognition(getModelPath(faceRecognitionName, faceRecognitionNameModel)[0], faceRecognitionNameThread);
|
||||
}else if(faceRecognitionName.equalsIgnoreCase("SeetaFaceOpenRecognition")){
|
||||
return new SeetaFaceOpenRecognition(getModelPath(faceRecognitionName, faceRecognitionNameModel)[0], faceRecognitionNameThread);
|
||||
}else{
|
||||
return new InsightArcFaceRecognition(getModelPath(faceRecognitionName, faceRecognitionNameModel)[0], faceRecognitionNameThread);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 人脸属性检测
|
||||
* @return
|
||||
*/
|
||||
@Bean(name = "visualAttributeDetection")
|
||||
public InsightAttributeDetection getAttributeDetection(){
|
||||
if(faceAttributeDetectionName.equalsIgnoreCase("InsightAttributeDetection")){
|
||||
return new InsightAttributeDetection(getModelPath(faceAttributeDetectionName, faceAttributeDetectionNameModel)[0], faceAttributeDetectionNameThread);
|
||||
}else{
|
||||
return new InsightAttributeDetection(getModelPath(faceAttributeDetectionName, faceAttributeDetectionNameModel)[0], faceAttributeDetectionNameThread);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建特征提取器
|
||||
* @param faceDetection 人脸识别模型
|
||||
@ -136,8 +155,20 @@ public class ModelConfig {
|
||||
@Qualifier("visualBackupFaceDetection")FaceDetection backupFaceDetection,
|
||||
@Qualifier("visualFaceKeyPoint")FaceKeyPoint faceKeyPoint,
|
||||
@Qualifier("visualFaceAlignment")FaceAlignment faceAlignment,
|
||||
@Qualifier("visualFaceRecognition")FaceRecognition faceRecognition){
|
||||
return new FaceFeatureExtractorImpl(faceDetection, backupFaceDetection, faceKeyPoint, faceAlignment, faceRecognition);
|
||||
@Qualifier("visualFaceRecognition")FaceRecognition faceRecognition,
|
||||
@Qualifier("visualAttributeDetection") FaceAttribute faceAttribute
|
||||
){
|
||||
if(faceDetection.getClass().isAssignableFrom(backupFaceDetection.getClass())){
|
||||
return new FaceFeatureExtractorImpl(
|
||||
faceDetection, null, faceKeyPoint,
|
||||
faceAlignment, faceRecognition, faceAttribute
|
||||
);
|
||||
}else{
|
||||
return new FaceFeatureExtractorImpl(
|
||||
faceDetection, backupFaceDetection, faceKeyPoint,
|
||||
faceAlignment, faceRecognition, faceAttribute
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -146,10 +177,12 @@ public class ModelConfig {
|
||||
* @return
|
||||
*/
|
||||
private String[] getModelPath(String modelName, String modelPath[]){
|
||||
|
||||
String basePath = "face-search-core/src/main/resources/";
|
||||
if("docker".equalsIgnoreCase(profile)){
|
||||
basePath = "/app/face-search/";
|
||||
if(StringUtils.isNotEmpty(this.baseModelPath)){
|
||||
basePath = this.baseModelPath;
|
||||
basePath = basePath.replaceAll("^\'|\'$", "");
|
||||
basePath = basePath.replaceAll("^\"|\"$", "");
|
||||
basePath = basePath.endsWith("/") ? basePath : basePath +"/";
|
||||
}
|
||||
|
||||
if((null == modelPath || modelPath.length != 3) && "PcnNetworkFaceDetection".equalsIgnoreCase(modelName)){
|
||||
@ -172,6 +205,18 @@ public class ModelConfig {
|
||||
return new String[]{basePath + "model/onnx/recognition_face_arc/glint360k_cosface_r18_fp16_0.1.onnx"};
|
||||
}
|
||||
|
||||
if((null == modelPath || modelPath.length != 1) && "SeetaFaceOpenRecognition".equalsIgnoreCase(modelName)){
|
||||
return new String[]{basePath + "model/onnx/recognition_face_seeta/face_recognizer_512.onnx"};
|
||||
}
|
||||
|
||||
if((null == modelPath || modelPath.length != 1) && "InsightAttributeDetection".equalsIgnoreCase(modelName)){
|
||||
return new String[]{basePath + "model/onnx/attribute_gender_age/insight_gender_age.onnx"};
|
||||
}
|
||||
|
||||
if((null == modelPath || modelPath.length != 1) && "SeetaMaskFaceKeyPoint".equalsIgnoreCase(modelName)){
|
||||
return new String[]{basePath + "model/onnx/keypoint_seeta_mask/landmarker_005_mask_pts5.onnx"};
|
||||
}
|
||||
|
||||
return modelPath;
|
||||
}
|
||||
}
|
||||
|
@ -1,45 +0,0 @@
|
||||
package com.visual.face.search.server.bootstrap.conf;
|
||||
|
||||
import com.github.xiaoymin.swaggerbootstrapui.annotations.EnableSwaggerBootstrapUI;
|
||||
import io.swagger.annotations.Api;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
|
||||
import springfox.documentation.builders.ApiInfoBuilder;
|
||||
import springfox.documentation.builders.RequestHandlerSelectors;
|
||||
import springfox.documentation.service.ApiInfo;
|
||||
import springfox.documentation.spi.DocumentationType;
|
||||
import springfox.documentation.spring.web.plugins.Docket;
|
||||
import springfox.documentation.swagger2.annotations.EnableSwagger2;
|
||||
|
||||
@Configuration
|
||||
@EnableSwagger2
|
||||
@EnableSwaggerBootstrapUI
|
||||
public class SwaggerConfig implements WebMvcConfigurer {
|
||||
|
||||
@Value("${visual.swagger.enable:true}")
|
||||
private Boolean enable;
|
||||
|
||||
@Bean
|
||||
public Docket ProductApi() {
|
||||
return new Docket(DocumentationType.SWAGGER_2)
|
||||
.enable(enable)
|
||||
.useDefaultResponseMessages(false)
|
||||
.forCodeGeneration(false)
|
||||
.pathMapping("/")
|
||||
.apiInfo(apiInfo())
|
||||
.select()
|
||||
.apis(RequestHandlerSelectors.withClassAnnotation(Api.class))
|
||||
.build();
|
||||
}
|
||||
|
||||
private ApiInfo apiInfo() {
|
||||
return new ApiInfoBuilder()
|
||||
.title("人脸搜索服务API")
|
||||
.description("人脸搜索服务API")
|
||||
.version("1.1.0")
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
@ -89,7 +89,7 @@ public class DruidConfig
|
||||
Filter filter = new Filter()
|
||||
{
|
||||
@Override
|
||||
public void init(javax.servlet.FilterConfig filterConfig) throws ServletException
|
||||
public void init(FilterConfig filterConfig) throws ServletException
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,23 @@
|
||||
package com.visual.face.search.server.controller.server.api;
|
||||
|
||||
import com.visual.face.search.server.domain.common.ResponseInfo;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public interface AdminControllerApi {
|
||||
|
||||
/**
|
||||
*获取命名空间集合列表
|
||||
* @return 命名空间列表
|
||||
*/
|
||||
public ResponseInfo<List<String>> getNamespaceList();
|
||||
|
||||
/**
|
||||
*根据命名空间查看集合列表
|
||||
* @param namespace 命名空间
|
||||
* @return 集合列表
|
||||
*/
|
||||
public ResponseInfo<List<Map<String, String>>> getCollectList(String namespace);
|
||||
|
||||
}
|
@ -0,0 +1,21 @@
|
||||
package com.visual.face.search.server.controller.server.impl;
|
||||
|
||||
import com.visual.face.search.server.controller.base.BaseController;
|
||||
import com.visual.face.search.server.controller.server.api.AdminControllerApi;
|
||||
import com.visual.face.search.server.domain.common.ResponseInfo;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class AdminControllerImpl extends BaseController implements AdminControllerApi {
|
||||
|
||||
@Override
|
||||
public ResponseInfo<List<String>> getNamespaceList() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ResponseInfo<List<Map<String, String>>> getCollectList(String namespace) {
|
||||
return null;
|
||||
}
|
||||
}
|
@ -0,0 +1,34 @@
|
||||
package com.visual.face.search.server.controller.server.restful;
|
||||
|
||||
import com.visual.face.search.server.controller.server.impl.AdminControllerImpl;
|
||||
import com.visual.face.search.server.domain.common.ResponseInfo;
|
||||
import io.swagger.annotations.Api;
|
||||
import io.swagger.annotations.ApiOperation;
|
||||
import io.swagger.annotations.ApiParam;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Api(tags="06、管理员接口")
|
||||
@RestController("visualAdminController")
|
||||
@RequestMapping("/visual/admin")
|
||||
public class AdminController extends AdminControllerImpl {
|
||||
|
||||
@ApiOperation(value="1、获取命名空间集合列表")
|
||||
@Override
|
||||
@ResponseBody
|
||||
@RequestMapping(value = "/getNamespaceList", method = RequestMethod.GET)
|
||||
public ResponseInfo<List<String>> getNamespaceList() {
|
||||
return super.getNamespaceList();
|
||||
}
|
||||
|
||||
@ApiOperation(value="2、根据命名空间查看集合列表")
|
||||
@Override
|
||||
@ResponseBody
|
||||
@RequestMapping(value = "/getCollectList", method = RequestMethod.GET)
|
||||
public ResponseInfo<List<Map<String, String>>> getCollectList(@ApiParam(value = "命名空间", name="namespace", required=true) @RequestParam(value = "namespace", required = true) String namespace) {
|
||||
return super.getCollectList(namespace);
|
||||
}
|
||||
|
||||
}
|
@ -28,7 +28,7 @@ public class CollectVo<ExtendsVo extends CollectVo<ExtendsVo>> extends BaseVo {
|
||||
/**数据分片中最大的文件个数**/
|
||||
@Min(value = 0, message = "maxDocsPerSegment must greater than or equal to 0")
|
||||
@ApiModelProperty(value="数据分片中最大的文件个数,默认为0(不限制),仅对Proxima引擎生效", position = 3,required = false)
|
||||
private Long maxDocsPerSegment;
|
||||
private Long replicasNum;
|
||||
/**数据分片中最大的文件个数**/
|
||||
@Min(value = 0, message = "shardsNum must greater than or equal to 0")
|
||||
@ApiModelProperty(value="要创建的集合的分片数,默认为0(即系统默认),仅对Milvus引擎生效", position = 4,required = false)
|
||||
@ -39,9 +39,6 @@ public class CollectVo<ExtendsVo extends CollectVo<ExtendsVo>> extends BaseVo {
|
||||
/**自定义的人脸字段**/
|
||||
@ApiModelProperty(value="自定义的人脸属性字段", position = 6,required = false)
|
||||
private List<FiledColumn> faceColumns = new ArrayList<>();
|
||||
/**启用binlog同步**/
|
||||
@ApiModelProperty(value="启用binlog同步。扩展字段,暂不支持该功能", position = 7,required = false)
|
||||
private Boolean syncBinLog;
|
||||
/**是否保留图片及人脸信息**/
|
||||
@ApiModelProperty(value="是否保留图片及人脸信息", position = 8,required = false)
|
||||
private Boolean storageFaceInfo;
|
||||
@ -86,12 +83,12 @@ public class CollectVo<ExtendsVo extends CollectVo<ExtendsVo>> extends BaseVo {
|
||||
return (ExtendsVo) this;
|
||||
}
|
||||
|
||||
public Long getMaxDocsPerSegment() {
|
||||
return maxDocsPerSegment;
|
||||
public Long getReplicasNum() {
|
||||
return replicasNum;
|
||||
}
|
||||
|
||||
public ExtendsVo setMaxDocsPerSegment(Long maxDocsPerSegment) {
|
||||
this.maxDocsPerSegment = maxDocsPerSegment;
|
||||
public ExtendsVo setReplicasNum(Long replicasNum) {
|
||||
this.replicasNum = replicasNum;
|
||||
return (ExtendsVo) this;
|
||||
}
|
||||
|
||||
@ -126,15 +123,6 @@ public class CollectVo<ExtendsVo extends CollectVo<ExtendsVo>> extends BaseVo {
|
||||
return (ExtendsVo) this;
|
||||
}
|
||||
|
||||
public boolean isSyncBinLog() {
|
||||
return null == syncBinLog ? false : syncBinLog;
|
||||
}
|
||||
|
||||
public ExtendsVo setSyncBinLog(Boolean syncBinLog) {
|
||||
this.syncBinLog = syncBinLog;
|
||||
return (ExtendsVo) this;
|
||||
}
|
||||
|
||||
public boolean getStorageFaceInfo() {
|
||||
return null == storageFaceInfo ? false : storageFaceInfo;
|
||||
}
|
||||
|
@ -16,16 +16,13 @@ public class SampleFaceVo implements Comparable<SampleFaceVo>, Serializable {
|
||||
@ApiModelProperty(value="人脸分数:[0,100]", position = 3, required = true)
|
||||
private Float faceScore;
|
||||
/**转换后的置信度**/
|
||||
@ApiModelProperty(value="向量距离:>=0", position = 4, required = true)
|
||||
private Float distance;
|
||||
/**转换后的置信度**/
|
||||
@ApiModelProperty(value="转换后的置信度:[-100,100],值越大,相似度越高。", position = 5, required = true)
|
||||
@ApiModelProperty(value="转换后的置信度:[-100,100],值越大,相似度越高。", position = 4, required = true)
|
||||
private Float confidence;
|
||||
/**样本扩展的额外数据**/
|
||||
@ApiModelProperty(value="样本扩展的额外数据", position = 6, required = false)
|
||||
@ApiModelProperty(value="样本扩展的额外数据", position = 5, required = false)
|
||||
private FieldKeyValues sampleData;
|
||||
/**人脸扩展的额外数据**/
|
||||
@ApiModelProperty(value="人脸扩展的额外数据", position = 7, required = false)
|
||||
@ApiModelProperty(value="人脸扩展的额外数据", position = 6, required = false)
|
||||
private FieldKeyValues faceData;
|
||||
|
||||
/**
|
||||
@ -76,14 +73,6 @@ public class SampleFaceVo implements Comparable<SampleFaceVo>, Serializable {
|
||||
this.faceScore = faceScore;
|
||||
}
|
||||
|
||||
public Float getDistance() {
|
||||
return distance;
|
||||
}
|
||||
|
||||
public void setDistance(Float distance) {
|
||||
this.distance = distance;
|
||||
}
|
||||
|
||||
public Float getConfidence() {
|
||||
return confidence;
|
||||
}
|
||||
|
@ -30,13 +30,16 @@ public class FaceSearchReqVo extends BaseVo {
|
||||
@Range(min = -100, max = 100, message = "faceScoreThreshold is not in the range")
|
||||
@ApiModelProperty(value="人脸匹配分数阈值,范围:[-100,100]:默认0", position = 4, required = false)
|
||||
private Float confidenceThreshold = 0f;
|
||||
/**选择搜索评分的算法,默认余弦相似度(COSINESIMIL),可选参数:L1、L2、LINF、COSINESIMIL、INNERPRODUCT、HAMMINGBIT**/
|
||||
@ApiModelProperty(hidden = true, value="选择搜索评分的算法,默认是余弦相似度(COSINESIMIL),可选参数:L1、L2、LINF、COSINESIMIL、INNERPRODUCT、HAMMINGBIT", position = 5, required = false)
|
||||
private String algorithm = SearchAlgorithm.COSINESIMIL.name();
|
||||
/**搜索条数:默认10**/
|
||||
@Min(value = 0, message = "limit must greater than or equal to 0")
|
||||
@ApiModelProperty(value="最大搜索条数:默认5", position = 5, required = false)
|
||||
@ApiModelProperty(value="最大搜索条数:默认5", position = 6, required = false)
|
||||
private Integer limit;
|
||||
/**对输入图像中多少个人脸进行检索比对**/
|
||||
@Min(value = 0, message = "maxFaceNum must greater than or equal to 0")
|
||||
@ApiModelProperty(value="对输入图像中多少个人脸进行检索比对:默认5", position = 6, required = false)
|
||||
@ApiModelProperty(value="对输入图像中多少个人脸进行检索比对:默认5", position = 7, required = false)
|
||||
private Integer maxFaceNum;
|
||||
|
||||
/**
|
||||
@ -95,6 +98,21 @@ public class FaceSearchReqVo extends BaseVo {
|
||||
return this;
|
||||
}
|
||||
|
||||
public SearchAlgorithm getAlgorithm() {
|
||||
if(null != algorithm && !algorithm.isEmpty()){
|
||||
return SearchAlgorithm.valueOf(this.algorithm);
|
||||
}else{
|
||||
return SearchAlgorithm.COSINESIMIL;
|
||||
}
|
||||
}
|
||||
|
||||
public FaceSearchReqVo setAlgorithm(String algorithm) {
|
||||
if(null != algorithm){
|
||||
this.algorithm = algorithm;
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public Integer getLimit() {
|
||||
return limit;
|
||||
}
|
||||
|
@ -0,0 +1,22 @@
|
||||
package com.visual.face.search.server.domain.request;
|
||||
|
||||
public enum SearchAlgorithm {
|
||||
L1("l1"),
|
||||
L2("l2"),
|
||||
LINF("linf"),
|
||||
HAMMINGBIT("innerproduct"),
|
||||
INNERPRODUCT("hammingbit"),
|
||||
COSINESIMIL("cosinesimil");
|
||||
|
||||
|
||||
private String algorithm;
|
||||
|
||||
SearchAlgorithm(String algorithm){
|
||||
this.algorithm = algorithm;
|
||||
}
|
||||
|
||||
public String algorithm(){
|
||||
return this.algorithm;
|
||||
}
|
||||
|
||||
}
|
@ -1,26 +0,0 @@
|
||||
package com.visual.face.search.server.engine.api;
|
||||
|
||||
import com.visual.face.search.server.engine.model.MapParam;
|
||||
import com.visual.face.search.server.engine.model.SearchResponse;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface SearchEngine {
|
||||
|
||||
public Object getEngine();
|
||||
|
||||
public boolean exist(String collectionName);
|
||||
|
||||
public boolean dropCollection(String collectionName);
|
||||
|
||||
public boolean createCollection(String collectionName, MapParam param);
|
||||
|
||||
public boolean insertVector(String collectionName, Long keyId, String faceID, float[] vectors);
|
||||
|
||||
public boolean deleteVectorByKey(String collectionName, Long keyId);
|
||||
|
||||
public boolean deleteVectorByKey(String collectionName, List<Long> keyIds);
|
||||
|
||||
public SearchResponse search(String collectionName, float[][] features, int topK);
|
||||
|
||||
}
|
@ -1,15 +0,0 @@
|
||||
package com.visual.face.search.server.engine.conf;
|
||||
|
||||
public class Constant {
|
||||
|
||||
public final static String ParamKeyShardsNum = "shardsNum";
|
||||
public final static String ParamKeyMaxDocsPerSegment = "maxDocsPerSegment";
|
||||
|
||||
public final static String ColumnPrimaryKey = "id";
|
||||
public final static String ColumnNameFaceId = "face_id";
|
||||
public final static String ColumnNameFaceScore = "face_score";
|
||||
public final static String ColumnNameFaceIndex = "face_index";
|
||||
public final static String ColumnNameFaceVector = "face_vector";
|
||||
public final static String ColumnNameSampleId = "sample_id";
|
||||
|
||||
}
|
@ -1,199 +0,0 @@
|
||||
package com.visual.face.search.server.engine.impl;
|
||||
|
||||
import com.visual.face.search.server.engine.api.SearchEngine;
|
||||
import com.visual.face.search.server.engine.conf.Constant;
|
||||
import com.visual.face.search.server.engine.model.*;
|
||||
import com.visual.face.search.server.engine.utils.VectorUtils;
|
||||
import io.milvus.response.SearchResultsWrapper;
|
||||
import io.milvus.client.MilvusServiceClient;
|
||||
import io.milvus.grpc.*;
|
||||
import io.milvus.param.*;
|
||||
import io.milvus.param.collection.*;
|
||||
import io.milvus.param.dml.DeleteParam;
|
||||
import io.milvus.param.dml.InsertParam;
|
||||
import io.milvus.param.dml.SearchParam;
|
||||
import io.milvus.param.index.CreateIndexParam;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
public class MilvusSearchEngine implements SearchEngine {
|
||||
|
||||
private static final Integer SUCCESS_STATUE = 0;
|
||||
|
||||
private MilvusServiceClient client;
|
||||
|
||||
public MilvusSearchEngine(MilvusServiceClient client) {
|
||||
this.client = client;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getEngine(){
|
||||
return this.client;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean exist(String collectionName) {
|
||||
HasCollectionParam param = HasCollectionParam.newBuilder().withCollectionName(collectionName).build();
|
||||
R<Boolean> response = this.client.hasCollection(param);
|
||||
if(SUCCESS_STATUE.equals(response.getStatus())){
|
||||
return response.getData();
|
||||
}else{
|
||||
throw new RuntimeException(response.getMessage(), response.getException());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean dropCollection(String collectionName) {
|
||||
DropCollectionParam param = DropCollectionParam.newBuilder().withCollectionName(collectionName).build();
|
||||
R<RpcStatus> response = this.client.dropCollection(param);
|
||||
if(SUCCESS_STATUE.equals(response.getStatus())){
|
||||
return true;
|
||||
}else{
|
||||
throw new RuntimeException(response.getMessage(), response.getException());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean createCollection(String collectionName, MapParam param) {
|
||||
FieldType keyFieldType = FieldType.newBuilder()
|
||||
.withName(Constant.ColumnPrimaryKey)
|
||||
.withDescription("id")
|
||||
.withDataType(DataType.Int64)
|
||||
.withPrimaryKey(true)
|
||||
.withAutoID(false)
|
||||
.build();
|
||||
|
||||
FieldType indexFieldType = FieldType.newBuilder()
|
||||
.withName(Constant.ColumnNameFaceIndex)
|
||||
.withDescription("face vector")
|
||||
.withDataType(DataType.FloatVector)
|
||||
.withDimension(512)
|
||||
.build();
|
||||
|
||||
CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
|
||||
.withCollectionName(collectionName)
|
||||
.withDescription(collectionName)
|
||||
.addFieldType(keyFieldType)
|
||||
.addFieldType(indexFieldType)
|
||||
.withShardsNum(param.getShardsNum())
|
||||
.build();
|
||||
|
||||
CreateIndexParam indexParam = CreateIndexParam.newBuilder()
|
||||
.withCollectionName(collectionName)
|
||||
.withFieldName(Constant.ColumnNameFaceIndex)
|
||||
.withIndexType(IndexType.IVF_FLAT)
|
||||
.withMetricType(MetricType.L2)
|
||||
.withExtraParam("{\"nlist\":128}")
|
||||
.withSyncMode(Boolean.TRUE)
|
||||
.build();
|
||||
|
||||
LoadCollectionParam loadParam = LoadCollectionParam.newBuilder()
|
||||
.withCollectionName(collectionName)
|
||||
.build();
|
||||
|
||||
R<RpcStatus> response = this.client.createCollection(createCollectionReq);
|
||||
if(SUCCESS_STATUE.equals(response.getStatus())){
|
||||
R<RpcStatus> indexResponse = this.client.createIndex(indexParam);
|
||||
if(SUCCESS_STATUE.equals(indexResponse.getStatus())){
|
||||
R<RpcStatus> loadResponse = this.client.loadCollection(loadParam);
|
||||
if(SUCCESS_STATUE.equals(loadResponse.getStatus())){
|
||||
return true;
|
||||
}else{
|
||||
this.dropCollection(collectionName);
|
||||
return false;
|
||||
}
|
||||
}else{
|
||||
this.dropCollection(collectionName);
|
||||
return false;
|
||||
}
|
||||
}else{
|
||||
throw new RuntimeException(response.getMessage(), response.getException());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean insertVector(String collectionName, Long keyId, String faceID, float[] vectors) {
|
||||
List<InsertParam.Field> fields = new ArrayList<>();
|
||||
fields.add(new InsertParam.Field(Constant.ColumnPrimaryKey, DataType.Int64, Collections.singletonList(keyId)));
|
||||
fields.add(new InsertParam.Field(Constant.ColumnNameFaceIndex, DataType.FloatVector, Collections.singletonList(VectorUtils.convertVector(vectors))));
|
||||
|
||||
InsertParam insertParam = InsertParam.newBuilder()
|
||||
.withCollectionName(collectionName)
|
||||
.withFields(fields)
|
||||
.build();
|
||||
|
||||
R<MutationResult> response = this.client.insert(insertParam);
|
||||
if(SUCCESS_STATUE.equals(response.getStatus())){
|
||||
return true;
|
||||
}else{
|
||||
throw new RuntimeException(response.getMessage(), response.getException());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean deleteVectorByKey(String collectionName, Long keyId) {
|
||||
String deleteExpr = Constant.ColumnPrimaryKey + " in " + "[" + keyId + "]";
|
||||
DeleteParam build = DeleteParam.newBuilder()
|
||||
.withCollectionName(collectionName)
|
||||
.withExpr(deleteExpr)
|
||||
.build();
|
||||
|
||||
R<MutationResult> response = this.client.delete(build);
|
||||
if(SUCCESS_STATUE.equals(response.getStatus())){
|
||||
return true;
|
||||
}else{
|
||||
throw new RuntimeException(response.getMessage(), response.getException());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean deleteVectorByKey(String collectionName, List<Long> keyIds) {
|
||||
String deleteExpr = Constant.ColumnPrimaryKey + " in " + keyIds.toString();
|
||||
DeleteParam build = DeleteParam.newBuilder()
|
||||
.withCollectionName(collectionName)
|
||||
.withExpr(deleteExpr)
|
||||
.build();
|
||||
|
||||
R<MutationResult> response = this.client.delete(build);
|
||||
if(SUCCESS_STATUE.equals(response.getStatus())){
|
||||
return true;
|
||||
}else{
|
||||
throw new RuntimeException(response.getMessage(), response.getException());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public SearchResponse search(String collectionName, float[][] features, int topK) {
|
||||
SearchParam searchParam = SearchParam.newBuilder()
|
||||
.withCollectionName(collectionName)
|
||||
.withMetricType(MetricType.L2)
|
||||
.withParams("{\"nprobe\": 128}")
|
||||
.withOutFields(Collections.singletonList(Constant.ColumnPrimaryKey))
|
||||
.withTopK(topK)
|
||||
.withVectors(VectorUtils.convertVector(features))
|
||||
.withVectorFieldName(Constant.ColumnNameFaceIndex)
|
||||
.build();
|
||||
R<SearchResults> response = this.client.search(searchParam);
|
||||
if(SUCCESS_STATUE.equals(response.getStatus())){
|
||||
SearchStatus status = SearchStatus.build(0, "success");
|
||||
List<SearchResult> result = new ArrayList<>();
|
||||
if(response.getData().getResults().hasIds()){
|
||||
SearchResultsWrapper wrapper = new SearchResultsWrapper(response.getData().getResults());
|
||||
for (int i = 0; i < features.length; ++i) {
|
||||
List<SearchDocument> documents = new ArrayList<>();
|
||||
List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
|
||||
for(SearchResultsWrapper.IDScore scoreId : scores){
|
||||
long primaryKey = scoreId.getLongID();
|
||||
float score = scoreId.getScore();
|
||||
documents.add(SearchDocument.build(primaryKey, score, null));
|
||||
}
|
||||
result.add(SearchResult.build(documents));
|
||||
}
|
||||
}
|
||||
return SearchResponse.build(status, result);
|
||||
}else{
|
||||
throw new RuntimeException(response.getMessage(), response.getException());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -1,174 +0,0 @@
|
||||
package com.visual.face.search.server.engine.impl;
|
||||
|
||||
import com.alibaba.proxima.be.client.*;
|
||||
import com.visual.face.search.server.engine.api.SearchEngine;
|
||||
import com.visual.face.search.server.engine.conf.Constant;
|
||||
import com.visual.face.search.server.engine.model.*;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
public class ProximaSearchEngine implements SearchEngine {
|
||||
|
||||
private ProximaSearchClient client;
|
||||
|
||||
/**
|
||||
* 构造获取连接对象
|
||||
* @param client
|
||||
*/
|
||||
public ProximaSearchEngine(ProximaSearchClient client){
|
||||
this.client = client;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getEngine(){
|
||||
return this.client;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean exist(String collectionName) {
|
||||
DescribeCollectionResponse response = client.describeCollection(collectionName);
|
||||
return response.ok();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean dropCollection(String collectionName) {
|
||||
if(exist(collectionName)){
|
||||
Status status = client.dropCollection(collectionName);
|
||||
return status.ok();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean createCollection(String collectionName, MapParam param) {
|
||||
Long maxDocsPerSegment = param.getMaxDocsPerSegment();
|
||||
CollectionConfig config = CollectionConfig.newBuilder()
|
||||
.withCollectionName(collectionName)
|
||||
.withMaxDocsPerSegment(maxDocsPerSegment)
|
||||
.withForwardColumnNames(Arrays.asList(Constant.ColumnNameFaceId))
|
||||
.addIndexColumnParam(Constant.ColumnNameFaceIndex, DataType.VECTOR_FP32, 512).build();
|
||||
Status status = client.createCollection(config);
|
||||
if(status.ok()){
|
||||
return true;
|
||||
}else{
|
||||
throw new RuntimeException(status.getReason());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean insertVector(String collectionName, Long keyId, String faceID, float[] vectors) {
|
||||
WriteRequest.Row insertRow = WriteRequest.Row.newBuilder()
|
||||
.withPrimaryKey(keyId)
|
||||
.addIndexValue(vectors)
|
||||
.addForwardValue(faceID)
|
||||
.withOperationType(WriteRequest.OperationType.INSERT)
|
||||
.build();
|
||||
WriteRequest writeRequest = WriteRequest.newBuilder()
|
||||
.withCollectionName(collectionName)
|
||||
.withForwardColumnList(Collections.singletonList(Constant.ColumnNameFaceId))
|
||||
.addIndexColumnMeta(Constant.ColumnNameFaceIndex, DataType.VECTOR_FP32, 512)
|
||||
.addRow(insertRow)
|
||||
.build();
|
||||
Status status = client.write(writeRequest);
|
||||
if(status.ok()){
|
||||
return true;
|
||||
}else{
|
||||
throw new RuntimeException(status.getReason());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean deleteVectorByKey(String collectionName, Long keyId) {
|
||||
WriteRequest.Row deleteRow = WriteRequest.Row.newBuilder()
|
||||
.withPrimaryKey(keyId)
|
||||
.withOperationType(WriteRequest.OperationType.DELETE)
|
||||
.build();
|
||||
WriteRequest writeRequest = WriteRequest.newBuilder()
|
||||
.withCollectionName(collectionName)
|
||||
.withForwardColumnList(Collections.singletonList(Constant.ColumnNameFaceId))
|
||||
.addIndexColumnMeta(Constant.ColumnNameFaceIndex, DataType.VECTOR_FP32, 512)
|
||||
.addRow(deleteRow)
|
||||
.build();
|
||||
Status status = client.write(writeRequest);
|
||||
if(status.ok()){
|
||||
return true;
|
||||
}else{
|
||||
throw new RuntimeException(status.getReason());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean deleteVectorByKey(String collectionName, List<Long> keyIds) {
|
||||
if(null == keyIds || keyIds.isEmpty()){
|
||||
return false;
|
||||
}
|
||||
List<Long> deleteIds = new ArrayList<>();
|
||||
for(Long keyId : keyIds){
|
||||
GetDocumentRequest r = GetDocumentRequest.newBuilder().withCollectionName(collectionName).withPrimaryKey(keyId).build();
|
||||
GetDocumentResponse p = client.getDocumentByKey(r);
|
||||
if(p.ok() && p.getDocument().getPrimaryKey() == keyId){
|
||||
deleteIds.add(keyId);
|
||||
}
|
||||
}
|
||||
if(deleteIds.isEmpty()){
|
||||
return true;
|
||||
}
|
||||
WriteRequest.Builder builder = WriteRequest.newBuilder().withCollectionName(collectionName);
|
||||
for(Long keyId : deleteIds){
|
||||
WriteRequest.Row deleteRow = WriteRequest.Row.newBuilder().withPrimaryKey(keyId).withOperationType(WriteRequest.OperationType.DELETE).build();
|
||||
builder.addRow(deleteRow);
|
||||
}
|
||||
Status status = client.write(builder.build());
|
||||
if(status.ok()){
|
||||
return true;
|
||||
}else{
|
||||
throw new RuntimeException(status.getReason());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public SearchResponse search(String collectionName, float[][] features, int topK) {
|
||||
QueryRequest queryRequest = QueryRequest.newBuilder()
|
||||
.withCollectionName(collectionName)
|
||||
.withKnnQueryParam(
|
||||
QueryRequest.KnnQueryParam.newBuilder()
|
||||
.withColumnName(Constant.ColumnNameFaceIndex)
|
||||
.withTopk(topK)
|
||||
.withFeatures(features)
|
||||
.build())
|
||||
.build();
|
||||
//搜索向量
|
||||
QueryResponse queryResponse;
|
||||
try {
|
||||
queryResponse = client.query(queryRequest);
|
||||
} catch (Exception e) {
|
||||
return SearchResponse.build(SearchStatus.build(1, e.getMessage()), null);
|
||||
}
|
||||
//搜索失败
|
||||
if(!queryResponse.ok()){
|
||||
return SearchResponse.build(SearchStatus.build(1, "response status is error"), null);
|
||||
}
|
||||
//转换对象
|
||||
List<SearchResult> result = new ArrayList<>();
|
||||
SearchStatus status = SearchStatus.build(0, "success");
|
||||
for (int i = 0; i < queryResponse.getQueryResultCount(); ++i) {
|
||||
List<SearchDocument> documents = new ArrayList<>();
|
||||
QueryResult queryResult = queryResponse.getQueryResult(i);
|
||||
for (int d = 0; d < queryResult.getDocumentCount(); ++d) {
|
||||
Document document = queryResult.getDocument(d);
|
||||
long primaryKey = document.getPrimaryKey();
|
||||
float score = document.getScore();
|
||||
Set<String> forwardKeys = document.getForwardKeySet();
|
||||
String faceId = null;
|
||||
if(forwardKeys.contains(Constant.ColumnNameFaceId)){
|
||||
faceId = document.getForwardValue(Constant.ColumnNameFaceId).getStringValue();
|
||||
}
|
||||
documents.add(SearchDocument.build(primaryKey, score, faceId));
|
||||
}
|
||||
result.add(SearchResult.build(documents));
|
||||
}
|
||||
//返回
|
||||
return SearchResponse.build(status, result);
|
||||
}
|
||||
|
||||
}
|
@ -1,43 +0,0 @@
|
||||
package com.visual.face.search.server.engine.model;
|
||||
|
||||
public class SearchDocument {
|
||||
private long primaryKey;
|
||||
private float score;
|
||||
private String faceId;
|
||||
|
||||
public SearchDocument(){}
|
||||
|
||||
public SearchDocument(long primaryKey, float score, String faceId) {
|
||||
this.primaryKey = primaryKey;
|
||||
this.score = score;
|
||||
this.faceId = faceId;
|
||||
}
|
||||
|
||||
public static SearchDocument build(long primaryKey, float score, String faceId){
|
||||
return new SearchDocument(primaryKey, score, faceId);
|
||||
}
|
||||
|
||||
public long getPrimaryKey() {
|
||||
return primaryKey;
|
||||
}
|
||||
|
||||
public void setPrimaryKey(long primaryKey) {
|
||||
this.primaryKey = primaryKey;
|
||||
}
|
||||
|
||||
public float getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
public void setScore(float score) {
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
public String getFaceId() {
|
||||
return faceId;
|
||||
}
|
||||
|
||||
public void setFaceId(String faceId) {
|
||||
this.faceId = faceId;
|
||||
}
|
||||
}
|
@ -0,0 +1,10 @@
|
||||
package com.visual.face.search.server.mapper;
|
||||
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
@Mapper
|
||||
public interface AdminMapper {
|
||||
|
||||
|
||||
|
||||
}
|
@ -44,7 +44,7 @@ public interface CollectMapper {
|
||||
"and namespace = #{namespace,jdbcType=VARCHAR} ",
|
||||
"and collection = #{collection,jdbcType=VARCHAR}"
|
||||
})
|
||||
int deleteByName(@Param("namespace")String namespace, @Param("collection")String collection);
|
||||
int deleteByName(@Param("namespace") String namespace, @Param("collection") String collection);
|
||||
|
||||
@Select({
|
||||
"select",
|
||||
@ -72,7 +72,7 @@ public interface CollectMapper {
|
||||
@Result(column="update_time", property="updateTime", jdbcType=JdbcType.TIMESTAMP),
|
||||
@Result(column="deleted", property="deleted", jdbcType=JdbcType.CHAR)
|
||||
})
|
||||
Collection selectByName(@Param("namespace")String namespace, @Param("collection")String collection);
|
||||
Collection selectByName(@Param("namespace") String namespace, @Param("collection") String collection);
|
||||
|
||||
@Select({
|
||||
"select",
|
||||
@ -98,5 +98,5 @@ public interface CollectMapper {
|
||||
@Result(column="update_time", property="updateTime", jdbcType=JdbcType.TIMESTAMP),
|
||||
@Result(column="deleted", property="deleted", jdbcType=JdbcType.CHAR)
|
||||
})
|
||||
List<Collection> selectByNamespace(@Param("namespace")String namespace);
|
||||
List<Collection> selectByNamespace(@Param("namespace") String namespace);
|
||||
}
|
||||
|
@ -61,25 +61,25 @@ public interface FaceDataMapper {
|
||||
int deleteById(@Param("table") String table, @Param("id") Long id);
|
||||
|
||||
@Delete({"delete from ${table} where sample_id = #{sampleId,jdbcType=VARCHAR}"})
|
||||
int deleteBySampleId(@Param("table") String table, @Param("sampleId")String sampleId);
|
||||
int deleteBySampleId(@Param("table") String table, @Param("sampleId") String sampleId);
|
||||
|
||||
@Delete({"delete from ${table} where sample_id = #{sampleId,jdbcType=VARCHAR} and face_id=#{faceId,jdbcType=VARCHAR}"})
|
||||
int deleteByFaceId(@Param("table") String table, @Param("sampleId")String sampleId, @Param("faceId")String faceId);
|
||||
int deleteByFaceId(@Param("table") String table, @Param("sampleId") String sampleId, @Param("faceId") String faceId);
|
||||
|
||||
@Select({"select count(1) from ${table} where sample_id = '${sampleId}'"})
|
||||
long countBySampleId(@Param("table") String table, @Param("sampleId")String sampleId);
|
||||
long countBySampleId(@Param("table") String table, @Param("sampleId") String sampleId);
|
||||
|
||||
@Select({"select count(1) from ${table} where sample_id = '${sampleId}' and face_id=#{faceId,jdbcType=VARCHAR}"})
|
||||
long count(@Param("table") String table, @Param("sampleId")String sampleId, @Param("faceId")String faceId);
|
||||
long count(@Param("table") String table, @Param("sampleId") String sampleId, @Param("faceId") String faceId);
|
||||
|
||||
@Select({"select id from ${table} where sample_id = '${sampleId}' and face_id=#{faceId,jdbcType=VARCHAR}"})
|
||||
Long getIdByFaceId(@Param("table") String table, @Param("sampleId")String sampleId, @Param("faceId")String faceId);
|
||||
Long getIdByFaceId(@Param("table") String table, @Param("sampleId") String sampleId, @Param("faceId") String faceId);
|
||||
|
||||
@Select({"select id from ${table} where sample_id=#{sampleId,jdbcType=VARCHAR}"})
|
||||
List<Long> getIdBySampleId(@Param("table") String table, @Param("sampleId")String sampleId);
|
||||
@Select({"select face_id from ${table} where sample_id=#{sampleId,jdbcType=VARCHAR}"})
|
||||
List<String> getFaceIdBySampleId(@Param("table") String table, @Param("sampleId") String sampleId);
|
||||
|
||||
@Select({"select * from ${table} where sample_id = #{sampleId,jdbcType=VARCHAR}"})
|
||||
List<Map<String, Object>> getBySampleId(@Param("table") String table, @Param("sampleId")String sampleId);
|
||||
List<Map<String, Object>> getBySampleId(@Param("table") String table, @Param("sampleId") String sampleId);
|
||||
|
||||
@Select({
|
||||
"<script>",
|
||||
@ -89,20 +89,24 @@ public interface FaceDataMapper {
|
||||
"</foreach>",
|
||||
"</script>"
|
||||
})
|
||||
List<Map<String, Object>> getBySampleIds(@Param("table") String table, @Param("sampleIds")List<String> sampleIds);
|
||||
List<Map<String, Object>> getBySampleIds(@Param("table") String table, @Param("sampleIds") List<String> sampleIds);
|
||||
|
||||
@Select({"select * from ${table} where sample_id = #{sampleId,jdbcType=VARCHAR} and face_id=#{faceId,jdbcType=VARCHAR}"})
|
||||
Map<String, Object> getByFaceId(@Param("table") String table, @Param("sampleId")String sampleId, @Param("faceId")String faceId);
|
||||
Map<String, Object> getByFaceId(@Param("table") String table, @Param("sampleId") String sampleId, @Param("faceId") String faceId);
|
||||
|
||||
@Select({
|
||||
"<script>",
|
||||
"select * from ${table} where face_id in ",
|
||||
"select",
|
||||
"<foreach item=\"item\" index=\"index\" collection=\"columns\" open=\"\" separator=\",\" close=\"\">",
|
||||
"${item}",
|
||||
"</foreach>",
|
||||
"from ${table} where face_id in ",
|
||||
"<foreach collection=\"faceIds\" item=\"item\" index=\"index\" open=\"(\" separator=\",\" close=\")\">",
|
||||
"#{item,jdbcType=VARCHAR}",
|
||||
"</foreach>",
|
||||
"</script>"
|
||||
})
|
||||
List<Map<String, Object>> getByFaceIds(@Param("table") String table, @Param("faceIds")List<String> faceIds);
|
||||
List<Map<String, Object>> getByFaceIds(@Param("table") String table, @Param("columns") List<String> columns, @Param("faceIds") List<String> faceIds);
|
||||
|
||||
@Select({
|
||||
"<script>",
|
||||
@ -112,6 +116,6 @@ public interface FaceDataMapper {
|
||||
"</foreach>",
|
||||
"</script>"
|
||||
})
|
||||
List<Map<String, Object>> getByPrimaryIds(@Param("table") String table, @Param("keyIds")List<Long> keyIds);
|
||||
List<Map<String, Object>> getByPrimaryIds(@Param("table") String table, @Param("keyIds") List<Long> keyIds);
|
||||
|
||||
}
|
||||
|
@ -18,6 +18,6 @@ public interface ImageDataMapper {
|
||||
"#{image.createTime,jdbcType=TIMESTAMP}, #{image.modifyTime,jdbcType=TIMESTAMP})"
|
||||
})
|
||||
@SelectKey(statement="SELECT LAST_INSERT_ID()", keyProperty="id", before=false, resultType=Long.class)
|
||||
int insert(@Param("table")String table, @Param("image")ImageData record);
|
||||
int insert(@Param("table") String table, @Param("image") ImageData record);
|
||||
|
||||
}
|
||||
|
@ -12,10 +12,10 @@ import com.visual.face.search.server.model.TableColumn;
|
||||
public interface OperateTableMapper {
|
||||
|
||||
@Select({"SHOW TABLES LIKE '${table}'"})
|
||||
String showTable(@Param("table")String table);
|
||||
String showTable(@Param("table") String table);
|
||||
|
||||
@Update({ "DROP TABLE IF EXISTS ${table}"})
|
||||
int dropTable(@Param("table")String table);
|
||||
int dropTable(@Param("table") String table);
|
||||
|
||||
@Update({
|
||||
"<script>",
|
||||
|
@ -50,7 +50,7 @@ public interface SampleDataMapper {
|
||||
int create(@Param("table") String table, @Param("sample") SampleData sample, @Param("columnValues") List<ColumnValue> columnValues);
|
||||
|
||||
@Select({"select count(1) from ${table} where sample_id = '${sampleId}'"})
|
||||
long count(@Param("table") String table, @Param("sampleId")String sampleId);
|
||||
long count(@Param("table") String table, @Param("sampleId") String sampleId);
|
||||
|
||||
@Update({
|
||||
"<script>",
|
||||
@ -79,16 +79,16 @@ public interface SampleDataMapper {
|
||||
"where sample_id = #{sampleId,jdbcType=BIGINT}",
|
||||
"</script>"
|
||||
})
|
||||
int update(@Param("table") String table, @Param("sampleId")String sampleId, @Param("columnValues") List<ColumnValue> columnValues);
|
||||
int update(@Param("table") String table, @Param("sampleId") String sampleId, @Param("columnValues") List<ColumnValue> columnValues);
|
||||
|
||||
@Delete({"delete from ${table} where sample_id = #{sampleId,jdbcType=VARCHAR}",})
|
||||
int delete(@Param("table") String table, @Param("sampleId")String sampleId);
|
||||
int delete(@Param("table") String table, @Param("sampleId") String sampleId);
|
||||
|
||||
@Select({"select * from ${table} where sample_id = #{sampleId,jdbcType=VARCHAR}"})
|
||||
Map<String, Object> getBySampleId(@Param("table") String table, @Param("sampleId")String sampleId);
|
||||
Map<String, Object> getBySampleId(@Param("table") String table, @Param("sampleId") String sampleId);
|
||||
|
||||
@Select({"select * from ${table} order by id ${order} limit ${offset}, ${limit}"})
|
||||
List<Map<String, Object>> getBySampleList(@Param("table") String table, @Param("offset")Integer offset, @Param("limit")Integer limit, @Param("order")String order);
|
||||
List<Map<String, Object>> getBySampleList(@Param("table") String table, @Param("offset") Integer offset, @Param("limit") Integer limit, @Param("order") String order);
|
||||
|
||||
@Select({
|
||||
"<script>",
|
||||
@ -98,5 +98,5 @@ public interface SampleDataMapper {
|
||||
"</foreach>",
|
||||
"</script>"
|
||||
})
|
||||
List<Map<String, Object>> getBySampleIds(@Param("table") String table, @Param("sampleIds")List<String> sampleIds);
|
||||
List<Map<String, Object>> getBySampleIds(@Param("table") String table, @Param("sampleIds") List<String> sampleIds);
|
||||
}
|
||||
|
@ -1,79 +1,79 @@
|
||||
package com.visual.face.search.server.scheduler;
|
||||
|
||||
import io.milvus.param.R;
|
||||
import io.milvus.grpc.FlushResponse;
|
||||
import io.milvus.client.MilvusServiceClient;
|
||||
import io.milvus.param.collection.FlushParam;
|
||||
import io.milvus.param.collection.HasCollectionParam;
|
||||
import com.visual.face.search.server.utils.VTableCache;
|
||||
import com.alibaba.proxima.be.client.ProximaSearchClient;
|
||||
import com.visual.face.search.server.engine.api.SearchEngine;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import javax.annotation.Resource;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.scheduling.annotation.Scheduled;
|
||||
|
||||
/**
|
||||
* 用于刷新数据,Milvus不刷新数据可能导致数据丢失
|
||||
*/
|
||||
@Component
|
||||
public class FlushScheduler {
|
||||
|
||||
private static final Integer SUCCESS_STATUE = 0;
|
||||
public Logger logger = LoggerFactory.getLogger(getClass());
|
||||
|
||||
@Resource
|
||||
private SearchEngine searchEngine;
|
||||
@Value("${visual.scheduler.flush.enable:true}")
|
||||
private boolean enable;
|
||||
|
||||
@Scheduled(fixedDelayString = "${visual.scheduler.flush.interval:300000}")
|
||||
public void flush(){
|
||||
if(enable){
|
||||
Object client = searchEngine.getEngine();
|
||||
if(client instanceof MilvusServiceClient){
|
||||
this.flushMilvus((MilvusServiceClient) client);
|
||||
}else if(client instanceof ProximaSearchClient){
|
||||
this.flushProxima((ProximaSearchClient) client);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 刷新数据到数据库中
|
||||
* @param client
|
||||
*/
|
||||
private void flushMilvus(MilvusServiceClient client){
|
||||
VTableCache.getVectorTables().forEach((vectorTable) -> {
|
||||
try {
|
||||
HasCollectionParam requestParam = HasCollectionParam.newBuilder().withCollectionName(vectorTable).build();
|
||||
boolean exist = client.hasCollection(requestParam).getData();
|
||||
if(exist){
|
||||
FlushParam flushParam = FlushParam.newBuilder().addCollectionName(vectorTable).build();
|
||||
R<FlushResponse> response = client.flush(flushParam);
|
||||
if(SUCCESS_STATUE.equals(response.getStatus())){
|
||||
VTableCache.remove(vectorTable);
|
||||
logger.info("flushMilvus success: table is {}", vectorTable);
|
||||
}else{
|
||||
throw new RuntimeException("FlushResponse Error");
|
||||
}
|
||||
}
|
||||
}catch (Exception e){
|
||||
logger.error("flushMilvus error:", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 刷新数据到数据库中
|
||||
* @param client
|
||||
*/
|
||||
private void flushProxima(ProximaSearchClient client){
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
//package com.visual.face.search.server.scheduler;
|
||||
//
|
||||
//import io.milvus.param.R;
|
||||
//import io.milvus.grpc.FlushResponse;
|
||||
//import io.milvus.client.MilvusServiceClient;
|
||||
//import io.milvus.param.collection.FlushParam;
|
||||
//import io.milvus.param.collection.HasCollectionParam;
|
||||
//import com.visual.face.search.server.utils.VTableCache;
|
||||
//import com.alibaba.proxima.be.client.ProximaSearchClient;
|
||||
//import com.visual.face.search.server.engine.api.SearchEngine;
|
||||
//
|
||||
//import org.slf4j.Logger;
|
||||
//import org.slf4j.LoggerFactory;
|
||||
//import javax.annotation.Resource;
|
||||
//
|
||||
//import org.springframework.beans.factory.annotation.Value;
|
||||
//import org.springframework.stereotype.Component;
|
||||
//import org.springframework.scheduling.annotation.Scheduled;
|
||||
//
|
||||
///**
|
||||
// * 用于刷新数据,Milvus不刷新数据可能导致数据丢失
|
||||
// */
|
||||
//@Component
|
||||
//public class FlushScheduler {
|
||||
//
|
||||
// private static final Integer SUCCESS_STATUE = 0;
|
||||
// public Logger logger = LoggerFactory.getLogger(getClass());
|
||||
//
|
||||
// @Resource
|
||||
// private SearchEngine searchEngine;
|
||||
// @Value("${visual.scheduler.flush.enable:true}")
|
||||
// private boolean enable;
|
||||
//
|
||||
// @Scheduled(fixedDelayString = "${visual.scheduler.flush.interval:300000}")
|
||||
// public void flush(){
|
||||
// if(enable){
|
||||
// Object client = searchEngine.getEngine();
|
||||
// if(client instanceof MilvusServiceClient){
|
||||
// this.flushMilvus((MilvusServiceClient) client);
|
||||
// }else if(client instanceof ProximaSearchClient){
|
||||
// this.flushProxima((ProximaSearchClient) client);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * 刷新数据到数据库中
|
||||
// * @param client
|
||||
// */
|
||||
// private void flushMilvus(MilvusServiceClient client){
|
||||
// VTableCache.getVectorTables().forEach((vectorTable) -> {
|
||||
// try {
|
||||
// HasCollectionParam requestParam = HasCollectionParam.newBuilder().withCollectionName(vectorTable).build();
|
||||
// boolean exist = client.hasCollection(requestParam).getData();
|
||||
// if(exist){
|
||||
// FlushParam flushParam = FlushParam.newBuilder().addCollectionName(vectorTable).build();
|
||||
// R<FlushResponse> response = client.flush(flushParam);
|
||||
// if(SUCCESS_STATUE.equals(response.getStatus())){
|
||||
// VTableCache.remove(vectorTable);
|
||||
// logger.info("flushMilvus success: table is {}", vectorTable);
|
||||
// }else{
|
||||
// throw new RuntimeException("FlushResponse Error");
|
||||
// }
|
||||
// }
|
||||
// }catch (Exception e){
|
||||
// logger.error("flushMilvus error:", e);
|
||||
// }
|
||||
// });
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * 刷新数据到数据库中
|
||||
// * @param client
|
||||
// */
|
||||
// private void flushProxima(ProximaSearchClient client){
|
||||
//
|
||||
// }
|
||||
//
|
||||
//}
|
||||
|
@ -0,0 +1,5 @@
|
||||
package com.visual.face.search.server.service.api;
|
||||
|
||||
public interface AdminService {
|
||||
|
||||
}
|
@ -0,0 +1,11 @@
|
||||
package com.visual.face.search.server.service.impl;
|
||||
|
||||
import com.visual.face.search.server.service.api.AdminService;
|
||||
import com.visual.face.search.server.service.base.BaseService;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service("adminCollectService")
|
||||
public class AdminServiceImpl extends BaseService implements AdminService {
|
||||
|
||||
|
||||
}
|
@ -5,10 +5,9 @@ import com.visual.face.search.core.utils.JsonUtil;
|
||||
import com.visual.face.search.core.utils.ThreadUtil;
|
||||
import com.visual.face.search.server.domain.request.CollectReqVo;
|
||||
import com.visual.face.search.server.domain.response.CollectRepVo;
|
||||
import com.visual.face.search.server.engine.api.SearchEngine;
|
||||
import com.visual.face.search.server.engine.conf.Constant;
|
||||
import com.visual.face.search.server.engine.model.MapParam;
|
||||
//import com.visual.face.search.server.mapper.CollectMapper;
|
||||
import com.visual.face.search.engine.api.SearchEngine;
|
||||
import com.visual.face.search.engine.conf.Constant;
|
||||
import com.visual.face.search.engine.model.MapParam;
|
||||
import com.visual.face.search.server.mapper.CollectMapper;
|
||||
import com.visual.face.search.server.model.Collection;
|
||||
import com.visual.face.search.server.service.api.CollectService;
|
||||
@ -21,12 +20,10 @@ import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Propagation;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.transaction.interceptor.TransactionAspectSupport;
|
||||
import org.springframework.util.DigestUtils;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
|
||||
@Service("visualCollectService")
|
||||
@ -91,8 +88,8 @@ public class CollectServiceImpl extends BaseService implements CollectService {
|
||||
}
|
||||
//创建人脸向量库
|
||||
MapParam param = MapParam.build()
|
||||
.put(Constant.ParamKeyShardsNum, collect.getShardsNum())
|
||||
.put(Constant.ParamKeyMaxDocsPerSegment, collect.getMaxDocsPerSegment());
|
||||
.put(Constant.IndexShardsNum, collect.getShardsNum())
|
||||
.put(Constant.IndexReplicasNum, collect.getReplicasNum());
|
||||
boolean createVectorFlag = searchEngine.createCollection(vectorTableName, param);
|
||||
if(!createVectorFlag){
|
||||
throw new RuntimeException("create vector table error");
|
||||
|
@ -1,12 +1,13 @@
|
||||
package com.visual.face.search.server.service.impl;
|
||||
|
||||
import com.visual.face.search.engine.api.SearchEngine;
|
||||
import com.visual.face.search.core.domain.ExtParam;
|
||||
import com.visual.face.search.core.domain.FaceImage;
|
||||
import com.visual.face.search.core.domain.FaceInfo;
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.extract.FaceFeatureExtractor;
|
||||
import com.visual.face.search.core.utils.JsonUtil;
|
||||
import com.visual.face.search.core.utils.Similarity;
|
||||
import com.visual.face.search.server.domain.request.SearchAlgorithm;
|
||||
import com.visual.face.search.server.domain.storage.StorageDataInfo;
|
||||
import com.visual.face.search.server.domain.storage.StorageImageInfo;
|
||||
import com.visual.face.search.server.domain.storage.StorageInfo;
|
||||
@ -14,11 +15,6 @@ import com.visual.face.search.server.domain.extend.FieldKeyValue;
|
||||
import com.visual.face.search.server.domain.extend.FieldKeyValues;
|
||||
import com.visual.face.search.server.domain.request.FaceDataReqVo;
|
||||
import com.visual.face.search.server.domain.response.FaceDataRepVo;
|
||||
import com.visual.face.search.server.engine.api.SearchEngine;
|
||||
import com.visual.face.search.server.engine.conf.Constant;
|
||||
import com.visual.face.search.server.engine.model.SearchDocument;
|
||||
import com.visual.face.search.server.engine.model.SearchResponse;
|
||||
import com.visual.face.search.server.engine.model.SearchResult;
|
||||
import com.visual.face.search.server.mapper.CollectMapper;
|
||||
import com.visual.face.search.server.mapper.FaceDataMapper;
|
||||
import com.visual.face.search.server.mapper.SampleDataMapper;
|
||||
@ -33,8 +29,6 @@ import com.visual.face.search.server.service.base.BaseService;
|
||||
import com.visual.face.search.server.utils.CollectionUtil;
|
||||
import com.visual.face.search.server.utils.TableUtils;
|
||||
import com.visual.face.search.server.utils.VTableCache;
|
||||
import com.visual.face.search.server.utils.ValueUtil;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Propagation;
|
||||
@ -101,45 +95,18 @@ public class FaceDataServiceImpl extends BaseService implements FaceDataService
|
||||
float[] embeds = faceInfo.embedding.embeds;
|
||||
//当前样本的人脸相似度的最小阈值
|
||||
if(null != face.getMinConfidenceThresholdWithThisSample() && face.getMinConfidenceThresholdWithThisSample() > 0){
|
||||
List<Map<String, Object>> faces = faceDataMapper.getBySampleId(collection.getFaceTable(), face.getSampleId());
|
||||
for(Map<String, Object> item : faces){
|
||||
String faceVectorStr = MapUtils.getString(item, Constant.ColumnNameFaceVector);
|
||||
float[] faceVector = ValueUtil.convertVector(faceVectorStr);
|
||||
float simVal = Similarity.cosineSimilarityNorm(embeds, faceVector);
|
||||
float confidence = (float) Math.floor(simVal * 10000)/100;
|
||||
if(confidence < face.getMinConfidenceThresholdWithThisSample()){
|
||||
throw new RuntimeException("this face confidence is less than minConfidenceThresholdWithThisSample,confidence="+confidence+",threshold="+face.getMinConfidenceThresholdWithThisSample());
|
||||
}
|
||||
float minScore = this.searchEngine.searchMinScoreBySampleId(collection.getVectorTable(), face.getSampleId(), embeds, SearchAlgorithm.COSINESIMIL.algorithm());
|
||||
float confidence = (float) Math.floor(minScore * 10000)/100;
|
||||
if(confidence < face.getMinConfidenceThresholdWithThisSample()){
|
||||
throw new RuntimeException("this face confidence is less than minConfidenceThresholdWithThisSample,confidence="+confidence+",threshold="+face.getMinConfidenceThresholdWithThisSample());
|
||||
}
|
||||
}
|
||||
//当前样本与其他样本的人脸相似度的最大阈值
|
||||
if(null != face.getMaxConfidenceThresholdWithOtherSample() && face.getMaxConfidenceThresholdWithOtherSample() > 0){
|
||||
//查询
|
||||
List<Long> otherFaceIds = new ArrayList<>();
|
||||
List<Long> faceIds = faceDataMapper.getIdBySampleId(collection.getFaceTable(), face.getSampleId());
|
||||
int topK = faceIds.size() + 2;
|
||||
float [][] vectors = new float[1][]; vectors[0] = embeds;
|
||||
SearchResponse response = searchEngine.search(collection.getVectorTable(), vectors, topK);
|
||||
if(response.getStatus().ok()){
|
||||
for(SearchResult result : response.getResult()){
|
||||
for(SearchDocument document : result.getDocuments()){
|
||||
if(!faceIds.contains(document.getPrimaryKey())){
|
||||
otherFaceIds.add(document.getPrimaryKey());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if(!otherFaceIds.isEmpty()){
|
||||
List<Map<String, Object>> faces = faceDataMapper.getByPrimaryIds(collection.getFaceTable(), otherFaceIds);
|
||||
for(Map<String, Object> item : faces){
|
||||
String faceVectorStr = MapUtils.getString(item, Constant.ColumnNameFaceVector);
|
||||
float[] faceVector = ValueUtil.convertVector(faceVectorStr);
|
||||
float simVal = Similarity.cosineSimilarityNorm(embeds, faceVector);
|
||||
float confidence = (float) Math.floor(simVal * 10000)/100;
|
||||
if(confidence > face.getMaxConfidenceThresholdWithOtherSample()){
|
||||
throw new RuntimeException("this face confidence is gather than maxConfidenceThresholdWithOtherSample,confidence="+confidence+",threshold="+face.getMaxConfidenceThresholdWithOtherSample());
|
||||
}
|
||||
}
|
||||
float minScore = this.searchEngine.searchMaxScoreBySampleId(collection.getVectorTable(), face.getSampleId(), embeds, SearchAlgorithm.COSINESIMIL.algorithm());
|
||||
float confidence = (float) Math.floor(minScore * 10000)/100;
|
||||
if(confidence > face.getMaxConfidenceThresholdWithOtherSample()){
|
||||
throw new RuntimeException("this face confidence is gather than maxConfidenceThresholdWithOtherSample,confidence="+confidence+",threshold="+face.getMaxConfidenceThresholdWithOtherSample());
|
||||
}
|
||||
}
|
||||
//保存图片信息并获取图片存储
|
||||
@ -188,7 +155,7 @@ public class FaceDataServiceImpl extends BaseService implements FaceDataService
|
||||
throw new RuntimeException("create face error");
|
||||
}
|
||||
//写入数据到人脸向量库
|
||||
boolean flag1 = searchEngine.insertVector(collection.getVectorTable(), facePo.getId(), faceId, embeds);
|
||||
boolean flag1 = searchEngine.insertVector(collection.getVectorTable(), face.getSampleId(), faceId, embeds);
|
||||
if(!flag1){
|
||||
throw new RuntimeException("create face vector error");
|
||||
}
|
||||
@ -218,7 +185,7 @@ public class FaceDataServiceImpl extends BaseService implements FaceDataService
|
||||
throw new RuntimeException("face id is not exist");
|
||||
}
|
||||
//删除向量
|
||||
boolean delete = searchEngine.deleteVectorByKey(collection.getVectorTable(), keyId);
|
||||
boolean delete = searchEngine.deleteVectorByKey(collection.getVectorTable(), faceId);
|
||||
if(!delete){
|
||||
throw new RuntimeException("delete face vector error");
|
||||
}
|
||||
|
@ -12,12 +12,13 @@ import com.visual.face.search.core.utils.Similarity;
|
||||
import com.visual.face.search.server.domain.extend.FaceLocation;
|
||||
import com.visual.face.search.server.domain.extend.SampleFaceVo;
|
||||
import com.visual.face.search.server.domain.request.FaceSearchReqVo;
|
||||
import com.visual.face.search.server.domain.request.SearchAlgorithm;
|
||||
import com.visual.face.search.server.domain.response.FaceSearchRepVo;
|
||||
import com.visual.face.search.server.engine.api.SearchEngine;
|
||||
import com.visual.face.search.server.engine.conf.Constant;
|
||||
import com.visual.face.search.server.engine.model.SearchDocument;
|
||||
import com.visual.face.search.server.engine.model.SearchResponse;
|
||||
import com.visual.face.search.server.engine.model.SearchResult;
|
||||
import com.visual.face.search.engine.api.SearchEngine;
|
||||
import com.visual.face.search.engine.conf.Constant;
|
||||
import com.visual.face.search.engine.model.SearchDocument;
|
||||
import com.visual.face.search.engine.model.SearchResponse;
|
||||
import com.visual.face.search.engine.model.SearchResult;
|
||||
import com.visual.face.search.server.mapper.CollectMapper;
|
||||
import com.visual.face.search.server.mapper.FaceDataMapper;
|
||||
import com.visual.face.search.server.mapper.SampleDataMapper;
|
||||
@ -82,7 +83,7 @@ public class FaceSearchServiceImpl extends BaseService implements FaceSearchServ
|
||||
}
|
||||
//特征搜索
|
||||
int topK = (null == search.getLimit() || search.getLimit() <= 0) ? 5 : search.getLimit();
|
||||
SearchResponse searchResponse =searchEngine.search(collection.getVectorTable(), vectors, topK);
|
||||
SearchResponse searchResponse =searchEngine.search(collection.getVectorTable(), vectors, search.getAlgorithm().algorithm(), topK);
|
||||
if(!searchResponse.getStatus().ok()){
|
||||
throw new RuntimeException(searchResponse.getStatus().getReason());
|
||||
}
|
||||
@ -106,32 +107,22 @@ public class FaceSearchServiceImpl extends BaseService implements FaceSearchServ
|
||||
return vos;
|
||||
}
|
||||
//获取关联数据ID
|
||||
boolean needFixFaceId = false;
|
||||
Set<Long> faceIds = new HashSet<>();
|
||||
Set<String> faceIds = new HashSet<>();
|
||||
for(SearchResult searchResult : result){
|
||||
List<SearchDocument> documents = searchResult.getDocuments();
|
||||
for(SearchDocument document : documents){
|
||||
faceIds.add(document.getPrimaryKey());
|
||||
if(null == document.getFaceId() || document.getFaceId().isEmpty()){
|
||||
needFixFaceId = true;
|
||||
}
|
||||
faceIds.add(document.getFaceId());
|
||||
}
|
||||
}
|
||||
//查询数据
|
||||
List<Map<String, Object>> faceList = faceDataMapper.getByPrimaryIds(collection.getFaceTable(), new ArrayList<>(faceIds));
|
||||
Set<String> sampleIds = faceList.stream().map(item -> MapUtils.getString(item, Constant.ColumnNameSampleId)).collect(Collectors.toSet());
|
||||
List<Map<String, Object>> sampleList = sampleDataMapper.getBySampleIds(collection.getSampleTable(), new ArrayList<>(sampleIds));
|
||||
Map<String, Map<String, Object>> faceMapping = ValueUtil.mapping(faceList, Constant.ColumnNameFaceId);
|
||||
Map<String, Map<String, Object>> sampleMapping = ValueUtil.mapping(sampleList, Constant.ColumnNameSampleId);
|
||||
//补全结果数据中的FaceId。由于milvus不支持字符串结构,只会返回人脸数据的主键ID
|
||||
if(needFixFaceId){
|
||||
Map<Long, String> mapping = ValueUtil.mapping(faceList, Constant.ColumnPrimaryKey, Constant.ColumnNameFaceId);
|
||||
for(SearchResult searchResult : result){
|
||||
List<SearchDocument> documents = searchResult.getDocuments();
|
||||
for(SearchDocument document : documents){
|
||||
document.setFaceId(mapping.get(document.getPrimaryKey()));
|
||||
}
|
||||
}
|
||||
Map<String, Map<String, Object>> faceMapping = new HashMap<>();
|
||||
Map<String, Map<String, Object>> sampleMapping = new HashMap<>();
|
||||
if(faceIds.size() > 0){
|
||||
List<Map<String, Object>> faceList = faceDataMapper.getByFaceIds(collection.getFaceTable(), ValueUtil.getAllFaceColumnNames(collection), new ArrayList<>(faceIds));
|
||||
Set<String> sampleIds = faceList.stream().map(item -> MapUtils.getString(item, Constant.ColumnNameSampleId)).collect(Collectors.toSet());
|
||||
List<Map<String, Object>> sampleList = sampleDataMapper.getBySampleIds(collection.getSampleTable(), new ArrayList<>(sampleIds));
|
||||
faceMapping = ValueUtil.mapping(faceList, Constant.ColumnNameFaceId);
|
||||
sampleMapping = ValueUtil.mapping(sampleList, Constant.ColumnNameSampleId);
|
||||
}
|
||||
//构造返回结果
|
||||
List<FaceSearchRepVo> vos = new ArrayList<>();
|
||||
@ -148,10 +139,12 @@ public class FaceSearchServiceImpl extends BaseService implements FaceSearchServ
|
||||
if(null != face){
|
||||
float faceScore = MapUtils.getFloatValue(face, Constant.ColumnNameFaceScore);
|
||||
String sampleId = MapUtils.getString(face, Constant.ColumnNameSampleId);
|
||||
String faceVectorStr = MapUtils.getString(face, Constant.ColumnNameFaceVector);
|
||||
float[] faceVector = ValueUtil.convertVector(faceVectorStr);
|
||||
float simVal = Similarity.cosineSimilarityNorm(faceInfos.get(i).embedding.embeds, faceVector);
|
||||
float confidence = (float) Math.floor(simVal * 1000000)/10000;
|
||||
float score = document.getScore();
|
||||
float confidence = score;
|
||||
if(SearchAlgorithm.COSINESIMIL == search.getAlgorithm()){
|
||||
score = Similarity.cosEnhance(score);
|
||||
confidence = (float) Math.floor(score * 1000000)/10000;
|
||||
}
|
||||
if(null != sampleId && sampleMapping.containsKey(sampleId) && confidence >= search.getConfidenceThreshold()){
|
||||
Map<String, Object> sample = sampleMapping.get(sampleId);
|
||||
SampleFaceVo faceVo = SampleFaceVo.build();
|
||||
@ -159,7 +152,6 @@ public class FaceSearchServiceImpl extends BaseService implements FaceSearchServ
|
||||
faceVo.setFaceId(document.getFaceId());
|
||||
faceVo.setFaceScore(faceScore);
|
||||
faceVo.setConfidence(confidence);
|
||||
faceVo.setDistance((float) Math.floor(document.getScore() * 10000) / 10000);
|
||||
faceVo.setFaceData(ValueUtil.getFieldKeyValues(face, ValueUtil.getFaceColumns(collection)));
|
||||
faceVo.setSampleData(ValueUtil.getFieldKeyValues(sample, ValueUtil.getSampleColumns(collection)));
|
||||
match.add(faceVo);
|
||||
|
@ -1,12 +1,12 @@
|
||||
package com.visual.face.search.server.service.impl;
|
||||
|
||||
import com.visual.face.search.engine.conf.Constant;
|
||||
import com.visual.face.search.engine.api.SearchEngine;
|
||||
import com.visual.face.search.server.domain.extend.FieldKeyValue;
|
||||
import com.visual.face.search.server.domain.extend.FieldKeyValues;
|
||||
import com.visual.face.search.server.domain.extend.SimpleFaceVo;
|
||||
import com.visual.face.search.server.domain.request.SampleDataReqVo;
|
||||
import com.visual.face.search.server.domain.response.SampleDataRepVo;
|
||||
import com.visual.face.search.server.engine.api.SearchEngine;
|
||||
import com.visual.face.search.server.engine.conf.Constant;
|
||||
import com.visual.face.search.server.mapper.CollectMapper;
|
||||
import com.visual.face.search.server.mapper.FaceDataMapper;
|
||||
import com.visual.face.search.server.mapper.SampleDataMapper;
|
||||
@ -21,7 +21,6 @@ import org.apache.commons.collections4.MapUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Propagation;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
@ -128,7 +127,7 @@ public class SampleDataServiceImpl extends BaseService implements SampleDataServ
|
||||
throw new RuntimeException("sample_id is not exist");
|
||||
}
|
||||
//删除向量数据
|
||||
List<Long> faceIds = faceDataMapper.getIdBySampleId(collection.getFaceTable(), sampleId);
|
||||
List<String> faceIds = faceDataMapper.getFaceIdBySampleId(collection.getFaceTable(), sampleId);
|
||||
searchEngine.deleteVectorByKey(collection.getVectorTable(), faceIds);
|
||||
//删除人脸数据
|
||||
faceDataMapper.deleteBySampleId(collection.getFaceTable(), sampleId);
|
||||
|
@ -39,7 +39,7 @@ public final class SpringUtils implements BeanFactoryPostProcessor, ApplicationC
|
||||
*
|
||||
* @param name
|
||||
* @return Object 一个以所给名字注册的bean的实例
|
||||
* @throws org.springframework.beans.BeansException
|
||||
* @throws BeansException
|
||||
*
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
@ -53,7 +53,7 @@ public final class SpringUtils implements BeanFactoryPostProcessor, ApplicationC
|
||||
*
|
||||
* @param clz
|
||||
* @return
|
||||
* @throws org.springframework.beans.BeansException
|
||||
* @throws BeansException
|
||||
*
|
||||
*/
|
||||
public static <T> T getBean(Class<T> clz) throws BeansException
|
||||
@ -78,7 +78,7 @@ public final class SpringUtils implements BeanFactoryPostProcessor, ApplicationC
|
||||
*
|
||||
* @param name
|
||||
* @return boolean
|
||||
* @throws org.springframework.beans.factory.NoSuchBeanDefinitionException
|
||||
* @throws NoSuchBeanDefinitionException
|
||||
*
|
||||
*/
|
||||
public static boolean isSingleton(String name) throws NoSuchBeanDefinitionException
|
||||
@ -89,7 +89,7 @@ public final class SpringUtils implements BeanFactoryPostProcessor, ApplicationC
|
||||
/**
|
||||
* @param name
|
||||
* @return Class 注册对象的类型
|
||||
* @throws org.springframework.beans.factory.NoSuchBeanDefinitionException
|
||||
* @throws NoSuchBeanDefinitionException
|
||||
*
|
||||
*/
|
||||
public static Class<?> getType(String name) throws NoSuchBeanDefinitionException
|
||||
@ -102,7 +102,7 @@ public final class SpringUtils implements BeanFactoryPostProcessor, ApplicationC
|
||||
*
|
||||
* @param name
|
||||
* @return
|
||||
* @throws org.springframework.beans.factory.NoSuchBeanDefinitionException
|
||||
* @throws NoSuchBeanDefinitionException
|
||||
*
|
||||
*/
|
||||
public static String[] getAliases(String name) throws NoSuchBeanDefinitionException
|
||||
|
@ -1,9 +1,11 @@
|
||||
package com.visual.face.search.server.utils;
|
||||
|
||||
import com.visual.face.search.core.utils.JsonUtil;
|
||||
import com.visual.face.search.engine.conf.Constant;
|
||||
import com.visual.face.search.server.domain.extend.FieldKeyValue;
|
||||
import com.visual.face.search.server.domain.extend.FieldKeyValues;
|
||||
import com.visual.face.search.server.domain.extend.FiledColumn;
|
||||
import com.visual.face.search.server.domain.extend.FiledDataType;
|
||||
import com.visual.face.search.server.domain.response.CollectRepVo;
|
||||
import com.visual.face.search.server.model.Collection;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
@ -12,6 +14,7 @@ import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class ValueUtil {
|
||||
|
||||
@ -25,6 +28,7 @@ public class ValueUtil {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
|
||||
public static List<FiledColumn> getSampleColumns(Collection collection){
|
||||
if(null != collection.getSchemaInfo() && !collection.getSchemaInfo().isEmpty()){
|
||||
CollectRepVo collectVo = JsonUtil.toEntity(collection.getSchemaInfo(), CollectRepVo.class);
|
||||
@ -35,6 +39,19 @@ public class ValueUtil {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
public static List<String> getAllFaceColumnNames(Collection collection){
|
||||
List<FiledColumn> columns = getAllFaceColumns(collection);
|
||||
return columns.stream().map(FiledColumn::getName).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public static List<FiledColumn> getAllFaceColumns(Collection collection){
|
||||
List<FiledColumn> columns = getFaceColumns(collection);
|
||||
columns.add(FiledColumn.build().setName(Constant.ColumnNameSampleId).setComment("样本ID").setDataType(FiledDataType.STRING));
|
||||
columns.add(FiledColumn.build().setName(Constant.ColumnNameFaceId).setComment("人脸ID").setDataType(FiledDataType.STRING));
|
||||
columns.add(FiledColumn.build().setName(Constant.ColumnNameFaceScore).setComment("人脸分数").setDataType(FiledDataType.FLOAT));
|
||||
return columns;
|
||||
}
|
||||
|
||||
public static FieldKeyValues getFieldKeyValues(Map<String, Object> map , List<FiledColumn> columns){
|
||||
columns = null != columns ? columns : new ArrayList<>();
|
||||
Map<String, String> keyMap = new HashMap<>();
|
||||
|
@ -22,6 +22,7 @@ logging:
|
||||
# 模型配置
|
||||
visual:
|
||||
model:
|
||||
baseModelPath: 'face-search-core/src/main/resources/'
|
||||
faceDetection:
|
||||
name: InsightScrfdFaceDetection
|
||||
modelPath:
|
||||
@ -42,14 +43,17 @@ visual:
|
||||
name: InsightArcFaceRecognition
|
||||
modelPath:
|
||||
thread: 1
|
||||
faceAttribute:
|
||||
name: InsightAttributeDetection
|
||||
modelPath:
|
||||
thread: 1
|
||||
engine:
|
||||
selected: milvus
|
||||
proxima:
|
||||
host: visual-face-search-proxima
|
||||
port: 16000
|
||||
milvus:
|
||||
host: visual-face-search-milvus
|
||||
port: 19530
|
||||
open-search:
|
||||
host: visual-face-search-opensearch
|
||||
port: 9200
|
||||
scheme: https
|
||||
username: admin
|
||||
password: admin
|
||||
scheduler:
|
||||
flush:
|
||||
enable: true
|
||||
@ -80,8 +84,8 @@ spring:
|
||||
# 主库数据源
|
||||
master:
|
||||
url: jdbc:mysql://visual-face-search-mysql:3306/visual_face_search?useUnicode=true&characterEncoding=utf8&zeroDateTimeBehavior=convertToNull&useSSL=true&serverTimezone=GMT%2B8
|
||||
username: visual
|
||||
password: visual
|
||||
username: root
|
||||
password: root
|
||||
slave:
|
||||
# 从数据源开关/默认关闭
|
||||
enabled: false
|
||||
|
@ -22,34 +22,38 @@ logging:
|
||||
# 模型配置
|
||||
visual:
|
||||
model:
|
||||
baseModelPath: ${VISUAL_MODEL_BASE_MODEL_PATH:/app/face-search/}
|
||||
faceDetection:
|
||||
name: InsightScrfdFaceDetection
|
||||
modelPath:
|
||||
thread: 4
|
||||
name: ${VISUAL_MODEL_FACEDETECTION_NAME:InsightScrfdFaceDetection}
|
||||
modelPath: ${VISUAL_MODEL_FACEDETECTION_PATH:}
|
||||
thread: ${VISUAL_MODEL_FACEDETECTION_THREAD:4}
|
||||
backup:
|
||||
name: PcnNetworkFaceDetection
|
||||
modelPath:
|
||||
thread: 4
|
||||
name: ${VISUAL_MODEL_FACEDETECTION_BACKUP_NAME:PcnNetworkFaceDetection}
|
||||
modelPath: ${VISUAL_MODEL_FACEDETECTION_BACKUP_PATH:}
|
||||
thread: ${VISUAL_MODEL_FACEDETECTION_BACKUP_THREAD:4}
|
||||
faceKeyPoint:
|
||||
name: InsightCoordFaceKeyPoint
|
||||
modelPath:
|
||||
thread: 4
|
||||
name: ${VISUAL_MODEL_FACEKEYPOINT_NAME:InsightCoordFaceKeyPoint}
|
||||
modelPath: ${VISUAL_MODEL_FACEKEYPOINT_PATH:}
|
||||
thread: ${VISUAL_MODEL_FACEKEYPOINT_THREAD:4}
|
||||
faceAlignment:
|
||||
name: Simple005pFaceAlignment
|
||||
modelPath:
|
||||
thread: 4
|
||||
name: ${VISUAL_MODEL_FACEALIGNMENT_NAME:Simple005pFaceAlignment}
|
||||
modelPath: ${VISUAL_MODEL_FACEALIGNMENT_PATH:}
|
||||
thread: ${VISUAL_MODEL_FACEALIGNMENT_THREAD:4}
|
||||
faceRecognition:
|
||||
name: InsightArcFaceRecognition
|
||||
modelPath:
|
||||
thread: 4
|
||||
name: ${VISUAL_MODEL_FACERECOGNITION_NAME:InsightArcFaceRecognition}
|
||||
modelPath: ${VISUAL_MODEL_FACERECOGNITION_PATH:}
|
||||
thread: ${VISUAL_MODEL_FACERECOGNITION_THREAD:4}
|
||||
faceAttribute:
|
||||
name: ${VISUAL_MODEL_FACEATTRIBUTE_NAME:InsightAttributeDetection}
|
||||
modelPath: ${VISUAL_MODEL_FACEATTRIBUTE_PATH:}
|
||||
thread: ${VISUAL_MODEL_FACEATTRIBUTE_THREAD:4}
|
||||
engine:
|
||||
selected: proxima
|
||||
proxima:
|
||||
host:
|
||||
port: 16000
|
||||
milvus:
|
||||
host:
|
||||
port: 19530
|
||||
open-search:
|
||||
host: ${VISUAL_ENGINE_OPENSEARCH_HOST}
|
||||
port: ${VISUAL_ENGINE_OPENSEARCH_PORT:9200}
|
||||
scheme: ${VISUAL_ENGINE_OPENSEARCH_SCHEME:https}
|
||||
username: ${VISUAL_ENGINE_OPENSEARCH_USERNAME:admin}
|
||||
password: ${VISUAL_ENGINE_OPENSEARCH_PASSWORD:admin}
|
||||
scheduler:
|
||||
flush:
|
||||
enable: true
|
||||
@ -58,7 +62,7 @@ visual:
|
||||
face-search: false
|
||||
face-compare: false
|
||||
swagger:
|
||||
enable: true
|
||||
enable: ${VISUAL_SWAGGER_ENABLE:true}
|
||||
|
||||
# Spring配置
|
||||
spring:
|
||||
@ -79,9 +83,9 @@ spring:
|
||||
druid:
|
||||
# 主库数据源
|
||||
master:
|
||||
url:
|
||||
username:
|
||||
password:
|
||||
url: ${SPRING_DATASOURCE_URL}
|
||||
username: ${SPRING_DATASOURCE_USERNAME:root}
|
||||
password: ${SPRING_DATASOURCE_PASSWORD:root}
|
||||
slave:
|
||||
# 从数据源开关/默认关闭
|
||||
enabled: false
|
||||
|
133
face-search-server/src/main/resources/application-local.yml
Executable file
@ -0,0 +1,133 @@
|
||||
# 开发环境配置
|
||||
server:
|
||||
# 服务器的HTTP端口,默认为80
|
||||
port: 8080
|
||||
servlet:
|
||||
# 应用的访问路径
|
||||
context-path: /
|
||||
tomcat:
|
||||
# tomcat的URI编码
|
||||
uri-encoding: UTF-8
|
||||
# tomcat最大线程数,默认为200
|
||||
max-threads: 10
|
||||
# Tomcat启动初始化的线程数,默认值25
|
||||
min-spare-threads: 5
|
||||
|
||||
# 日志配置
|
||||
logging:
|
||||
level:
|
||||
com.visual.face.search: info
|
||||
org.springframework: warn
|
||||
|
||||
# 模型配置
|
||||
visual:
|
||||
model:
|
||||
baseModelPath: 'face-search-core/src/main/resources/'
|
||||
faceDetection:
|
||||
name: InsightScrfdFaceDetection
|
||||
modelPath:
|
||||
thread: 1
|
||||
backup:
|
||||
name: PcnNetworkFaceDetection
|
||||
modelPath:
|
||||
thread: 1
|
||||
faceKeyPoint:
|
||||
name: InsightCoordFaceKeyPoint
|
||||
modelPath:
|
||||
thread: 1
|
||||
faceAlignment:
|
||||
name: Simple005pFaceAlignment
|
||||
modelPath:
|
||||
thread: 1
|
||||
faceRecognition:
|
||||
name: InsightArcFaceRecognition
|
||||
modelPath:
|
||||
thread: 1
|
||||
faceAttribute:
|
||||
name: InsightAttributeDetection
|
||||
modelPath:
|
||||
thread: 1
|
||||
engine:
|
||||
open-search:
|
||||
host: visual-face-search-opensearch
|
||||
port: 9200
|
||||
scheme: https
|
||||
username: admin
|
||||
password: admin
|
||||
scheduler:
|
||||
flush:
|
||||
enable: true
|
||||
interval: 60000
|
||||
face-mask:
|
||||
face-search: false
|
||||
face-compare: false
|
||||
swagger:
|
||||
enable: true
|
||||
|
||||
# Spring配置
|
||||
spring:
|
||||
jackson:
|
||||
time-zone: GMT+8
|
||||
date-format: yyyy-MM-dd HH:mm:ss
|
||||
# 文件上传
|
||||
servlet:
|
||||
multipart:
|
||||
# 单个文件大小
|
||||
max-file-size: 10MB
|
||||
# 设置总上传的文件大小
|
||||
max-request-size: 20MB
|
||||
#数据源
|
||||
datasource:
|
||||
type: com.alibaba.druid.pool.DruidDataSource
|
||||
driverClassName: com.mysql.cj.jdbc.Driver
|
||||
druid:
|
||||
# 主库数据源
|
||||
master:
|
||||
url: jdbc:mysql://visual-face-search-mysql:3306/visual_face_search?useUnicode=true&characterEncoding=utf8&zeroDateTimeBehavior=convertToNull&useSSL=false&serverTimezone=GMT%2B8
|
||||
username: root
|
||||
password: root
|
||||
slave:
|
||||
# 从数据源开关/默认关闭
|
||||
enabled: false
|
||||
url:
|
||||
username:
|
||||
password:
|
||||
# 初始连接数
|
||||
initialSize: 5
|
||||
# 最小连接池数量
|
||||
minIdle: 10
|
||||
# 最大连接池数量
|
||||
maxActive: 20
|
||||
# 配置获取连接等待超时的时间
|
||||
maxWait: 60000
|
||||
# 配置间隔多久才进行一次检测,检测需要关闭的空闲连接,单位是毫秒
|
||||
timeBetweenEvictionRunsMillis: 60000
|
||||
# 配置一个连接在池中最小生存的时间,单位是毫秒
|
||||
minEvictableIdleTimeMillis: 300000
|
||||
# 配置一个连接在池中最大生存的时间,单位是毫秒
|
||||
maxEvictableIdleTimeMillis: 900000
|
||||
# 配置检测连接是否有效
|
||||
validationQuery: SELECT 1 FROM DUAL
|
||||
testWhileIdle: true
|
||||
testOnBorrow: false
|
||||
testOnReturn: false
|
||||
webStatFilter:
|
||||
enabled: true
|
||||
statViewServlet:
|
||||
enabled: true
|
||||
# 设置白名单,不填则允许所有访问
|
||||
allow:
|
||||
url-pattern: /druid/*
|
||||
# 控制台管理用户名和密码
|
||||
login-username:
|
||||
login-password:
|
||||
filter:
|
||||
stat:
|
||||
enabled: true
|
||||
# 慢SQL记录
|
||||
log-slow-sql: true
|
||||
slow-sql-millis: 1000
|
||||
merge-sql: true
|
||||
wall:
|
||||
config:
|
||||
multi-statement-allow: true
|