mirror of
https://gitee.com/open-visual/face-search.git
synced 2025-07-25 19:41:42 +08:00
!17 添加open search近似knn支持
Merge pull request !17 from foy/approximateKnnSupport
This commit is contained in:
commit
6f36d44563
@ -48,7 +48,8 @@ public class CollectHandler extends BaseHandler<CollectHandler>{
|
||||
.setFaceColumns(collect.getFaceColumns())
|
||||
.setShardsNum(collect.getShardsNum())
|
||||
.setStorageFaceInfo(collect.getStorageFaceInfo())
|
||||
.setStorageEngine(collect.getStorageEngine());
|
||||
.setStorageEngine(collect.getStorageEngine())
|
||||
.setApproximateKnn(collect.isApproximateKnn());
|
||||
return HttpClient.post(Api.getUrl(this.serverHost, Api.collect_create), collectReq);
|
||||
}
|
||||
|
||||
|
@ -45,7 +45,8 @@ public class SearchHandler extends BaseHandler<SearchHandler>{
|
||||
.setMaxFaceNum(search.getMaxFaceNum())
|
||||
.setLimit(search.getLimit())
|
||||
.setConfidenceThreshold(search.getConfidenceThreshold())
|
||||
.setFaceScoreThreshold(search.getFaceScoreThreshold());
|
||||
.setFaceScoreThreshold(search.getFaceScoreThreshold())
|
||||
.setApproximateKnn(search.isApproximateKnn());
|
||||
return HttpClient.post(Api.getUrl(this.serverHost, Api.visual_search), searchReq, new TypeReference<Response<List<SearchRep>>>(){});
|
||||
}
|
||||
}
|
||||
|
@ -23,6 +23,9 @@ public class Collect<ExtendsVo extends Collect<ExtendsVo>> implements Serializab
|
||||
/**保留图片及人脸信息的存储组件**/
|
||||
private StorageEngine storageEngine;
|
||||
|
||||
/**是否启用近似knn搜索**/
|
||||
private boolean approximateKnn;
|
||||
|
||||
/**
|
||||
* 构建集合对象
|
||||
* @return
|
||||
@ -51,6 +54,15 @@ public class Collect<ExtendsVo extends Collect<ExtendsVo>> implements Serializab
|
||||
return (ExtendsVo) this;
|
||||
}
|
||||
|
||||
public boolean isApproximateKnn(){
|
||||
return this.approximateKnn;
|
||||
}
|
||||
|
||||
public ExtendsVo setApproximateKnn(boolean approximateKnn){
|
||||
this.approximateKnn = approximateKnn;
|
||||
return (ExtendsVo) this;
|
||||
}
|
||||
|
||||
public Integer getShardsNum() {
|
||||
return shardsNum;
|
||||
}
|
||||
|
@ -15,6 +15,8 @@ public class Search<ExtendsVo extends Search<ExtendsVo>> implements Serializable
|
||||
/**对输入图像中多少个人脸进行检索比对:默认5**/
|
||||
private Integer maxFaceNum=5;
|
||||
|
||||
/**是否启用近似knn搜索**/
|
||||
private boolean approximateKnn = false;
|
||||
/**
|
||||
* 构建检索对象
|
||||
* @param imageBase64 待检索的图片
|
||||
@ -33,6 +35,15 @@ public class Search<ExtendsVo extends Search<ExtendsVo>> implements Serializable
|
||||
return (ExtendsVo) this;
|
||||
}
|
||||
|
||||
public boolean isApproximateKnn() {
|
||||
return approximateKnn;
|
||||
}
|
||||
|
||||
public ExtendsVo setApproximateKnn(boolean approximateKnn) {
|
||||
this.approximateKnn = approximateKnn;
|
||||
return (ExtendsVo)this;
|
||||
}
|
||||
|
||||
public Float getFaceScoreThreshold() {
|
||||
return faceScoreThreshold;
|
||||
}
|
||||
|
@ -14,7 +14,11 @@ public interface SearchEngine {
|
||||
|
||||
public boolean dropCollection(String collectionName);
|
||||
|
||||
public boolean createCollection(String collectionName, MapParam param);
|
||||
default public boolean createCollection(String collectionName, MapParam param){
|
||||
return createCollection(collectionName,param,false);
|
||||
};
|
||||
|
||||
public boolean createCollection(String collectionName, MapParam param,boolean approximateKnn);
|
||||
|
||||
public boolean insertVector(String collectionName, String sampleId, String faceId, float[] vectors);
|
||||
|
||||
@ -22,7 +26,11 @@ public interface SearchEngine {
|
||||
|
||||
public boolean deleteVectorByKey(String collectionName, List<String> faceIds);
|
||||
|
||||
public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK);
|
||||
default public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK){
|
||||
return search(collectionName,features,algorithm,topK,false);
|
||||
};
|
||||
|
||||
public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK, boolean approximateKnn);
|
||||
|
||||
public float searchMinScoreBySampleId(String collectionName, String sampleId,float[] feature, String algorithm);
|
||||
|
||||
|
@ -4,6 +4,7 @@ public class Constant {
|
||||
|
||||
public final static String IndexShardsNum = "shardsNum";
|
||||
public final static String IndexReplicasNum = "replicasNum";
|
||||
public final static String IndexAlgoParamEfSearch = "algoParamEfSearch";
|
||||
|
||||
public final static String ColumnNameFaceId = "face_id";
|
||||
public final static String ColumnNameSampleId = "sample_id";
|
||||
|
@ -3,6 +3,7 @@ package com.visual.face.search.engine.impl;
|
||||
import com.visual.face.search.engine.api.SearchEngine;
|
||||
import com.visual.face.search.engine.conf.Constant;
|
||||
import com.visual.face.search.engine.exps.SearchEngineException;
|
||||
import com.visual.face.search.engine.impl.query.ApproximateKnnQueryBuilder;
|
||||
import com.visual.face.search.engine.model.*;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.opensearch.action.DocWriteResponse;
|
||||
@ -71,17 +72,32 @@ public class OpenSearchEngine implements SearchEngine {
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean createCollection(String collectionName, MapParam param) {
|
||||
public boolean createCollection(String collectionName, MapParam param,boolean approximateKnn) {
|
||||
try {
|
||||
//构建请求
|
||||
CreateIndexRequest createIndexRequest = new CreateIndexRequest(collectionName);
|
||||
createIndexRequest.settings(Settings.builder()
|
||||
.put("index.number_of_shards", param.getIndexShardsNum())
|
||||
.put("index.number_of_replicas", param.getIndexReplicasNum())
|
||||
);
|
||||
Settings.Builder builder = Settings.builder()
|
||||
.put("index.number_of_shards", param.getIndexShardsNum())
|
||||
.put("index.number_of_replicas", param.getIndexReplicasNum());
|
||||
if(approximateKnn){
|
||||
//启用open search近似knn搜索支持
|
||||
builder.put("index.knn",true);
|
||||
builder.put("index.knn.algo_param.ef_search",param.getIndexAlgoParamEfSearch());
|
||||
}
|
||||
|
||||
createIndexRequest.settings(builder);
|
||||
HashMap<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constant.ColumnNameSampleId, Map.of("type", "keyword"));
|
||||
properties.put(Constant.ColumnNameFaceVector, Map.of("type", "knn_vector", "dimension", "512"));
|
||||
if(approximateKnn){
|
||||
//启用open search近似knn搜索支持
|
||||
properties.put(Constant.ColumnNameFaceVector, Map.of("type", "knn_vector", "dimension", "512",
|
||||
"method",Map.of("engine","nmslib",
|
||||
"space_type","cosinesimil",
|
||||
"name","hnsw",
|
||||
"parameters",Map.of())));
|
||||
}else {
|
||||
properties.put(Constant.ColumnNameFaceVector, Map.of("type", "knn_vector", "dimension", "512"));
|
||||
}
|
||||
createIndexRequest.mapping(Map.of("properties", properties));
|
||||
//创建集合
|
||||
CreateIndexResponse createIndexResponse = client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
|
||||
@ -135,23 +151,37 @@ public class OpenSearchEngine implements SearchEngine {
|
||||
}
|
||||
|
||||
@Override
|
||||
public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK) {
|
||||
public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK,boolean approximateKnn) {
|
||||
try {
|
||||
//构建搜索请求
|
||||
MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
|
||||
for(float[] feature : features){
|
||||
QueryBuilder queryBuilder = new MatchAllQueryBuilder();
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("field", Constant.ColumnNameFaceVector);
|
||||
params.put("space_type", algorithm);
|
||||
params.put("query_value", feature);
|
||||
Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "knn", "knn_score", params);
|
||||
ScriptScoreQueryBuilder scriptScoreQueryBuilder = new ScriptScoreQueryBuilder(queryBuilder, script);
|
||||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
|
||||
.query(scriptScoreQueryBuilder).size(topK)
|
||||
.fetchSource(null, Constant.ColumnNameFaceVector); //是否需要向量字段
|
||||
SearchRequest searchRequest = new SearchRequest(collectionName).source(searchSourceBuilder);
|
||||
multiSearchRequest.add(searchRequest);
|
||||
if(approximateKnn){
|
||||
//近似knn搜索
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("vector",feature);
|
||||
params.put("k",topK);
|
||||
ApproximateKnnQueryBuilder approximateKnnQueryBuilder = new ApproximateKnnQueryBuilder(params);
|
||||
SearchSourceBuilder searchSourceBuilder =new SearchSourceBuilder()
|
||||
.query(approximateKnnQueryBuilder).size(topK)
|
||||
.fetchSource(null,Constant.ColumnNameFaceVector);
|
||||
SearchRequest searchRequest = new SearchRequest(collectionName).source(searchSourceBuilder);
|
||||
multiSearchRequest.add(searchRequest);
|
||||
}else {
|
||||
//常规搜索
|
||||
QueryBuilder queryBuilder = new MatchAllQueryBuilder();
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("field", Constant.ColumnNameFaceVector);
|
||||
params.put("space_type", algorithm);
|
||||
params.put("query_value", feature);
|
||||
Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "knn", "knn_score", params);
|
||||
ScriptScoreQueryBuilder scriptScoreQueryBuilder = new ScriptScoreQueryBuilder(queryBuilder, script);
|
||||
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
|
||||
.query(scriptScoreQueryBuilder).size(topK)
|
||||
.fetchSource(null, Constant.ColumnNameFaceVector); //是否需要向量字段
|
||||
SearchRequest searchRequest = new SearchRequest(collectionName).source(searchSourceBuilder);
|
||||
multiSearchRequest.add(searchRequest);
|
||||
}
|
||||
}
|
||||
//查询索引
|
||||
MultiSearchResponse response = this.client.msearch(multiSearchRequest, RequestOptions.DEFAULT);
|
||||
@ -167,7 +197,7 @@ public class OpenSearchEngine implements SearchEngine {
|
||||
if(searchHits != null){
|
||||
for(SearchHit searchHit : searchHits){
|
||||
String faceId = searchHit.getId();
|
||||
float score = searchHit.getScore()-1;
|
||||
float score = approximateKnn? searchHit.getScore() : (searchHit.getScore()-1);
|
||||
Map<String, Object> sourceMap = searchHit.getSourceAsMap();
|
||||
String sampleId = MapUtils.getString(sourceMap, Constant.ColumnNameSampleId);
|
||||
Object faceVector = MapUtils.getObject(sourceMap, Constant.ColumnNameFaceVector);
|
||||
|
@ -0,0 +1,58 @@
|
||||
package com.visual.face.search.engine.impl.query;
|
||||
|
||||
import com.visual.face.search.engine.conf.Constant;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.opensearch.common.io.stream.StreamOutput;
|
||||
import org.opensearch.core.xcontent.XContentBuilder;
|
||||
import org.opensearch.index.query.AbstractQueryBuilder;
|
||||
import org.opensearch.index.query.QueryShardContext;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* 近似KNN搜索
|
||||
* @Author Foy Lian
|
||||
* @Date 2023/9/13 14:26
|
||||
**/
|
||||
|
||||
public class ApproximateKnnQueryBuilder extends AbstractQueryBuilder<ApproximateKnnQueryBuilder> {
|
||||
|
||||
private Map<String, Object> mParams;
|
||||
|
||||
public ApproximateKnnQueryBuilder(Map<String, Object> params){
|
||||
this.mParams = params;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doWriteTo(StreamOutput streamOutput) throws IOException {
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
|
||||
xContentBuilder.startObject("knn");
|
||||
xContentBuilder.field(Constant.ColumnNameFaceVector,mParams);
|
||||
xContentBuilder.endObject();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Query doToQuery(QueryShardContext queryShardContext) throws IOException {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean doEquals(ApproximateKnnQueryBuilder approximateKnnQueryBuilder) {
|
||||
return Objects.equals(this.mParams, approximateKnnQueryBuilder.mParams);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected int doHashCode() {
|
||||
return Objects.hash(new Object[]{this.mParams});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return "knn";
|
||||
}
|
||||
}
|
@ -71,4 +71,9 @@ public class MapParam extends ConcurrentHashMap<String, Object> {
|
||||
return shardsNum;
|
||||
}
|
||||
|
||||
public Integer getIndexAlgoParamEfSearch(){
|
||||
Integer algoParamEfSearch = this.getInteger(Constant.IndexAlgoParamEfSearch, 512);
|
||||
algoParamEfSearch = (null == algoParamEfSearch || algoParamEfSearch <= 0) ? 512 : algoParamEfSearch;
|
||||
return algoParamEfSearch;
|
||||
}
|
||||
}
|
||||
|
@ -46,6 +46,11 @@ public class CollectVo<ExtendsVo extends CollectVo<ExtendsVo>> extends BaseVo {
|
||||
@ApiModelProperty(value="保留图片及人脸信息的存储组件", position = 9,required = false)
|
||||
private StorageEngine storageEngine;
|
||||
|
||||
/**是否启用近似knn搜索**/
|
||||
@ApiModelProperty(value="是否启用近似knn搜索", position = 10,required = false)
|
||||
private boolean approximateKnn;
|
||||
|
||||
|
||||
/**
|
||||
* 构建集合对象
|
||||
* @param namespace 命名空间
|
||||
@ -56,6 +61,15 @@ public class CollectVo<ExtendsVo extends CollectVo<ExtendsVo>> extends BaseVo {
|
||||
return new CollectVo().setNamespace(namespace).setCollectionName(collectionName);
|
||||
}
|
||||
|
||||
public boolean isApproximateKnn() {
|
||||
return approximateKnn;
|
||||
}
|
||||
|
||||
public ExtendsVo setApproximateKnn(boolean approximateKnn){
|
||||
this.approximateKnn = approximateKnn;
|
||||
return (ExtendsVo) this;
|
||||
}
|
||||
|
||||
public String getNamespace() {
|
||||
return namespace;
|
||||
}
|
||||
@ -144,4 +158,5 @@ public class CollectVo<ExtendsVo extends CollectVo<ExtendsVo>> extends BaseVo {
|
||||
}
|
||||
return (ExtendsVo) this;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -42,6 +42,12 @@ public class FaceSearchReqVo extends BaseVo {
|
||||
@ApiModelProperty(value="对输入图像中多少个人脸进行检索比对:默认5", position = 7, required = false)
|
||||
private Integer maxFaceNum;
|
||||
|
||||
/**是否使用近似knn搜索**/
|
||||
@ApiModelProperty(value="是否使用近似knn搜索", position = 8, required = false)
|
||||
private boolean approximateKnn;
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* 构建检索对象
|
||||
* @param namespace 命名空间
|
||||
@ -130,4 +136,13 @@ public class FaceSearchReqVo extends BaseVo {
|
||||
this.maxFaceNum = maxFaceNum;
|
||||
return this;
|
||||
}
|
||||
|
||||
public boolean isApproximateKnn() {
|
||||
return approximateKnn;
|
||||
}
|
||||
|
||||
public FaceSearchReqVo setApproximateKnn(boolean approximateKnn) {
|
||||
this.approximateKnn = approximateKnn;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
@ -90,7 +90,7 @@ public class CollectServiceImpl extends BaseService implements CollectService {
|
||||
MapParam param = MapParam.build()
|
||||
.put(Constant.IndexShardsNum, collect.getShardsNum())
|
||||
.put(Constant.IndexReplicasNum, collect.getReplicasNum());
|
||||
boolean createVectorFlag = searchEngine.createCollection(vectorTableName, param);
|
||||
boolean createVectorFlag = searchEngine.createCollection(vectorTableName, param,collect.isApproximateKnn());
|
||||
if(!createVectorFlag){
|
||||
throw new RuntimeException("create vector table error");
|
||||
}
|
||||
|
@ -83,7 +83,7 @@ public class FaceSearchServiceImpl extends BaseService implements FaceSearchServ
|
||||
}
|
||||
//特征搜索
|
||||
int topK = (null == search.getLimit() || search.getLimit() <= 0) ? 5 : search.getLimit();
|
||||
SearchResponse searchResponse =searchEngine.search(collection.getVectorTable(), vectors, search.getAlgorithm().algorithm(), topK);
|
||||
SearchResponse searchResponse =searchEngine.search(collection.getVectorTable(), vectors, search.getAlgorithm().algorithm(), topK,search.isApproximateKnn());
|
||||
if(!searchResponse.getStatus().ok()){
|
||||
throw new RuntimeException(searchResponse.getStatus().getReason());
|
||||
}
|
||||
|
@ -26,6 +26,12 @@ public class FaceSearchExample {
|
||||
public static String collectionName = "collect_20211201_v11";
|
||||
public static FaceSearch faceSearch = FaceSearch.build(serverHost, namespace, collectionName);
|
||||
|
||||
//是否启用近似knn,建议底库集比较大时启用.
|
||||
public static boolean approximateKnn = false;
|
||||
|
||||
//底库集比较大时,建议调大,如:32个分片
|
||||
public static int shardsNum = 4;
|
||||
|
||||
/**集合创建*/
|
||||
public static void collect(){
|
||||
//样本属性字段
|
||||
@ -44,7 +50,12 @@ public class FaceSearchExample {
|
||||
//是否保存人脸及图片数据信息
|
||||
.setStorageFaceInfo(true)
|
||||
//目前只实现了数据库存储,对其他类型存储实现StorageImageService接口即可
|
||||
.setStorageEngine(StorageEngine.CURR_DB);
|
||||
.setStorageEngine(StorageEngine.CURR_DB)
|
||||
//设置分片大小
|
||||
.setShardsNum(shardsNum)
|
||||
//开启关闭近似knn搜索
|
||||
.setApproximateKnn(approximateKnn);
|
||||
|
||||
//删除集合
|
||||
Response<Boolean> deleteCollect = faceSearch.collect().deleteCollect();
|
||||
System.out.println(deleteCollect);
|
||||
@ -97,6 +108,9 @@ public class FaceSearchExample {
|
||||
.search(Search.build(imageBase64)
|
||||
.setConfidenceThreshold(50f) //最小置信分:50
|
||||
.setMaxFaceNum(10).setLimit(1)
|
||||
//collect()创建集合时即使开了近似knn搜索,这里设置为false也可以使用精确knn搜索。
|
||||
//这里数据量足够大时,就能发现近似knn返回较快
|
||||
.setApproximateKnn(approximateKnn)
|
||||
);
|
||||
Long e = System.currentTimeMillis();
|
||||
System.out.println("search cost:" + (e-s)+"ms");
|
||||
|
Loading…
x
Reference in New Issue
Block a user