This commit is contained in:
divenswu 2021-12-27 10:14:50 +08:00
commit 3f04f3362f
369 changed files with 30291 additions and 0 deletions

LICENSE Normal file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2021 divenswu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md Normal file

@ -0,0 +1,116 @@
## Face Search M:N
* This project is an open-source alternative to the 1:N face search service on Alibaba Cloud's Visual Intelligence platform. All models used are open source. The project supports both milvus and proxima as vector stores and is highly customizable.
* It is written in pure Java, avoiding the service instability that a Python stack can introduce.
* 1:N answers "who are you": a captured portrait is matched against a large gallery of enrolled faces. Typical scenarios include face-based attendance and access control in office buildings, community gate access, construction-site attendance, and meeting sign-in.
* M:N matches every face detected in a scene against the gallery. As a dynamic face-comparison mode it is widely used, for example in public security and greeting robots.
### Project Overview
* Overall architecture
![Architecture diagram](scripts/images/%E4%BA%BA%E8%84%B8%E6%90%9C%E7%B4%A2%E6%B5%81%E7%A8%8B%E5%9B%BE.jpg)
* Components used
    1. spring boot
    2. [onnx](https://github.com/onnx/onnx)
    3. [milvus](https://github.com/milvus-io/milvus/)
    4. [proxima](https://github.com/alibaba/proximabilin)
* Deep-learning models
    1. [insightface](https://github.com/deepinsight/insightface)
    2. [PCN](https://github.com/Rock-100/FaceKit/tree/master/PCN)
### Documentation
* [Docs 1.0.0](https://gitee.com/open-visual/face-search/blob/master/scripts/docs/doc-1.0.0.md)
* Start the service with Swagger enabled and open host:port/doc.html, e.g. http://127.0.0.1:8080/doc.html
### Search Client
* Java dependency. It is not published to Maven Central; build it yourself and deploy it to a private repository.
```
<dependency>
<groupId>com.visual.face.search</groupId>
<artifactId>face-search-client</artifactId>
<version>1.0.0</version>
</dependency>
```
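* A minimal Java sketch of enrolling a person with this client, based on the classes added in this commit. It is illustrative only: host, namespace, collection name, sample id, field name and image path are placeholders, and a running face-search service is assumed.
```
import java.util.Arrays;
import com.visual.face.search.FaceSearch;
import com.visual.face.search.model.*;
import com.visual.face.search.utils.Base64Util;

public class EnrollExample {
    public static void main(String[] args) {
        // Placeholder host/namespace/collection; point these at your own deployment.
        FaceSearch client = FaceSearch.build("http://127.0.0.1:8080", "demo", "person_faces");

        // 1. Create a collection with one custom sample field.
        Collect collect = Collect.build()
                .setCollectionComment("demo collection")
                .setSampleColumns(Arrays.asList(FiledColumn.build()
                        .setName("name").setComment("person name").setDataType(FiledDataType.STRING)));
        Response<Boolean> created = client.collect().createCollect(collect);
        System.out.println("create collection: " + created.ok());

        // 2. Register a sample (one person) and bind a face image to it.
        client.sample().createSample(Sample.build("sample-001")
                .setSampleData(KeyValues.build().add(KeyValue.build("name", "Alice"))));
        Response<FaceRep> addFace = client.face().createFace(Face.build("sample-001")
                .setImageBase64(Base64Util.encode("/path/to/alice.jpg")));
        System.out.println("add face: " + addFace.ok());
    }
}
```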
* Other languages
&ensp; &ensp;Use the RESTful API; see [Docs 1.0.0](https://gitee.com/open-visual/face-search/blob/master/scripts/docs/doc-1.0.0.md).
### Deployment
* Docker deployment scripts are in face-search/scripts
```
1. Use milvus as the vector search engine:
docker-compose -f docker-compose-milvus.yml --compatibility up -d
2. Use proxima as the vector search engine:
docker-compose -f docker-compose-proxima.yml --compatibility up -d
```
* Building the project
```
1. Clone the project:
git clone https://gitee.com/open-visual/face-search.git
2. Package it:
cd face-search && sh scripts/docker_build.sh
```
* Deployment parameters
| Parameter | Description | Default | Options / Notes |
| -------- | -----: | :----: |--------|
| VISUAL_SWAGGER_ENABLE | Enable Swagger | true | |
| SPRING_DATASOURCE_URL | Database URL | | |
| SPRING_DATASOURCE_USERNAME | Database username | | |
| SPRING_DATASOURCE_PASSWORD | Database password | | |
| VISUAL_ENGINE_SELECTED | Vector storage engine | proxima | proxima, milvus |
| VISUAL_ENGINE_PROXIMA_HOST | Proxima host | | takes effect when VISUAL_ENGINE_SELECTED=proxima |
| VISUAL_ENGINE_PROXIMA_PORT | Proxima port | 16000 | takes effect when VISUAL_ENGINE_SELECTED=proxima |
| VISUAL_ENGINE_MILVUS_HOST | Milvus host | | takes effect when VISUAL_ENGINE_SELECTED=milvus |
| VISUAL_ENGINE_MILVUS_PORT | Milvus port | 19530 | takes effect when VISUAL_ENGINE_SELECTED=milvus |
| VISUAL_MODEL_FACEDETECTION_NAME | Face detection model | PcnNetworkFaceDetection | PcnNetworkFaceDetection, InsightScrfdFaceDetection |
| VISUAL_MODEL_FACEDETECTION_BACKUP_NAME | Backup face detection model | InsightScrfdFaceDetection | PcnNetworkFaceDetection, InsightScrfdFaceDetection |
| VISUAL_MODEL_FACEKEYPOINT_NAME | Face keypoint model | InsightCoordFaceKeyPoint | InsightCoordFaceKeyPoint |
| VISUAL_MODEL_FACEALIGNMENT_NAME | Face alignment model | Simple106pFaceAlignment | Simple106pFaceAlignment, Simple005pFaceAlignment |
| VISUAL_MODEL_FACERECOGNITION_NAME | Face recognition (feature extraction) model | InsightArcFaceRecognition | InsightArcFaceRecognition |
### Performance Tuning
* To raise the face detection rate, the project uses a primary and a backup detection model; two detectors are currently implemented, insightface and PCN. In the Docker deployment the primary defaults to PCN and the backup to insightface. insightface is faster but has a lower detection rate for heavily rotated faces, while PCN handles large rotations at the cost of speed. If your images are mostly frontal faces, use insightface as the primary model and PCN as the backup; see the deployment parameters above for how to switch.
* In testing, proxima was slightly faster than milvus but less stable; for production use, milvus is the recommended vector search engine.
### Demo
* Test case in face-search-test: [FaceSearchExample](https://gitee.com/open-visual/face-search/blob/master/face-search-test/src/main/java/com/visual/face/search/valid/exps/FaceSearchExample.java); a minimal search sketch follows the screenshot below.
* ![Validation result](scripts/images/validate.jpg)
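* The linked FaceSearchExample drives the full flow end to end. A minimal search-only sketch using the client added in this commit looks roughly like this (host, namespace, collection and image path are placeholders; it is not the FaceSearchExample itself):
```
import java.util.List;
import com.visual.face.search.FaceSearch;
import com.visual.face.search.model.*;
import com.visual.face.search.utils.Base64Util;

public class SearchExample {
    public static void main(String[] args) {
        FaceSearch client = FaceSearch.build("http://127.0.0.1:8080", "demo", "person_faces");
        // Search a query image against the collection; limit caps the matches returned per face.
        Search condition = Search.build(Base64Util.encode("/path/to/query.jpg"))
                .setMaxFaceNum(1)
                .setLimit(5)
                .setConfidenceThreshold(0f)
                .setFaceScoreThreshold(0f);
        Response<List<SearchRep>> result = client.search().search(condition);
        if (!result.ok()) {
            System.out.println("search failed: " + result.getMessage());
            return;
        }
        // One SearchRep per detected face in the query image; each carries its candidate matches.
        for (SearchRep detected : result.getData()) {
            System.out.println("face score: " + detected.getFaceScore());
            for (SampleFace match : detected.getMatch()) {
                System.out.println("  " + match.getSampleId() + " confidence=" + match.getConfidence());
            }
        }
    }
}
```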
### Community
* DingTalk group
Follow the WeChat official account and reply "钉钉群".
* WeChat group
Follow the WeChat official account and reply "微信群".
* WeChat official account
![WeChat official account](scripts/images/%E5%85%AC%E4%BC%97%E5%8F%B7-%E5%BE%AE%E4%BF%A1.jpg)

pom.xml Normal file

@ -0,0 +1,31 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<artifactId>face-search-client</artifactId>
<groupId>com.visual.face.search</groupId>
<version>1.0.0</version>
<properties>
<java.version>1.8</java.version>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>1.2.58</version>
</dependency>
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
<version>4.5</version>
</dependency>
</dependencies>
</project>

FaceSearch.java Normal file

@ -0,0 +1,85 @@
package com.visual.face.search;
import com.visual.face.search.handle.CollectHandler;
import com.visual.face.search.handle.FaceHandler;
import com.visual.face.search.handle.SampleHandler;
import com.visual.face.search.handle.SearchHandler;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class FaceSearch {
/**server host**/
private String serverHost;
/**namespace**/
private String namespace;
/**collection name**/
private String collectionName;
/**instance cache**/
private final static Map<String, FaceSearch> ins = new ConcurrentHashMap<>();
/**
* Create a client bound to a server host, namespace and collection.
* @param serverHost server host
* @param namespace namespace
* @param collectionName collection name
*/
private FaceSearch(String serverHost, String namespace, String collectionName){
this.serverHost = serverHost;
this.namespace = namespace;
this.collectionName = collectionName;
}
/**
* Build (or reuse) the client instance for a server host, namespace and collection.
* @param serverHost server host
* @param namespace namespace
* @param collectionName collection name
* @return shared FaceSearch instance
*/
public static FaceSearch build (String serverHost,String namespace, String collectionName){
String key = serverHost+"|_|"+namespace + "|_|" + collectionName;
if(!ins.containsKey(key)){
synchronized (FaceSearch.class){
if(!ins.containsKey(key)){
ins.put(key, new FaceSearch(serverHost, namespace, collectionName));
}
}
}
return ins.get(key);
}
/**
* Collection operations.
* @return CollectHandler
*/
public CollectHandler collect(){
return CollectHandler.build(serverHost, namespace, collectionName);
}
/**
* Sample operations.
* @return SampleHandler
*/
public SampleHandler sample(){
return SampleHandler.build(serverHost, namespace, collectionName);
}
/**
* Face operations.
* @return FaceHandler
*/
public FaceHandler face(){
return FaceHandler.build(serverHost, namespace, collectionName);
}
/**
* Face search operations.
* @return SearchHandler
*/
public SearchHandler search(){
return SearchHandler.build(serverHost, namespace, collectionName);
}
}

Api.java Normal file

@ -0,0 +1,26 @@
package com.visual.face.search.common;
public class Api {
public static final String collect_get = "/visual/collect/get";
public static final String collect_list = "/visual/collect/list";
public static final String collect_create = "/visual/collect/create";
public static final String collect_delete = "/visual/collect/delete";
public static final String sample_get = "/visual/sample/get";
public static final String sample_list = "/visual/sample/list";
public static final String sample_create = "/visual/sample/create";
public static final String sample_update = "/visual/sample/update";
public static final String sample_delete = "/visual/sample/delete";
public static final String face_delete = "/visual/face/delete";
public static final String face_create = "/visual/face/create";
public static final String visual_search = "/visual/search/do";
public static String getUrl(String host, String uri){
host = host.replaceAll ("/+$", "");
return host + uri;
}
}

BaseHandler.java Normal file

@ -0,0 +1,39 @@
package com.visual.face.search.handle;
public class BaseHandler<ExtendsVo extends BaseHandler<ExtendsVo>> {
/**server host**/
protected String serverHost;
/**namespace**/
protected String namespace;
/**collection name**/
protected String collectionName;
public String getServerHost() {
return serverHost;
}
public ExtendsVo setServerHost(String serverHost) {
this.serverHost = serverHost;
return (ExtendsVo) this;
}
public String getNamespace() {
return namespace;
}
public ExtendsVo setNamespace(String namespace) {
this.namespace = namespace;
return (ExtendsVo) this;
}
public String getCollectionName() {
return collectionName;
}
public ExtendsVo setCollectionName(String collectionName) {
this.collectionName = collectionName;
return (ExtendsVo) this;
}
}

CollectHandler.java Normal file

@ -0,0 +1,85 @@
package com.visual.face.search.handle;
import java.util.Map;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import com.visual.face.search.common.Api;
import com.visual.face.search.http.HttpClient;
import com.visual.face.search.model.*;
public class CollectHandler extends BaseHandler<CollectHandler>{
/**实例对象**/
private final static Map<String, CollectHandler> ins = new ConcurrentHashMap<>();
/**
* 构建集合对象
* @param serverHost 服务地址
* @param namespace 命名空间
* @param collectionName 集合名称
* @return
*/
public static CollectHandler build(String serverHost, String namespace, String collectionName){
String key = serverHost+"|_|"+namespace + "|_|" + collectionName;
if(!ins.containsKey(key)){
synchronized (CollectHandler.class){
if(!ins.containsKey(key)){
ins.put(key, new CollectHandler().setServerHost(serverHost)
.setNamespace(namespace).setCollectionName(collectionName));
}
}
}
return ins.get(key);
}
/**
* 创建一个集合
* @param collect 集合的定义信息
* @return 是否创建成功
*/
public Response<Boolean> createCollect(Collect collect){
CollectReq collectReq = CollectReq
.build(this.namespace, this.collectionName)
.setCollectionComment(collect.getCollectionComment())
.setMaxDocsPerSegment(collect.getMaxDocsPerSegment())
.setSampleColumns(collect.getSampleColumns())
.setFaceColumns(collect.getFaceColumns())
.setSyncBinLog(collect.isSyncBinLog())
.setShardsNum(collect.getShardsNum());
return HttpClient.post(Api.getUrl(this.serverHost, Api.collect_create), collectReq);
}
/**
*根据命名空间集合名称删除集合
* @return 是否删除成功
*/
public Response<Boolean> deleteCollect(){
MapParam param = MapParam.build()
.put("namespace", namespace)
.put("collectionName", collectionName);
return HttpClient.get(Api.getUrl(this.serverHost, Api.collect_delete), param);
}
/**
*根据命名空间集合名称查看集合信息
* @return 集合信息
*/
public Response<CollectRep> getCollect(){
MapParam param = MapParam.build()
.put("namespace", namespace)
.put("collectionName", collectionName);
return HttpClient.get(Api.getUrl(this.serverHost, Api.collect_get), param);
}
/**
*根据命名空间查看集合列表
* @return 集合列表
*/
public Response<List<CollectRep>> collectList(){
MapParam param = MapParam.build().put("namespace", namespace);
return HttpClient.get(Api.getUrl(this.serverHost, Api.collect_list), param);
}
}

FaceHandler.java Normal file

@ -0,0 +1,68 @@
package com.visual.face.search.handle;
import com.visual.face.search.common.Api;
import com.visual.face.search.http.HttpClient;
import com.visual.face.search.model.*;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class FaceHandler extends BaseHandler<FaceHandler>{
/**实例对象**/
private final static Map<String, FaceHandler> ins = new ConcurrentHashMap<>();
/**
* 构建样本对象
* @param serverHost 服务地址
* @param namespace 命名空间
* @param collectionName 集合名称
* @return
*/
public static FaceHandler build(String serverHost, String namespace, String collectionName){
String key = serverHost+"|_|"+namespace + "|_|" + collectionName;
if(!ins.containsKey(key)){
synchronized (FaceHandler.class){
if(!ins.containsKey(key)){
ins.put(key, new FaceHandler().setServerHost(serverHost)
.setNamespace(namespace).setCollectionName(collectionName));
}
}
}
return ins.get(key);
}
/**
* 创建一个人脸数据
* @param face 人脸的定义信息
* @return 是否创建成功
*/
public Response<FaceRep> createFace(Face face){
FaceReq faceReq = FaceReq
.build(this.namespace, this.collectionName)
.setSampleId(face.getSampleId())
.setImageBase64(face.getImageBase64())
.setFaceData(face.getFaceData())
.setFaceScoreThreshold(face.getFaceScoreThreshold())
.setMinConfidenceThresholdWithThisSample(face.getMinConfidenceThresholdWithThisSample())
.setMaxConfidenceThresholdWithOtherSample(face.getMaxConfidenceThresholdWithOtherSample());
return HttpClient.post(Api.getUrl(this.serverHost, Api.face_create), faceReq);
}
/**
*根据条件删除人脸数据
* @param sampleId 样本ID
* @param faceId 人脸ID
* @return 是否删除成功
*/
public Response<Boolean> deleteFace(String sampleId, String faceId){
MapParam param = MapParam.build()
.put("namespace", namespace)
.put("collectionName", collectionName)
.put("sampleId", sampleId)
.put("faceId", faceId);
return HttpClient.get(Api.getUrl(this.serverHost, Api.face_delete), param);
}
}

SampleHandler.java Normal file

@ -0,0 +1,103 @@
package com.visual.face.search.handle;
import com.alibaba.fastjson.TypeReference;
import com.visual.face.search.common.Api;
import com.visual.face.search.http.HttpClient;
import com.visual.face.search.model.*;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class SampleHandler extends BaseHandler<SampleHandler> {
/**实例对象**/
private final static Map<String, SampleHandler> ins = new ConcurrentHashMap<>();
/**
* 构建样本对象
* @param serverHost 服务地址
* @param namespace 命名空间
* @param collectionName 集合名称
* @return
*/
public static SampleHandler build(String serverHost, String namespace, String collectionName){
String key = serverHost+"|_|"+namespace + "|_|" + collectionName;
if(!ins.containsKey(key)){
synchronized (SampleHandler.class){
if(!ins.containsKey(key)){
ins.put(key, new SampleHandler().setServerHost(serverHost)
.setNamespace(namespace).setCollectionName(collectionName));
}
}
}
return ins.get(key);
}
/**
* 创建一个样本
* @param sample 样本的定义信息
* @return 是否创建成功
*/
public Response<Boolean> createSample(Sample sample){
SampleReq sampleReq = SampleReq
.build(this.namespace, this.collectionName)
.setSampleId(sample.getSampleId())
.setSampleData(sample.getSampleData());
return HttpClient.post(Api.getUrl(this.serverHost, Api.sample_create), sampleReq);
}
/**
* 更新一个样本
* @param sample 样本的定义信息
* @return 是否创建成功
*/
public Response<Boolean> updateSample(Sample sample){
SampleReq sampleReq = SampleReq
.build(this.namespace, this.collectionName)
.setSampleId(sample.getSampleId())
.setSampleData(sample.getSampleData());
return HttpClient.post(Api.getUrl(this.serverHost, Api.sample_update), sampleReq);
}
/**
*根据条件删除样本
* @return 是否删除成功
*/
public Response<Boolean> deleteSample(String sampleId){
MapParam param = MapParam.build()
.put("namespace", namespace)
.put("collectionName", collectionName)
.put("sampleId", sampleId);
return HttpClient.get(Api.getUrl(this.serverHost, Api.sample_delete), param);
}
/**
*根据条件查看样本
* @return 样本
*/
public Response<SampleRep> getSample(String sampleId){
MapParam param = MapParam.build()
.put("namespace", namespace)
.put("collectionName", collectionName)
.put("sampleId", sampleId);
return HttpClient.get(Api.getUrl(this.serverHost, Api.sample_get), param, new TypeReference<Response<SampleRep>>(){});
}
/**
* 根据查询信息查看样本列表
* @param offset 起始记录
* @param limit 样本数目
* @param order 排列方式包括asc升序和desc降序
* @return
*/
public Response<List<SampleRep>> sampleList(Integer offset, Integer limit, Order order){
MapParam param = MapParam.build()
.put("namespace", namespace)
.put("collectionName", collectionName)
.put("offset", offset)
.put("limit", limit)
.put("order", order.name());
return HttpClient.get(Api.getUrl(this.serverHost, Api.sample_list), param, new TypeReference<Response<List<SampleRep>>>(){});
}
}

SearchHandler.java Normal file

@ -0,0 +1,51 @@
package com.visual.face.search.handle;
import com.alibaba.fastjson.TypeReference;
import com.visual.face.search.common.Api;
import com.visual.face.search.http.HttpClient;
import com.visual.face.search.model.*;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class SearchHandler extends BaseHandler<SearchHandler>{
/**实例对象**/
private final static Map<String, SearchHandler> ins = new ConcurrentHashMap<>();
/**
* 构建样本对象
* @param serverHost 服务地址
* @param namespace 命名空间
* @param collectionName 集合名称
* @return
*/
public static SearchHandler build(String serverHost, String namespace, String collectionName){
String key = serverHost+"|_|"+namespace + "|_|" + collectionName;
if(!ins.containsKey(key)){
synchronized (SearchHandler.class){
if(!ins.containsKey(key)){
ins.put(key, new SearchHandler().setServerHost(serverHost)
.setNamespace(namespace).setCollectionName(collectionName));
}
}
}
return ins.get(key);
}
/**
* 人脸搜索
* @param search 搜索条件
* @return 获取当前匹配的列表
*/
public Response<List<SearchRep>> search(Search search){
SearchReq searchReq = SearchReq.build(namespace, collectionName)
.setImageBase64(search.getImageBase64())
.setMaxFaceNum(search.getMaxFaceNum())
.setLimit(search.getLimit())
.setConfidenceThreshold(search.getConfidenceThreshold())
.setFaceScoreThreshold(search.getFaceScoreThreshold());
return HttpClient.post(Api.getUrl(this.serverHost, Api.visual_search), searchReq, new TypeReference<Response<List<SearchRep>>>(){});
}
}

HttpClient.java Normal file

@ -0,0 +1,223 @@
package com.visual.face.search.http;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.TypeReference;
import com.visual.face.search.model.Collect;
import com.visual.face.search.model.CollectRep;
import com.visual.face.search.model.MapParam;
import com.visual.face.search.model.Response;
import com.visual.face.search.utils.JsonUtil;
import org.apache.http.NameValuePair;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.config.Registry;
import org.apache.http.config.RegistryBuilder;
import org.apache.http.conn.socket.ConnectionSocketFactory;
import org.apache.http.conn.socket.PlainConnectionSocketFactory;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.BasicCookieStore;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
import org.apache.http.message.BasicNameValuePair;
import org.apache.http.util.EntityUtils;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.IOException;
import java.net.URISyntaxException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
public class HttpClient {
/**character encoding*/
private static final String ENCODING = "UTF-8";
/**connect timeout: 10 seconds*/
public static final int DEFAULT_CONNECT_TIMEOUT = 10 * 1000;
/**socket read timeout: 10 seconds*/
public static final int DEFAULT_READ_TIMEOUT = 10 * 1000;
/**timeout for obtaining a connection from the pool: 30 seconds*/
public static final int DEFAULT_CONNECT_REQUEST_TIMEOUT = 30 * 1000;
/**maximum total connections in the pool*/
private static final int MAX_TOTAL = 8;
/**maximum concurrent connections per route*/
private static final int MAX_PER_ROUTE = 4;
private static RequestConfig requestConfig;
private static PoolingHttpClientConnectionManager connectionManager;
private static BasicCookieStore cookieStore;
private static HttpClientBuilder httpBuilder;
private static CloseableHttpClient httpClient;
private static CloseableHttpClient httpsClient;
private static SSLContext sslContext;
/**
* 创建SSLContext对象用来绕过https证书认证实现访问
*/
static {
try {
sslContext = SSLContext.getInstance("TLS");
// 实现一个X509TrustManager接口用于绕过验证不用修改里面的方法
X509TrustManager tm = new X509TrustManager() {
@Override
public void checkClientTrusted(X509Certificate[] chain, String authType)
throws CertificateException {
}
@Override
public void checkServerTrusted(X509Certificate[] chain, String authType)
throws CertificateException {
}
@Override
public X509Certificate[] getAcceptedIssuers() {
return null;
}
};
sslContext.init(null, new TrustManager[] {tm}, null);
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* 初始化httpclient对象以及在创建httpclient对象之前的一些自定义配置
*/
static {
// 自定义配置信息
requestConfig = RequestConfig.custom()
.setSocketTimeout(DEFAULT_READ_TIMEOUT)
.setConnectTimeout(DEFAULT_CONNECT_TIMEOUT)
.setConnectionRequestTimeout(DEFAULT_CONNECT_REQUEST_TIMEOUT)
.build();
//设置协议http和https对应的处理socket链接工厂的对象
Registry<ConnectionSocketFactory> socketFactoryRegistry = RegistryBuilder.<ConnectionSocketFactory> create()
.register("http", new PlainConnectionSocketFactory())
.register("https", new SSLConnectionSocketFactory(sslContext))
.build();
connectionManager = new PoolingHttpClientConnectionManager(socketFactoryRegistry);
// 设置cookie存储对像在需要获取cookie信息时可以使用这个对象
cookieStore = new BasicCookieStore();
// 设置最大连接数
connectionManager.setMaxTotal(MAX_TOTAL);
// 设置路由并发数
connectionManager.setDefaultMaxPerRoute(MAX_PER_ROUTE);
httpBuilder = HttpClientBuilder.create();
httpBuilder.setDefaultRequestConfig(requestConfig);
httpBuilder.setConnectionManager(connectionManager);
httpBuilder.setDefaultCookieStore(cookieStore);
// 实例化http https的对象
httpClient = httpBuilder.build();
httpsClient = httpBuilder.build();
}
/**
* post请求
* @param url
* @param data
* @param <T>
* @return
*/
public static <T> Response<T> post(String url, Object data) {
return post(url, data, new TypeReference<Response<T>>() {});
}
/**
* post请求
* @param url
* @param data
* @param <T>
* @return
*/
public static <T> Response<T> post(String url, Object data, TypeReference<Response<T>> type) {
// 创建HTTP对象
HttpPost httpPost = new HttpPost(url);
httpPost.setConfig(requestConfig);
httpPost.addHeader("Content-Type", "application/json;charset=UTF-8");
// 设置请求头
if(null != data){
String json = data instanceof String ? String.valueOf(data) : JsonUtil.toString(data);
StringEntity stringEntity = new StringEntity(json, ENCODING);
stringEntity.setContentEncoding(ENCODING);
httpPost.setEntity(stringEntity);
}
// 创建httpResponse对象
CloseableHttpClient client = url.toLowerCase().startsWith("https") ? httpsClient : httpClient;
try {
CloseableHttpResponse httpResponse = client.execute(httpPost);
String content = EntityUtils.toString(httpResponse.getEntity(), ENCODING);
return JsonUtil.toEntity(content, type);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* get
* @param url
* @param param
* @param <T>
* @return
*/
public static <T> Response<T> get(String url, MapParam param) {
return get(url, param, new TypeReference<Response<T>>() {});
}
/**
* get
* @param url
* @param param
* @param <T>
* @return
*/
public static <T> Response<T> get(String url, MapParam param, TypeReference<Response<T>> type) {
try {
//参数构建
URIBuilder uriBuilder = new URIBuilder(url);
if(null != param && !param.isEmpty()){
List<NameValuePair> list = new LinkedList<>();
for(String key : param.keySet()){
Object value = param.get(key);
if(null == value){
list.add(new BasicNameValuePair(key, null));
}else{
list.add(new BasicNameValuePair(key, String.valueOf(value)));
}
}
uriBuilder.setParameters(list);
}
//构建请求
HttpGet httpGet = new HttpGet(uriBuilder.build());
httpGet.setConfig(requestConfig);
httpGet.addHeader("Content-Type", "application/json;charset=UTF-8");
// 创建httpResponse对象
CloseableHttpClient client = url.toLowerCase().startsWith("https") ? httpsClient : httpClient;
CloseableHttpResponse httpResponse = client.execute(httpGet);
String content = EntityUtils.toString(httpResponse.getEntity(), ENCODING);
return JsonUtil.toEntity(content, type);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
public static void main(String[] args) {
String url = "http://127.0.0.1:8080/visual/collect/create";
CollectRep collectRep = CollectRep.build("n1", "c1003").setCollectionComment("xxxxxx");
Response<Boolean> res = HttpClient.post(url, collectRep);
System.out.println(res.getCode());
System.out.println(res.getData());
System.out.println(res.getMessage());
}
}

HttpClientResult.java Normal file

@ -0,0 +1,4 @@
package com.visual.face.search.http;
public class HttpClientResult {
}

HttpClientUtils.java Normal file

@ -0,0 +1,466 @@
package com.visual.face.search.http;
import org.apache.http.HttpStatus;
import org.apache.http.NameValuePair;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.entity.UrlEncodedFormEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpDelete;
import org.apache.http.client.methods.HttpEntityEnclosingRequestBase;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpPut;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.config.Registry;
import org.apache.http.config.RegistryBuilder;
import org.apache.http.conn.socket.ConnectionSocketFactory;
import org.apache.http.conn.socket.PlainConnectionSocketFactory;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.cookie.Cookie;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.BasicCookieStore;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
import org.apache.http.message.BasicNameValuePair;
import org.apache.http.util.EntityUtils;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * @Description: common HttpClient helper methods
 * @Author: ggf
 * @Date: 2020/06/06
 */
public class HttpClientUtils {
/**
 * character encoding
 */
private static final String ENCODING = "UTF-8";
/**
 * connect timeout: 6 seconds (6000 ms)
 */
public static final int DEFAULT_CONNECT_TIMEOUT = 6000;
/**
 * socket read timeout: 6 seconds (6000 ms)
 */
public static final int DEFAULT_READ_TIMEOUT = 6000;
/**
 * timeout for obtaining a connection from the pool: 6 seconds (6000 ms)
 */
public static final int DEFAULT_CONNECT_REQUEST_TIMEOUT = 6000;
/**
 * maximum total connections in the pool
 */
private static final int MAX_TOTAL = 64;
/**
 * maximum concurrent connections per route
 */
private static final int MAX_PER_ROUTE = 32;
private static RequestConfig requestConfig;
private static PoolingHttpClientConnectionManager connectionManager;
private static BasicCookieStore cookieStore;
private static HttpClientBuilder httpBuilder;
private static CloseableHttpClient httpClient;
private static CloseableHttpClient httpsClient;
private static SSLContext sslContext;
/**
* 创建SSLContext对象用来绕过https证书认证实现访问
*/
static {
try {
sslContext = SSLContext.getInstance("TLS");
// 实现一个X509TrustManager接口用于绕过验证不用修改里面的方法
X509TrustManager tm = new X509TrustManager() {
@Override
public void checkClientTrusted(X509Certificate[] chain, String authType)
throws CertificateException {
}
@Override
public void checkServerTrusted(X509Certificate[] chain, String authType)
throws CertificateException {
}
@Override
public X509Certificate[] getAcceptedIssuers() {
return null;
}
};
sslContext.init(null, new TrustManager[] {tm}, null);
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* 初始化httpclient对象以及在创建httpclient对象之前的一些自定义配置
*/
static {
// 自定义配置信息
requestConfig = RequestConfig.custom()
.setSocketTimeout(DEFAULT_READ_TIMEOUT)
.setConnectTimeout(DEFAULT_CONNECT_TIMEOUT)
.setConnectionRequestTimeout(DEFAULT_CONNECT_REQUEST_TIMEOUT)
.build();
//设置协议http和https对应的处理socket链接工厂的对象
Registry<ConnectionSocketFactory> socketFactoryRegistry = RegistryBuilder.<ConnectionSocketFactory> create()
.register("http", new PlainConnectionSocketFactory())
.register("https", new SSLConnectionSocketFactory(sslContext))
.build();
connectionManager = new PoolingHttpClientConnectionManager(socketFactoryRegistry);
// 设置cookie存储对像在需要获取cookie信息时可以使用这个对象
cookieStore = new BasicCookieStore();
// 设置最大连接数
connectionManager.setMaxTotal(MAX_TOTAL);
// 设置路由并发数
connectionManager.setDefaultMaxPerRoute(MAX_PER_ROUTE);
httpBuilder = HttpClientBuilder.create();
httpBuilder.setDefaultRequestConfig(requestConfig);
httpBuilder.setConnectionManager(connectionManager);
httpBuilder.setDefaultCookieStore(cookieStore);
// 实例化http https的对象
httpClient = httpBuilder.build();
httpsClient = httpBuilder.build();
}
/**
* 封装无参数的get请求http
* @param url 请求url
* @return 返回对象HttpClientResult
*/
public static HttpClientResult doGet(String url) {
return doGet(url, false);
}
/**
* 封装无参get请求支持https协议
* @param url 请求url
* @param https 请求的是否是https协议true 否false
* @return
*/
public static HttpClientResult doGet(String url, boolean https){
return doGet(url, null, null, https);
}
/**
* 封装带参数的get请求支持https协议
* @param url 请求url
* @param params 请求参数
* @param https 是否是https协议
*/
public static HttpClientResult doGet(String url, Map<String, String> params, boolean https){
return doGet(url, null, params, https);
}
/**
* 封装带参数和带请求头信息的GET方法支持https协议请求
* @param url 请求url
* @param headers 请求头信息
* @param params 请求参数
* @param https 是否使用https协议
*/
public static HttpClientResult doGet(String url, Map<String, String> headers, Map<String, String> params, boolean https){
// 创建HttpGet
HttpGet httpGet = null;
// 创建httpResponse对象
CloseableHttpResponse httpResponse = null;
try {
// 创建访问的地址
URIBuilder uriBuilder = new URIBuilder(url);
if (params != null) {
Set<Map.Entry<String, String>> entrySet = params.entrySet();
for (Map.Entry<String, String> entry : entrySet) {
uriBuilder.setParameter(entry.getKey(), entry.getValue());
}
}
// 创建HTTP对象
httpGet = new HttpGet(uriBuilder.build());
httpGet.setConfig(requestConfig);
// 设置请求头
setHeader(headers, httpGet);
// 使用不同的协议进行请求返回自定义的响应对象
if (https) {
return getHttpClientResult(httpResponse, httpsClient, httpGet);
} else {
return getHttpClientResult(httpResponse, httpClient, httpGet);
}
} catch (Exception e) {
e.printStackTrace();
} finally {
// 释放资源
if (httpGet != null) {
httpGet.releaseConnection();
}
release(httpResponse);
}
return null;
}
/**
* POST不带参数只支持http协议
* @param url 请求url
*/
public static HttpClientResult doPost(String url) {
return doPost(url, Boolean.FALSE);
}
/**
* 封装不带参数的post请求支持https协议
* @param url 请求url
* @param https 是否是https协议
*/
public static HttpClientResult doPost(String url, boolean https) {
return doPost(url, null, (Map<String, String>)null, https);
}
/**
* 带参数的post请求支持https协议
* @param url 请求url
* @param params 请求参数
* @param https 是否是https协议
*/
public static HttpClientResult doPost(String url, Map<String, String> params, boolean https) {
return doPost(url, null, params, https);
}
/**
* 带参数和请求头的POST请求支持https
*
* @param url 请求url
* @param headers 请求头
* @param params 请求参数参数为K=V格式
* @param https 是否https协议
*/
public static HttpClientResult doPost(String url, Map<String, String> headers, Map<String, String> params, boolean https) {
// 创建HTTP对象
HttpPost httpPost = new HttpPost(url);
httpPost.setConfig(requestConfig);
// 设置请求头
setHeader(headers, httpPost);
// 封装请求参数
setParam(params, httpPost);
// 创建httpResponse对象
CloseableHttpResponse httpResponse = null;
try {
if (https) {
return getHttpClientResult(httpResponse, httpsClient, httpPost);
} else {
return getHttpClientResult(httpResponse, httpClient, httpPost);
}
} finally {
httpPost.releaseConnection();
release(httpResponse);
}
}
/**
* 带参数带请求头的POST请求支持https协议
*
* @param url 请求url
* @param headers 请求头
* @param json 请求参数为json格式
* @param https 是否使用https协议
* @throws Exception
*/
public static HttpClientResult doPost(String url, Map<String, String> headers, String json, boolean https) {
// 创建HTTP对象
HttpPost httpPost = new HttpPost(url);
httpPost.setConfig(requestConfig);
// 设置请求头
setHeader(headers, httpPost);
StringEntity stringEntity = new StringEntity(json, ENCODING);
stringEntity.setContentEncoding(ENCODING);
httpPost.setEntity(stringEntity);
// 创建httpResponse对象
CloseableHttpResponse httpResponse = null;
try {
if (https) {
return getHttpClientResult(httpResponse, httpsClient, httpPost);
} else {
return getHttpClientResult(httpResponse, httpClient, httpPost);
}
} finally {
httpPost.releaseConnection();
release(httpResponse);
}
}
/**
 * Send a PUT request without parameters.
 *
 * @param url request url
 * @return response result
 */
public static HttpClientResult doPut(String url) {
// delegate to the parameterized overload; calling doPut(url) here would recurse forever
return doPut(url, null);
}
/**
* 发送put请求带请求参数
*
* @param url 请求地址
* @param params 参数集合
* @return
* @throws Exception
*/
public static HttpClientResult doPut(String url, Map<String, String> params) {
HttpPut httpPut = new HttpPut(url);
httpPut.setConfig(requestConfig);
setParam(params, httpPut);
CloseableHttpResponse httpResponse = null;
try {
return getHttpClientResult(httpResponse, httpClient, httpPut);
} finally {
httpPut.releaseConnection();
release(httpResponse);
}
}
/**
* 发送delete请求不带请求参数
*
* @param url 请求url
* @return
* @throws Exception
*/
public static HttpClientResult doDelete(String url) {
HttpDelete httpDelete = new HttpDelete(url);
httpDelete.setConfig(requestConfig);
CloseableHttpResponse httpResponse = null;
try {
return getHttpClientResult(httpResponse, httpClient, httpDelete);
} finally {
httpDelete.releaseConnection();
release(httpResponse);
}
}
/**
* 发送delete请求带请求参数 支持https协议
*
* @param url 请求url
* @param params 请求参数
* @param https 是否https
* @return
* @throws Exception
*/
public static HttpClientResult doDelete(String url, Map<String, String> params, boolean https) {
if (params == null) {
params = new HashMap<String, String>();
}
params.put("_method", "delete");
return doPost(url, params, https);
}
/**
* 获取cookie信息
* @return 返回所有cookie集合
*/
public static List<Cookie> getCookies() {
return cookieStore.getCookies();
}
/**
* 设置封装请求头
*
* @param params 头信息
* @param httpMethod 请求对象
*/
public static void setHeader(Map<String, String> params, HttpRequestBase httpMethod) {
// 封装请求头
if (null != params && !params.isEmpty()) {
Set<Map.Entry<String, String>> entrySet = params.entrySet();
for (Map.Entry<String, String> entry : entrySet) {
// 设置到请求头到HttpRequestBase对象中
httpMethod.setHeader(entry.getKey(), entry.getValue());
}
}
}
/**
* 封装请求参数
*
* @param params 请求参数
* @param httpMethod 请求方法
*/
public static void setParam(Map<String, String> params, HttpEntityEnclosingRequestBase httpMethod) {
// 封装请求参数
if (null != params && !params.isEmpty()) {
List<NameValuePair> nvps = new ArrayList<NameValuePair>();
Set<Map.Entry<String, String>> entrySet = params.entrySet();
for (Map.Entry<String, String> entry : entrySet) {
nvps.add(new BasicNameValuePair(entry.getKey(), entry.getValue()));
}
UrlEncodedFormEntity entity = null;
try {
entity = new UrlEncodedFormEntity(nvps, ENCODING);
} catch (UnsupportedEncodingException e) {
e.printStackTrace();
}
// 设置到请求的http对象中
httpMethod.setEntity(entity);
}
}
/**
* 获得响应结果
*
* @param httpResponse 响应对象
* @param httpClient httpclient对象
* @param httpMethod 请求方法
* @return
* @throws Exception
*/
public static HttpClientResult getHttpClientResult(CloseableHttpResponse httpResponse, CloseableHttpClient httpClient, HttpRequestBase httpMethod) {
try {
// 执行请求
httpResponse = httpClient.execute(httpMethod);
// 获取返回结果
if (httpResponse != null && httpResponse.getStatusLine() != null) {
String content = "";
if (httpResponse.getEntity() != null) {
content = EntityUtils.toString(httpResponse.getEntity(), ENCODING);
}
// return new HttpClientResult(httpResponse.getStatusLine().getStatusCode(), content);
}
} catch (IOException e) {
e.printStackTrace();
}
// return new HttpClientResult(HttpStatus.SC_INTERNAL_SERVER_ERROR);
return null;
}
/**
* 释放资源
*
* @param httpResponse 响应对象
*/
public static void release(CloseableHttpResponse httpResponse) {
// 释放资源
if (httpResponse != null) {
try {
httpResponse.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}

Collect.java Normal file

@ -0,0 +1,95 @@
package com.visual.face.search.model;
import java.util.ArrayList;
import java.util.List;
import java.io.Serializable;
/***
* 集合信息对象
*/
public class Collect<ExtendsVo extends Collect<ExtendsVo>> implements Serializable {
/**集合描述**/
private String collectionComment;
/**数据分片中最大的文件个数,仅对Proxima引擎生效**/
private Long maxDocsPerSegment = 0L;
/**要创建的集合的分片数,仅对Milvus引擎生效**/
private Integer shardsNum = 0;
/**自定义的样本字段**/
private List<FiledColumn> sampleColumns = new ArrayList<>();
/**自定义的人脸字段**/
private List<FiledColumn> faceColumns = new ArrayList<>();
/**启用binlog同步**/
private Boolean syncBinLog = false;
/**
* 构建集合对象
* @return
*/
public static Collect build(){
return new Collect();
}
public String getCollectionComment() {
return collectionComment;
}
public ExtendsVo setCollectionComment(String collectionComment) {
this.collectionComment = collectionComment;
return (ExtendsVo) this;
}
public Long getMaxDocsPerSegment() {
return maxDocsPerSegment;
}
public ExtendsVo setMaxDocsPerSegment(Long maxDocsPerSegment) {
if(null != maxDocsPerSegment && maxDocsPerSegment >= 0){
this.maxDocsPerSegment = maxDocsPerSegment;
}
return (ExtendsVo) this;
}
public Integer getShardsNum() {
return shardsNum;
}
public ExtendsVo setShardsNum(Integer shardsNum) {
if(null != shardsNum && shardsNum >= 0){
this.shardsNum = shardsNum;
}
return (ExtendsVo) this;
}
public List<FiledColumn> getSampleColumns() {
return sampleColumns;
}
public ExtendsVo setSampleColumns(List<FiledColumn> sampleColumns) {
if(null != sampleColumns){
this.sampleColumns = sampleColumns;
}
return (ExtendsVo) this;
}
public List<FiledColumn> getFaceColumns() {
return faceColumns;
}
public ExtendsVo setFaceColumns(List<FiledColumn> faceColumns) {
if(null != faceColumns){
this.faceColumns = faceColumns;
}
return (ExtendsVo) this;
}
public boolean isSyncBinLog() {
return null == syncBinLog ? false : syncBinLog;
}
public ExtendsVo setSyncBinLog(Boolean syncBinLog) {
if(null != syncBinLog){
this.syncBinLog = syncBinLog;
}
return (ExtendsVo) this;
}
}

CollectRep.java Normal file

@ -0,0 +1,38 @@
package com.visual.face.search.model;
public class CollectRep extends Collect<CollectRep> {
/**命名空间**/
private String namespace;
/**集合名称**/
private String collectionName;
/**
* 构建集合对象
* @param namespace 命名空间
* @param collectionName 集合名称
* @return
*/
public static CollectRep build(String namespace, String collectionName){
return new CollectRep().setNamespace(namespace).setCollectionName(collectionName);
}
public String getNamespace() {
return namespace;
}
public CollectRep setNamespace(String namespace) {
this.namespace = namespace;
return this;
}
public String getCollectionName() {
return collectionName;
}
public CollectRep setCollectionName(String collectionName) {
this.collectionName = collectionName;
return this;
}
}

CollectReq.java Normal file

@ -0,0 +1,38 @@
package com.visual.face.search.model;
public class CollectReq extends Collect<CollectReq> {
/**命名空间**/
private String namespace;
/**集合名称**/
private String collectionName;
/**
* 构建集合对象
* @param namespace 命名空间
* @param collectionName 集合名称
* @return
*/
public static CollectReq build(String namespace, String collectionName){
return new CollectReq().setNamespace(namespace).setCollectionName(collectionName);
}
public String getNamespace() {
return namespace;
}
public CollectReq setNamespace(String namespace) {
this.namespace = namespace;
return this;
}
public String getCollectionName() {
return collectionName;
}
public CollectReq setCollectionName(String collectionName) {
this.collectionName = collectionName;
return this;
}
}

Face.java Normal file

@ -0,0 +1,91 @@
package com.visual.face.search.model;
import java.io.Serializable;
public class Face<ExtendsVo extends Face<ExtendsVo>> implements Serializable {
/**样本ID**/
private String sampleId;
/**图像Base64编码值**/
private String imageBase64;
/**人脸质量分数阈值**/
private Float faceScoreThreshold = 0f;
/**当前样本里,人脸自信度的最小阈值**/
private Float minConfidenceThresholdWithThisSample = 0f;
/**当前样本与其他样本里,人脸自信度的最大阈值**/
private Float maxConfidenceThresholdWithOtherSample = 0f;
/**人脸扩展的额外数据**/
private KeyValues faceData = KeyValues.build();
/**
* 构建样本数据
* @param sampleId 样本ID
* @return
*/
public static Face build(String sampleId){
return new Face().setSampleId(sampleId);
}
public String getSampleId() {
return sampleId;
}
public ExtendsVo setSampleId(String sampleId) {
this.sampleId = sampleId;
return (ExtendsVo) this;
}
public KeyValues getFaceData() {
return faceData;
}
public ExtendsVo setFaceData(KeyValues faceData) {
if(null != faceData){
this.faceData = faceData;
}
return (ExtendsVo) this;
}
public String getImageBase64() {
return imageBase64;
}
public ExtendsVo setImageBase64(String imageBase64) {
this.imageBase64 = imageBase64;
return (ExtendsVo) this;
}
public Float getFaceScoreThreshold() {
return faceScoreThreshold;
}
public ExtendsVo setFaceScoreThreshold(Float faceScoreThreshold) {
if(null != faceScoreThreshold && faceScoreThreshold >= 0){
this.faceScoreThreshold = faceScoreThreshold;
}
return (ExtendsVo) this;
}
public Float getMinConfidenceThresholdWithThisSample() {
return minConfidenceThresholdWithThisSample;
}
public ExtendsVo setMinConfidenceThresholdWithThisSample(Float minConfidenceThresholdWithThisSample) {
if(null != minConfidenceThresholdWithThisSample && minConfidenceThresholdWithThisSample >= 0){
this.minConfidenceThresholdWithThisSample = minConfidenceThresholdWithThisSample;
}
return (ExtendsVo) this;
}
public Float getMaxConfidenceThresholdWithOtherSample() {
return maxConfidenceThresholdWithOtherSample;
}
public ExtendsVo setMaxConfidenceThresholdWithOtherSample(Float maxConfidenceThresholdWithOtherSample) {
if(null != maxConfidenceThresholdWithOtherSample && maxConfidenceThresholdWithOtherSample >= 0){
this.maxConfidenceThresholdWithOtherSample = maxConfidenceThresholdWithOtherSample;
}
return (ExtendsVo) this;
}
}

FaceLocation.java Normal file

@ -0,0 +1,66 @@
package com.visual.face.search.model;
import java.io.Serializable;
public class FaceLocation implements Serializable {
/**左上角x坐标**/
private int x;
/**左上角y坐标**/
private int y;
/**宽度**/
private int w;
/**高度**/
private int h;
/**
* 构建坐标
* @param x
* @param y
* @param w
* @param h
* @return
*/
public static FaceLocation build(int x, int y, int w, int h){
return new FaceLocation().setX(x).setY(y).setW(w).setH(h);
}
public static FaceLocation build(float x, float y, float w, float h){
return new FaceLocation().setX((int) x).setY((int) y).setW((int) w).setH((int) h);
}
public int getX() {
return x;
}
public FaceLocation setX(int x) {
this.x = x;
return this;
}
public int getY() {
return y;
}
public FaceLocation setY(int y) {
this.y = y;
return this;
}
public int getW() {
return w;
}
public FaceLocation setW(int w) {
this.w = w;
return this;
}
public int getH() {
return h;
}
public FaceLocation setH(int h) {
this.h = h;
return this;
}
}

FaceRep.java Normal file

@ -0,0 +1,84 @@
package com.visual.face.search.model;
import java.io.Serializable;
public class FaceRep implements Serializable {
/**命名空间**/
private String namespace;
/**集合名称**/
private String collectionName;
/**样本ID**/
private String sampleId;
/**人脸ID**/
private String faceId;
/**人脸分数**/
private Float faceScore;
/**人脸扩展的额外数据**/
private KeyValues faceData = KeyValues.build();
/**
* 构建集合对象
* @param namespace 命名空间
* @param collectionName 集合名称
* @return
*/
public static FaceRep build(String namespace, String collectionName){
return new FaceRep().setNamespace(namespace).setCollectionName(collectionName);
}
public String getNamespace() {
return namespace;
}
public FaceRep setNamespace(String namespace) {
this.namespace = namespace;
return this;
}
public String getCollectionName() {
return collectionName;
}
public FaceRep setCollectionName(String collectionName) {
this.collectionName = collectionName;
return this;
}
public String getSampleId() {
return sampleId;
}
public FaceRep setSampleId(String sampleId) {
this.sampleId = sampleId;
return this;
}
public String getFaceId() {
return faceId;
}
public FaceRep setFaceId(String faceId) {
this.faceId = faceId;
return this;
}
public Float getFaceScore() {
return faceScore;
}
public FaceRep setFaceScore(Float faceScore) {
this.faceScore = faceScore;
return this;
}
public KeyValues getFaceData() {
return faceData;
}
public FaceRep setFaceData(KeyValues faceData) {
this.faceData = faceData;
return this;
}
}

FaceReq.java Normal file

@ -0,0 +1,38 @@
package com.visual.face.search.model;
public class FaceReq extends Face<FaceReq>{
/**命名空间**/
private String namespace;
/**集合名称**/
private String collectionName;
/**
* 构建集合对象
* @param namespace 命名空间
* @param collectionName 集合名称
* @return
*/
public static FaceReq build(String namespace, String collectionName){
return new FaceReq().setNamespace(namespace).setCollectionName(collectionName);
}
public String getNamespace() {
return namespace;
}
public FaceReq setNamespace(String namespace) {
this.namespace = namespace;
return this;
}
public String getCollectionName() {
return collectionName;
}
public FaceReq setCollectionName(String collectionName) {
this.collectionName = collectionName;
return this;
}
}

FiledColumn.java Normal file

@ -0,0 +1,47 @@
package com.visual.face.search.model;
import java.io.Serializable;
/**
* 字段定义
*/
public class FiledColumn implements Serializable {
/**字段名称*/
private String name;
/**字段描述*/
private String comment;
/**字段类型*/
private FiledDataType dataType;
/**构建工具**/
public static FiledColumn build(){
return new FiledColumn();
}
public String getName() {
return name;
}
public FiledColumn setName(String name) {
this.name = name;
return this;
}
public String getComment() {
return comment;
}
public FiledColumn setComment(String comment) {
this.comment = comment;
return this;
}
public FiledDataType getDataType() {
return dataType;
}
public FiledColumn setDataType(FiledDataType dataType) {
this.dataType = dataType;
return this;
}
}

FiledDataType.java Normal file

@ -0,0 +1,59 @@
package com.visual.face.search.model;
/**
* 定义字段类型
*/
public enum FiledDataType {
/**
* Undefined data type
*/
UNDEFINED(0),
/**
* String data type
*/
STRING(1),
/**
* Bool data type
*/
BOOL(2),
/**
* Int32 data type
*/
INT(3),
/**
* Float data type
*/
FLOAT(4),
/**
* Double data type
*/
DOUBLE(5);
private int value;
FiledDataType(int value) {
this.value = value;
}
public int getValue() {
return this.value;
}
public static FiledDataType valueOf(int value) {
switch (value) {
case 1:
return STRING;
case 2:
return BOOL;
case 3:
return INT;
case 4:
return FLOAT;
case 5:
return DOUBLE;
default:
return UNDEFINED;
}
}
}

KeyValue.java Normal file

@ -0,0 +1,110 @@
package com.visual.face.search.model;
public class KeyValue {
private String key;
private Object value;
public KeyValue(){}
public KeyValue(String key, Object value) {
this.key = key;
this.value = value;
}
public static KeyValue build(String key, String value){
return new KeyValue(key, value);
}
public static KeyValue build(String key, Boolean value){
return new KeyValue(key, value);
}
public static KeyValue build(String key, Integer value){
return new KeyValue(key, value);
}
public static KeyValue build(String key, Float value){
return new KeyValue(key, value);
}
public static KeyValue build(String key, Double value){
return new KeyValue(key, value);
}
public static KeyValue build(String key, Object value){
return new KeyValue(key, value);
}
public String getKey() {
return key;
}
public void setKey(String key) {
this.key = key;
}
public Object getValue() {
return value;
}
public void setValue(Object value) {
this.value = value;
}
public String toStringValue(){
if(null == value){
return null;
}
if(value instanceof String){
return value.toString();
}else{
return String.valueOf(value);
}
}
public Boolean toBooleanValue(){
if(null == value){
return null;
}
if(value instanceof Boolean){
return (Boolean)value;
}else{
return Boolean.parseBoolean(String.valueOf(value));
}
}
public Integer toIntegerValue(){
if(null == value){
return null;
}
if(value instanceof Integer){
return (Integer)value;
}else{
return Integer.parseInt(String.valueOf(value));
}
}
public Float toFloatValue(){
if(null == value){
return null;
}
if(value instanceof Float){
return (Float)value;
}else{
return Float.parseFloat(String.valueOf(value));
}
}
public Double toDoubleValue(){
if(null == value){
return null;
}
if(value instanceof Double){
return (Double)value;
}else{
return Double.parseDouble(String.valueOf(value));
}
}
}

KeyValues.java Normal file

@ -0,0 +1,71 @@
package com.visual.face.search.model;
import java.util.ArrayList;
import java.util.Arrays;
public class KeyValues extends ArrayList<KeyValue>{
public static KeyValues build(){
return new KeyValues();
}
public KeyValues add(KeyValue...keyValue){
this.addAll(Arrays.asList(keyValue));
return this;
}
public String getString(String key){
for(KeyValue keyValue : this){
if(key.equalsIgnoreCase(keyValue.getKey())){
return keyValue.toStringValue();
}
}
return null;
}
public Boolean getBoolean(String key){
for(KeyValue keyValue : this){
if(key.equalsIgnoreCase(keyValue.getKey())){
return keyValue.toBooleanValue();
}
}
return null;
}
public Integer getInteger(String key){
for(KeyValue keyValue : this){
if(key.equalsIgnoreCase(keyValue.getKey())){
return keyValue.toIntegerValue();
}
}
return null;
}
public Float getFloat(String key){
for(KeyValue keyValue : this){
if(key.equalsIgnoreCase(keyValue.getKey())){
return keyValue.toFloatValue();
}
}
return null;
}
public Double getDouble(String key){
for(KeyValue keyValue : this){
if(key.equalsIgnoreCase(keyValue.getKey())){
return keyValue.toDoubleValue();
}
}
return null;
}
public Object getObject(String key){
for(KeyValue keyValue : this){
if(key.equalsIgnoreCase(keyValue.getKey())){
return keyValue.getValue();
}
}
return null;
}
}

MapParam.java Normal file

@ -0,0 +1,17 @@
package com.visual.face.search.model;
import java.util.HashMap;
public class MapParam extends HashMap<String, Object> {
public static MapParam build(){
return new MapParam();
}
public MapParam put(String key, Object value){
super.put(key, value);
return this;
}
}

Order.java Normal file

@ -0,0 +1,5 @@
package com.visual.face.search.model;
public enum Order {
asc, desc;
}

Response.java Normal file

@ -0,0 +1,62 @@
package com.visual.face.search.model;
import java.io.Serializable;
/**
 * des: API response wrapper
 * @author diven
 * @date 9:34 AM 2018/7/12
 */
public class Response<T> implements Serializable{
private static final long serialVersionUID = -6919611972884058300L;
private Integer code = -1;
private String message;
private T data;
public Response(){}
public Response(Integer code, String message, T data) {
if(null != code) {
this.code = code;
}
this.message = message;
this.data = data;
}
public Integer getCode() {
return code;
}
public void setCode(Integer code) {
if(null != code){
this.code = code;
}
}
public String getMessage() {
return message;
}
public void setMessage(String message) {
this.message = message;
}
public T getData() {
return data;
}
public void setData(T data) {
this.data = data;
}
public boolean ok(){
// a code of 0 means success; Integer.valueOf avoids the boxing constructor
return Integer.valueOf(0).equals(code);
}
@Override
public String toString() {
return "Response{" + "code=" + code + ", message='" + message + '\'' + ", data=" + data + '}';
}
}

Sample.java Normal file

@ -0,0 +1,40 @@
package com.visual.face.search.model;
import java.io.Serializable;
public class Sample<ExtendsVo extends Sample<ExtendsVo>> implements Serializable {
/**样本ID**/
private String sampleId;
/**样本扩展的额外数据**/
private KeyValues sampleData;
public Sample(){}
/**
* 构建样本数据
* @param sampleId 样本ID
* @return
*/
public static Sample build(String sampleId){
return new Sample().setSampleId(sampleId);
}
public String getSampleId() {
return sampleId;
}
public ExtendsVo setSampleId(String sampleId) {
this.sampleId = sampleId;
return (ExtendsVo) this;
}
public KeyValues getSampleData() {
return sampleData;
}
public ExtendsVo setSampleData(KeyValues sampleData) {
this.sampleData = sampleData;
return (ExtendsVo) this;
}
}

SampleFace.java Normal file

@ -0,0 +1,90 @@
package com.visual.face.search.model;
import java.io.Serializable;
public class SampleFace implements Comparable<SampleFace>, Serializable {
/**样本ID**/
private String sampleId;
/**人脸ID**/
private String faceId;
/**face quality score**/
private Float faceScore;
/**distance value returned by the vector engine**/
private Float distance;
/**converted confidence**/
private Float confidence;
/**样本扩展的额外数据**/
private KeyValues sampleData;
/**人脸扩展的额外数据**/
private KeyValues faceData;
/**
* 构造数据
* @return
*/
public static SampleFace build(){
return new SampleFace();
}
public String getSampleId() {
return sampleId;
}
public void setSampleId(String sampleId) {
this.sampleId = sampleId;
}
public KeyValues getSampleData() {
return sampleData;
}
public void setSampleData(KeyValues sampleData) {
this.sampleData = sampleData;
}
public String getFaceId() {
return faceId;
}
public void setFaceId(String faceId) {
this.faceId = faceId;
}
public KeyValues getFaceData() {
return faceData;
}
public void setFaceData(KeyValues faceData) {
this.faceData = faceData;
}
public Float getFaceScore() {
return faceScore;
}
public void setFaceScore(Float faceScore) {
this.faceScore = faceScore;
}
public Float getDistance() {
return distance;
}
public void setDistance(Float distance) {
this.distance = distance;
}
public Float getConfidence() {
return confidence;
}
public void setConfidence(Float confidence) {
this.confidence = confidence;
}
@Override
public int compareTo(SampleFace that) {
return Float.compare(that.confidence, this.confidence);
}
}

SampleRep.java Normal file

@ -0,0 +1,51 @@
package com.visual.face.search.model;
import java.util.ArrayList;
import java.util.List;
public class SampleRep extends Sample<SampleRep>{
/**命名空间**/
private String namespace;
/**集合名称**/
private String collectionName;
/**人脸数据**/
private List<SimpleFace> faces = new ArrayList<>();
/**
* 构建集合对象
* @param namespace 命名空间
* @param collectionName 集合名称
* @return
*/
public static SampleRep build(String namespace, String collectionName){
return new SampleRep().setNamespace(namespace).setCollectionName(collectionName);
}
public String getNamespace() {
return namespace;
}
public SampleRep setNamespace(String namespace) {
this.namespace = namespace;
return this;
}
public String getCollectionName() {
return collectionName;
}
public SampleRep setCollectionName(String collectionName) {
this.collectionName = collectionName;
return this;
}
public List<SimpleFace> getFaces() {
return faces;
}
public SampleRep setFaces(List<SimpleFace> faces) {
this.faces = faces;
return this;
}
}

SampleReq.java Normal file

@ -0,0 +1,37 @@
package com.visual.face.search.model;
public class SampleReq extends Sample<SampleReq>{
/**命名空间**/
private String namespace;
/**集合名称**/
private String collectionName;
/**
* 构建集合对象
* @param namespace 命名空间
* @param collectionName 集合名称
* @return
*/
public static SampleReq build(String namespace, String collectionName){
return new SampleReq().setNamespace(namespace).setCollectionName(collectionName);
}
public String getNamespace() {
return namespace;
}
public SampleReq setNamespace(String namespace) {
this.namespace = namespace;
return this;
}
public String getCollectionName() {
return collectionName;
}
public SampleReq setCollectionName(String collectionName) {
this.collectionName = collectionName;
return this;
}
}

Search.java Normal file

@ -0,0 +1,71 @@
package com.visual.face.search.model;
import java.io.Serializable;
public class Search<ExtendsVo extends Search<ExtendsVo>> implements Serializable {
/**图像Base64编码值**/
private String imageBase64;
/**人脸质量分数阈值默认0,范围:[0,100]。当设置为0时会默认使用当前模型的默认值该方法为推荐使用方式**/
private Float faceScoreThreshold=0f;
/**人脸匹配分数阈值默认0,范围:[-100,100]**/
private Float confidenceThreshold=0f;
/**搜索条数默认5**/
private Integer limit=5;
/**对输入图像中多少个人脸进行检索比对默认5**/
private Integer maxFaceNum=5;
/**
* 构建检索对象
* @param imageBase64 待检索的图片
* @return
*/
public static Search build(String imageBase64){
return new Search().setImageBase64(imageBase64);
}
public String getImageBase64() {
return imageBase64;
}
public ExtendsVo setImageBase64(String imageBase64) {
this.imageBase64 = imageBase64;
return (ExtendsVo) this;
}
public Float getFaceScoreThreshold() {
return faceScoreThreshold;
}
public ExtendsVo setFaceScoreThreshold(Float faceScoreThreshold) {
this.faceScoreThreshold = faceScoreThreshold;
return (ExtendsVo) this;
}
public Float getConfidenceThreshold() {
return confidenceThreshold;
}
public ExtendsVo setConfidenceThreshold(Float confidenceThreshold) {
this.confidenceThreshold = confidenceThreshold;
return (ExtendsVo) this;
}
public Integer getLimit() {
return limit;
}
public ExtendsVo setLimit(Integer limit) {
this.limit = limit;
return (ExtendsVo) this;
}
public Integer getMaxFaceNum() {
return maxFaceNum;
}
public ExtendsVo setMaxFaceNum(Integer maxFaceNum) {
this.maxFaceNum = maxFaceNum;
return (ExtendsVo) this;
}
}

SearchRep.java Normal file

@ -0,0 +1,47 @@
package com.visual.face.search.model;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
public class SearchRep implements Serializable {
/**人脸位置信息**/
private FaceLocation location;
/**人脸质量分数**/
private Float faceScore;
/**匹配的人脸列表**/
private List<SampleFace> match = new ArrayList<>();
/**
* 构建对象
* @return
*/
public static SearchRep build(){
return new SearchRep();
}
public FaceLocation getLocation() {
return location;
}
public void setLocation(FaceLocation location) {
this.location = location;
}
public Float getFaceScore() {
return faceScore;
}
public void setFaceScore(Float faceScore) {
this.faceScore = faceScore;
}
public List<SampleFace> getMatch() {
return match;
}
public void setMatch(List<SampleFace> match) {
this.match = match;
}
}

View File

@ -0,0 +1,37 @@
package com.visual.face.search.model;
public class SearchReq extends Search<SearchReq> {
/**命名空间**/
private String namespace;
/**集合名称**/
private String collectionName;
/**
* 构建集合对象
* @param namespace 命名空间
* @param collectionName 集合名称
* @return
*/
public static SearchReq build(String namespace, String collectionName){
return new SearchReq().setNamespace(namespace).setCollectionName(collectionName);
}
public String getNamespace() {
return namespace;
}
public SearchReq setNamespace(String namespace) {
this.namespace = namespace;
return this;
}
public String getCollectionName() {
return collectionName;
}
public SearchReq setCollectionName(String collectionName) {
this.collectionName = collectionName;
return this;
}
}
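
For reference, a hedged sketch of assembling a search request with this DTO; the namespace, collection name, thresholds and image value below are placeholders, not values prescribed by the project:

```
import com.visual.face.search.model.SearchReq;

public class SearchReqDemo {
    public static void main(String[] args) {
        // Placeholder namespace/collection and thresholds; tune them per deployment.
        SearchReq req = SearchReq.build("n1", "c000001")
                .setImageBase64("...base64 of the query image...")
                .setFaceScoreThreshold(0f)   // 0 -> use the model's default face-quality threshold
                .setConfidenceThreshold(50f) // keep matches above this similarity score
                .setLimit(5)                 // top-5 matches per face
                .setMaxFaceNum(5);           // compare at most 5 faces from the query image
        System.out.println(req.getCollectionName());
    }
}
```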

View File

@ -0,0 +1,44 @@
package com.visual.face.search.model;
import java.io.Serializable;
public class SimpleFace implements Serializable {
/**人脸ID**/
private String faceId;
/**人脸扩展的额外数据**/
private KeyValues faceData;
/**人脸质量分数**/
private Float faceScore;
public static SimpleFace build(String faceId){
return new SimpleFace().setFaceId(faceId);
}
public String getFaceId() {
return faceId;
}
public SimpleFace setFaceId(String faceId) {
this.faceId = faceId;
return this;
}
public KeyValues getFaceData() {
return faceData;
}
public SimpleFace setFaceData(KeyValues faceData) {
this.faceData = faceData;
return this;
}
public Float getFaceScore() {
return faceScore;
}
public SimpleFace setFaceScore(Float faceScore) {
this.faceScore = faceScore;
return this;
}
}

View File

@ -0,0 +1,44 @@
package com.visual.face.search.utils;
import java.io.*;
import org.apache.commons.codec.binary.Base64;
public class Base64Util {
public static String encode(byte[] binaryData) {
byte[] bytes = Base64.encodeBase64(binaryData);
return new String(bytes);
}
public static String encode(InputStream in) {
// 读取图片字节数组
try {
ByteArrayOutputStream swapStream = new ByteArrayOutputStream();
byte[] buff = new byte[100];
int rc;
while ((rc = in.read(buff, 0, 100)) > 0) {
swapStream.write(buff, 0, rc);
}
return encode(swapStream.toByteArray());
} catch (Exception e) {
throw new RuntimeException(e);
} finally {
if (in != null) {
try {
in.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
public static String encode(String filePath){
try {
return encode(new FileInputStream(filePath));
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
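
A small usage sketch for Base64Util; the file path is a placeholder. The helper only encodes, so the reverse direction in this sketch falls back to the JDK's java.util.Base64:

```
import java.util.Base64;
import com.visual.face.search.utils.Base64Util;

public class Base64UtilDemo {
    public static void main(String[] args) {
        // Placeholder path; point it at a real image before running.
        String base64 = Base64Util.encode("/tmp/face.jpg");
        System.out.println("encoded length: " + base64.length());

        // Decode with the JDK when the raw bytes are needed again.
        byte[] bytes = Base64.getMimeDecoder().decode(base64);
        System.out.println("decoded bytes: " + bytes.length);
    }
}
```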

View File

@ -0,0 +1,128 @@
package com.visual.face.search.utils;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.TypeReference;
import com.alibaba.fastjson.serializer.SerializerFeature;
import java.util.List;
import java.util.Map;
public class JsonUtil {
/**
* 将Bean转化为json字符串
*
* @param obj bean对象
* @return json
*/
public static String toString(Object obj) {
return toString(obj, false, false);
}
public static String toSimpleString(Object obj) {
return toString(obj, false, true);
}
/**
* 将Bean转化为json字符串
*
* @param obj bean对象
* @param prettyFormat 是否格式化
* @param noNull 是否忽略值为null的字段
* @return json
*/
public static String toString(Object obj, boolean prettyFormat, boolean noNull) {
if (prettyFormat) {
if (noNull) {
return JSON.toJSONString(obj, SerializerFeature.DisableCircularReferenceDetect, SerializerFeature.PrettyFormat);
} else {
return JSON.toJSONString(obj, SerializerFeature.WriteMapNullValue, SerializerFeature.WriteNullListAsEmpty, SerializerFeature.DisableCircularReferenceDetect, SerializerFeature.PrettyFormat);
}
} else {
if (noNull) {
return JSON.toJSONString(obj, SerializerFeature.DisableCircularReferenceDetect);
} else {
return JSON.toJSONString(obj, SerializerFeature.WriteMapNullValue, SerializerFeature.WriteNullListAsEmpty, SerializerFeature.DisableCircularReferenceDetect);
}
}
}
/**
* 将字符串转换为Entity
*
* @param json 数据字符串
* @param clazz Entity class
* @return
*/
public static <T> T toEntity(String json, Class<T> clazz) {
return JSON.parseObject(json, clazz);
}
/**
* 将字符串转换为Entity
*
* @param json 数据字符串
* @param typeReference Entity class
* @return
*/
public static <T> T toEntity(String json, TypeReference<T> typeReference) {
return JSON.parseObject(json, typeReference);
}
/**
* 将字符串转换为Map
*
* @param json 数据字符串
* @return Map
*/
public static Map<String, Object> toMap(String json) {
return JSON.parseObject(json, new TypeReference<Map<String, Object>>() {
});
}
/**
* 将字符串转换为List<T>
*
* @param json 数据字符串
* @param collectionClass 泛型
* @return list<T>
*/
public static <T> List<T> toList(String json, Class<T> collectionClass) {
return JSON.parseArray(json, collectionClass);
}
/**
* 将字符串转换为List<Map<String, Object>>
*
* @param json 数据字符串
* @return list<map>
*/
public static List<Map<String, Object>> toListMap(String json) {
return JSON.parseObject(json, new TypeReference<List<Map<String, Object>>>() {
});
}
/**
* 将字符串转换为Object
*
* @param json 数据字符串
* @return list<map>
*/
public static JSONObject toJsonObject(String json) {
return JSON.parseObject(json);
}
/**
* 将字符串转换为Array
*
* @param json 数据字符串
* @return list<map>
*/
public static JSONArray toJsonArray(String json) {
return JSON.parseArray(json);
}
}
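
A short round-trip sketch for JsonUtil; the `Person` bean is hypothetical and exists only to show the null-handling difference between toString and toSimpleString, and the expected outputs in the comments are indicative:

```
import java.util.Map;
import com.visual.face.search.utils.JsonUtil;

public class JsonUtilDemo {
    // Hypothetical bean used only for this example.
    public static class Person {
        private String name = "张三";
        private Integer age;   // left null on purpose
        public String getName() { return name; }
        public void setName(String name) { this.name = name; }
        public Integer getAge() { return age; }
        public void setAge(Integer age) { this.age = age; }
    }

    public static void main(String[] args) {
        Person p = new Person();
        // toString keeps null fields, toSimpleString drops them.
        System.out.println(JsonUtil.toString(p));        // e.g. {"age":null,"name":"张三"}
        System.out.println(JsonUtil.toSimpleString(p));  // e.g. {"name":"张三"}

        // Back to a Map and to a bean.
        Map<String, Object> map = JsonUtil.toMap("{\"name\":\"李四\",\"age\":30}");
        Person p2 = JsonUtil.toEntity("{\"name\":\"李四\",\"age\":30}", Person.class);
        System.out.println(map.get("age") + ", " + p2.getName());
    }
}
```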

View File

@ -0,0 +1,22 @@
package com.visual.face.search.base;
import java.io.File;
import java.util.Map;
import java.util.TreeMap;
public abstract class BaseTest {
public static Map<String, String> getImagePathMap(String imagePath){
Map<String, String> map = new TreeMap<>();
File file = new File(imagePath);
if(file.isFile()){
map.put(file.getName(), file.getAbsolutePath());
}else if(file.isDirectory()){
for(File tmpFile : file.listFiles()){
map.putAll(getImagePathMap(tmpFile.getPath()));
}
}
return map;
}
}

View File

@ -0,0 +1,96 @@
package com.visual.face.search.unit;
import com.visual.face.search.FaceSearch;
import com.visual.face.search.base.BaseTest;
import com.visual.face.search.model.*;
import com.visual.face.search.utils.Base64Util;
import java.io.File;
import java.util.*;
public class FaceSearchTest extends BaseTest {
public static String serverHost = "http://127.0.0.1:8080";
public static String namespace = "n1";
public static String collectionName = "c000002";
public static FaceSearch faceSearch = FaceSearch.build(serverHost, namespace, collectionName);
public static void collect(){
List<FiledColumn> sampleColumns = new ArrayList<>();
sampleColumns.add(FiledColumn.build().setName("name").setDataType(FiledDataType.STRING).setComment("姓名"));
sampleColumns.add(FiledColumn.build().setName("age").setDataType(FiledDataType.INT).setComment("年龄"));
List<FiledColumn> faceColumns = new ArrayList<>();
faceColumns.add(FiledColumn.build().setName("label_1").setDataType(FiledDataType.STRING).setComment("标签1"));
faceColumns.add(FiledColumn.build().setName("label_2").setDataType(FiledDataType.STRING).setComment("标签1"));
Collect collect = Collect.build().setCollectionComment("人脸库").setSampleColumns(sampleColumns).setFaceColumns(faceColumns);
Response<Boolean> deleteCollect = faceSearch.collect().deleteCollect();
System.out.println(deleteCollect);
Response<Boolean> createCollect = faceSearch.collect().createCollect(collect);
System.out.println(createCollect);
Response<CollectRep> getCollect = faceSearch.collect().getCollect();
System.out.println(getCollect);
Response<List<CollectRep>> collectList = faceSearch.collect().collectList();
System.out.println(collectList);
}
public static List<String> sample(){
Set<String> allIds = new HashSet<>();
List<String> updateIds = new ArrayList<>();
List<String> deleteIds = new ArrayList<>();
for(int i=0; i< 20; i++){
Sample sample = Sample.build(UUID.randomUUID().toString().toLowerCase().replace("-",""));
KeyValues sampleData = KeyValues.build().add(KeyValue.build("name", "姓名"+i), KeyValue.build("age", new Random().nextInt(80)));
sample.setSampleData(sampleData);
Response<Boolean> createSample = faceSearch.sample().createSample(sample);
System.out.println("createSample:"+i+":"+createSample);
if(new Random().nextInt(10) <= 1){
deleteIds.add(sample.getSampleId());
}
if(new Random().nextInt(10) <= 1){
updateIds.add(sample.getSampleId());
}
allIds.add(sample.getSampleId());
}
for(String id : updateIds){
Sample sample = Sample.build(id);
KeyValues sampleData = KeyValues.build().add(KeyValue.build("name", "姓名"+ new Random().nextInt(10000)), KeyValue.build("age", new Random().nextInt(80)));
sample.setSampleData(sampleData);
Response<Boolean> updateSample = faceSearch.sample().updateSample(sample);
System.out.println("updateSample:"+id+":"+updateSample);
}
for(String id : deleteIds){
Response<Boolean> deleteSample = faceSearch.sample().deleteSample(id);
System.out.println("deleteSample:"+id+":"+deleteSample);
}
for(String id : updateIds){
Response<SampleRep> getSample = faceSearch.sample().getSample(id);
System.out.println("getSample:"+id+":"+getSample);
}
Response<List<SampleRep>> getSample = faceSearch.sample().sampleList(0, 10, Order.asc);
System.out.println("getSample:"+getSample);
for(SampleRep item : getSample.getData()){
System.out.println("getSample:item:"+item);
}
//从全部ID中剔除已删除的样本ID
allIds.removeAll(deleteIds);
return new ArrayList<>(allIds);
}
public static void main(String[] args) {
collect();
sample();
}
}

35
face-search-core/pom.xml Normal file
View File

@ -0,0 +1,35 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>face-search</artifactId>
<groupId>com.visual.face.search</groupId>
<version>1.0.0</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>face-search-core</artifactId>
<dependencies>
<dependency>
<groupId>org.openpnp</groupId>
<artifactId>opencv</artifactId>
</dependency>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,110 @@
package com.visual.face.search.core.base;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtLoggingLevel;
import ai.onnxruntime.OrtSession;
public abstract class BaseOnnxInfer extends OpenCVLoader{
private OrtEnvironment env;
private String[] inputNames;
private OrtSession[] sessions;
/**
* 构造函数
* @param modelPath
* @param threads
*/
public BaseOnnxInfer(String modelPath, int threads){
try {
this.env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
opts.setInterOpNumThreads(threads);
opts.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_ERROR);
opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.BASIC_OPT);
this.sessions = new OrtSession[]{env.createSession(modelPath, opts)};
this.inputNames = new String[]{this.sessions[0].getInputNames().iterator().next()};
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/**
* 构造函数
* @param modelPaths
* @param threads
*/
public BaseOnnxInfer(String[] modelPaths, int threads){
try {
this.env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
opts.setInterOpNumThreads(threads);
opts.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_ERROR);
opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.BASIC_OPT);
this.inputNames = new String[modelPaths.length];
this.sessions = new OrtSession[modelPaths.length];
for(int i=0; i< modelPaths.length; i++){
OrtSession session = env.createSession(modelPaths[i], opts);
this.sessions[i] = session;
this.inputNames[i] = session.getInputNames().iterator().next();
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
/**
* 获取环境信息
* @return
*/
public OrtEnvironment getEnv() {
return env;
}
/**
* 获取输入端的名称
* @return
*/
public String getInputName() {
return inputNames[0];
}
/**
* 获取session
* @return
*/
public OrtSession getSession() {
return sessions[0];
}
/**
* 获取所有输入端的名称
* @return
*/
public String[] getInputNames() {
return inputNames;
}
/**
* 获取所有session
* @return
*/
public OrtSession[] getSessions() {
return sessions;
}
/**
* 关闭服务
*/
public void close(){
try {
if(sessions != null){
for(OrtSession session : sessions){
session.close();
}
}
} catch (Exception e) {
e.printStackTrace();
}
}
}
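
A minimal sketch of how a concrete inferer built on BaseOnnxInfer typically drives its session; the subclass name, the model path and the assumption that the model's first output is a 2-D float tensor are all illustrative, not part of the project:

```
import java.util.Collections;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtSession;
import com.visual.face.search.core.base.BaseOnnxInfer;

// Hypothetical subclass, only to show the getSession()/getInputName() flow.
public class DemoOnnxInfer extends BaseOnnxInfer {

    public DemoOnnxInfer(String modelPath, int threads) {
        super(modelPath, threads);
    }

    public float[][] infer(float[][][][] input) {
        OnnxTensor tensor = null;
        OrtSession.Result output = null;
        try {
            tensor = OnnxTensor.createTensor(getEnv(), input);
            output = getSession().run(Collections.singletonMap(getInputName(), tensor));
            // Assumes the first output is a 2-D float tensor.
            return (float[][]) output.get(0).getValue();
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            if (tensor != null) { tensor.close(); }
            if (output != null) { output.close(); }
        }
    }
}
```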

View File

@ -0,0 +1,21 @@
package com.visual.face.search.core.base;
import java.util.Map;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.domain.FaceInfo.Points;
/**
* 对图像进行对齐
*/
public interface FaceAlignment {
/**
* 对图像进行对齐
* @param imageMat 图像信息
* @param imagePoint 人脸关键点信息
* @param params 参数信息
* @return
*/
ImageMat inference(ImageMat imageMat, Points imagePoint, Map<String, Object> params);
}

View File

@ -0,0 +1,25 @@
package com.visual.face.search.core.base;
import java.util.List;
import java.util.Map;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
/**
* 人脸检测接口
*/
public interface FaceDetection {
/**
*获取人脸信息
* @param image 图像信息
* @param scoreTh 人脸分数阈值
* @param iouTh 人脸iou阈值
* @param params 参数信息
* @return
*/
List<FaceInfo> inference(ImageMat image, float scoreTh, float iouTh, Map<String, Object> params);
}

View File

@ -0,0 +1,20 @@
package com.visual.face.search.core.base;
import java.util.Map;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.domain.FaceInfo.Points;
/**
* 人脸关键点检测
*/
public interface FaceKeyPoint {
/**
* 人脸关键点检测
* @param imageMat 图像数据
* @param params 参数信息
* @return
*/
Points inference(ImageMat imageMat, Map<String, Object> params);
}

View File

@ -0,0 +1,21 @@
package com.visual.face.search.core.base;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.domain.FaceInfo.Embedding;
import java.util.Map;
/**
* 人脸识别模型
*/
public interface FaceRecognition {
/**
* 人脸识别人脸特征向量
* @param image 图像信息
* @param params 参数信息
* @return
*/
Embedding inference(ImageMat image, Map<String, Object> params);
}

View File

@ -0,0 +1,8 @@
package com.visual.face.search.core.base;
public abstract class OpenCVLoader {
//静态加载动态链接库
static{ nu.pattern.OpenCV.loadShared(); }
}

View File

@ -0,0 +1,23 @@
package com.visual.face.search.core.common.annotation;
import java.lang.annotation.*;
import com.visual.face.search.core.common.enums.DataSourceType;
/**
* 自定义多数据源切换注解
*
* 优先级:先方法,后类。如果方法覆盖了类上的数据源类型,则以方法的为准,否则以类上的为准
*
* @author diven
*/
@Target({ ElementType.METHOD, ElementType.TYPE })
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Inherited
public @interface DataSource
{
/**
* 切换数据源名称
*/
public DataSourceType value() default DataSourceType.MASTER;
}

View File

@ -0,0 +1,28 @@
package com.visual.face.search.core.common.enums;
public enum CollectionStatue {
UNDEFINED(-1),
NORMAL(0);
private int value;
CollectionStatue(int value) {
this.value = value;
}
public int getValue() {
return this.value;
}
public static CollectionStatue valueOf(int value) {
switch (value) {
case 0:
return NORMAL;
default:
return UNDEFINED;
}
}
}

View File

@ -0,0 +1,18 @@
package com.visual.face.search.core.common.enums;
/**
* 数据源
*
* @author diven
*/
public enum DataSourceType {
/**
* 主库
*/
MASTER,
/**
* 从库
*/
SLAVE
}

View File

@ -0,0 +1,67 @@
package com.visual.face.search.core.domain;
import java.io.Serializable;
public class ExtParam implements Serializable {
private float scoreTh;
private float iouTh;
private float scaling;
private boolean mask;
private int topK = 5;
private ExtParam(){}
public static ExtParam build(){
return new ExtParam();
}
public float getScoreTh() {
return scoreTh;
}
public ExtParam setScoreTh(float scoreTh) {
this.scoreTh = scoreTh;
return this;
}
public float getIouTh() {
return iouTh;
}
public ExtParam setIouTh(float iouTh) {
this.iouTh = iouTh;
return this;
}
public float getScaling() {
return scaling;
}
public ExtParam setScaling(float scaling) {
this.scaling = scaling;
return this;
}
public boolean isMask() {
return mask;
}
public ExtParam setMask(boolean mask) {
this.mask = mask;
return this;
}
public int getTopK() {
return topK;
}
public ExtParam setTopK(int topK) {
this.topK = topK;
return this;
}
}
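
A hedged example of filling ExtParam with commonly used values; passing 0 for the thresholds defers to each model's defaults, and the remaining values are placeholders rather than recommendations:

```
import com.visual.face.search.core.domain.ExtParam;

public class ExtParamDemo {
    public static void main(String[] args) {
        ExtParam extParam = ExtParam.build()
                .setScoreTh(0f)     // 0 -> fall back to the detector's default score threshold
                .setIouTh(0f)       // 0 -> fall back to the detector's default IoU threshold
                .setScaling(1.5f)   // crop the face box enlarged by 1.5x
                .setMask(false)     // do not mask the region outside the face
                .setTopK(5);        // keep at most 5 faces per image
        System.out.println(extParam.getTopK());
    }
}
```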

View File

@ -0,0 +1,64 @@
package com.visual.face.search.core.domain;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
/**
* 图片人脸信息
*/
public class FaceImage implements Serializable {
/**图像数据**/
public String imageBase64;
/**人脸解析数据**/
public List<FaceInfo> faceInfos;
/**
* 构造函数
* @param imageBase64 图像数据
* @param faceInfos 人脸解析数据
*/
private FaceImage(String imageBase64, List<FaceInfo> faceInfos) {
this.imageBase64 = imageBase64;
this.faceInfos = faceInfos;
}
/**
* 构建对象
* @param imageBase64 图像数据
* @param faceInfos 人脸解析数据
* @return
*/
public static FaceImage build(String imageBase64, List<FaceInfo> faceInfos){
if(faceInfos == null){
faceInfos = new ArrayList<>();
}
return new FaceImage(imageBase64, faceInfos);
}
/**
* 图像数据
* @return
*/
public String imageBase64(){
return this.imageBase64;
}
/**
* 获取图像数据
* @return
*/
public ImageMat imageMat(){
return ImageMat.fromBase64(this.imageBase64);
}
/**
* 获取人脸解析数据
* @return
*/
public List<FaceInfo> faceInfos(){
return this.faceInfos;
}
}

View File

@ -0,0 +1,449 @@
package com.visual.face.search.core.domain;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
public class FaceInfo implements Comparable<FaceInfo>, Serializable {
/**人脸分数**/
public float score;
/**人脸旋转角度**/
public float angle;
/**人脸框**/
public FaceBox box;
/**人脸关键点**/
public Points points;
/**人脸特征向量**/
public Embedding embedding;
/**
* 构造函数
* @param score 人脸分数
* @param box 人脸框
* @param points 人脸关键点
* @param angle 人脸旋转角度
* @param embedding 人脸特征向量
*/
private FaceInfo(float score, FaceBox box, Points points, float angle, Embedding embedding) {
this.score = score;
this.angle = angle;
this.box = box;
this.points = points;
this.embedding = embedding;
}
/**
* 构造一个人脸信息
* @param score 人脸分数
* @param box 人脸框
* @param points 人脸关键点
* @param angle 人脸旋转角度
*/
public static FaceInfo build(float score, float angle, FaceBox box, Points points){
return new FaceInfo(score, box, points, angle, null);
}
/**
* 构造一个人脸信息
* @param score 人脸分数
* @param box 人脸框
* @param points 人脸关键点
* @param angle 人脸旋转角度
* @param embedding 人脸特征向量
*/
public static FaceInfo build(float score, float angle, FaceBox box, Points points, Embedding embedding){
return new FaceInfo(score, box, points, angle, embedding);
}
/**
* 判断两个框的重叠率
* @param that
* @return
*/
public double iou(FaceInfo that){
float areaA = this.box.area();
float areaB = that.box.area();
float wTotal = Math.max(this.box.x2(), that.box.x2()) - Math.min(this.box.x1(), that.box.x1());
float hTotal = Math.max(this.box.y2(), that.box.y2()) - Math.min(this.box.y1(), that.box.y1());
float wOverlap = wTotal - this.box.width() - that.box.width();
float hOverlap = hTotal - this.box.height() - that.box.height();
float areaOverlap = (wOverlap >= 0 || hOverlap >= 0) ? 0 : wOverlap * hOverlap;
return 1.0 * areaOverlap / (areaA + areaB - areaOverlap);
}
/**
* 对人脸框进行旋转对应的角度
* @return
*/
public FaceBox rotateFaceBox(){
return this.box.rotate(this.angle);
}
@Override
public int compareTo(FaceInfo that) {
return Float.compare(that.score, this.score);
}
/**
* 关键点
*/
public static class Point implements Serializable {
/**坐标X的值**/
public float x;
/**坐标Y的值**/
public float y;
/**
* 构造函数
* @param x 坐标X的值
* @param y 坐标Y的值
*/
private Point(float x, float y){
this.x = x;
this.y = y;
}
/**
* 构造一个点
* @param x 坐标X的值
* @param y 坐标Y的值
* @return
*/
public static Point build(float x, float y){
return new Point(x, y);
}
/**
* 对点进行中心旋转
* @param center 中心点
* @param angle 旋转角度
* @return 旋转后的点
*/
public Point rotation(Point center, float angle){
double k = Math.toRadians(angle);
float nx1 = (float) ((this.x-center.x)*Math.cos(k) +(this.y-center.y)*Math.sin(k)+center.x);
float ny1 = (float) (-(this.x-center.x)*Math.sin(k) + (this.y-center.y)*Math.cos(k)+center.y);
return new Point(nx1, ny1);
}
/**
* 计算两点之间的距离
* @param that
* @return 距离
*/
public float distance(Point that){
return (float) Math.sqrt(Math.pow((this.x-that.x), 2)+Math.pow((this.y-that.y), 2));
}
}
/**
* 关键点集合
*/
public static class Points extends ArrayList<Point>{
/**
* 构建一个集合
* @return
*/
public static Points build(){
return new Points();
}
/**
* 添加点
* @param point
* @return
*/
public Points add(Point ...point){
super.addAll(Arrays.asList(point));
return this;
}
/**
* 转换为double数组
* @return
*/
public double[][] toDoubleArray(){
double[][] arr = new double[this.size()][2];
for(int i=0; i< this.size(); i++){
arr[i][0] = this.get(i).x;
arr[i][1] = this.get(i).y;
}
return arr;
}
/**
* 对点进行中心旋转
* @param center 中心点
* @param angle 旋转角度
* @return 旋转后的点集合
*/
public Points rotation(Point center, float angle){
Points points = build();
for(Point item : this){
points.add(item.rotation(center, angle));
}
return points;
}
/**
* 加法操作,对所有的点都加上point的值
* @param point
* @return
*/
public Points operateAdd(Point point){
Points points = build();
for(Point item : this){
float x = item.x + point.x;
float y = item.y + point.y;
points.add(Point.build(x, y));
}
return points;
}
/**
* 减法操作,对所有的点都减去point的值
* @param point
* @return
*/
public Points operateSubtract(Point point){
Points points = build();
for(Point item : this){
float x = item.x - point.x;
float y = item.y - point.y;
points.add(Point.build(x, y));
}
return points;
}
/**
* 选择关键点
* @param indexes 关键点索引号
* @return 关键点集合
*/
public Points select(int ...indexes){
Points points = build();
for(int index : indexes){
points.add(this.get(index));
}
return points;
}
/**
* 乘法操作,对所有的点都乘以scale的值
* @param scale
* @return
*/
public Points operateMultiply(float scale){
Points points = build();
for(Point item : this){
float x = item.x * scale;
float y = item.y * scale;
points.add(Point.build(x, y));
}
return points;
}
}
/**
* 标准坐标系下的人脸框
*/
public static class FaceBox implements Serializable {
/**左上角坐标值**/
public Point leftTop;
/**右上角坐标**/
public Point rightTop;
/**右下角坐标**/
public Point rightBottom;
/**左下角坐标**/
public Point leftBottom;
/**
* 构造函数
* @param leftTop 左上角坐标值
* @param rightTop 右上角坐标
* @param rightBottom 右下角坐标
* @param leftBottom 左下角坐标
*/
public FaceBox(Point leftTop, Point rightTop, Point rightBottom, Point leftBottom) {
this.leftTop = leftTop;
this.rightTop = rightTop;
this.rightBottom = rightBottom;
this.leftBottom = leftBottom;
}
/**
* 构造函数
* @param x1 左上角坐标X的值
* @param y1 左上角坐标Y的值
* @param x2 右下角坐标X的值
* @param y2 右下角坐标Y的值
*/
private FaceBox(float x1, float y1, float x2, float y2){
this.leftTop = Point.build(x1, y1);
this.rightTop = Point.build(x2, y1);
this.rightBottom = Point.build(x2, y2);
this.leftBottom = Point.build(x1, y2);
}
/**
* 构造一个人脸框
* @param x1 左上角坐标X的值
* @param y1 左上角坐标Y的值
* @param x2 右下角坐标X的值
* @param y2 右下角坐标Y的值
*/
public static FaceBox build(float x1, float y1, float x2, float y2){
return new FaceBox((int)x1,(int)y1,(int)x2,(int)y2);
}
/**
* x的最小坐标
* @return
*/
public float x1(){
return Math.min(Math.min(Math.min(leftTop.x, rightTop.x), rightBottom.x), leftBottom.x);
}
/**
* y的最小坐标
* @return
*/
public float y1(){
return Math.min(Math.min(Math.min(leftTop.y, rightTop.y), rightBottom.y), leftBottom.y);
}
/**
* x的最大坐标
* @return
*/
public float x2(){
return Math.max(Math.max(Math.max(leftTop.x, rightTop.x), rightBottom.x), leftBottom.x);
}
/**
* y的最大坐标
* @return
*/
public float y2(){
return Math.max(Math.max(Math.max(leftTop.y, rightTop.y), rightBottom.y), leftBottom.y);
}
/**
* 判断当前的人脸框是否为标准的人脸框,即非旋转后的人脸框
* @return 是否为标准的人脸框
*/
public boolean isNormal(){
if((int)leftTop.x == (int)leftBottom.x && (int)leftTop.y == (int)rightTop.y){
if((int)rightBottom.x == (int)rightTop.x && (int)rightBottom.y == (int)leftBottom.y){
return true;
}
}
return false;
}
/**
* 获取宽度
* @return
*/
public float width(){
return (float) Math.sqrt(Math.pow((rightTop.x-leftTop.x), 2)+Math.pow((rightTop.y-leftTop.y), 2));
}
/**
* 获取高度
* @return
*/
public float height(){
return (float) Math.sqrt(Math.pow((rightTop.x-rightBottom.x), 2)+Math.pow((rightTop.y-rightBottom.y), 2));
}
/**
* 获取面积
* @return
*/
public float area(){
return this.width() * this.height();
}
/**
* 中心点坐标
* @return
*/
public Point center(){
return Point.build((rightTop.x + leftBottom.x) / 2, (rightTop.y + leftBottom.y) / 2);
}
/**
* 对人脸框进行旋转对应的角度
* @param angle 旋转角
* @return
*/
public FaceBox rotate(float angle){
Point center = this.center();
Point rPoint1 = this.leftTop.rotation(center, angle);
Point rPoint2 = this.rightTop.rotation(center, angle);
Point rPoint3 = this.rightBottom.rotation(center, angle);
Point rPoint4 = this.leftBottom.rotation(center, angle);
return new FaceBox(rPoint1, rPoint2, rPoint3, rPoint4);
}
/**
* 中心缩放
* @param scale
* @return
*/
public FaceBox scaling(float scale){
//p1-p3
float length_p1_p3 = leftTop.distance(rightBottom);
float x_diff_p1_p3 = leftTop.x-rightBottom.x;
float y_diff_p1_p3 = leftTop.y-rightBottom.y;
float change_p1_p3 = length_p1_p3 * (1-scale);
float change_x_p1_p3 = change_p1_p3 * x_diff_p1_p3 / length_p1_p3 / 2;
float change_y_p1_p3 = change_p1_p3 * y_diff_p1_p3 / length_p1_p3 / 2;
//p2-p4
float length_p2_p4 = rightTop.distance(leftBottom);
float x_diff_p2_p4 = rightTop.x-leftBottom.x;
float y_diff_p2_p4 = rightTop.y-leftBottom.y;
float change_p2_p4 = length_p2_p4 * (1-scale);
float change_x_p2_p4 = change_p2_p4 * x_diff_p2_p4 / length_p2_p4 / 2;
float change_y_p2_p4 = change_p2_p4 * y_diff_p2_p4 / length_p2_p4 / 2;
//构造人脸框
return new FaceBox(
new Point(leftTop.x - change_x_p1_p3, leftTop.y - change_y_p1_p3),
new Point(rightTop.x - change_x_p2_p4, rightTop.y - change_y_p2_p4),
new Point(rightBottom.x + change_x_p1_p3, rightBottom.y + change_y_p1_p3),
new Point(leftBottom.x + change_x_p2_p4, leftBottom.y + change_y_p2_p4)
);
}
}
/**
* 人脸特征向量
*/
public static class Embedding implements Serializable {
/**当前图片的base64编码值**/
public String image;
/**当前图片的人脸向量信息**/
public float[] embeds;
/**
* 构造函数
* @param image 当前图片的base64编码值
* @param embeds 当前图片的人脸向量信息
*/
private Embedding(String image, float[] embeds){
this.image = image;
this.embeds = embeds;
}
/**
* 构建人脸特征向量
* @param image 当前图片的base64编码值
* @param embeds 当前图片的人脸向量信息
*/
public static Embedding build(String image, float[] embeds){
return new Embedding(image, embeds);
}
}
}
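
A small worked example of the FaceBox/IoU helpers above; the coordinates are arbitrary and only show the expected ballpark of the overlap ratio:

```
import com.visual.face.search.core.domain.FaceInfo;

public class FaceBoxIouDemo {
    public static void main(String[] args) {
        // Two axis-aligned 100x100 boxes, shifted by 50 pixels horizontally.
        FaceInfo a = FaceInfo.build(0.9f, 0f,
                FaceInfo.FaceBox.build(0, 0, 100, 100), FaceInfo.Points.build());
        FaceInfo b = FaceInfo.build(0.8f, 0f,
                FaceInfo.FaceBox.build(50, 0, 150, 100), FaceInfo.Points.build());
        // Overlap area 50*100 = 5000, union 2*10000 - 5000 = 15000, so IoU is about 0.333.
        System.out.println(a.iou(b));
    }
}
```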

View File

@ -0,0 +1,685 @@
package com.visual.face.search.core.domain;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import com.visual.face.search.core.utils.MatUtil;
import org.opencv.core.*;
import org.opencv.core.Point;
import org.opencv.dnn.Dnn;
import org.opencv.highgui.HighGui;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.util.Base64;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
import java.util.ArrayList;
/**
* 图片加载工具
*/
public class ImageMat implements Serializable {
//静态加载动态链接库
static{ nu.pattern.OpenCV.loadShared(); }
private OrtEnvironment env = OrtEnvironment.getEnvironment();
//对象成员
private Mat mat;
private ImageMat(Mat mat){
this.mat = mat;
}
/**
* 读取图片并转换为Mat
* @param imagePath 图片地址
* @return
*/
public static ImageMat fromImage(String imagePath){
return new ImageMat(Imgcodecs.imread(imagePath));
}
/**
* 直接读取Mat
* @param mat 图片mat值
* @return
*/
public static ImageMat fromCVMat(Mat mat){
try {
return new ImageMat(mat);
}catch (Exception e){
throw new RuntimeException(e);
}
}
/**
* 读取图片并转换为Mat
* @param base64Str 图片Base64编码值
* @return
*/
public static ImageMat fromBase64(String base64Str){
InputStream inputStream = null;
try {
byte[] data = Base64.getMimeDecoder().decode(base64Str);
inputStream = new ByteArrayInputStream(data);
return fromInputStream(inputStream);
}catch (Exception e){
throw new RuntimeException(e);
}finally {
if(null != inputStream){
try {
inputStream.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
/**
* 读取图片并转换为Mat
* @param inputStream 图片数据
* @return
*/
public static ImageMat fromInputStream(InputStream inputStream){
try {
BufferedImage image = ImageIO.read(inputStream);
return fromBufferedImage(image);
}catch (Exception e){
throw new RuntimeException(e);
}
}
/**
* 读取图片并转换为Mat
* @param image 图片数据
* @return
*/
public static ImageMat fromBufferedImage(BufferedImage image){
try {
if(image.getType() != BufferedImage.TYPE_3BYTE_BGR){
BufferedImage temp = new BufferedImage(image.getWidth(), image.getHeight(), BufferedImage.TYPE_3BYTE_BGR);
Graphics2D g = temp.createGraphics();
try {
g.setComposite(AlphaComposite.Src);
g.drawImage(image, 0, 0, null);
} finally {
g.dispose();
}
image = temp;
}
byte[] pixels = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();
Mat mat = Mat.eye(image.getHeight(), image.getWidth(), CvType.CV_8UC3);
mat.put(0, 0, pixels);
return new ImageMat(mat);
}catch (Exception e){
throw new RuntimeException(e);
}
}
/**
* 显示图片用于数据调试
*/
public void imShow() {
HighGui.imshow("image", mat);
HighGui.waitKey();
}
/**
*获取数据的宽度
* @return
*/
public int getWidth(){
return (int) mat.size().width;
}
/**
* 获取数据的高度
* @return
*/
public int getHeight(){
return (int) mat.size().height;
}
/**
* 克隆ImageMat
* @return
*/
public ImageMat clone(){
return ImageMat.fromCVMat(this.mat.clone());
}
/**
* 获取图像的中心点
* @return
*/
public Point center(){
return new Point(mat.size(1)/2, mat.size(0)/2);
}
/**
* 获取当前的CV Mat
* @return
*/
public Mat toCvMat() {
return mat;
}
/**
* 数据格式转换,不释放原始图片数据
* @param code Imgproc.COLOR_*
*/
public ImageMat cvtColorAndNoReleaseMat(int code) {
return this.cvtColor(code, false);
}
/**
* 数据格式转换,并释放原始图片数据
* @param code Imgproc.COLOR_*
*/
public ImageMat cvtColorAndDoReleaseMat(int code) {
return this.cvtColor(code, true);
}
/**
* 数据格式转换
* @param code Imgproc.COLOR_*
* @param release 是否释放参数mat
*/
private ImageMat cvtColor(int code, boolean release) {
try {
Mat dst = new Mat();
Imgproc.cvtColor(mat, dst, code);
return new ImageMat(dst);
}finally {
if(release){
this.release();
}
}
}
/**
* 重新设置图片尺寸,不释放原始图片数据
* @param width 图片宽度
* @param height 图片高度
* @return
*/
public ImageMat resizeAndNoReleaseMat(int width, int height){
return this.resize(width, height, false);
}
/**
* 重新设置图片尺寸,并释放原始图片数据
* @param width 图片宽度
* @param height 图片高度
* @return
*/
public ImageMat resizeAndDoReleaseMat(int width, int height){
return this.resize(width, height, true);
}
/**
* 重新设置图片尺寸
* @param width 图片宽度
* @param height 图片高度
* @param release 是否释放参数mat
* @return
*/
private ImageMat resize(int width, int height, boolean release){
try {
Mat dst = new Mat();
Imgproc.resize(mat, dst, new Size(width,height), 0, 0, Imgproc.INTER_AREA);
return new ImageMat(dst);
}finally {
if(release){
this.release();
}
}
}
/**
* 对图像进行预处理,不释放原始图片数据
* @param scale 图像各通道数值的缩放比例
* @param mean 用于各通道减去的值以降低光照的影响
* @param swapRB 交换RB通道默认为False.
* @return
*/
public ImageMat blobFromImageAndNoReleaseMat(double scale, Scalar mean, boolean swapRB){
return this.blobFromImage(scale, mean, swapRB, false);
}
/**
* 对图像进行预处理,并释放原始图片数据
* @param scale 图像各通道数值的缩放比例
* @param mean 用于各通道减去的值以降低光照的影响
* @param swapRB 交换RB通道默认为False.
* @return
*/
public ImageMat blobFromImageAndDoReleaseMat(double scale, Scalar mean, boolean swapRB){
return this.blobFromImage(scale, mean, swapRB, true);
}
/**
* 对图像进行预处理
* @param scale 图像各通道数值的缩放比例
* @param mean 用于各通道减去的值以降低光照的影响
* @param swapRB 交换RB通道默认为False.
* @param release 是否释放参数mat
* @return
*/
private ImageMat blobFromImage(double scale, Scalar mean, boolean swapRB, boolean release){
try {
Mat dst = Dnn.blobFromImage(mat, scale, new Size( mat.cols(), mat.rows()), mean, swapRB);
java.util.List<Mat> mats = new ArrayList<>();
Dnn.imagesFromBlob(dst, mats);
dst.release();
return new ImageMat(mats.get(0));
}finally {
if(release){
this.release();
}
}
}
/**
* 转换为base64,不释放原始图片数据
* @return
*/
public String toBase64AndNoReleaseMat(){
return toBase64(false);
}
/**
* 转换为base64,并释放原始图片数据
* @return
*/
public String toBase64AndDoReleaseMat(){
return toBase64(true);
}
/**
* 转换为base64
* @param release 是否释放参数mat
* @return
*/
private String toBase64(boolean release){
if(null != mat){
try {
return MatUtil.matToBase64(mat);
}finally {
if(release){
this.release();
}
}
}else{
return null;
}
}
/**
* 转换为整形数组,不释放原始图片数据
* @param firstChannel
* @return
*/
public int[][][][] to4dIntArrayAndNoReleaseMat(boolean firstChannel){
return this.to4dIntArray(firstChannel, false);
}
/**
* 转换为整形数组,并释放原始图片数据
* @param firstChannel
* @return
*/
public int[][][][] to4dIntArrayAndDoReleaseMat(boolean firstChannel){
return this.to4dIntArray(firstChannel, true);
}
/**
* 转换为整形数组
* @param firstChannel
* @param release 是否释放参数mat
* @return
*/
private int[][][][] to4dIntArray(boolean firstChannel, boolean release){
try {
int width = this.mat.cols();
int height = this.mat.rows();
int channel = this.mat.channels();
int[][][][] array;
if(firstChannel){
array = new int[1][channel][height][width];
for(int i=0; i<height; i++){
for(int j=0; j<width; j++){
double[] c = mat.get(i, j);
for(int k=0; k< channel; k++){
array[0][k][i][j] = (int) Math.round(c[k]);
}
}
}
}else{
array = new int[1][height][width][channel];
for(int i=0; i<height; i++){
for(int j=0; j<width; j++){
double[] c = mat.get(i, j);
for(int k=0; k< channel; k++){
array[0][i][j][k] = (int) Math.round(c[k]);
}
}
}
}
return array;
}finally {
if(release){
this.release();
}
}
}
/**
* 转换为长整形数组,不释放原始图片数据
* @param firstChannel
* @return
*/
public long[][][][] to4dLongArrayAndNoReleaseMat(boolean firstChannel){
return this.to4dLongArray(firstChannel, false);
}
/**
* 转换为长整形数组,并释放原始图片数据
* @param firstChannel
* @return
*/
public long[][][][] to4dLongArrayAndDoReleaseMat(boolean firstChannel){
return this.to4dLongArray(firstChannel, true);
}
/**
* 转换为长整形数组
* @param firstChannel
* @param release 是否释放参数mat
* @return
*/
private long[][][][] to4dLongArray(boolean firstChannel, boolean release){
try {
int width = this.mat.cols();
int height = this.mat.rows();
int channel = this.mat.channels();
long[][][][] array;
if(firstChannel){
array = new long[1][channel][height][width];
for(int i=0; i<height; i++){
for(int j=0; j<width; j++){
double[] c = mat.get(i, j);
for(int k=0; k< channel; k++){
array[0][k][i][j] = Math.round(c[k]);
}
}
}
}else{
array = new long[1][height][width][channel];
for(int i=0; i<height; i++){
for(int j=0; j<width; j++){
double[] c = mat.get(i, j);
for(int k=0; k< channel; k++){
array[0][i][j][k] = Math.round(c[k]);
}
}
}
}
return array;
}finally {
if(release){
this.release();
}
}
}
/**
* 转换为单精度形数组,不释放原始图片数据
* @param firstChannel
* @return
*/
public float[][][][] to4dFloatArrayAndNoReleaseMat(boolean firstChannel){
return this.to4dFloatArray(firstChannel, false);
}
/**
* 转换为单精度形数组,并释放原始图片数据
* @param firstChannel
* @return
*/
public float[][][][] to4dFloatArrayAndDoReleaseMat(boolean firstChannel){
return this.to4dFloatArray(firstChannel, true);
}
/**
* 转换为单精度形数组
* @param firstChannel
* @param release 是否释放参数mat
* @return
*/
private float[][][][] to4dFloatArray(boolean firstChannel, boolean release){
try {
int width = this.mat.cols();
int height = this.mat.rows();
int channel = this.mat.channels();
float[][][][] array;
if(firstChannel){
array = new float[1][channel][height][width];
for(int i=0; i<height; i++){
for(int j=0; j<width; j++){
double[] c = mat.get(i, j);
for(int k=0; k< channel; k++){
array[0][k][i][j] = (float) c[k];
}
}
}
}else{
array = new float[1][height][width][channel];
for(int i=0; i<height; i++){
for(int j=0; j<width; j++){
double[] c = mat.get(i, j);
for(int k=0; k< channel; k++){
array[0][i][j][k] = (float) c[k];
}
}
}
}
return array;
}finally {
if(release){
this.release();
}
}
}
/**
* 转换为双精度形数组,不释放原始图片数据
* @param firstChannel
* @return
*/
public double[][][][] to4dDoubleArrayAndNoReleaseMat(boolean firstChannel){
return this.to4dDoubleArray(firstChannel, false);
}
/**
* 转换为双精度形数组,并释放原始图片数据
* @param firstChannel
* @return
*/
public double[][][][] to4dDoubleArrayAndDoReleaseMat(boolean firstChannel){
return this.to4dDoubleArray(firstChannel, true);
}
/**
* 转换为双精度形数组
* @param firstChannel
* @param release 是否释放参数mat
* @return
*/
private double[][][][] to4dDoubleArray(boolean firstChannel, boolean release){
try {
int width = this.mat.cols();
int height = this.mat.rows();
int channel = this.mat.channels();
double[][][][] array;
if(firstChannel){
array = new double[1][channel][height][width];
for(int i=0; i<height; i++){
for(int j=0; j<width; j++){
double[] c = mat.get(i, j);
for(int k=0; k< channel; k++){
array[0][k][i][j] = c[k];
}
}
}
}else{
array = new double[1][height][width][channel];
for(int i=0; i<height; i++){
for(int j=0; j<width; j++){
double[] c = mat.get(i, j);
for(int k=0; k< channel; k++){
array[0][i][j][k] = c[k];
}
}
}
}
return array;
}finally {
if(release){
this.release();
}
}
}
/**
* 转换为整形OnnxTensor,不释放原始图片数据
* @param firstChannel
* @return
*/
public OnnxTensor to4dIntOnnxTensorAndNoReleaseMat(boolean firstChannel){
try {
return OnnxTensor.createTensor(env, this.to4dIntArrayAndNoReleaseMat(firstChannel));
}catch (Exception e){
throw new RuntimeException(e);
}
}
/**
* 转换为整形OnnxTensor,并释放原始图片数据
* @param firstChannel
* @return
*/
public OnnxTensor to4dIntOnnxTensorAndDoReleaseMat(boolean firstChannel){
try {
return OnnxTensor.createTensor(env, this.to4dIntArrayAndDoReleaseMat(firstChannel));
}catch (Exception e){
throw new RuntimeException(e);
}
}
/**
* 转换为长整形OnnxTensor,不释放原始图片数据
* @param firstChannel
* @return
*/
public OnnxTensor to4dLongOnnxTensorAndNoReleaseMat(boolean firstChannel) {
try {
return OnnxTensor.createTensor(env, this.to4dLongArrayAndNoReleaseMat(firstChannel));
}catch (Exception e){
throw new RuntimeException(e);
}
}
/**
* 转换为长整形OnnxTensor,并释放原始图片数据
* @param firstChannel
* @return
*/
public OnnxTensor to4dLongOnnxTensorAndDoReleaseMat(boolean firstChannel) {
try {
return OnnxTensor.createTensor(env, this.to4dLongArrayAndDoReleaseMat(firstChannel));
}catch (Exception e){
throw new RuntimeException(e);
}
}
/**
* 转换为单精度形OnnxTensor,不释放原始图片数据
* @param firstChannel
* @return
*/
public OnnxTensor to4dFloatOnnxTensorAndNoReleaseMat(boolean firstChannel) {
try {
return OnnxTensor.createTensor(env, this.to4dFloatArrayAndNoReleaseMat(firstChannel));
}catch (Exception e){
throw new RuntimeException(e);
}
}
/**
* 转换为单精度形OnnxTensor,并释放原始图片数据
* @param firstChannel
* @return
*/
public OnnxTensor to4dFloatOnnxTensorAndDoReleaseMat(boolean firstChannel) {
try {
return OnnxTensor.createTensor(env, this.to4dFloatArrayAndDoReleaseMat(firstChannel));
}catch (Exception e){
throw new RuntimeException(e);
}
}
/**
* 转换为双精度形OnnxTensor,不释放原始图片数据
* @param firstChannel
* @return
*/
public OnnxTensor to4dDoubleOnnxTensorAndNoReleaseMat(boolean firstChannel) {
try {
return OnnxTensor.createTensor(env, this.to4dDoubleArrayAndNoReleaseMat(firstChannel));
}catch (Exception e){
throw new RuntimeException(e);
}
}
/**
* 转换为双精度形OnnxTensor,并释放原始图片数据
* @param firstChannel
* @return
*/
public OnnxTensor to4dDoubleOnnxTensorAndDoReleaseMat(boolean firstChannel) {
try {
return OnnxTensor.createTensor(env, this.to4dDoubleArrayAndDoReleaseMat(firstChannel));
}catch (Exception e){
throw new RuntimeException(e);
}
}
/**
* 释放资源
*/
public void release(){
if(this.mat != null){
this.mat.release();
this.mat = null;
}
}
}
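
A short sketch of the typical ImageMat preprocessing chain used by the models below; the image path and the 112x112 size are placeholders (112 happens to match the ArcFace input used later, but any model-specific size works the same way):

```
import ai.onnxruntime.OnnxTensor;
import org.opencv.core.Scalar;
import com.visual.face.search.core.domain.ImageMat;

public class ImageMatDemo {
    public static void main(String[] args) throws Exception {
        // Placeholder path; the chain below releases intermediate Mats as it goes.
        ImageMat image = ImageMat.fromImage("/tmp/face.jpg");
        OnnxTensor tensor = image
                .resizeAndNoReleaseMat(112, 112)                   // keep the original image alive
                .blobFromImageAndDoReleaseMat(1.0 / 127.5,
                        new Scalar(127.5, 127.5, 127.5), true)     // scale, mean-subtract, swap BGR->RGB
                .to4dFloatOnnxTensorAndDoReleaseMat(true);         // NCHW float tensor
        System.out.println(java.util.Arrays.toString(tensor.getInfo().getShape()));
        tensor.close();
        image.release();                                           // release the original Mat explicitly
    }
}
```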

View File

@ -0,0 +1,23 @@
package com.visual.face.search.core.extract;
import com.visual.face.search.core.domain.ExtParam;
import com.visual.face.search.core.domain.FaceImage;
import com.visual.face.search.core.domain.ImageMat;
import java.util.Map;
/**
* 人脸特征提取器
*/
public interface FaceFeatureExtractor {
/**
* 人脸特征提取
* @param image
* @param extParam
* @param params
* @return
*/
public FaceImage extract(ImageMat image, ExtParam extParam, Map<String, Object> params);
}

View File

@ -0,0 +1,114 @@
package com.visual.face.search.core.extract;
import java.util.List;
import java.util.Map;
import com.visual.face.search.core.base.FaceAlignment;
import com.visual.face.search.core.base.FaceDetection;
import com.visual.face.search.core.base.FaceKeyPoint;
import com.visual.face.search.core.base.FaceRecognition;
import com.visual.face.search.core.domain.ExtParam;
import com.visual.face.search.core.domain.FaceImage;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.models.InsightCoordFaceKeyPoint;
import com.visual.face.search.core.utils.CropUtil;
import com.visual.face.search.core.utils.MaskUtil;
import org.opencv.core.Mat;
/**
* 人脸特征提取器实现
*/
public class FaceFeatureExtractorImpl implements FaceFeatureExtractor {
public static final float defScaling = 1.5f;
private FaceKeyPoint faceKeyPoint;
private FaceDetection faceDetection;
private FaceAlignment faceAlignment;
private FaceRecognition faceRecognition;
private FaceDetection backupFaceDetection;
/**
* 构造函数
* @param faceDetection 人脸检测模型
* @param backupFaceDetection 备用人脸检测模型
* @param faceKeyPoint 人脸关键点模型
* @param faceAlignment 人脸对齐模型
* @param faceRecognition 人脸特征提取模型
*/
public FaceFeatureExtractorImpl(FaceDetection faceDetection, FaceDetection backupFaceDetection, FaceKeyPoint faceKeyPoint, FaceAlignment faceAlignment, FaceRecognition faceRecognition) {
this.faceKeyPoint = faceKeyPoint;
this.faceDetection = faceDetection;
this.faceAlignment = faceAlignment;
this.faceRecognition = faceRecognition;
this.backupFaceDetection = backupFaceDetection;
}
/**
* 人脸特征提取
* @param image
* @param extParam
* @param params
* @return
*/
@Override
public FaceImage extract(ImageMat image, ExtParam extParam, Map<String, Object> params) {
//人脸检测
List<FaceInfo> faceInfos = this.faceDetection.inference(image, extParam.getScoreTh(), extParam.getIouTh(), params);
//启用备用的人脸检测
if(faceInfos.isEmpty() && null != backupFaceDetection){
faceInfos = this.backupFaceDetection.inference(image, extParam.getScoreTh(), extParam.getIouTh(), params);
}
//取人脸topK
int topK = (extParam.getTopK() > 0) ? extParam.getTopK() : 5;
if(faceInfos.size() > topK){
faceInfos = faceInfos.subList(0, topK);
}
//处理数据
for(FaceInfo faceInfo : faceInfos) {
Mat cropFace = null;
ImageMat cropImageMat = null;
ImageMat alignmentImage = null;
try {
//缩放人脸框的比例
float scaling = extParam.getScaling() <= 0 ? defScaling : extParam.getScaling();
//通过旋转角度获取正脸坐标并进行图像裁剪
FaceInfo.FaceBox box = faceInfo.rotateFaceBox().scaling(scaling);
cropFace = CropUtil.crop(image.toCvMat(), box);
//人脸标记关键点
cropImageMat = ImageMat.fromCVMat(cropFace);
FaceInfo.Points corpPoints = this.faceKeyPoint.inference(cropImageMat, params);
//还原原始图片中的关键点
FaceInfo.Point corpImageCenter = FaceInfo.Point.build((float)cropImageMat.center().x, (float)cropImageMat.center().y);
FaceInfo.Points imagePoints = corpPoints.rotation(corpImageCenter, faceInfo.angle).operateSubtract(corpImageCenter);
faceInfo.points = imagePoints.operateAdd(box.center());
//人脸对齐
alignmentImage = this.faceAlignment.inference(cropImageMat, corpPoints, params);
//判断是否需要遮罩人脸以外的区域
if(extParam.isMask()){
if(faceKeyPoint instanceof InsightCoordFaceKeyPoint){
FaceInfo.Points alignmentPoints = this.faceKeyPoint.inference(alignmentImage, params);
alignmentImage = MaskUtil.maskFor106InsightCoordModel(alignmentImage, alignmentPoints, true);
}
}
//人脸特征提取
FaceInfo.Embedding embedding = this.faceRecognition.inference(alignmentImage, params);
faceInfo.embedding = embedding;
}finally {
if(null != alignmentImage){
alignmentImage.release();
}
if(null != cropImageMat){
cropImageMat.release();
}
if(null != cropFace){
cropFace.release();
}
}
}
return FaceImage.build(image.toBase64AndNoReleaseMat(), faceInfos);
}
}
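
A hedged wiring sketch for the extractor. The .onnx paths and the image path are placeholders, and the pass-through alignment lambda is only a stand-in so the sketch stays self-contained; in the real pipeline the project's FaceAlignment model should be plugged in instead (and PCN could be used as the primary detector the same way):

```
import java.util.HashMap;
import com.visual.face.search.core.base.FaceAlignment;
import com.visual.face.search.core.domain.ExtParam;
import com.visual.face.search.core.domain.FaceImage;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.extract.FaceFeatureExtractor;
import com.visual.face.search.core.extract.FaceFeatureExtractorImpl;
import com.visual.face.search.core.models.InsightArcFaceRecognition;
import com.visual.face.search.core.models.InsightCoordFaceKeyPoint;
import com.visual.face.search.core.models.InsightScrfdFaceDetection;

public class ExtractorDemo {
    public static void main(String[] args) {
        // Stand-in alignment: returns the crop unchanged; replace with the project's real alignment model.
        FaceAlignment passThroughAlignment = (imageMat, imagePoint, params) -> imageMat;

        // Placeholder .onnx paths; point them at the model files bundled with the project.
        FaceFeatureExtractor extractor = new FaceFeatureExtractorImpl(
                new InsightScrfdFaceDetection("models/face-detection.onnx", 1),   // primary detector
                null,                                                             // no backup detector in this sketch
                new InsightCoordFaceKeyPoint("models/face-keypoint.onnx", 1),
                passThroughAlignment,
                new InsightArcFaceRecognition("models/face-recognition.onnx", 1));

        ImageMat image = ImageMat.fromImage("/tmp/face.jpg");                     // placeholder image
        ExtParam extParam = ExtParam.build().setScoreTh(0f).setIouTh(0f)
                .setScaling(1.5f).setTopK(5).setMask(false);
        FaceImage result = extractor.extract(image, extParam, new HashMap<>());
        System.out.println("faces detected: " + result.faceInfos().size());
        image.release();
    }
}
```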

View File

@ -0,0 +1,58 @@
package com.visual.face.search.core.models;
import java.util.Collections;
import java.util.Map;
import org.opencv.core.Scalar;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtSession;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.base.BaseOnnxInfer;
import com.visual.face.search.core.base.FaceRecognition;
import com.visual.face.search.core.domain.FaceInfo.Embedding;
import com.visual.face.search.core.domain.ImageMat;
/**
* 人脸识别-人脸特征提取-512维
* git:https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch
*/
public class InsightArcFaceRecognition extends BaseOnnxInfer implements FaceRecognition {
/**
* 构造函数
* @param modelPath 模型路径
* @param threads 线程数
*/
public InsightArcFaceRecognition(String modelPath, int threads) {
super(modelPath, threads);
}
/**
* 人脸识别人脸特征向量
* @param image 图像信息
* @return
*/
@Override
public Embedding inference(ImageMat image, Map<String, Object> params) {
OnnxTensor tensor = null;
OrtSession.Result output = null;
try {
tensor = image.resizeAndNoReleaseMat(112,112)
.blobFromImageAndDoReleaseMat(1.0/127.5, new Scalar(127.5, 127.5, 127.5), true)
.to4dFloatOnnxTensorAndDoReleaseMat(true);
output = getSession().run(Collections.singletonMap(getInputName(), tensor));
float[][] embeds = (float[][]) output.get(0).getValue();
return FaceInfo.Embedding.build(image.toBase64AndNoReleaseMat(), embeds[0]);
} catch (Exception e) {
throw new RuntimeException(e);
}finally {
if(null != tensor){
tensor.close();
}
if(null != output){
output.close();
}
}
}
}
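
A hedged usage sketch: embeddings from this model are typically compared with cosine similarity. The model and image paths are placeholders, and the images are assumed to already be aligned face crops (in the full pipeline the extractor takes care of detection and alignment first):

```
import java.util.HashMap;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.models.InsightArcFaceRecognition;

public class ArcFaceDemo {
    // Cosine similarity between two embeddings.
    static float cosine(float[] a, float[] b) {
        float dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) { dot += a[i] * b[i]; na += a[i] * a[i]; nb += b[i] * b[i]; }
        return (float) (dot / (Math.sqrt(na) * Math.sqrt(nb)));
    }

    public static void main(String[] args) {
        // Placeholder model and image paths.
        InsightArcFaceRecognition recognition = new InsightArcFaceRecognition("models/face-recognition.onnx", 1);
        FaceInfo.Embedding e1 = recognition.inference(ImageMat.fromImage("/tmp/face_a.jpg"), new HashMap<>());
        FaceInfo.Embedding e2 = recognition.inference(ImageMat.fromImage("/tmp/face_b.jpg"), new HashMap<>());
        System.out.println("cosine similarity: " + cosine(e1.embeds, e2.embeds));
        recognition.close();
    }
}
```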

View File

@ -0,0 +1,140 @@
package com.visual.face.search.core.models;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtSession;
import com.visual.face.search.core.base.BaseOnnxInfer;
import com.visual.face.search.core.base.FaceKeyPoint;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.utils.MathUtil;
import org.apache.commons.math3.linear.RealMatrix;
import org.opencv.core.*;
import org.opencv.imgproc.Imgproc;
import java.util.Collections;
import java.util.Map;
public class InsightCoordFaceKeyPoint extends BaseOnnxInfer implements FaceKeyPoint {
private static final int[] outputSize = new int[]{192, 192};
/**
* 构造函数
* @param modelPath 模型路径
* @param threads 线程数
*/
public InsightCoordFaceKeyPoint(String modelPath, int threads) {
super(modelPath, threads);
}
/**
* 人脸关键点检测
* @param imageMat 图像数据
* @param params 参数信息
* @return
*/
@Override
public FaceInfo.Points inference(ImageMat imageMat, Map<String, Object> params) {
Mat M =null;
Mat IM = null;
Mat img = null;
OnnxTensor tensor = null;
OrtSession.Result output = null;
try {
Mat image = imageMat.toCvMat();
int w = image.size(1);
int h = image.size(0);
float cx = 1.0f * w / 2;
float cy = 1.0f * h / 2;
float[]center = new float[]{cx, cy};
float rotate = 0;
float _scale = (float) (1.0f * outputSize[0] / (Math.max(w, h)*1.5));
Mat[] transform = transform(image, center, outputSize, _scale, rotate);
img = transform[0];
M = transform[1];
tensor = ImageMat.fromCVMat(img)
.blobFromImageAndDoReleaseMat(1.0, new Scalar(0, 0, 0), true)
.to4dFloatOnnxTensorAndDoReleaseMat(true);
output = this.getSession().run(Collections.singletonMap(this.getInputName(), tensor));
float[] value = ((float[][]) output.get(0).getValue())[0];
float points[][] = new float[106][2];
for(int i=0; i< 106; i++){
points[i][0] = (value[2*i] + 1) * 96;
points[i][1] = (value[2*i + 1] + 1) * 96;
}
IM = new Mat();
Imgproc.invertAffineTransform(M, IM);
return transPoints(points, IM);
} catch (Exception e) {
throw new RuntimeException(e);
}finally {
if(null != tensor){
tensor.close();
}
if(null != output){
output.close();
}
if(null != M){
M.release();
}
if(null != IM){
IM.release();
}
if(null != img){
img.release();
}
}
}
/**
* 获取人脸数据和仿射矩阵
* @param image
* @param center
* @param outputSize
* @param scale
* @param rotation
* @return
*/
private static Mat[] transform(Mat image, float[]center, int[]outputSize, float scale, float rotation){
double scale_ratio = scale;
double rot = rotation * Math.PI / 180.0;
double cx = center[0] * scale_ratio;
double cy = center[1] * scale_ratio;
//矩阵构造
RealMatrix t1 = MathUtil.similarityTransform((Double[][]) null, scale_ratio, null, null);
RealMatrix t2 = MathUtil.similarityTransform((Double[][]) null, null, null, new Double[]{- cx, - cy});
RealMatrix t3 = MathUtil.similarityTransform((Double[][]) null, null, rot, null);
RealMatrix t4 = MathUtil.similarityTransform((Double[][]) null, null, null, new Double[]{1.0*outputSize[0]/2, 1.0*outputSize[1]/2});
RealMatrix tx = MathUtil.dotProduct(t4, MathUtil.dotProduct(t3, MathUtil.dotProduct(t2, t1)));
RealMatrix tm = tx.getSubMatrix(0, 1, 0, 2);
//仿射矩阵
Mat matMTemp = new MatOfDouble(MathUtil.flatMatrix(tm, 1).toArray());
Mat matM = new Mat(2, 3, CvType.CV_32FC3);
matMTemp.reshape(1,2).copyTo(matM);
matMTemp.release();
//使用open cv做仿射变换
Mat dst = new Mat();
Imgproc.warpAffine(image, dst, matM, new Size(outputSize[0], outputSize[1]));
return new Mat[]{dst, matM};
}
/**
* 关键点转换
* @param points 预测的点
* @param mat 仿射矩阵
* @return 标记的关键点
*/
private static FaceInfo.Points transPoints(float points[][], Mat mat){
int length = points.length;
FaceInfo.Points pointList = FaceInfo.Points.build();
double M[][] = ImageMat.fromCVMat(mat).to4dDoubleArrayAndNoReleaseMat(true)[0][0];
for(int i=0; i<length; i++){
float pt[] = points[i];
double new_pt [] = new double[]{pt[0], pt[1], 1.0f};
double x = MathUtil.dotProduct(MathUtil.createVector(M[0]), MathUtil.createVector(new_pt));
double y = MathUtil.dotProduct(MathUtil.createVector(M[1]), MathUtil.createVector(new_pt));
pointList.add(FaceInfo.Point.build((float)x, (float)y));
}
return pointList;
}
}

View File

@ -0,0 +1,150 @@
package com.visual.face.search.core.models;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.visual.face.search.core.base.BaseOnnxInfer;
import com.visual.face.search.core.base.FaceDetection;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
import org.opencv.core.Scalar;
import java.util.*;
/**
* 人脸检测-SCRFD
* git:https://github.com/deepinsight/insightface/tree/master/detection/scrfd
*/
public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDetection {
//图像的最大尺寸
private final static int maxSizeLength = 640;
//模型人脸检测的步长
private final static int[] strides = new int[]{8, 16, 32};
//人脸预测分数阈值
public final static float defScoreTh = 0.5f;
//人脸重叠iou阈值
public final static float defIouTh = 0.7f;
/**
* 构造函数
* @param modelPath 模型路径
* @param threads 线程数
*/
public InsightScrfdFaceDetection(String modelPath, int threads) {
super(modelPath, threads);
}
/**
*获取人脸信息
* @param image 图像信息
* @param scoreTh 人脸分数阈值
* @param iouTh 人脸iou阈值
* @return 人脸模型
*/
@Override
public List<FaceInfo> inference(ImageMat image, float scoreTh, float iouTh, Map<String, Object> params) {
OnnxTensor tensor = null;
OrtSession.Result output = null;
ImageMat imageMat = image.clone();
try {
float imgScale = 1.0f;
iouTh = iouTh <= 0 ? defIouTh : iouTh;
scoreTh = scoreTh <= 0 ? defScoreTh : scoreTh;
int imageWidth = imageMat.getWidth(), imageHeight = imageMat.getHeight();
int modelWidth = imageWidth, modelHeight = imageHeight;
if(imageWidth > maxSizeLength || imageHeight > maxSizeLength){
if(imageWidth > imageHeight){
modelWidth = maxSizeLength;
imgScale = 1.0f * imageWidth / maxSizeLength;
modelHeight = imageHeight * maxSizeLength / imageWidth;
}else {
modelHeight = maxSizeLength ;
imgScale = 1.0f * imageHeight / maxSizeLength;
modelWidth = modelWidth * maxSizeLength / imageHeight;
}
imageMat = imageMat.resizeAndDoReleaseMat(modelWidth, modelHeight);
}
tensor = imageMat
.blobFromImageAndDoReleaseMat(1.0/128, new Scalar(127.5, 127.5, 127.5), true)
.to4dFloatOnnxTensorAndDoReleaseMat(true);
output = getSession().run(Collections.singletonMap(getInputName(), tensor));
return fitterBoxes(output, scoreTh, iouTh, tensor.getInfo().getShape()[3], imgScale);
} catch (Exception e) {
throw new RuntimeException(e);
}finally {
if(null != tensor){
tensor.close();
}
if(null != output){
output.close();
}
if(null != imageMat){
imageMat.release();
}
}
}
/**
* 过滤人脸框
* @param output 数据输出
* @param scoreTh 人脸分数阈值
* @param iouTh 人脸重叠阈值
* @param tensorWidth 输出层的宽度
* @param imgScale 图像的缩放比例
* @return
* @throws OrtException
*/
private List<FaceInfo> fitterBoxes(OrtSession.Result output, float scoreTh, float iouTh, long tensorWidth, float imgScale) throws OrtException {
//分数过滤及计算正确的人脸框值
List<FaceInfo> faceInfos = new ArrayList<>();
for(int index=0; index< 3; index++) {
float[][] scores = (float[][]) output.get(index).getValue();
float[][] boxes = (float[][]) output.get(index + 3).getValue();
float[][] points = (float[][]) output.get(index + 6).getValue();
int ws = (int) Math.ceil(1.0f * tensorWidth / strides[index]);
for(int i=0; i< scores.length; i++){
if(scores[i][0] >= scoreTh){
int anchorIndex = i / 2;
int rowNum = anchorIndex / ws;
int colNum = anchorIndex % ws;
//计算人脸框
float anchorX = colNum * strides[index];
float anchorY = rowNum * strides[index];
float x1 = (anchorX - boxes[i][0] * strides[index]) * imgScale;
float y1 = (anchorY - boxes[i][1] * strides[index]) * imgScale;
float x2 = (anchorX + boxes[i][2] * strides[index]) * imgScale;
float y2 = (anchorY + boxes[i][3] * strides[index]) * imgScale;
//计算关键点
float [] point = points[i];
FaceInfo.Points keyPoints = FaceInfo.Points.build();
for(int pointIndex=0; pointIndex<(point.length/2); pointIndex++){
float pointX = (point[2*pointIndex] * strides[index] + anchorX) * imgScale;
float pointY = (point[2*pointIndex+1] * strides[index] + anchorY) * imgScale;
keyPoints.add(FaceInfo.Point.build(pointX, pointY));
}
faceInfos.add(FaceInfo.build(scores[i][0], 0, FaceInfo.FaceBox.build(x1,y1,x2,y2), keyPoints));
}
}
}
//对人脸框进行iou过滤
Collections.sort(faceInfos);
List<FaceInfo> faces = new ArrayList<>();
while(!faceInfos.isEmpty()){
Iterator<FaceInfo> iterator = faceInfos.iterator();
//获取第一个元素并删除元素
FaceInfo firstFace = iterator.next();
iterator.remove();
//对比后面元素与第一个元素之间的iou
while (iterator.hasNext()) {
FaceInfo nextFace = iterator.next();
if(firstFace.iou(nextFace) >= iouTh){
iterator.remove();
}
}
faces.add(firstFace);
}
//返回
return faces;
}
}
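
A short sketch of running the SCRFD detector on its own; the model and image paths are placeholders, and passing 0 for both thresholds falls back to defScoreTh/defIouTh as shown above:

```
import java.util.HashMap;
import java.util.List;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.models.InsightScrfdFaceDetection;

public class ScrfdDemo {
    public static void main(String[] args) {
        // Placeholder paths.
        InsightScrfdFaceDetection detection = new InsightScrfdFaceDetection("models/face-detection.onnx", 1);
        ImageMat image = ImageMat.fromImage("/tmp/group.jpg");
        List<FaceInfo> faces = detection.inference(image, 0f, 0f, new HashMap<>());
        for (FaceInfo face : faces) {
            FaceInfo.FaceBox box = face.box;
            System.out.printf("score=%.3f box=(%.0f,%.0f)-(%.0f,%.0f)%n",
                    face.score, box.x1(), box.y1(), box.x2(), box.y2());
        }
        image.release();
        detection.close();
    }
}
```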

View File

@ -0,0 +1,648 @@
package com.visual.face.search.core.models;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.visual.face.search.core.base.BaseOnnxInfer;
import com.visual.face.search.core.base.FaceDetection;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
import org.opencv.core.*;
import org.opencv.imgproc.Imgproc;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* 人脸检测-PCN
* git:
* https://github.com/Rock-100/FaceKit/tree/master/PCN
* https://github.com/siriusdemon/pytorch-PCN
*/
public class PcnNetworkFaceDetection extends BaseOnnxInfer implements FaceDetection {
//常量参数
private final static float stride_ = 8;
private final static float minFace_ = 28;
private final static float scale_ = 1.414f;
private final static float angleRange_ = 45f;
private final static double EPS = 1e-5;
//人脸预测分数阈值
public final static float defScoreTh = 0.8f;
//人脸重叠iou阈值
public final static float defIouTh = 0.6f;
//人脸预测分数阈值
public static float[] defScoreThs = new float[]{0.3f, 0.4f, defScoreTh};
//人脸重叠iou阈值
public static float[] defIouThs = new float[]{defIouTh, defIouTh, 0.3f};
/**
* 构造函数
* @param modelPaths 模型路径
* @param threads 线程数
*/
public PcnNetworkFaceDetection(String[] modelPaths, int threads) {
super(modelPaths, threads);
}
/**
*获取人脸信息
* @param image 图像信息
* @param scoreTh 人脸分数阈值
* @param iouTh 人脸iou阈值
* @return 人脸模型
*/
@Override
public List<FaceInfo> inference(ImageMat image, float scoreTh, float iouTh, Map<String, Object> params) {
Mat mat = null;
Mat imgPad = null;
ImageMat imageMat = image.clone();
try {
mat = imageMat.toCvMat();
imgPad = pad_img_not_release_mat(mat);
float[] iouThs = iouTh <= 0 ? defIouThs : new float[]{iouTh, iouTh, 0.3f};
float[] scoreThs = scoreTh <= 0 ? defScoreThs : new float[]{0.375f * scoreTh, 0.5f * scoreTh, scoreTh};
List<PcnNetworkFaceDetection.Window2> willis = detect(this.getSessions(), mat, imgPad, scoreThs, iouThs);
return trans_window(mat, imgPad, willis);
} catch (Exception e) {
throw new RuntimeException(e);
}finally {
if(null != mat){
mat.release();
}
if(null != imgPad){
imgPad.release();
}
if(null != imageMat){
imageMat.release();
}
}
}
/********************************分割线*************************************/
private static Mat pad_img_not_release_mat(Mat mat){
int row = Math.min((int)(mat.size().height * 0.2), 100);
int col = Math.min((int)(mat.size().width * 0.2), 100);
Mat dst = new Mat();
Core.copyMakeBorder(mat, dst, row, row, col, col, Core.BORDER_CONSTANT);
return dst;
}
private static Mat resize_img(Mat mat, float scale){
double h = mat.size().height;
double w = mat.size().width;
int h_ = (int) (h / scale);
int w_ = (int) (w / scale);
Mat matF32 = new Mat();
if(mat.type() != CvType.CV_32FC3){
mat.convertTo(matF32, CvType.CV_32FC3);
}else{
mat.copyTo(matF32);
}
Mat dst = new Mat();
Imgproc.resize(matF32, dst, new Size(w_, h_), 0,0, Imgproc.INTER_NEAREST);
mat.release();
matF32.release();
return dst;
}
private static Mat preprocess_img(Mat mat, int dim){
Mat matTmp = new Mat();
if(dim > 0){
Imgproc.resize(mat, matTmp, new Size(dim, dim), 0, 0, Imgproc.INTER_NEAREST);
}else{
mat.copyTo(matTmp);
}
//格式转化
Mat matF32 = new Mat();
if(mat.type() != CvType.CV_32FC3){
matTmp.convertTo(matF32, CvType.CV_32FC3);
}else{
matTmp.copyTo(matF32);
}
Mat dst = new Mat();
Core.subtract(matF32, new Scalar(104, 117, 123), dst);
mat.release();
matTmp.release();
matF32.release();
return dst;
}
private static OnnxTensor set_input(Mat mat){
Mat dst = null;
try {
dst = new Mat();
mat.copyTo(dst);
return ImageMat.fromCVMat(dst).to4dFloatOnnxTensorAndDoReleaseMat(true);
}finally {
if(null != dst){
dst.release();
}
}
}
private static OnnxTensor set_input(List<Mat> mats){
float[][][][] arrays = new float[mats.size()][][][];
for(int i=0; i< mats.size(); i++){
Mat dst = new Mat();
mats.get(i).copyTo(dst);
float[][][][] array = ImageMat.fromCVMat(dst).to4dFloatArrayAndDoReleaseMat(true);
arrays[i] = array[0];
dst.release();
}
try {
return OnnxTensor.createTensor(OrtEnvironment.getEnvironment(), arrays);
}catch (Exception e){
throw new RuntimeException(e);
}
}
private static boolean legal(int x, int y, Mat mat){
if(0 <= x && x < mat.size().width && 0 <= y && y< mat.size().height){
return true;
}else{
return false;
}
}
private static boolean inside(int x, int y, Window2 rect){
if(rect.x <= x && x < (rect.x + rect.w) && rect.y <= y && y< (rect.y + rect.h)){
return true;
}else{
return false;
}
}
private static float IoU(Window2 w1, Window2 w2){
float xOverlap = Math.max(0, Math.min(w1.x + w1.w - 1, w2.x + w2.w - 1) - Math.max(w1.x, w2.x) + 1);
float yOverlap = Math.max(0, Math.min(w1.y + w1.h - 1, w2.y + w2.h - 1) - Math.max(w1.y, w2.y) + 1);
float intersection = xOverlap * yOverlap;
float union = w1.w * w1.h + w2.w * w2.h - intersection;
return intersection / union;
}
private static List<Window2> NMS(List<Window2> winlist, boolean local, float threshold){
if(winlist==null || winlist.isEmpty()){
return new ArrayList<>();
}
//排序
Collections.sort(winlist);
int [] flag = new int[winlist.size()];
for(int i=0; i< winlist.size(); i++){
if(flag[i] > 0){
continue;
}
for(int j=i+1; j<winlist.size(); j++){
if(local && Math.abs(winlist.get(i).scale - winlist.get(j).scale) > EPS){
continue;
}
if(IoU(winlist.get(i), winlist.get(j)) > threshold){
flag[j] = 1;
}
}
}
List<Window2> list = new ArrayList<>();
for(int i=0; i< flag.length; i++){
if(flag[i] == 0){
list.add(winlist.get(i));
}
}
return list;
}
private static List<Window2> deleteFP(List<Window2> winlist){
if (winlist == null || winlist.isEmpty()) {
return new ArrayList<>();
}
//排序
Collections.sort(winlist);
int [] flag = new int[winlist.size()];
for(int i=0; i< winlist.size(); i++){
if(flag[i] > 0){
continue;
}
for(int j=i+1; j<winlist.size(); j++){
Window2 win = winlist.get(j);
if(inside(win.x, win.y, winlist.get(i)) && inside(win.x + win.w - 1, win.y + win.h - 1, winlist.get(i))){
flag[j] = 1;
}
}
}
List<Window2> list = new ArrayList<>();
for(int i=0; i< flag.length; i++){
if(flag[i] == 0){
list.add(winlist.get(i));
}
}
return list;
}
private static List<FaceInfo> trans_window(Mat img, Mat imgPad, List<Window2> winlist){
int row = (imgPad.size(0) - img.size(0)) / 2;
int col = (imgPad.size(1) - img.size(1)) / 2;
List<FaceInfo> ret = new ArrayList<>();
for(Window2 win : winlist){
if( win.w > 0 && win.h > 0){
int x1 = win.x - col;
int y1 = win.y - row;
int x2 = win.x - col + win.w;
int y2 = win.y - row + win.h; //windows are square (w == h); use h for the vertical extent
int angle = (win.angle + 360) % 360;
//扩展人脸高度
float rw = 0f;
float rh = 0.1f;
int w = Math.abs(x2 - x1);
int h = Math.abs(y2 - y1);
x1 = Math.max(x1 - (int)(w * rw), 1);
y1 = Math.max(y1 - (int)(h * rh), 1);
x2 = Math.min(x2 + (int)(w * rw), img.size(1)-1);
y2 = Math.min(y2 + (int)(h * rh), img.size(0)-1);
//构建人脸信息
FaceInfo faceInfo = FaceInfo.build(win.conf, angle, FaceInfo.FaceBox.build(x1, y1, x2, y2), FaceInfo.Points.build());
ret.add(faceInfo);
}
}
return ret;
}
/**
* Stage 1: sliding-window scan over an image pyramid (verified against the reference implementation).
* @param img    original image
* @param imgPad padded image
* @param net    stage-1 ONNX session
* @param thres  score threshold
* @return candidate windows
*/
private static List<Window2> stage1(Mat img, Mat imgPad, OrtSession net, float thres) throws RuntimeException {
int netSize = 24;
float curScale = minFace_ / netSize;
int row = (int) ((imgPad.size().height - img.size().height) / 2);
int col = (int) ((imgPad.size().width - img.size().width) / 2);
Mat img_resized = null;
OnnxTensor net_input1 = null;
OrtSession.Result output = null;
List<Window2> winlist = new ArrayList<>();
try {
img_resized = resize_img(img.clone(), curScale);
while(Math.min(img_resized.size().height, img_resized.size().width) >= netSize){
img_resized = preprocess_img(img_resized, 0);
net_input1 = set_input(img_resized);
output = net.run(Collections.singletonMap(net.getInputNames().iterator().next(), net_input1));
float[][][][] cls_prob = (float[][][][]) output.get(0).getValue();
float[][][][] rotate = (float[][][][]) output.get(1).getValue();
float[][][][] bbox = (float[][][][]) output.get(2).getValue();
//关闭对象
if(null != net_input1){
net_input1.close();
net_input1 = null;
}
if(null != output){
output.close();
output = null;
}
//计算业务逻辑
float w = netSize * curScale;
for(int i=0; i< cls_prob[0][0].length; i++){
for(int j=0; j< cls_prob[0][0][0].length; j++){
if(cls_prob[0][1][i][j] > thres){
float sn = bbox[0][0][i][j];
float xn = bbox[0][1][i][j];
float yn = bbox[0][2][i][j];
int rx = (int)(j * curScale * stride_ - 0.5 * sn * w + sn * xn * w + 0.5 * w) + col;
int ry = (int)(i * curScale * stride_ - 0.5 * sn * w + sn * yn * w + 0.5 * w) + row;
int rw = (int)(w * sn);
if (legal(rx, ry, imgPad) && legal(rx + rw - 1, ry + rw - 1, imgPad)){
if (rotate[0][1][i][j] > 0.5){
winlist.add(new Window2(rx, ry, rw, rw, 0, curScale, cls_prob[0][1][i][j]));
}else{
winlist.add(new Window2(rx, ry, rw, rw, 180, curScale, cls_prob[0][1][i][j]));
}
}
}
}
}
img_resized = resize_img(img_resized, scale_);
curScale = (float) (img.size().height / img_resized.size().height);
}
}catch (Exception e){
throw new RuntimeException(e);
}finally {
if(null != net_input1){
net_input1.close();
}
if(null != output){
output.close();
}
if(null != img_resized){
img_resized.release();
}
}
//返回
return winlist;
}
/**
* 验证通过
* @param img
* @param img180
* @param net
* @param thres
* @param dim
* @param winlist
* @return
* @throws OrtException
*/
private static List<Window2> stage2(Mat img, Mat img180, OrtSession net, float thres, int dim, List<Window2> winlist) throws OrtException {
if(winlist==null || winlist.isEmpty()){
return new ArrayList<>();
}
int height = img.size(0);
List<Mat> datalist = new ArrayList<>();
for(Window2 win : winlist){
if(Math.abs(win.angle) < EPS){
Mat cloneMat = img.clone();
Mat corp = new Mat(cloneMat, new Rect(win.x, win.y, win.w, win.h));
Mat corp1 = preprocess_img(corp, dim);
datalist.add(corp1);
if(null != corp){
corp.release();
}
if(null != cloneMat){
cloneMat.release();
}
}else{
int y2 = win.y + win.h - 1;
int y = height - 1 - y2;
Mat cloneMat = img180.clone();
Mat corp = new Mat(cloneMat, new Rect(win.x, y, win.w, win.h));
Mat corp1 = preprocess_img(corp, dim);
datalist.add(corp1);
if(null != corp){
corp.release();
}
if(null != cloneMat){
cloneMat.release();
}
}
}
OnnxTensor net_input = set_input(datalist);
OrtSession.Result output = net.run(Collections.singletonMap(net.getInputNames().iterator().next(), net_input));
float[][] cls_prob = (float[][]) output.get(0).getValue();
float[][] rotate = (float[][]) output.get(1).getValue();
float[][] bbox = (float[][]) output.get(2).getValue();
//关闭对象
for(Mat mat : datalist){
mat.release();
}
if(null != net_input){
net_input.close();
}
if(null != output){
output.close();
}
List<Window2> ret = new ArrayList<>();
for(int i=0; i<winlist.size(); i++){
if(cls_prob[i][1] > thres){
float sn = bbox[i][0];
float xn = bbox[i][1];
float yn = bbox[i][2];
float cropX = winlist.get(i).x;
float cropY = winlist.get(i).y;
float cropW = winlist.get(i).w;
if(Math.abs(winlist.get(i).angle) > EPS){
cropY = height - 1 - (cropY + cropW - 1);
}
int w = (int)(sn * cropW);
int x = (int)(cropX - 0.5 * sn * cropW + cropW * sn * xn + 0.5 * cropW);
int y = (int)(cropY - 0.5 * sn * cropW + cropW * sn * yn + 0.5 * cropW);
float maxRotateScore = 0;
int maxRotateIndex = 0;
for(int j=0; j<3; j++){
if(rotate[i][j] > maxRotateScore){
maxRotateScore = rotate[i][j];
maxRotateIndex = j;
}
}
if(legal(x, y, img) && legal(x + w - 1, y + w - 1, img)){
int angle = 0;
if(Math.abs(winlist.get(i).angle) < EPS){
if(maxRotateIndex == 0){
angle = 90;
}else if(maxRotateIndex == 1){
angle = 0;
}else{
angle = -90;
}
ret.add(new Window2(x, y, w, w, angle, winlist.get(i).scale, cls_prob[i][1]));
}else{
if(maxRotateIndex == 0){
angle = 90;
}else if(maxRotateIndex == 1){
angle = 180;
}else{
angle = -90;
}
ret.add(new Window2(x, height - 1 - (y + w - 1), w, w, angle, winlist.get(i).scale, cls_prob[i][1]));
}
}
}
}
return ret;
}
/**
* 验证通过
* @param imgPad
* @param img180
* @param img90
* @param imgNeg90
* @param net
* @param thres
* @param dim
* @param winlist
* @return
* @throws OrtException
*/
private static List<Window2> stage3(Mat imgPad, Mat img180, Mat img90, Mat imgNeg90, OrtSession net, float thres, int dim, List<Window2> winlist) throws OrtException {
if (winlist == null || winlist.isEmpty()) {
return new ArrayList<>();
}
int height = imgPad.size(0);
int width = imgPad.size(1);
List<Mat> datalist = new ArrayList<>();
for(Window2 win : winlist){
if(Math.abs(win.angle) < EPS){
Mat corp = new Mat(imgPad, new Rect(win.x, win.y, win.w, win.h));
datalist.add(preprocess_img(corp, dim));
}else if(Math.abs(win.angle - 90) < EPS){
Mat corp = new Mat(img90, new Rect(win.y, win.x, win.h, win.w));
datalist.add(preprocess_img(corp, dim));
}else if(Math.abs(win.angle + 90) < EPS){
int x = win.y;
int y = width - 1 - (win.x + win.w - 1);
Mat corp = new Mat(imgNeg90, new Rect(x, y, win.w, win.h));
datalist.add(preprocess_img(corp, dim));
}else{
int y2 = win.y + win.h - 1;
int y = height - 1 - y2;
Mat corp = new Mat(img180, new Rect(win.x, y, win.w, win.h));
datalist.add(preprocess_img(corp, dim));
}
}
OnnxTensor net_input = set_input(datalist);
OrtSession.Result output = net.run(Collections.singletonMap(net.getInputNames().iterator().next(), net_input));
float[][] cls_prob = (float[][]) output.get(0).getValue();
float[][] rotate = (float[][]) output.get(1).getValue();
float[][] bbox = (float[][]) output.get(2).getValue();
//关闭对象
for(Mat mat : datalist){
mat.release();
}
if(null != net_input){
net_input.close();
}
if(null != output){
output.close();
}
List<Window2> ret = new ArrayList<>();
for(int i=0; i<winlist.size(); i++) {
if (cls_prob[i][1] > thres) {
float sn = bbox[i][0];
float xn = bbox[i][1];
float yn = bbox[i][2];
float cropX = winlist.get(i).x;
float cropY = winlist.get(i).y;
float cropW = winlist.get(i).w;
Mat img_tmp = imgPad;
if (Math.abs(winlist.get(i).angle - 180) < EPS) {
cropY = height - 1 - (cropY + cropW - 1);
img_tmp = img180;
}else if (Math.abs(winlist.get(i).angle - 90) < EPS) {
cropX = winlist.get(i).y;
cropY = winlist.get(i).x;
img_tmp = img90;
}else if (Math.abs(winlist.get(i).angle + 90) < EPS) {
cropX = winlist.get(i).y;
cropY = width - 1 - (winlist.get(i).x + winlist.get(i).w - 1);
img_tmp = imgNeg90;
}
int w = (int) (sn * cropW);
int x = (int) (cropX - 0.5 * sn * cropW + cropW * sn * xn + 0.5 * cropW);
int y = (int) (cropY - 0.5 * sn * cropW + cropW * sn * yn + 0.5 * cropW);
int angle = (int)(angleRange_ * rotate[i][0]);
if(legal(x, y, img_tmp) && legal(x + w - 1, y + w - 1, img_tmp)){
if(Math.abs(winlist.get(i).angle) < EPS){
ret.add(new Window2(x, y, w, w, angle, winlist.get(i).scale, cls_prob[i][1]));
}else if(Math.abs(winlist.get(i).angle - 180) < EPS){
ret.add(new Window2(x, height - 1 - (y + w - 1), w, w, 180 - angle, winlist.get(i).scale, cls_prob[i][1]));
}else if(Math.abs(winlist.get(i).angle - 90) < EPS){
ret.add(new Window2(y, x, w, w, 90 - angle, winlist.get(i).scale, cls_prob[i][1]));
}else{
ret.add(new Window2(width - y - w, x, w, w, -90 + angle, winlist.get(i).scale, cls_prob[i][1]));
}
}
}
}
return ret;
}
/**
* Run the full three-stage PCN pipeline (verified against the reference implementation).
* @param sessions the three stage ONNX sessions, in order
* @param img      original image
* @param imgPad   padded image
* @param scoreThs per-stage score thresholds
* @param iouThs   per-stage IoU thresholds
* @return detected windows
* @throws OrtException
* @throws IOException
*/
private static List<PcnNetworkFaceDetection.Window2> detect(OrtSession[] sessions, Mat img, Mat imgPad, float[] scoreThs, float iouThs[]) throws OrtException, IOException {
Mat img180 = new Mat();
Core.flip(imgPad, img180, 0);
Mat img90 = new Mat();
Core.transpose(imgPad, img90);
Mat imgNeg90 = new Mat();
Core.flip(img90, imgNeg90, 0);
List<PcnNetworkFaceDetection.Window2> winlist = stage1(img, imgPad, sessions[0], scoreThs[0]);
winlist = NMS(winlist, true, iouThs[0]);
winlist = stage2(imgPad, img180, sessions[1], scoreThs[1], 24, winlist);
winlist = NMS(winlist, true, iouThs[1]);
winlist = stage3(imgPad, img180, img90, imgNeg90, sessions[2], scoreThs[2], 48, winlist);
winlist = NMS(winlist, false, iouThs[2]);
winlist = deleteFP(winlist);
img90.release();
img180.release();
imgNeg90.release();
return winlist;
}
/**
* 临时的人脸框
*/
private static class Window2 implements Comparable<PcnNetworkFaceDetection.Window2>{
public int x;
public int y;
public int w;
public int h;
public int angle;
public float scale;
public float conf;
public Window2(int x, int y, int w, int h, int angle, float scale, float conf) {
this.x = x;
this.y = y;
this.w = w;
this.h = h;
this.angle = angle;
this.scale = scale;
this.conf = conf;
}
@Override
public int compareTo(PcnNetworkFaceDetection.Window2 o) {
if(o.conf == this.conf){
return Integer.compare(this.y, o.y);
}else{
return Float.compare(o.conf, this.conf);
}
}
@Override
public String toString() {
return "Window2{"
+ "x=" + x +
", y=" + y +
", w=" + w +
", h=" + h +
", angle=" + angle +
", scale=" + scale +
", conf=" + conf +
'}' +"\n";
}
}
}
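/*
 * Usage sketch, mirroring PcnNetworkFaceDetectionTest further down in this commit. The three
 * model paths are the PCN stage models shipped with the repository, passed in stage order;
 * "face.jpg" is a placeholder.
 *
 *   PcnNetworkFaceDetection detector = new PcnNetworkFaceDetection(new String[]{
 *       "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn1_sd.onnx",
 *       "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn2_sd.onnx",
 *       "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn3_sd.onnx"}, 1);
 *   ImageMat imageMat = ImageMat.fromImage("face.jpg");
 *   List<FaceInfo> faces = detector.inference(imageMat,
 *       PcnNetworkFaceDetection.defScoreTh, PcnNetworkFaceDetection.defIouTh, null);
 *   imageMat.release();
 */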

View File

@ -0,0 +1,46 @@
package com.visual.face.search.core.models;
import com.visual.face.search.core.base.FaceAlignment;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.utils.AlignUtil;
import org.opencv.core.Mat;
import java.util.Map;
/**
* Five-point face alignment
*/
public class Simple005pFaceAlignment implements FaceAlignment {
/** target key-point template for a 112x112 aligned face **/
private final static double[][] dst_points = new double[][]{
{30.2946f + 8.0000f, 51.6963f},
{65.5318f + 8.0000f, 51.6963f},
{48.0252f + 8.0000f, 71.7366f},
{33.5493f + 8.0000f, 92.3655f},
{62.7299f + 8.0000f, 92.3655f}
};
/**
* Align the face image to the 112x112 template.
* @param imageMat   image data
* @param imagePoint key points of the image (5 or 106 points)
* @param params     extra parameters
* @return aligned image
*/
@Override
public ImageMat inference(ImageMat imageMat, FaceInfo.Points imagePoint, Map<String, Object> params) {
double [][] image_points;
if(imagePoint.size() == 5){
image_points = imagePoint.toDoubleArray();
}else if(imagePoint.size() == 106){
image_points = imagePoint.select(38, 88, 80, 52, 61).toDoubleArray();
}else{
throw new RuntimeException("need 5 point, but get "+ imagePoint.size());
}
Mat alignMat = AlignUtil.alignedImage(imageMat.toCvMat(), image_points, 112, 112, dst_points);
return ImageMat.fromCVMat(alignMat);
}
}
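/*
 * Usage sketch. The aligner accepts either a 5-point set or the 106-point set produced by
 * InsightCoordFaceKeyPoint, from which it selects indices 38, 88, 80, 52, 61. "faceMat" and
 * "points" below are placeholders for the detected face image and its key points, which are
 * assumed to share the same coordinate frame.
 *
 *   FaceAlignment alignment = new Simple005pFaceAlignment();
 *   ImageMat aligned = alignment.inference(faceMat, points, null);   //returns a 112x112 aligned face
 */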

View File

@ -0,0 +1,138 @@
package com.visual.face.search.core.models;
import com.visual.face.search.core.base.FaceAlignment;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.utils.AlignUtil;
import org.opencv.core.Mat;
import java.util.Map;
public class Simple106pFaceAlignment implements FaceAlignment {
/** offsets applied to the template points **/
private final static double x_offset = 0;
private final static double y_offset = -8;
/** target 106-point template for a 112x112 aligned face **/
private final static double[][] dst_points = new double[][]{
{x_offset + 55.9033, y_offset + 109.5666},
{x_offset + 11.1944, y_offset + 34.8529},
{x_offset + 23.4634, y_offset + 84.9067},
{x_offset + 26.6072, y_offset + 89.7229},
{x_offset + 30.1978, y_offset + 94.2239},
{x_offset + 34.1534, y_offset + 98.4295},
{x_offset + 38.4549, y_offset + 102.5791},
{x_offset + 43.3488, y_offset + 106.149},
{x_offset + 49.0901, y_offset + 108.6559},
{x_offset + 11.1362, y_offset + 40.8701},
{x_offset + 11.5769, y_offset + 46.7289},
{x_offset + 12.4184, y_offset + 52.4678},
{x_offset + 13.5954, y_offset + 58.0757},
{x_offset + 15.0485, y_offset + 63.6245},
{x_offset + 16.6973, y_offset + 69.0877},
{x_offset + 18.5132, y_offset + 74.5257},
{x_offset + 20.7569, y_offset + 79.817},
{x_offset + 99.9696, y_offset + 33.7004},
{x_offset + 88.8002, y_offset + 84.668},
{x_offset + 85.5828, y_offset + 89.5583},
{x_offset + 81.9019, y_offset + 94.1172},
{x_offset + 77.8414, y_offset + 98.3508},
{x_offset + 73.5045, y_offset + 102.5337},
{x_offset + 68.579 , y_offset + 106.1167},
{x_offset + 62.8033, y_offset + 108.6437},
{x_offset + 100.1725, y_offset + 39.7505},
{x_offset + 99.9274, y_offset + 45.6823},
{x_offset + 99.2375, y_offset + 51.4997},
{x_offset + 98.2293, y_offset + 57.1924},
{x_offset + 96.9149, y_offset + 62.8431},
{x_offset + 95.4353, y_offset + 68.4604},
{x_offset + 93.744, y_offset + 74.0441},
{x_offset + 91.5361, y_offset + 79.4612},
{x_offset + 32.9994, y_offset + 44.2254},
{x_offset + 34.125, y_offset + 40.0104},
{x_offset + 24.6554, y_offset + 40.2778},
{x_offset + 28.1858, y_offset + 42.7984},
{x_offset + 38.3286, y_offset + 43.9908},
{x_offset + 34.1243, y_offset + 40.0099},
{x_offset + 43.1474, y_offset + 43.3184},
{x_offset + 34.0167, y_offset + 36.9718},
{x_offset + 28.7573, y_offset + 37.653},
{x_offset + 39.3212, y_offset + 38.9433},
{x_offset + 16.2452, y_offset + 28.4257},
{x_offset + 22.6223, y_offset + 27.7422},
{x_offset + 29.3128, y_offset + 28.5235},
{x_offset + 43.2792, y_offset + 33.2877},
{x_offset + 36.3115, y_offset + 30.5346},
{x_offset + 22.1746, y_offset + 23.7718},
{x_offset + 29.9789, y_offset + 23.9014},
{x_offset + 44.2766, y_offset + 29.9646},
{x_offset + 37.5711, y_offset + 26.3851},
{x_offset + 41.0353, y_offset + 85.7782},
{x_offset + 55.2768, y_offset + 93.9337},
{x_offset + 48.1265, y_offset + 86.7238},
{x_offset + 44.5563, y_offset + 89.7497},
{x_offset + 49.021, y_offset + 92.7902},
{x_offset + 62.5841, y_offset + 86.3203},
{x_offset + 66.422, y_offset + 89.1865},
{x_offset + 61.7114, y_offset + 92.4944},
{x_offset + 55.186, y_offset + 87.3286},
{x_offset + 70.197, y_offset + 84.9727},
{x_offset + 55.1775, y_offset + 85.9714},
{x_offset + 51.3597, y_offset + 81.3182},
{x_offset + 45.7722, y_offset + 83.1928},
{x_offset + 43.166, y_offset + 85.9001},
{x_offset + 48.134, y_offset + 85.7435},
{x_offset + 58.739, y_offset + 81.1461},
{x_offset + 64.7266, y_offset + 82.6822},
{x_offset + 67.9371, y_offset + 85.2036},
{x_offset + 62.5068, y_offset + 85.3484},
{x_offset + 55.0752, y_offset + 82.203},
{x_offset + 54.1078, y_offset + 41.0726},
{x_offset + 54.1622, y_offset + 50.449},
{x_offset + 54.2472, y_offset + 59.7894},
{x_offset + 47.8631, y_offset + 43.7787},
{x_offset + 45.3862, y_offset + 63.1931},
{x_offset + 43.1486, y_offset + 70.3748},
{x_offset + 46.5561, y_offset + 72.9832},
{x_offset + 50.3907, y_offset + 74.013},
{x_offset + 54.7346, y_offset + 75.5877},
{x_offset + 60.6909, y_offset + 43.5412},
{x_offset + 63.7283, y_offset + 62.8403},
{x_offset + 66.4094, y_offset + 69.925},
{x_offset + 63.0125, y_offset + 72.6655},
{x_offset + 59.1424, y_offset + 73.8402},
{x_offset + 54.3016, y_offset + 69.1335},
{x_offset + 75.6758, y_offset + 44.0704},
{x_offset + 74.6523, y_offset + 39.8236},
{x_offset + 65.424, y_offset + 43.1928},
{x_offset + 70.2963, y_offset + 43.8439},
{x_offset + 80.5739, y_offset + 42.5855},
{x_offset + 74.6507, y_offset + 39.8237},
{x_offset + 84.2044, y_offset + 40.0282},
{x_offset + 74.6292, y_offset + 36.7377},
{x_offset + 69.2614, y_offset + 38.7215},
{x_offset + 79.9957, y_offset + 37.425},
{x_offset + 63.7995, y_offset + 32.7549},
{x_offset + 71.0128, y_offset + 29.9861},
{x_offset + 78.3723, y_offset + 27.9521},
{x_offset + 85.5398, y_offset + 27.2929},
{x_offset + 92.6151, y_offset + 28.1619},
{x_offset + 62.768, y_offset + 29.3678},
{x_offset + 69.733, y_offset + 25.7456},
{x_offset + 77.7338, y_offset + 23.3563},
{x_offset + 86.0268, y_offset + 23.3427}};
@Override
public ImageMat inference(ImageMat imageMat, FaceInfo.Points imagePoint, Map<String, Object> params) {
double [][] image_points;
if(imagePoint.size() == 106){
image_points = imagePoint.toDoubleArray();
}else{
throw new RuntimeException("need 106 point, but get "+ imagePoint.size());
}
Mat alignMat = AlignUtil.alignedImage(imageMat.toCvMat(), image_points, 112, 112, dst_points);
return ImageMat.fromCVMat(alignMat);
}
}

View File

@ -0,0 +1,122 @@
package com.visual.face.search.core.utils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.opencv.core.*;
import org.opencv.imgproc.Imgproc;
/**
* 图像对齐工具
*/
public class AlignUtil {
/**
* 人脸对齐
* @param image 图像数据
* @param imagePoint 图像中的关键点
* @param stdWidth 定义的标准图像的宽度
* @param stdHeight 定义的标准图像的高度
* @param stdPoint 定义的标准关键点
*/
public static Mat alignedImage(Mat image, double[][] imagePoint, int stdWidth, int stdHeight, double[][] stdPoint){
Mat warp = null;
Mat rectMat = null;
try {
warp = warpAffine(image, imagePoint, stdPoint);
double imgWidth = warp.size().width;
double imgHeight = warp.size().height;
if(stdWidth <= imgWidth && stdHeight <= imgHeight){
Mat crop = new Mat(warp, new Rect(0, 0, stdWidth, stdHeight));
return crop;
}
//计算需要裁剪的宽和高
int h, w;
if((1.0*imgWidth/imgHeight) >= (1.0 * stdWidth/stdHeight)){
h = (int) Math.floor(1.0 * imgHeight);
w = (int) Math.floor(1.0 * stdWidth * imgHeight / stdHeight);
}else{
w = (int) Math.floor(1.0 * imgWidth);
h = (int) Math.floor(1.0 * stdHeight * imgWidth / stdWidth);
}
//需要裁剪图片
rectMat = new Mat(warp, new Rect(0, 0, w, h));
Mat crop = new Mat();
Imgproc.resize(rectMat, crop, new Size(stdWidth, stdHeight), 0, 0, Imgproc.INTER_NEAREST);
return crop;
}finally {
if(null != rectMat){
rectMat.release();
}
if(null != warp){
warp.release();
}
}
}
/**
* 图像仿射变换
* @param image 图像数据
* @param imgPoint 图像中的关键点
* @param stdPoint 定义的标准关键点
* @return 图像的仿射结果图
*/
public static Mat warpAffine(Mat image, double[][] imgPoint, double[][] stdPoint){
Mat matM = null;
Mat matMTemp = null;
try {
//转换为矩阵
RealMatrix imgPointMatrix = MathUtil.createMatrix(imgPoint);
RealMatrix stdPointMatrix = MathUtil.createMatrix(stdPoint);
//判断数据的行列是否一致
int row = imgPointMatrix.getRowDimension();
int col = imgPointMatrix.getColumnDimension();
if(row <= 0 || col <=0 || row != stdPointMatrix.getRowDimension() || col != stdPointMatrix.getColumnDimension()){
throw new RuntimeException("row or col is not equal");
}
//求列的均值
RealVector imgPointMeanVector = MathUtil.mean(imgPointMatrix, 0);
RealVector stdPointMeanVector = MathUtil.mean(stdPointMatrix, 0);
//对关键点进行减去均值
RealMatrix imgPointMatrix1 = imgPointMatrix.subtract(MathUtil.createMatrix(row, imgPointMeanVector.toArray()));
RealMatrix stdPointMatrix1 = stdPointMatrix.subtract(MathUtil.createMatrix(row, stdPointMeanVector.toArray()));
//计算关键点的标准差
double imgPointStd = MathUtil.std(imgPointMatrix1);
double stdPointStd = MathUtil.std(stdPointMatrix1);
//对关键点除以标准差
RealMatrix imgPointMatrix2 = MathUtil.scalarDivision(imgPointMatrix1, imgPointStd);
RealMatrix stdPointMatrix2 = MathUtil.scalarDivision(stdPointMatrix1, stdPointStd);
//获取矩阵的分量
RealMatrix pointsT = imgPointMatrix2.transpose().multiply(stdPointMatrix2);
SingularValueDecomposition svdH = new SingularValueDecomposition(pointsT);
RealMatrix U = svdH.getU(); RealMatrix S = svdH.getS(); RealMatrix Vt = svdH.getVT();
//计算仿射矩阵
RealMatrix R = U.multiply(Vt).transpose();
RealMatrix R1 = R.scalarMultiply(stdPointStd/imgPointStd);
RealMatrix v21 = MathUtil.createMatrix(1, stdPointMeanVector.toArray()).transpose();
RealMatrix v22 = R.multiply(MathUtil.createMatrix(1, imgPointMeanVector.toArray()).transpose());
RealMatrix v23 = v22.scalarMultiply(stdPointStd/imgPointStd);
RealMatrix R2 = v21.subtract(v23);
RealMatrix M = MathUtil.hstack(R1, R2);
//变化仿射矩阵为Mat
matMTemp = new MatOfDouble(MathUtil.flatMatrix(M, 1).toArray());
matM = new Mat(2, 3, CvType.CV_32FC3);
matMTemp.reshape(1,2).copyTo(matM);
//使用open cv做仿射变换
Mat dst = new Mat();
Imgproc.warpAffine(image, dst, matM, image.size());
return dst;
}finally {
if(null != matM){
matM.release();
}
if(null != matMTemp){
matMTemp.release();
}
}
}
}
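/*
 * Note on warpAffine above: it estimates a 2D similarity transform (uniform scale, rotation,
 * translation) in the least-squares sense. Both point sets are centred and normalised by their
 * standard deviation, the rotation is recovered from the SVD of imgPoints^T * stdPoints, and the
 * scale and translation follow from the two standard deviations and means; the result is applied
 * with Imgproc.warpAffine. A minimal usage sketch, matching how Simple005pFaceAlignment calls it
 * ("imagePoints" are detected key points, "stdPoints" is a 112x112 key-point template):
 *
 *   Mat aligned = AlignUtil.alignedImage(image, imagePoints, 112, 112, stdPoints);
 */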

View File

@ -0,0 +1,58 @@
package com.visual.face.search.core.utils;
import java.util.List;
import java.util.ArrayList;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Size;
import org.opencv.imgproc.Imgproc;
import org.opencv.utils.Converters;
import com.visual.face.search.core.domain.FaceInfo;
/**
* 图像裁剪工具
*/
public class CropUtil {
/**
* Crop the image by the four corner points of the face box, with perspective rectification.
* @param image   source image
* @param faceBox face box whose four corners define the crop quadrilateral
* @return cropped image of size faceBox.width() x faceBox.height()
*/
public static Mat crop(Mat image, FaceInfo.FaceBox faceBox){
Mat endM = null;
Mat startM = null;
Mat perspectiveTransform = null;
try {
List<Point> dest = new ArrayList<>();
dest.add(new Point(faceBox.leftTop.x, faceBox.leftTop.y));
dest.add(new Point(faceBox.rightTop.x, faceBox.rightTop.y));
dest.add(new Point(faceBox.rightBottom.x, faceBox.rightBottom.y));
dest.add(new Point(faceBox.leftBottom.x, faceBox.leftBottom.y));
startM = Converters.vector_Point2f_to_Mat(dest);
List<Point> ends = new ArrayList<>();
ends.add(new Point(0, 0));
ends.add(new Point(faceBox.width(), 0));
ends.add(new Point(faceBox.width(), faceBox.height()));
ends.add(new Point(0, faceBox.height()));
endM = Converters.vector_Point2f_to_Mat(ends);
perspectiveTransform = Imgproc.getPerspectiveTransform(startM, endM);
Mat outputMat = new Mat((int)faceBox.height() , (int)faceBox.width(), CvType.CV_8UC4);
Imgproc.warpPerspective(image, outputMat, perspectiveTransform, new Size((int)faceBox.width(), (int)faceBox.height()), Imgproc.INTER_CUBIC);
return outputMat;
}finally {
if(null != endM){
endM.release();
}
if(null != startM){
startM.release();
}
if(null != perspectiveTransform){
perspectiveTransform.release();
}
}
}
}
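/*
 * Usage sketch, mirroring the commented-out call in FaceFeatureExtractTest further down in this
 * commit: crop a possibly rotated face box out of the original frame. "faceInfo" comes from one
 * of the detectors.
 *
 *   FaceInfo.FaceBox box = faceInfo.rotateFaceBox();
 *   Mat faceMat = CropUtil.crop(image, box);
 *   //... use faceMat ...
 *   faceMat.release();
 */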

View File

@ -0,0 +1,128 @@
package com.visual.face.search.core.utils;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.TypeReference;
import com.alibaba.fastjson.serializer.SerializerFeature;
import java.util.List;
import java.util.Map;
public class JsonUtil {
/**
* 将Bean转化为json字符串
*
* @param obj bean对象
* @return json
*/
public static String toString(Object obj) {
return toString(obj, false, false);
}
public static String toSimpleString(Object obj) {
return toString(obj, false, true);
}
/**
* 将Bean转化为json字符串
*
* @param obj bean对象
* @param prettyFormat 是否格式化
* @return json
*/
public static String toString(Object obj, boolean prettyFormat, boolean noNull) {
if (prettyFormat) {
if (noNull) {
return JSON.toJSONString(obj, SerializerFeature.DisableCircularReferenceDetect, SerializerFeature.PrettyFormat);
} else {
return JSON.toJSONString(obj, SerializerFeature.WriteMapNullValue, SerializerFeature.WriteNullListAsEmpty, SerializerFeature.DisableCircularReferenceDetect, SerializerFeature.PrettyFormat);
}
} else {
if (noNull) {
return JSON.toJSONString(obj, SerializerFeature.DisableCircularReferenceDetect);
} else {
return JSON.toJSONString(obj, SerializerFeature.WriteMapNullValue, SerializerFeature.WriteNullListAsEmpty, SerializerFeature.DisableCircularReferenceDetect);
}
}
}
/**
* 将字符串转换为Entity
*
* @param json 数据字符串
* @param clazz Entity class
* @return
*/
public static <T> T toEntity(String json, Class<T> clazz) {
return JSON.parseObject(json, clazz);
}
/**
* 将字符串转换为Entity
*
* @param json 数据字符串
* @param typeReference Entity class
* @return
*/
public static <T> T toEntity(String json, TypeReference<T> typeReference) {
return JSON.parseObject(json, typeReference);
}
/**
* 将字符串转换为Map
*
* @param json 数据字符串
* @return Map
*/
public static Map<String, Object> toMap(String json) {
return JSON.parseObject(json, new TypeReference<Map<String, Object>>() {
});
}
/**
* 将字符串转换为List<T>
*
* @param json 数据字符串
* @param collectionClass 泛型
* @return list<T>
*/
public static <T> List<T> toList(String json, Class<T> collectionClass) {
return JSON.parseArray(json, collectionClass);
}
/**
* 将字符串转换为List<Map<String, Object>>
*
* @param json 数据字符串
* @return list<map>
*/
public static List<Map<String, Object>> toListMap(String json) {
return JSON.parseObject(json, new TypeReference<List<Map<String, Object>>>() {
});
}
/**
* 将字符串转换为Object
*
* @param json 数据字符串
* @return list<map>
*/
public static JSONObject toJsonObject(String json) {
return JSON.parseObject(json);
}
/**
* 将字符串转换为Array
*
* @param json 数据字符串
* @return list<map>
*/
public static JSONArray toJsonArray(String json) {
return JSON.parseArray(json);
}
}
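/*
 * Usage sketch; any serializable bean works, FaceInfo is used only for illustration.
 *
 *   String json = JsonUtil.toString(faceInfo);                       //bean -> JSON
 *   FaceInfo restored = JsonUtil.toEntity(json, FaceInfo.class);     //JSON -> bean
 *   Map<String, Object> asMap = JsonUtil.toMap(json);                //JSON -> Map
 */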

View File

@ -0,0 +1,111 @@
package com.visual.face.search.core.utils;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
import org.opencv.core.*;
import org.opencv.imgproc.Imgproc;
import java.util.ArrayList;
import java.util.List;
public class MaskUtil {
/**添加遮罩层所需要的索引号:InsightCoordFaceKeyPoint**/
private static int [] MASK_106_IST_ROUND_INDEX = new int[]{
1,9,10,11,12,13,14,15,16,2,3,4,5,6,7,8,0,
24,23,22,21,20,19,18,32,31,30,29,28,27,26,25,17,
101,105,104,103,102,50,51,49,48,43
};
/**
* Apply a mask: only the polygon regions in pts are copied to the output, everything else is left black.
* @param image   source image
* @param pts     polygon regions to keep
* @param release whether to release the image parameter
* @return masked image
*/
public static Mat mask(Mat image, List<MatOfPoint> pts, boolean release){
Mat pattern = null;
try {
pattern = MatOfPoint.zeros(image.size(), CvType.CV_8U);
Imgproc.fillPoly(pattern, pts, new Scalar(1,1,1));
Mat dst = new Mat();
image.copyTo(dst, pattern);
return dst;
}finally {
if(null != pattern){
pattern.release();
}
if(release && null != pts){
for(MatOfPoint pt : pts){
pt.release();
}
}
if(release && null != image){
image.release();
}
}
}
/**
* Apply a mask: only the polygon defined by fillPoints is kept, everything else is left black.
* @param image      source image
* @param fillPoints points of the polygon region to keep
* @param release    whether to release the image parameter
* @return masked image
*/
public static Mat mask(Mat image, Point[] fillPoints, boolean release){
List<MatOfPoint> pts = null;
try {
pts = new ArrayList<>();
pts.add(new MatOfPoint(fillPoints));
return mask(image, pts, false);
}finally {
if(null != pts){
for(MatOfPoint pt : pts){
pt.release();
}
}
if(release && null != image){
image.release();
}
}
}
/**
* 添加遮罩层:InsightCoordFaceKeyPoint
* @param image 原始图像
* @param points 人脸标记点
* @param release 是否释放参数image
* @return
*/
public static Mat maskFor106InsightCoordModel(Mat image, FaceInfo.Points points, boolean release){
try {
Point[] fillPoints = PointUtil.convert(points.select(MASK_106_IST_ROUND_INDEX));
return mask(image, fillPoints, false);
}finally {
if(release && null != image){
image.release();
}
}
}
/**
* 添加遮罩层:InsightCoordFaceKeyPoint
* @param image 原始图像
* @param points 人脸标记点
* @param release 是否释放参数image
* @return
*/
public static ImageMat maskFor106InsightCoordModel(ImageMat image, FaceInfo.Points points, boolean release){
try {
Mat mat = maskFor106InsightCoordModel(image.toCvMat(), points, false);
return ImageMat.fromCVMat(mat);
}finally {
if(release && null != image){
image.release();
}
}
}
}
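/*
 * Usage sketch. The mask is typically applied after the 106-point InsightCoordFaceKeyPoint model,
 * so that everything outside the face contour is blacked out; "faceMat" is a placeholder for the
 * aligned face image.
 *
 *   FaceInfo.Points points = insightCoordFaceKeyPoint.inference(faceMat, null);
 *   ImageMat masked = MaskUtil.maskFor106InsightCoordModel(faceMat, points, false);
 */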

View File

@ -0,0 +1,58 @@
package com.visual.face.search.core.utils;
import org.opencv.core.Mat;
import sun.misc.BASE64Encoder;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.ByteArrayOutputStream;
import java.util.Objects;
public class MatUtil {
/**
* 将Mat转换为BufferedImage
* @param mat
* @return BufferedImage
*/
public static BufferedImage matToBufferedImage(Mat mat) {
int dataSize = mat.cols() * mat.rows() * (int) mat.elemSize();
byte[] data = new byte[dataSize];
mat.get(0, 0, data);
int type = mat.channels() == 1 ? BufferedImage.TYPE_BYTE_GRAY : BufferedImage.TYPE_3BYTE_BGR;
if (type == BufferedImage.TYPE_3BYTE_BGR) {
for (int i = 0; i < dataSize; i += 3) {
byte blue = data[i + 0];
data[i + 0] = data[i + 2];
data[i + 2] = blue;
}
}
BufferedImage image = new BufferedImage(mat.cols(), mat.rows(), type);
image.getRaster().setDataElements(0, 0, mat.cols(), mat.rows(), data);
return image;
}
/**
* 将Mat转换为 Base64
* @param mat
* @return Base64
*/
public static String matToBase64(Mat mat) {
ByteArrayOutputStream byteArrayOutputStream = null;
try {
byteArrayOutputStream = new ByteArrayOutputStream();
ImageIO.write(matToBufferedImage(mat), "jpg", byteArrayOutputStream);
byte[] bytes = byteArrayOutputStream.toByteArray();
BASE64Encoder encoder = new BASE64Encoder();
return encoder.encodeBuffer(Objects.requireNonNull(bytes));
}catch (Exception e){
throw new RuntimeException(e);
}finally {
if(null != byteArrayOutputStream){
try {
byteArrayOutputStream.close();
} catch (Exception e) {}
}
}
}
}
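/*
 * Usage sketch: encode a Mat as a Base64 JPEG string. Note that encodeBuffer inserts line breaks,
 * so strip them if a single-line value is required; "face.jpg" is a placeholder path.
 *
 *   Mat mat = Imgcodecs.imread("face.jpg");
 *   String base64Jpeg = MatUtil.matToBase64(mat);
 *   mat.release();
 */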

View File

@ -0,0 +1,289 @@
package com.visual.face.search.core.utils;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
public class MathUtil {
/**
* 创建向量
* @param array 数组
* @return 向量
*/
public static RealVector createVector(double[] array){
return new ArrayRealVector(array);
}
/**
* 创建向量
* @param array 数组
* @return 向量
*/
public static RealVector createVector(Double[] array){
return new ArrayRealVector(array);
}
/**
* 创建矩阵
* @param array 矩阵数组
* @return 矩阵
*/
public static RealMatrix createMatrix(double[][] array){
return new Array2DRowRealMatrix(array);
}
/**
* 创建矩阵
* @param array 矩阵数组
* @return 矩阵
*/
public static RealMatrix createMatrix(Double[][] array){
double[][] data = new double[array.length][];
for(int i=0; i< array.length; i++){
double [] item = new double[array[i].length];
for(int j=0; j<array[i].length; j++){
item[j] = array[i][j];
}
data[i] = item;
}
return new Array2DRowRealMatrix(data);
}
/**
* 创建矩阵
* @param rows 重复的行数
* @param array 矩阵数组
* @return 矩阵
*/
public static RealMatrix createMatrix(int rows, double[] array){
double[][] data = new double[rows][array.length];
for(int i=0; i<rows;i++){
data[i] = array;
}
return new Array2DRowRealMatrix(data);
}
/**
* 将矩阵的每个值都加上value值
* @param matrix 矩阵
* @param value 加值
* @return 矩阵
*/
public static RealMatrix scalarAdd(RealMatrix matrix, double value){
return matrix.scalarAdd(value);
}
/**
* 将矩阵的每个值都减去value值
* @param matrix 矩阵
* @param value 减值
* @return 矩阵
*/
public static RealMatrix scalarSub(RealMatrix matrix, double value){
return matrix.scalarAdd(-value);
}
/**
* 将矩阵的每个值都乘以value值
* @param matrix 矩阵
* @param value 乘值
* @return 矩阵
*/
public static RealMatrix scalarMultiply(RealMatrix matrix, double value){
return matrix.scalarMultiply(value);
}
/**
* 将矩阵的每个值都除以value值
* @param matrix 矩阵
* @param value 除值
* @return 矩阵
*/
public static RealMatrix scalarDivision(RealMatrix matrix, double value){
return matrix.scalarMultiply(1.0/value);
}
/**
* Mean of a matrix along an axis: 0 = per-column means, 1 = per-row means.
* @param matrix data matrix
* @param axis   0 for per-column, 1 for per-row
* @return vector of means
*/
public static RealVector mean(RealMatrix matrix, int axis){
if(axis == 0){
double[] means = new double[matrix.getColumnDimension()];
for(int i=0;i<matrix.getColumnDimension(); i++){
means[i] = new Mean().evaluate(matrix.getColumn(i));
}
return new ArrayRealVector(means);
}else {
double[] means = new double[matrix.getRowDimension()];
for(int i=0;i<matrix.getRowDimension(); i++){
means[i] = new Mean().evaluate(matrix.getRow(i));
}
return new ArrayRealVector(means);
}
}
/**
* 计算矩阵的整体标准差
* @param matrix 数据矩阵
* @return 整体标准差
*/
public static double std(RealMatrix matrix){
double[] data = new double[matrix.getColumnDimension() * matrix.getRowDimension()];
for(int i=0;i<matrix.getRowDimension(); i++){
for(int j=0;j<matrix.getColumnDimension(); j++){
data[i*matrix.getColumnDimension()+j] = matrix.getEntry(i, j);
}
}
return new StandardDeviation(false).evaluate(data);
}
/**
* 矩阵列拼接
* @param matrix1 数据矩阵1
* @param matrix2 数据矩阵2
* @return 数据矩阵
*/
public static RealMatrix hstack(RealMatrix matrix1, RealMatrix matrix2){
int row = matrix1.getRowDimension();
int col = matrix1.getColumnDimension()+matrix2.getColumnDimension();
double[][] data = new double[row][col];
for(int i=0;i<matrix1.getRowDimension(); i++){
for(int j=0;j<matrix1.getColumnDimension(); j++){
data[i][j] = matrix1.getEntry(i, j);
}
for(int j=0;j<matrix2.getColumnDimension(); j++){
data[i][matrix1.getColumnDimension()+j] = matrix2.getEntry(i, j);
}
}
return new Array2DRowRealMatrix(data);
}
/**
* 矩阵行拼接
* @param matrix1 数据矩阵1
* @param matrix2 数据矩阵2
* @return 数据矩阵
*/
public static RealMatrix vstack(RealMatrix matrix1, RealMatrix matrix2){
int row = matrix1.getRowDimension()+matrix2.getRowDimension();
int col = matrix1.getColumnDimension();
double[][] data = new double[row][col];
for(int i=0;i<matrix1.getRowDimension(); i++){
for(int j=0;j<matrix1.getColumnDimension(); j++){
data[i][j] = matrix1.getEntry(i, j);
}
}
for(int i=0;i<matrix2.getRowDimension(); i++){
for(int j=0;j<matrix2.getColumnDimension(); j++){
data[i+matrix1.getRowDimension()][j] = matrix2.getEntry(i, j);
}
}
return new Array2DRowRealMatrix(data);
}
/**
* Flatten a matrix into a vector.
* @param matrix matrix
* @param axis   0 to concatenate columns, 1 to concatenate rows
* @return flattened vector
*/
public static RealVector flatMatrix(RealMatrix matrix, int axis){
RealVector vector = new ArrayRealVector();
if(0 == axis){
for(int i=0; i< matrix.getColumnDimension(); i++){
vector = vector.append(matrix.getColumnVector(i));
}
}else{
for(int i=0; i< matrix.getRowDimension(); i++){
vector = vector.append(matrix.getRowVector(i));
}
}
return vector;
}
/**
* 向量点积
* @param vector1 向量1
* @param vector2 向量2
* @return 点积
*/
public static double dotProduct(RealVector vector1, RealVector vector2){
return vector1.dotProduct(vector2);
}
/**
* 矩阵点积
* @param matrix1 矩阵1
* @param matrix2 矩阵2
* @return 点积矩阵
*/
public static RealMatrix dotProduct(RealMatrix matrix1, RealMatrix matrix2){
double[][] data = new double[matrix1.getRowDimension()][matrix1.getColumnDimension()];
for(int row = 0; row < matrix1.getRowDimension(); row ++){
for(int col=0; col < matrix1.getColumnDimension(); col ++){
data[row][col] = matrix1.getRowVector(row).dotProduct(matrix2.getColumnVector(col));
}
}
return createMatrix(data);
}
/**
* 矩阵相似变换
* @param matrix
* @param scale
* @param rotation
* @param translation
* @return
*/
public static RealMatrix similarityTransform(Double[][] matrix, Double scale, Double rotation, Double[] translation){
if(matrix == null && translation == null){
return similarityTransform((RealMatrix)null, scale, rotation, null);
}else if(matrix == null){
return similarityTransform(null, scale, rotation, createVector(translation));
}else if(translation == null){
return similarityTransform(createMatrix(matrix), scale, rotation, null);
}else{
return similarityTransform(createMatrix(matrix), scale, rotation, createVector(translation));
}
}
/**
* 矩阵相似变换
* @param matrix
* @param scale
* @param rotation
* @param translation
* @return
*/
public static RealMatrix similarityTransform(RealMatrix matrix, Double scale, Double rotation, RealVector translation){
boolean hasParams = (scale != null || rotation!= null || translation!= null);
if(hasParams && matrix != null){
throw new RuntimeException("You cannot specify the transformation matrix and the implicit parameters at the same time.");
}else if(matrix != null){
if(matrix.getColumnDimension() != 3 || matrix.getRowDimension() != 3){ //a homogeneous transformation matrix must be 3x3
throw new RuntimeException("Invalid shape of transformation matrix.");
}else {
return matrix;
}
}else if(hasParams){
scale = scale == null ? 1 : scale;
rotation = rotation == null ? 0 : rotation;
translation = translation == null ? createVector(new double[]{0, 0}) : translation;
return createMatrix(new double[][]{
{Math.cos(rotation) * scale, -Math.sin(rotation) * scale, translation.getEntry(0)},
{Math.sin(rotation) * scale, Math.cos(rotation) * scale, translation.getEntry(1)},
{0, 0, 1}
});
}else {
return createMatrix(new double[][]{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}});
}
}
}
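/*
 * For reference, similarityTransform(null, s, r, new Double[]{tx, ty}) returns the homogeneous
 * 2D similarity matrix
 *
 *   [ s*cos(r)  -s*sin(r)  tx ]
 *   [ s*sin(r)   s*cos(r)  ty ]
 *   [ 0          0          1  ]
 *
 * For example, scale 2, rotation 0, translation (3, 4):
 *
 *   RealMatrix m = MathUtil.similarityTransform(null, 2d, 0d, new Double[]{3d, 4d});
 *   //m = [[2, 0, 3], [0, 2, 4], [0, 0, 1]]
 */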

View File

@ -0,0 +1,39 @@
package com.visual.face.search.core.utils;
import com.visual.face.search.core.domain.FaceInfo;
import org.opencv.core.Point;
public class PointUtil {
/**
* 转换点对象
* @param point
* @return
*/
public static FaceInfo.Point convert(Point point){
return FaceInfo.Point.build((float)point.x, (float)point.y);
}
/**
* 转换点对象
* @param point
* @return
*/
public static Point convert(FaceInfo.Point point){
return new Point(point.x, point.y);
}
/**
* 转换点对象
* @param points
* @return
*/
public static Point[] convert(FaceInfo.Points points){
Point[] result = new Point[points.size()];
for(int i=0; i< points.size(); i++){
result[i] = convert(points.get(i));
}
return result;
}
}

View File

@ -0,0 +1,34 @@
package com.visual.face.search.core.utils;
public class Similarity {
/**
* 向量余弦相似度
* @param leftVector
* @param rightVector
* @return
*/
public static float cosineSimilarity(float[] leftVector, float[] rightVector) {
double dotProduct = 0;
for (int i=0; i< leftVector.length; i++) {
dotProduct += leftVector[i] * rightVector[i];
}
double d1 = 0.0d;
for (float value : leftVector) {
d1 += Math.pow(value, 2);
}
double d2 = 0.0d;
for (float value : rightVector) {
d2 += Math.pow(value, 2);
}
double cosineSimilarity;
if (d1 <= 0.0 || d2 <= 0.0) {
cosineSimilarity = 0.0;
} else {
cosineSimilarity = (dotProduct / (Math.sqrt(d1) * Math.sqrt(d2)));
}
return (float) cosineSimilarity;
}
}
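/*
 * cosine(a, b) = (a . b) / (|a| * |b|); returns 0 when either vector has zero norm. A typical use
 * in this project is comparing two face embedding vectors of equal length ("vectorA" and
 * "vectorB" are float[] placeholders):
 *
 *   float score = Similarity.cosineSimilarity(vectorA, vectorB);   //1.0 means identical direction
 */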

View File

@ -0,0 +1,9 @@
package com.visual.face.search.core.utils;
public class ThreadUtil {
public static void run(Runnable runnable){
new Thread(runnable).start();
}
}

View File

@ -0,0 +1,28 @@
package com.visual.face.search.core.test.base;
import ai.onnxruntime.OrtEnvironment;
import java.io.File;
import java.util.Map;
import java.util.TreeMap;
public abstract class BaseTest {
//静态加载动态链接库
// static{ nu.pattern.OpenCV.loadShared(); }
// private OrtEnvironment env = OrtEnvironment.getEnvironment();
public static Map<String, String> getImagePathMap(String imagePath){
Map<String, String> map = new TreeMap<>();
File file = new File(imagePath);
if(file.isFile()){
map.put(file.getName(), file.getAbsolutePath());
}else if(file.isDirectory()){
for(File tmpFile : file.listFiles()){
map.putAll(getImagePathMap(tmpFile.getPath()));
}
}
return map;
}
}

View File

@ -0,0 +1,116 @@
package com.visual.face.search.core.test.extract;
import com.visual.face.search.core.base.FaceAlignment;
import com.visual.face.search.core.base.FaceDetection;
import com.visual.face.search.core.base.FaceKeyPoint;
import com.visual.face.search.core.base.FaceRecognition;
import com.visual.face.search.core.domain.ExtParam;
import com.visual.face.search.core.domain.FaceImage;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.extract.FaceFeatureExtractor;
import com.visual.face.search.core.extract.FaceFeatureExtractorImpl;
import com.visual.face.search.core.models.*;
import com.visual.face.search.core.test.base.BaseTest;
import java.util.List;
import java.util.Map;
public class FaceFeatureExtractOOMTest extends BaseTest {
private static String modelPcn1Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn1_sd.onnx";
private static String modelPcn2Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn2_sd.onnx";
private static String modelPcn3Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn3_sd.onnx";
private static String modelScrfdPath = "face-search-core/src/main/resources/model/onnx/detection_face_scrfd/scrfd_500m_bnkps.onnx";
private static String modelCoordPath = "face-search-core/src/main/resources/model/onnx/keypoint_coordinate/coordinate_106_mobilenet_05.onnx";
private static String modelArcPath = "face-search-core/src/main/resources/model/onnx/recognition_face_arc/glint360k_cosface_r18_fp16_0.1.onnx";
// private static String imagePath = "face-search-core/src/test/resources/images/faces";
private static String imagePath = "face-search-core/src/test/resources/images/faces/debug/debug_0001.jpg";
private static Map<String, String> map = getImagePathMap(imagePath);
public static void main(String[] args) {
testAll();
// testPcnNetworkFaceDetection();
// testInsightScrfdFaceDetection(); //无异常内存稳定
// testInsightCoordFaceKeyPoint(); //无异常内存稳定
// testInsightArcFaceRecognition(); //无异常内存稳定,389-351
}
public static void testAll() {
FaceDetection insightScrfdFaceDetection = new InsightScrfdFaceDetection(modelScrfdPath, 1);
FaceKeyPoint insightCoordFaceKeyPoint = new InsightCoordFaceKeyPoint(modelCoordPath, 1);
FaceRecognition insightArcFaceRecognition = new InsightArcFaceRecognition(modelArcPath, 1);
FaceAlignment simple005pFaceAlignment = new Simple005pFaceAlignment();
FaceAlignment simple106pFaceAlignment = new Simple106pFaceAlignment();
FaceDetection pcnNetworkFaceDetection = new PcnNetworkFaceDetection(new String[]{modelPcn1Path, modelPcn2Path, modelPcn3Path}, 1);
FaceFeatureExtractor extractor = new FaceFeatureExtractorImpl(insightScrfdFaceDetection, pcnNetworkFaceDetection, insightCoordFaceKeyPoint, simple106pFaceAlignment, insightArcFaceRecognition);
// FaceFeatureExtractor extractor = new FaceFeatureExtractorImpl(insightScrfdFaceDetection, insightCoordFaceKeyPoint, simple106pFaceAlignment, insightArcFaceRecognition);
for (int i = 0; i < 100000; i++) {
for (String fileName : map.keySet()) {
long s = System.currentTimeMillis();
ImageMat imageMat = ImageMat.fromImage(map.get(fileName));
ExtParam extParam = ExtParam.build().setMask(true).setScoreTh(InsightScrfdFaceDetection.defScoreTh).setIouTh(InsightScrfdFaceDetection.defIouTh);
FaceImage faceImage = extractor.extract(imageMat, extParam, null);
long e = System.currentTimeMillis();
System.out.println(i + ",cost=" +(e-s)+"ms:"+ fileName + ":" + faceImage);
imageMat.release();
}
}
}
public static void testPcnNetworkFaceDetection(){
FaceDetection pcnNetworkFaceDetection = new PcnNetworkFaceDetection(new String[]{modelPcn1Path, modelPcn2Path, modelPcn3Path}, 4);
for (int i = 0; i < 100000; i++) {
for (String fileName : map.keySet()) {
ImageMat imageMat = ImageMat.fromImage(map.get(fileName));
List<FaceInfo> list = pcnNetworkFaceDetection.inference(imageMat, PcnNetworkFaceDetection.defScoreTh, PcnNetworkFaceDetection.defIouTh, null);
System.out.println(i + "," + fileName + ":" + list.size());
imageMat.release();
}
}
}
//验证无内存泄露
public static void testInsightScrfdFaceDetection(){
FaceDetection insightScrfdFaceDetection = new InsightScrfdFaceDetection(modelScrfdPath, 4);
for (int i = 0; i < 100000; i++) {
for (String fileName : map.keySet()) {
ImageMat imageMat = ImageMat.fromImage(map.get(fileName));
List<FaceInfo> list = insightScrfdFaceDetection.inference(imageMat, InsightScrfdFaceDetection.defScoreTh, InsightScrfdFaceDetection.defIouTh, null);
System.out.println(i + "," + fileName + ":" + list.size());
imageMat.release();
}
}
}
//验证无内存泄露
public static void testInsightCoordFaceKeyPoint(){
FaceKeyPoint insightCoordFaceKeyPoint = new InsightCoordFaceKeyPoint(modelCoordPath, 1);
for (int i = 0; i < 100000; i++) {
for (String fileName : map.keySet()) {
ImageMat imageMat = ImageMat.fromImage(map.get(fileName));
FaceInfo.Points list = insightCoordFaceKeyPoint.inference(imageMat, null);
System.out.println(i + "," + fileName + ":" + list.size());
imageMat.release();
}
}
}
//验证无内存泄露
public static void testInsightArcFaceRecognition(){
FaceRecognition insightArcFaceRecognition = new InsightArcFaceRecognition(modelArcPath, 1);
for (int i = 0; i < 100000; i++) {
for (String fileName : map.keySet()) {
ImageMat imageMat = ImageMat.fromImage(map.get(fileName));
FaceInfo.Embedding embedding = insightArcFaceRecognition.inference(imageMat, null);
System.out.println(i + "," + fileName + ":" + embedding);
imageMat.release();
}
}
}
}

View File

@ -0,0 +1,97 @@
package com.visual.face.search.core.test.extract;
import com.visual.face.search.core.base.FaceAlignment;
import com.visual.face.search.core.base.FaceDetection;
import com.visual.face.search.core.base.FaceKeyPoint;
import com.visual.face.search.core.base.FaceRecognition;
import com.visual.face.search.core.domain.ExtParam;
import com.visual.face.search.core.domain.FaceImage;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.extract.FaceFeatureExtractor;
import com.visual.face.search.core.extract.FaceFeatureExtractorImpl;
import com.visual.face.search.core.models.*;
import com.visual.face.search.core.test.base.BaseTest;
import com.visual.face.search.core.utils.CropUtil;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.highgui.HighGui;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.util.List;
import java.util.Map;
public class FaceFeatureExtractTest extends BaseTest {
private static String modelPcn1Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn1_sd.onnx";
private static String modelPcn2Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn2_sd.onnx";
private static String modelPcn3Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn3_sd.onnx";
private static String modelScrfdPath = "face-search-core/src/main/resources/model/onnx/detection_face_scrfd/scrfd_500m_bnkps.onnx";
private static String modelCoordPath = "face-search-core/src/main/resources/model/onnx/keypoint_coordinate/coordinate_106_mobilenet_05.onnx";
private static String modelArcPath = "face-search-core/src/main/resources/model/onnx/recognition_face_arc/glint360k_cosface_r18_fp16_0.1.onnx";
private static String imagePath = "face-search-core/src/test/resources/images/faces";
public static void main(String[] args) {
Map<String, String> map = getImagePathMap(imagePath);
FaceDetection insightScrfdFaceDetection = new InsightScrfdFaceDetection(modelScrfdPath, 1);
FaceKeyPoint insightCoordFaceKeyPoint = new InsightCoordFaceKeyPoint(modelCoordPath, 1);
FaceRecognition insightArcFaceRecognition = new InsightArcFaceRecognition(modelArcPath, 1);
FaceAlignment simple005pFaceAlignment = new Simple005pFaceAlignment();
FaceAlignment simple106pFaceAlignment = new Simple106pFaceAlignment();
FaceDetection pcnNetworkFaceDetection = new PcnNetworkFaceDetection(new String[]{modelPcn1Path, modelPcn2Path, modelPcn3Path}, 1);
FaceFeatureExtractor extractor = new FaceFeatureExtractorImpl(pcnNetworkFaceDetection, insightScrfdFaceDetection, insightCoordFaceKeyPoint, simple106pFaceAlignment, insightArcFaceRecognition);
for(String fileName : map.keySet()){
String imageFilePath = map.get(fileName);
System.out.println(imageFilePath);
Mat image = Imgcodecs.imread(imageFilePath);
long s = System.currentTimeMillis();
ExtParam extParam = ExtParam.build()
.setMask(true)
.setTopK(20)
.setScoreTh(0)
.setIouTh(0);
FaceImage faceImage = extractor.extract(ImageMat.fromCVMat(image), extParam, null);
List<FaceInfo> faceInfos = faceImage.faceInfos();
long e = System.currentTimeMillis();
System.out.println("fileName="+fileName+",\tcost="+(e-s)+",\t"+faceInfos);
for(FaceInfo faceInfo : faceInfos){
FaceInfo.FaceBox box = faceInfo.rotateFaceBox().scaling(1.0f);
Imgproc.line(image, new Point(box.leftTop.x, box.leftTop.y), new Point(box.rightTop.x, box.rightTop.y), new Scalar(0,0,255), 1);
Imgproc.line(image, new Point(box.rightTop.x, box.rightTop.y), new Point(box.rightBottom.x, box.rightBottom.y), new Scalar(255,0,0), 1);
Imgproc.line(image, new Point(box.rightBottom.x, box.rightBottom.y), new Point(box.leftBottom.x, box.leftBottom.y), new Scalar(255,0,0), 1);
Imgproc.line(image, new Point(box.leftBottom.x, box.leftBottom.y), new Point(box.leftTop.x, box.leftTop.y), new Scalar(255,0,0), 1);
Imgproc.putText(image, String.valueOf(faceInfo.angle), new Point(box.leftTop.x, box.leftTop.y), Imgproc.FONT_HERSHEY_PLAIN, 1, new Scalar(0,0,255));
// Imgproc.rectangle(image, new Point(faceInfo.box.x1(), faceInfo.box.y1()), new Point(faceInfo.box.x2(), faceInfo.box.y2()), new Scalar(255,0,255));
FaceInfo.FaceBox box1 = faceInfo.rotateFaceBox();
Imgproc.circle(image, new Point(box1.leftTop.x, box1.leftTop.y), 3, new Scalar(0,0,255), -1);
Imgproc.circle(image, new Point(box1.rightTop.x, box1.rightTop.y), 3, new Scalar(0,0,255), -1);
Imgproc.circle(image, new Point(box1.rightBottom.x, box1.rightBottom.y), 3, new Scalar(0,0,255), -1);
Imgproc.circle(image, new Point(box1.leftBottom.x, box1.leftBottom.y), 3, new Scalar(0,0,255), -1);
int pointNum = 1;
for(FaceInfo.Point keyPoint : faceInfo.points){
Imgproc.circle(image, new Point(keyPoint.x, keyPoint.y), 1, new Scalar(0,0,255), -1);
// Imgproc.putText(image, String.valueOf(pointNum), new Point(keyPoint.x+1, keyPoint.y), Imgproc.FONT_HERSHEY_PLAIN, 1, new Scalar(255,0,0));
pointNum ++ ;
}
// Mat crop = CropUtil.crop(image, box);
// HighGui.imshow(fileName, crop);
// HighGui.waitKey();
}
HighGui.imshow(fileName, image);
HighGui.waitKey();
image.release();
}
System.exit(1);
}
}

View File

@ -0,0 +1,49 @@
package com.visual.face.search.core.test.models;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.models.InsightCoordFaceKeyPoint;
import com.visual.face.search.core.test.base.BaseTest;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.highgui.HighGui;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.util.Map;
public class InsightCoordFaceKeyPointTest extends BaseTest {
private static String modelPath = "face-search-core/src/main/resources/model/onnx/keypoint_coordinate/coordinate_106_mobilenet_05.onnx";
private static String imagePath = "face-search-core/src/test/resources/images/faces";
// private static String imagePath = "face-search-core/src/test/resources/images/faces/debug";
// private static String imagePath = "face-search-core/src/test/resources/images/faces/rotate";
public static void main(String[] args) {
Map<String, String> map = getImagePathMap(imagePath);
InsightCoordFaceKeyPoint infer = new InsightCoordFaceKeyPoint(modelPath, 2);
for(String fileName : map.keySet()){
String imageFilePath = map.get(fileName);
System.out.println(imageFilePath);
Mat image = Imgcodecs.imread(imageFilePath);
long s = System.currentTimeMillis();
FaceInfo.Points points = infer.inference(ImageMat.fromCVMat(image), null);
long e = System.currentTimeMillis();
System.out.println("fileName="+fileName+",\tcost="+(e-s)+",\t"+points);
int pointNum = 1;
for(FaceInfo.Point keyPoint : points){
Imgproc.circle(image, new Point(keyPoint.x, keyPoint.y), 1, new Scalar(0,0,255), -1);
Imgproc.putText(image, String.valueOf(pointNum), new Point(keyPoint.x+1, keyPoint.y), Imgproc.FONT_HERSHEY_PLAIN, 1, new Scalar(255,0,0));
pointNum ++ ;
System.out.println("["+keyPoint.x+"," +keyPoint.y+"],");
}
HighGui.imshow(fileName, image);
HighGui.waitKey();
}
System.exit(1);
}
}

View File

@ -0,0 +1,57 @@
package com.visual.face.search.core.test.models;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.models.InsightScrfdFaceDetection;
import com.visual.face.search.core.test.base.BaseTest;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.highgui.HighGui;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.util.List;
import java.util.Map;
public class InsightScrfdFaceDetectionTest extends BaseTest {
private static String modelPath = "face-search-core/src/main/resources/model/onnx/detection_face_scrfd/scrfd_500m_bnkps.onnx";
private static String imagePath = "face-search-core/src/test/resources/images/faces";
// private static String imagePath = "face-search-core/src/test/resources/images/faces/rotate";
// private static String imagePath = "face-search-core/src/test/resources/images/faces/debug";
public static void main(String[] args) {
Map<String, String> map = getImagePathMap(imagePath);
InsightScrfdFaceDetection infer = new InsightScrfdFaceDetection(modelPath, 2);
for(String fileName : map.keySet()){
String imageFilePath = map.get(fileName);
System.out.println(imageFilePath);
Mat image = Imgcodecs.imread(imageFilePath);
long s = System.currentTimeMillis();
List<FaceInfo> faceInfos = infer.inference(ImageMat.fromCVMat(image), 0.5f, 0.7f, null);
long e = System.currentTimeMillis();
if(faceInfos.size() > 0){
System.out.println("fileName="+fileName+",\tcost="+(e-s)+",\t"+faceInfos.get(0).score);
}else{
System.out.println("fileName="+fileName+",\tcost="+(e-s)+",\t"+faceInfos);
}
for(FaceInfo faceInfo : faceInfos){
Imgproc.rectangle(image, new Point(faceInfo.box.x1(), faceInfo.box.y1()), new Point(faceInfo.box.x2(), faceInfo.box.y2()), new Scalar(0,0,255));
int pointNum = 1;
for(FaceInfo.Point keyPoint : faceInfo.points){
Imgproc.circle(image, new Point(keyPoint.x, keyPoint.y), 3, new Scalar(0,0,255), -1);
Imgproc.putText(image, String.valueOf(pointNum), new Point(keyPoint.x+1, keyPoint.y), Imgproc.FONT_HERSHEY_PLAIN, 1, new Scalar(255,0,0));
pointNum ++ ;
}
}
HighGui.imshow(fileName, image);
HighGui.waitKey();
}
System.exit(1);
}
}

View File

@ -0,0 +1,77 @@
package com.visual.face.search.core.test.models;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.models.PcnNetworkFaceDetection;
import com.visual.face.search.core.test.base.BaseTest;
import com.visual.face.search.core.utils.CropUtil;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.highgui.HighGui;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.util.List;
import java.util.Map;
public class PcnNetworkFaceDetectionTest extends BaseTest {
private static String model1Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn1_sd.onnx";
private static String model2Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn2_sd.onnx";
private static String model3Path = "face-search-core/src/main/resources/model/onnx/detection_face_pcn/pcn3_sd.onnx";
private static String imagePath = "face-search-core/src/test/resources/images/faces";
// private static String imagePath = "face-search-core/src/test/resources/images/faces/rotate";
// private static String imagePath = "face-search-core/src/test/resources/images/faces/debug";
public static void main(String[] args) {
Map<String, String> map = getImagePathMap(imagePath);
PcnNetworkFaceDetection infer = new PcnNetworkFaceDetection(new String[]{model1Path, model2Path, model3Path}, 2);
for(String fileName : map.keySet()){
String imageFilePath = map.get(fileName);
System.out.println(imageFilePath);
Mat image = Imgcodecs.imread(imageFilePath);
long s = System.currentTimeMillis();
List<FaceInfo> faceInfos = infer.inference(ImageMat.fromCVMat(image), PcnNetworkFaceDetection.defScoreTh, PcnNetworkFaceDetection.defIouTh, null);
long e = System.currentTimeMillis();
if(faceInfos.size() > 0){
System.out.println("fileName="+fileName+",\tcost="+(e-s)+",\t"+faceInfos.get(0).score);
}else{
System.out.println("fileName="+fileName+",\tcost="+(e-s)+",\t"+faceInfos);
}
for(FaceInfo faceInfo : faceInfos){
FaceInfo.FaceBox box = faceInfo.rotateFaceBox().scaling(1.0f);
Imgproc.line(image, new Point(box.leftTop.x, box.leftTop.y), new Point(box.rightTop.x, box.rightTop.y), new Scalar(0,0,255), 1);
Imgproc.line(image, new Point(box.rightTop.x, box.rightTop.y), new Point(box.rightBottom.x, box.rightBottom.y), new Scalar(255,0,0), 1);
Imgproc.line(image, new Point(box.rightBottom.x, box.rightBottom.y), new Point(box.leftBottom.x, box.leftBottom.y), new Scalar(255,0,0), 1);
Imgproc.line(image, new Point(box.leftBottom.x, box.leftBottom.y), new Point(box.leftTop.x, box.leftTop.y), new Scalar(255,0,0), 1);
Imgproc.putText(image, String.valueOf(faceInfo.angle), new Point(box.leftTop.x, box.leftTop.y), Imgproc.FONT_HERSHEY_PLAIN, 1, new Scalar(0,0,255));
// Imgproc.rectangle(image, new Point(faceInfo.box.x1(), faceInfo.box.y1()), new Point(faceInfo.box.x2(), faceInfo.box.y2()), new Scalar(255,0,255));
FaceInfo.FaceBox box1 = faceInfo.rotateFaceBox();
Imgproc.circle(image, new Point(box1.leftTop.x, box1.leftTop.y), 3, new Scalar(0,0,255), -1);
Imgproc.circle(image, new Point(box1.rightTop.x, box1.rightTop.y), 3, new Scalar(0,0,255), -1);
Imgproc.circle(image, new Point(box1.rightBottom.x, box1.rightBottom.y), 3, new Scalar(0,0,255), -1);
Imgproc.circle(image, new Point(box1.leftBottom.x, box1.leftBottom.y), 3, new Scalar(0,0,255), -1);
int pointNum = 1;
for(FaceInfo.Point keyPoint : faceInfo.points){
Imgproc.circle(image, new Point(keyPoint.x, keyPoint.y), 3, new Scalar(0,0,255), -1);
Imgproc.putText(image, String.valueOf(pointNum), new Point(keyPoint.x+1, keyPoint.y), Imgproc.FONT_HERSHEY_PLAIN, 1, new Scalar(255,0,0));
pointNum ++ ;
}
// Mat crop = CropUtil.crop(image, box);
// HighGui.imshow(fileName, crop);
// HighGui.waitKey();
}
HighGui.imshow(fileName, image);
HighGui.waitKey();
}
System.exit(1);
}
}

View File

@ -0,0 +1,175 @@
package com.visual.face.search.core.test.other;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.utils.MaskUtil;
import com.visual.face.search.core.utils.PointUtil;
import org.opencv.core.*;
import org.opencv.highgui.HighGui;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.util.ArrayList;
import java.util.List;
public class TestCvFillPoly {
// Load the OpenCV native library statically
static{ nu.pattern.OpenCV.loadShared(); }
public static void main(String[] args) {
String imagePath = "face-search-core/src/test/resources/images/faces/debug/debug_0007.jpg";
Mat image = Imgcodecs.imread(imagePath);
Point[] src_points = new Point[106];
FaceInfo.Points points = FaceInfo.Points.build();
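// NOTE: src_points above is only referenced by the commented-out experiments below and is never populated here.
// The dst_points template at the bottom of this class appears to be defined on a 112x112 aligned face;
// the loop below scales it to the ~439px test image for visualization.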
for(int i=0; i< dst_points.length; i++){
points.add(PointUtil.convert(new Point(dst_points[i][0] * 439 / 112, dst_points[i][1] * 439 / 112)));
}
Mat d = MaskUtil.maskFor106InsightCoordModel(image, points, false);
// int [] indexes = new int[]{1,9,10,11,12,13,14,15,16,2,3,4,5,6,7,8,0,24,23,22,21,20,19,18,32,31,30,29,28,27,26,25,17,101,105,104,103,102,50,51,49,48,43};
// Point[] fill_points = new Point[indexes.length];
// for(int i=0; i< indexes.length; i++){
// fill_points[i] = src_points[indexes[i]];
// }
//
// // MatOfPoint.zeros(image.size(), CvType.CV_8U);
//// Mat pattern = new Mat(image.size(), CvType.CV_8U);
//
// Mat pattern = MatOfPoint.zeros(image.size(), CvType.CV_8U);;
// List<MatOfPoint> pts = new ArrayList<>();;
// pts.add(new MatOfPoint(fill_points));
// Imgproc.fillPoly(pattern, pts, new Scalar(1,1,1));
//
// HighGui.imshow("Drawing a polygon", pattern);
// HighGui.waitKey();
//
// for (Point src_point : src_points) {
// Imgproc.circle(image, src_point, 1, new Scalar(0, 0, 255), -1);
// }
// Mat d = new Mat();
// image.copyTo(d, pattern);
HighGui.imshow("fileName", d);
HighGui.waitKey();
System.exit(1);
// Imgproc.fil
}
/** 106-point alignment template coordinates **/
private final static double[][] dst_points = new double[][]{
{56.9405, 104.8443},
{18.5795, 41.9579},
{29.7909, 83.8938},
{32.4892, 87.9299},
{35.5363, 91.7113},
{38.8522, 95.2713},
{42.4044, 98.8245},
{46.4592, 101.8917},
{51.2277, 104.0554},
{18.6211, 47.0351},
{19.1112, 51.9579},
{19.9591, 56.7764},
{21.0883, 61.467},
{22.4328, 66.0933},
{23.9187, 70.6575},
{25.5052, 75.2115},
{27.4518, 79.6344},
{94.3557, 40.9158},
{84.5255, 83.7362},
{81.8184, 87.8399},
{78.7297, 91.68},
{75.3339, 95.2649},
{71.7154, 98.8306},
{67.5886, 101.8988},
{62.7335, 104.062},
{94.4373, 46.0395},
{94.149, 51.0301},
{93.5035, 55.9094},
{92.5847, 60.6884},
{91.4355, 65.4221},
{90.1689, 70.1432},
{88.7174, 74.8307},
{86.8375, 79.3739},
{37.6054, 49.8439},
{38.3254, 46.2809},
{30.5494, 46.5533},
{33.5341, 48.6793},
{42.097, 49.5744},
{38.3255, 46.2799},
{46.1797, 48.9365},
{38.3534, 43.7203},
{33.9532, 44.3446},
{42.8563, 45.3288},
{23.6783, 36.8122},
{29.0158, 35.7892},
{34.6376, 36.2009},
{46.5079, 40.1862},
{40.5936, 37.8703},
{28.5353, 32.615},
{35.1155, 32.4559},
{47.2983, 37.4058},
{41.5737, 34.4037},
{44.3619, 84.8402},
{56.3115, 91.5205},
{50.3433, 85.5404},
{47.3686, 88.0585},
{51.1102, 90.5446},
{62.4342, 85.3069},
{65.6106, 87.7116},
{61.6732, 90.382},
{56.2631, 86.0649},
{68.7674, 84.2605},
{56.1979, 85.0424},
{53.0131, 81.1738},
{48.3541, 82.7133},
{46.1528, 84.9518},
{50.31, 84.8129},
{59.1781, 81.046},
{64.1735, 82.3522},
{66.8877, 84.454},
{62.3321, 84.5486},
{56.1101, 81.9197},
{55.4506, 47.1509},
{55.5035, 55.1007},
{55.5601, 63.0179},
{50.1072, 49.3489},
{47.9679, 65.6062},
{45.988, 71.6565},
{48.8313, 73.9777},
{52.0659, 74.8593},
{55.7886, 76.272},
{60.909, 49.2251},
{63.4678, 65.3933},
{65.718, 71.3635},
{62.8521, 73.7614},
{59.5402, 74.8053},
{55.5967, 70.9364},
{73.4902, 49.5798},
{72.4674, 46.0168},
{64.9145, 48.9078},
{68.9848, 49.4388},
{77.5852, 48.3019},
{72.4663, 46.0169},
{80.5914, 46.1361},
{72.5983, 43.4159},
{68.0961, 45.1217},
{77.0717, 43.964},
{63.9449, 40.2127},
{69.945, 37.6809},
{76.0014, 35.8078},
{81.9096, 35.2657},
{87.7462, 36.15},
{63.0412, 37.4403},
{68.8145, 34.2255},
{75.4285, 32.0634},
{82.3109, 32.0248}
};
}

Six binary image files added (content not shown); sizes: 35 KiB, 38 KiB, 22 KiB, 25 KiB, 150 KiB, 248 KiB.

102
face-search-server/pom.xml Normal file
View File

@ -0,0 +1,102 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>face-search</artifactId>
<groupId>com.visual.face.search</groupId>
<version>1.0.0</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>face-search-server</artifactId>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>com.visual.face.search</groupId>
<artifactId>face-search-core</artifactId>
</dependency>
<dependency>
<groupId>org.mybatis</groupId>
<artifactId>mybatis</artifactId>
</dependency>
<dependency>
<groupId>org.mybatis</groupId>
<artifactId>mybatis-spring</artifactId>
</dependency>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
</dependency>
<!-- Vector search engines -->
<dependency>
<groupId>com.alibaba.proxima</groupId>
<artifactId>proxima-be-java-sdk</artifactId>
</dependency>
<dependency>
<groupId>io.milvus</groupId>
<artifactId>milvus-java-sdk</artifactId>
</dependency>
<!-- Alibaba Druid database connection pool -->
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>druid-spring-boot-starter</artifactId>
</dependency>
<!-- Spring Boot AOP (interceptor) support -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- Common utility libraries -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-collections4</artifactId>
</dependency>
<!-- Pagination plugin -->
<dependency>
<groupId>com.github.pagehelper</groupId>
<artifactId>pagehelper</artifactId>
</dependency>
<dependency>
<groupId>com.github.pagehelper</groupId>
<artifactId>pagehelper-spring-boot-starter</artifactId>
</dependency>
<!-- API documentation (Swagger) -->
<dependency>
<groupId>io.springfox</groupId>
<artifactId>springfox-swagger2</artifactId>
</dependency>
<dependency>
<groupId>com.github.xiaoymin</groupId>
<artifactId>swagger-bootstrap-ui</artifactId>
</dependency>
</dependencies>
<build>
<finalName>face-search-server</finalName>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
</plugins>
</build>
</project>

View File

@ -0,0 +1,14 @@
package com.visual.face.search.server.bootstrap;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
@SpringBootApplication(exclude = { DataSourceAutoConfiguration.class })
public class FaceSearchApplication {
public static void main(String[] args) {
SpringApplication.run(FaceSearchApplication.class, args);
}
}

View File

@ -0,0 +1,49 @@
package com.visual.face.search.server.bootstrap.conf;
import com.alibaba.proxima.be.client.ConnectParam;
import com.alibaba.proxima.be.client.ProximaGrpcSearchClient;
import com.alibaba.proxima.be.client.ProximaSearchClient;
import com.visual.face.search.server.engine.api.SearchEngine;
import com.visual.face.search.server.engine.impl.MilvusSearchEngine;
import com.visual.face.search.server.engine.impl.ProximaSearchEngine;
import io.milvus.client.MilvusServiceClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration("visualEngineConfig")
public class EngineConfig {
// Logger
public Logger logger = LoggerFactory.getLogger(getClass());
@Value("${visual.engine.selected:proxima}")
private String selected;
@Value("${visual.engine.proxima.host}")
private String proximaHost;
@Value("${visual.engine.proxima.port:16000}")
private Integer proximaPort;
@Value("${visual.engine.milvus.host}")
private String milvusHost;
@Value("${visual.engine.milvus.port:19530}")
private Integer milvusPort;
@Bean(name = "visualSearchEngine")
public SearchEngine getSearchEngine(){
if(selected.equalsIgnoreCase("milvus")){
logger.info("current vector engine is milvus");
io.milvus.param.ConnectParam connectParam = io.milvus.param.ConnectParam.newBuilder().withHost(milvusHost).withPort(milvusPort).build();
MilvusServiceClient client = new MilvusServiceClient(connectParam);
return new MilvusSearchEngine(client);
}else{
logger.info("current vector engine is proxima");
ConnectParam connectParam = ConnectParam.newBuilder().withHost(proximaHost).withPort(proximaPort).build();
ProximaSearchClient client = new ProximaGrpcSearchClient(connectParam);
return new ProximaSearchEngine(client);
}
}
}
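
For a quick local run, the engine selection above can be driven by the same property keys that the @Value annotations declare. A minimal sketch, assuming a locally reachable Milvus instance; the launcher class itself is hypothetical:

```
import org.springframework.boot.SpringApplication;
import com.visual.face.search.server.bootstrap.FaceSearchApplication;

// Hypothetical helper: sets the engine-selection properties programmatically
// (property names taken from the @Value annotations in EngineConfig) and boots the app.
public class MilvusLocalLauncher {
    public static void main(String[] args) {
        System.setProperty("visual.engine.selected", "milvus");
        System.setProperty("visual.engine.milvus.host", "127.0.0.1");
        System.setProperty("visual.engine.milvus.port", "19530");
        SpringApplication.run(FaceSearchApplication.class, args);
    }
}
```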

View File

@ -0,0 +1,177 @@
package com.visual.face.search.server.bootstrap.conf;
import com.visual.face.search.core.base.FaceAlignment;
import com.visual.face.search.core.base.FaceDetection;
import com.visual.face.search.core.base.FaceKeyPoint;
import com.visual.face.search.core.base.FaceRecognition;
import com.visual.face.search.core.extract.FaceFeatureExtractor;
import com.visual.face.search.core.extract.FaceFeatureExtractorImpl;
import com.visual.face.search.core.models.*;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration("visualModelConfig")
public class ModelConfig {
@Value("${spring.profiles.active}")
private String profile;
@Value("${visual.model.faceDetection.name}")
private String faceDetectionName;
@Value("${visual.model.faceDetection.modelPath}")
private String[] faceDetectionModel;
@Value("${visual.model.faceDetection.thread:4}")
private Integer faceDetectionThread;
@Value("${visual.model.faceDetection.backup.name}")
private String backupFaceDetectionName;
@Value("${visual.model.faceDetection.backup.modelPath}")
private String[] backupFaceDetectionModel;
@Value("${visual.model.faceDetection.backup.thread:4}")
private Integer backupFaceDetectionThread;
@Value("${visual.model.faceKeyPoint.name:InsightCoordFaceKeyPoint}")
private String faceKeyPointName;
@Value("${visual.model.faceKeyPoint.modelPath}")
private String[] faceKeyPointModel;
@Value("${visual.model.faceKeyPoint.thread:4}")
private Integer faceKeyPointThread;
@Value("${visual.model.faceAlignment.name:Simple005pFaceAlignment}")
private String faceAlignmentName;
@Value("${visual.model.faceRecognition.name:InsightArcFaceRecognition}")
private String faceRecognitionName;
@Value("${visual.model.faceRecognition.modelPath}")
private String[] faceRecognitionNameModel;
@Value("${visual.model.faceRecognition.thread:4}")
private Integer faceRecognitionNameThread;
/**
* Primary face detection model
* @return
*/
@Bean(name = "visualFaceDetection")
public FaceDetection getFaceDetection(){
if(faceDetectionName.equalsIgnoreCase("PcnNetworkFaceDetection")){
return new PcnNetworkFaceDetection(getModelPath(faceDetectionName, faceDetectionModel), faceDetectionThread);
}else if(faceDetectionName.equalsIgnoreCase("InsightScrfdFaceDetection")){
return new InsightScrfdFaceDetection(getModelPath(faceDetectionName, faceDetectionModel)[0], faceDetectionThread);
}else{
return new PcnNetworkFaceDetection(getModelPath(faceDetectionName, faceDetectionModel), faceDetectionThread);
}
}
/**
* Backup face detection model
* @return
*/
@Bean(name = "visualBackupFaceDetection")
public FaceDetection getBackupFaceDetection(){
if(faceDetectionName.equalsIgnoreCase(backupFaceDetectionName)){
return null;
}else if(backupFaceDetectionName.equalsIgnoreCase("PcnNetworkFaceDetection")){
return new PcnNetworkFaceDetection(getModelPath(backupFaceDetectionName, backupFaceDetectionModel), backupFaceDetectionThread);
}else if(backupFaceDetectionName.equalsIgnoreCase("InsightScrfdFaceDetection")){
return new InsightScrfdFaceDetection(getModelPath(backupFaceDetectionName, backupFaceDetectionModel)[0], backupFaceDetectionThread);
}else{
return new PcnNetworkFaceDetection(backupFaceDetectionModel, backupFaceDetectionThread);
}
}
/**
* Face key-point (landmark) detection model
* @return
*/
@Bean(name = "visualFaceKeyPoint")
public FaceKeyPoint getFaceKeyPoint(){
if(faceKeyPointName.equalsIgnoreCase("InsightCoordFaceKeyPoint")){
return new InsightCoordFaceKeyPoint(getModelPath(faceKeyPointName, faceKeyPointModel)[0], faceKeyPointThread);
}else{
return new InsightCoordFaceKeyPoint(getModelPath(faceKeyPointName, faceKeyPointModel)[0], faceKeyPointThread);
}
}
/**
* Face alignment model
* @return
*/
@Bean(name = "visualFaceAlignment")
public FaceAlignment getFaceAlignment(){
if(faceAlignmentName.equalsIgnoreCase("Simple005pFaceAlignment")){
return new Simple005pFaceAlignment();
}else if(faceAlignmentName.equalsIgnoreCase("Simple106pFaceAlignment")){
return new Simple106pFaceAlignment();
}else{
return new Simple005pFaceAlignment();
}
}
/**
* Face feature extraction (recognition) model
* @return
*/
@Bean(name = "visualFaceRecognition")
public FaceRecognition getFaceRecognition(){
if(faceRecognitionName.equalsIgnoreCase("InsightArcFaceRecognition")){
return new InsightArcFaceRecognition(getModelPath(faceRecognitionName, faceRecognitionNameModel)[0], faceRecognitionNameThread);
}else{
return new InsightArcFaceRecognition(getModelPath(faceRecognitionName, faceRecognitionNameModel)[0], faceRecognitionNameThread);
}
}
/**
* Build the face feature extractor
* @param faceDetection       primary face detection model
* @param backupFaceDetection backup face detection model
* @param faceKeyPoint        face key-point model
* @param faceAlignment       face alignment model
* @param faceRecognition     face feature extraction model
*/
@Bean(name = "visualFaceFeatureExtractor")
public FaceFeatureExtractor getFaceFeatureExtractor(
@Qualifier("visualFaceDetection")FaceDetection faceDetection,
@Qualifier("visualBackupFaceDetection")FaceDetection backupFaceDetection,
@Qualifier("visualFaceKeyPoint")FaceKeyPoint faceKeyPoint,
@Qualifier("visualFaceAlignment")FaceAlignment faceAlignment,
@Qualifier("visualFaceRecognition")FaceRecognition faceRecognition){
return new FaceFeatureExtractorImpl(faceDetection, backupFaceDetection, faceKeyPoint, faceAlignment, faceRecognition);
}
/**
* Resolve the model file path(s), falling back to the bundled defaults when none are configured
* @param modelName model name
* @param modelPath configured model path(s)
* @return
*/
private String[] getModelPath(String modelName, String modelPath[]){
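// Outside docker the model files are resolved from the source tree; with the "docker" profile they are bundled under /app/face-search/.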
String basePath = "face-search-core/src/main/resources/";
if("docker".equalsIgnoreCase(profile)){
basePath = "/app/face-search/";
}
if((null == modelPath || modelPath.length != 3) && "PcnNetworkFaceDetection".equalsIgnoreCase(modelName)){
return new String[]{
basePath + "model/onnx/detection_face_pcn/pcn1_sd.onnx",
basePath + "model/onnx/detection_face_pcn/pcn2_sd.onnx",
basePath + "model/onnx/detection_face_pcn/pcn3_sd.onnx"
};
}
if((null == modelPath || modelPath.length != 1) && "InsightScrfdFaceDetection".equalsIgnoreCase(modelName)){
return new String[]{basePath + "model/onnx/detection_face_scrfd/scrfd_500m_bnkps.onnx"};
}
if((null == modelPath || modelPath.length != 1) && "InsightCoordFaceKeyPoint".equalsIgnoreCase(modelName)){
return new String[]{basePath + "model/onnx/keypoint_coordinate/coordinate_106_mobilenet_05.onnx"};
}
if((null == modelPath || modelPath.length != 1) && "InsightArcFaceRecognition".equalsIgnoreCase(modelName)){
return new String[]{basePath + "model/onnx/recognition_face_arc/glint360k_cosface_r18_fp16_0.1.onnx"};
}
return modelPath;
}
}

View File

@ -0,0 +1,28 @@
package com.visual.face.search.server.bootstrap.conf;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.transaction.annotation.EnableTransactionManagement;
@Configuration("visualServerConfig")
@EnableTransactionManagement
public class ServerConfig {
@Configuration
@MapperScan("com.visual.face.search.server.mapper")
public static class MapperConfig {}
@Configuration
@ComponentScan("com.visual.face.search.server.config")
public static class SearchConfig {}
@Configuration
@ComponentScan({"com.visual.face.search.server.service"})
public static class ServiceConfig {}
@Configuration
@ComponentScan({"com.visual.face.search.server.controller"})
public static class ControllerConfig {}
}

View File

@ -0,0 +1,45 @@
package com.visual.face.search.server.bootstrap.conf;
import com.github.xiaoymin.swaggerbootstrapui.annotations.EnableSwaggerBootstrapUI;
import io.swagger.annotations.Api;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import springfox.documentation.builders.ApiInfoBuilder;
import springfox.documentation.builders.RequestHandlerSelectors;
import springfox.documentation.service.ApiInfo;
import springfox.documentation.spi.DocumentationType;
import springfox.documentation.spring.web.plugins.Docket;
import springfox.documentation.swagger2.annotations.EnableSwagger2;
@Configuration
@EnableSwagger2
@EnableSwaggerBootstrapUI
public class SwaggerConfig implements WebMvcConfigurer {
@Value("${visual.swagger.enable:true}")
private Boolean enable;
@Bean
public Docket ProductApi() {
return new Docket(DocumentationType.SWAGGER_2)
.enable(enable)
.useDefaultResponseMessages(false)
.forCodeGeneration(false)
.pathMapping("/")
.apiInfo(apiInfo())
.select()
.apis(RequestHandlerSelectors.withClassAnnotation(Api.class))
.build();
}
private ApiInfo apiInfo() {
return new ApiInfoBuilder()
.title("Face Search Service API")
.description("Face Search Service API")
.version("1.0.0")
.build();
}
}

View File

@ -0,0 +1,121 @@
package com.visual.face.search.server.config;
import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.druid.spring.boot.autoconfigure.DruidDataSourceBuilder;
import com.alibaba.druid.spring.boot.autoconfigure.properties.DruidStatProperties;
import com.alibaba.druid.util.Utils;
import com.visual.face.search.core.common.enums.DataSourceType;
import com.visual.face.search.server.config.datasource.DynamicDataSource;
import com.visual.face.search.server.config.properties.DruidProperties;
import com.visual.face.search.server.utils.SpringUtils;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import javax.servlet.*;
import javax.sql.DataSource;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
/**
* Druid multi-datasource configuration
*
* @author diven
*/
@Configuration
public class DruidConfig
{
@Bean
@ConfigurationProperties("spring.datasource.druid.master")
public DataSource masterDataSource(DruidProperties druidProperties)
{
DruidDataSource dataSource = DruidDataSourceBuilder.create().build();
return druidProperties.dataSource(dataSource);
}
@Bean
@ConfigurationProperties("spring.datasource.druid.slave")
@ConditionalOnProperty(prefix = "spring.datasource.druid.slave", name = "enabled", havingValue = "true")
public DataSource slaveDataSource(DruidProperties druidProperties) {
DruidDataSource dataSource = DruidDataSourceBuilder.create().build();
return druidProperties.dataSource(dataSource);
}
@Bean(name = "dynamicDataSource")
@Primary
public DynamicDataSource dataSource(DataSource masterDataSource) {
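// MASTER is always registered; SLAVE is added only when the optional slaveDataSource bean exists
// (i.e. spring.datasource.druid.slave.enabled=true).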
Map<Object, Object> targetDataSources = new HashMap<>();
targetDataSources.put(DataSourceType.MASTER.name(), masterDataSource);
setDataSource(targetDataSources, DataSourceType.SLAVE.name(), "slaveDataSource");
return new DynamicDataSource(masterDataSource, targetDataSources);
}
/**
* Register an optional datasource
*
* @param targetDataSources candidate datasource map
* @param sourceName        datasource name
* @param beanName          bean name
*/
public void setDataSource(Map<Object, Object> targetDataSources, String sourceName, String beanName) {
try
{
DataSource dataSource = SpringUtils.getBean(beanName);
targetDataSources.put(sourceName, dataSource);
}
catch (Exception e)
{
}
}
/**
* Strip the advertisement banner from the bottom of the Druid monitoring page
*/
@SuppressWarnings({ "rawtypes", "unchecked" })
@Bean
@ConditionalOnProperty(name = "spring.datasource.druid.statViewServlet.enabled", havingValue = "true")
public FilterRegistrationBean removeDruidFilterRegistrationBean(DruidStatProperties properties) {
// Read the StatViewServlet (web monitoring page) configuration
DruidStatProperties.StatViewServlet config = properties.getStatViewServlet();
// Derive the URL pattern of common.js from the servlet mapping
String pattern = config.getUrlPattern() != null ? config.getUrlPattern() : "/druid/*";
String commonJsPattern = pattern.replaceAll("\\*", "js/common.js");
final String filePath = "support/http/resources/js/common.js";
// Create a filter that rewrites the served common.js
Filter filter = new Filter()
{
@Override
public void init(javax.servlet.FilterConfig filterConfig) throws ServletException
{
}
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException
{
chain.doFilter(request, response);
// Reset the response buffer; response headers are not reset
response.resetBuffer();
// Load the original common.js from the Druid resources
String text = Utils.readFromResource(filePath);
// Strip the banner and advertisement markup via regex
text = text.replaceAll("<a.*?banner\"></a><br/>", "");
text = text.replaceAll("powered.*?shrek.wang</a>", "");
response.getWriter().write(text);
}
@Override
public void destroy()
{
}
};
FilterRegistrationBean registrationBean = new FilterRegistrationBean();
registrationBean.setFilter(filter);
registrationBean.addUrlPatterns(commonJsPattern);
return registrationBean;
}
}

View File

@ -0,0 +1,61 @@
package com.visual.face.search.server.config.aspectj;
import com.visual.face.search.core.common.annotation.DataSource;
import com.visual.face.search.server.config.datasource.DynamicDataSourceContextHolder;
import com.visual.face.search.server.utils.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import java.util.Objects;
/**
* Multi-datasource switching aspect
*
* @author diven
*/
@Aspect
@Order(1)
@Component
public class DataSourceAspect {
protected Logger logger = LoggerFactory.getLogger(getClass());
@Pointcut("@annotation(com.visual.face.search.core.common.annotation.DataSource)" + "|| @within(com.visual.face.search.core.common.annotation.DataSource)")
public void dsPointCut() {}
@Around("dsPointCut()")
public Object around(ProceedingJoinPoint point) throws Throwable {
DataSource dataSource = getDataSource(point);
if (StringUtils.isNotNull(dataSource)) {
DynamicDataSourceContextHolder.setDataSourceType(dataSource.value().name());
}
try{
return point.proceed();
}
finally{
// Clear the datasource binding after the method has executed
DynamicDataSourceContextHolder.clearDataSourceType();
}
}
/**
* Determine the datasource to switch to (method-level annotation first, then class-level)
*/
public DataSource getDataSource(ProceedingJoinPoint point) {
MethodSignature signature = (MethodSignature) point.getSignature();
DataSource dataSource = AnnotationUtils.findAnnotation(signature.getMethod(), DataSource.class);
if (Objects.nonNull(dataSource)) {
return dataSource;
}
return AnnotationUtils.findAnnotation(signature.getDeclaringType(), DataSource.class);
}
}
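
The aspect is driven purely by the @DataSource annotation. A minimal usage sketch follows; the service class is hypothetical, and it assumes the annotation accepts a DataSourceType value, as implied by dataSource.value().name() above:

```
import com.visual.face.search.core.common.annotation.DataSource;
import com.visual.face.search.core.common.enums.DataSourceType;
import org.springframework.stereotype.Service;

// Hypothetical service showing how a method can be routed to the slave datasource.
@Service
public class FaceReadService {

    // While this method runs, DynamicDataSource resolves lookups to SLAVE;
    // the aspect clears the binding in its finally block afterwards.
    @DataSource(DataSourceType.SLAVE)
    public long countFaces() {
        // ... call a read-only mapper here ...
        return 0L;
    }
}
```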

View File

@ -0,0 +1,56 @@
package com.visual.face.search.server.config.datainit;
import java.io.Reader;
import javax.sql.DataSource;
import java.sql.Connection;
import org.apache.ibatis.io.Resources;
import org.apache.ibatis.jdbc.ScriptRunner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Component;
@Component
public class DataSourceInitializer implements InitializingBean {
public Logger logger = LoggerFactory.getLogger(getClass());
@Autowired
ApplicationContext applicationContext;
@Override
public void afterPropertiesSet() throws Exception {
logger.info("start init schema info.");
DataSource dataSource = applicationContext.getBean(DataSource.class);
long startCurrentTime = System.currentTimeMillis();
boolean initFlag = false;
Exception exception = null;
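// Retry for up to two minutes so the schema script still runs when the database is not yet reachable
// (e.g. when the application and database containers start together).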
while ((System.currentTimeMillis() - startCurrentTime) <= 2 * 60 * 1000){
try (
Connection connection= dataSource.getConnection();
Reader resourceAsReader = Resources.getResourceAsReader("sqls/schema-init.sql")
){
ScriptRunner runner=new ScriptRunner(connection);
runner.setLogWriter(null);
runner.setErrorLogWriter(null);
runner.runScript(resourceAsReader);
initFlag = true;
break;
} catch (Exception e) {
exception = e;
}
}
if(initFlag){
logger.info("success init schema info.");
}else{
if(null == exception){
logger.error("run schema-init error");
}else{
logger.error("run schema-init error", exception);
}
}
}
}

View File

@ -0,0 +1,27 @@
package com.visual.face.search.server.config.datasource;
import java.util.Map;
import javax.sql.DataSource;
import org.springframework.jdbc.datasource.lookup.AbstractRoutingDataSource;
import com.visual.face.search.server.config.datasource.DynamicDataSourceContextHolder;
/**
* Dynamic routing datasource
*
* @author diven
*/
public class DynamicDataSource extends AbstractRoutingDataSource
{
public DynamicDataSource(DataSource defaultTargetDataSource, Map<Object, Object> targetDataSources)
{
super.setDefaultTargetDataSource(defaultTargetDataSource);
super.setTargetDataSources(targetDataSources);
super.afterPropertiesSet();
}
@Override
protected Object determineCurrentLookupKey()
{
return DynamicDataSourceContextHolder.getDataSourceType();
}
}

View File

@ -0,0 +1,45 @@
package com.visual.face.search.server.config.datasource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Datasource switching context holder
*
* @author diven
*/
public class DynamicDataSourceContextHolder
{
public static final Logger log = LoggerFactory.getLogger(DynamicDataSourceContextHolder.class);
/**
* A ThreadLocal keeps the current datasource key: each thread gets its own independent copy,
* so changing it on one thread never affects the copy seen by other threads.
*/
private static final ThreadLocal<String> CONTEXT_HOLDER = new ThreadLocal<>();
/**
* Set the datasource key for the current thread
*/
public static void setDataSourceType(String dsType)
{
log.info("switching to the {} datasource", dsType);
CONTEXT_HOLDER.set(dsType);
}
/**
* Get the datasource key for the current thread
*/
public static String getDataSourceType()
{
return CONTEXT_HOLDER.get();
}
/**
* Clear the datasource key for the current thread
*/
public static void clearDataSourceType()
{
CONTEXT_HOLDER.remove();
}
}
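
Outside the aspect, the holder can also be used directly. A minimal sketch; the surrounding class and the work inside the try block are placeholders:

```
import com.visual.face.search.core.common.enums.DataSourceType;
import com.visual.face.search.server.config.datasource.DynamicDataSourceContextHolder;

// Hypothetical example of switching the routing key manually for one thread.
public class SlaveQueryExample {
    public static void main(String[] args) {
        DynamicDataSourceContextHolder.setDataSourceType(DataSourceType.SLAVE.name());
        try {
            // ... run read-only work here; DynamicDataSource resolves the key set above ...
        } finally {
            // Always clear the ThreadLocal so later work on this thread falls back to the default datasource.
            DynamicDataSourceContextHolder.clearDataSourceType();
        }
    }
}
```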

Some files were not shown because too many files have changed in this diff.