OpenSearch 近似 kNN 搜索支持

This commit is contained in:
lianjunsheng 2023-11-15 16:59:54 +08:00
parent 20e1e8cb62
commit 4d0c1debd9
14 changed files with 199 additions and 28 deletions

View File

@ -48,7 +48,8 @@ public class CollectHandler extends BaseHandler<CollectHandler>{
.setFaceColumns(collect.getFaceColumns()) .setFaceColumns(collect.getFaceColumns())
.setShardsNum(collect.getShardsNum()) .setShardsNum(collect.getShardsNum())
.setStorageFaceInfo(collect.getStorageFaceInfo()) .setStorageFaceInfo(collect.getStorageFaceInfo())
.setStorageEngine(collect.getStorageEngine()); .setStorageEngine(collect.getStorageEngine())
.setApproximateKnn(collect.isApproximateKnn());
return HttpClient.post(Api.getUrl(this.serverHost, Api.collect_create), collectReq); return HttpClient.post(Api.getUrl(this.serverHost, Api.collect_create), collectReq);
} }

View File

@ -45,7 +45,8 @@ public class SearchHandler extends BaseHandler<SearchHandler>{
.setMaxFaceNum(search.getMaxFaceNum()) .setMaxFaceNum(search.getMaxFaceNum())
.setLimit(search.getLimit()) .setLimit(search.getLimit())
.setConfidenceThreshold(search.getConfidenceThreshold()) .setConfidenceThreshold(search.getConfidenceThreshold())
.setFaceScoreThreshold(search.getFaceScoreThreshold()); .setFaceScoreThreshold(search.getFaceScoreThreshold())
.setApproximateKnn(search.isApproximateKnn());
return HttpClient.post(Api.getUrl(this.serverHost, Api.visual_search), searchReq, new TypeReference<Response<List<SearchRep>>>(){}); return HttpClient.post(Api.getUrl(this.serverHost, Api.visual_search), searchReq, new TypeReference<Response<List<SearchRep>>>(){});
} }
} }

View File

@ -23,6 +23,9 @@ public class Collect<ExtendsVo extends Collect<ExtendsVo>> implements Serializab
/**保留图片及人脸信息的存储组件**/ /**保留图片及人脸信息的存储组件**/
private StorageEngine storageEngine; private StorageEngine storageEngine;
/**是否启用近似knn搜索**/
private boolean approximateKnn;
/** /**
* 构建集合对象 * 构建集合对象
* @return * @return
@ -51,6 +54,15 @@ public class Collect<ExtendsVo extends Collect<ExtendsVo>> implements Serializab
return (ExtendsVo) this; return (ExtendsVo) this;
} }
/**
 * Reports whether approximate kNN search is enabled on this collection definition.
 *
 * @return the current value of the {@code approximateKnn} flag
 */
public boolean isApproximateKnn() {
    return approximateKnn;
}
/**
 * Enables or disables approximate kNN search on this collection definition.
 *
 * @param approximateKnn the new value of the flag
 * @return this object, cast to the self type so calls can be chained
 */
public ExtendsVo setApproximateKnn(boolean approximateKnn) {
    this.approximateKnn = approximateKnn;
    // The self-typed generic makes this cast unchecked; keep the suppression on the smallest scope.
    @SuppressWarnings("unchecked")
    ExtendsVo self = (ExtendsVo) this;
    return self;
}
public Integer getShardsNum() { public Integer getShardsNum() {
return shardsNum; return shardsNum;
} }

View File

@ -15,6 +15,8 @@ public class Search<ExtendsVo extends Search<ExtendsVo>> implements Serializable
/**对输入图像中多少个人脸进行检索比对默认5**/ /**对输入图像中多少个人脸进行检索比对默认5**/
private Integer maxFaceNum=5; private Integer maxFaceNum=5;
/**是否启用近似knn搜索**/
private boolean approximateKnn = false;
/** /**
* 构建检索对象 * 构建检索对象
* @param imageBase64 待检索的图片 * @param imageBase64 待检索的图片
@ -33,6 +35,15 @@ public class Search<ExtendsVo extends Search<ExtendsVo>> implements Serializable
return (ExtendsVo) this; return (ExtendsVo) this;
} }
/**
 * Reports whether this search request asks for approximate kNN search.
 *
 * @return the current value of the {@code approximateKnn} flag
 */
public boolean isApproximateKnn() {
    return this.approximateKnn;
}
/**
 * Sets whether this search request should use approximate kNN search.
 *
 * @param approximateKnn the new value of the flag
 * @return this object, cast to the self type so calls can be chained
 */
public ExtendsVo setApproximateKnn(boolean approximateKnn) {
    this.approximateKnn = approximateKnn;
    // The self-typed generic makes this cast unchecked; keep the suppression on the smallest scope.
    @SuppressWarnings("unchecked")
    ExtendsVo self = (ExtendsVo) this;
    return self;
}
public Float getFaceScoreThreshold() { public Float getFaceScoreThreshold() {
return faceScoreThreshold; return faceScoreThreshold;
} }

View File

@ -14,7 +14,11 @@ public interface SearchEngine {
public boolean dropCollection(String collectionName); public boolean dropCollection(String collectionName);
public boolean createCollection(String collectionName, MapParam param); default public boolean createCollection(String collectionName, MapParam param){
return createCollection(collectionName,param,false);
};
public boolean createCollection(String collectionName, MapParam param,boolean approximateKnn);
public boolean insertVector(String collectionName, String sampleId, String faceId, float[] vectors); public boolean insertVector(String collectionName, String sampleId, String faceId, float[] vectors);
@ -22,7 +26,11 @@ public interface SearchEngine {
public boolean deleteVectorByKey(String collectionName, List<String> faceIds); public boolean deleteVectorByKey(String collectionName, List<String> faceIds);
public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK); default public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK){
return search(collectionName,features,algorithm,topK,false);
};
public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK, boolean approximateKnn);
public float searchMinScoreBySampleId(String collectionName, String sampleId,float[] feature, String algorithm); public float searchMinScoreBySampleId(String collectionName, String sampleId,float[] feature, String algorithm);

View File

@ -4,6 +4,7 @@ public class Constant {
public final static String IndexShardsNum = "shardsNum"; public final static String IndexShardsNum = "shardsNum";
public final static String IndexReplicasNum = "replicasNum"; public final static String IndexReplicasNum = "replicasNum";
public final static String IndexAlgoParamEfSearch = "algoParamEfSearch";
public final static String ColumnNameFaceId = "face_id"; public final static String ColumnNameFaceId = "face_id";
public final static String ColumnNameSampleId = "sample_id"; public final static String ColumnNameSampleId = "sample_id";

View File

@ -3,6 +3,7 @@ package com.visual.face.search.engine.impl;
import com.visual.face.search.engine.api.SearchEngine; import com.visual.face.search.engine.api.SearchEngine;
import com.visual.face.search.engine.conf.Constant; import com.visual.face.search.engine.conf.Constant;
import com.visual.face.search.engine.exps.SearchEngineException; import com.visual.face.search.engine.exps.SearchEngineException;
import com.visual.face.search.engine.impl.query.ApproximateKnnQueryBuilder;
import com.visual.face.search.engine.model.*; import com.visual.face.search.engine.model.*;
import org.apache.commons.collections4.MapUtils; import org.apache.commons.collections4.MapUtils;
import org.opensearch.action.DocWriteResponse; import org.opensearch.action.DocWriteResponse;
@ -71,17 +72,32 @@ public class OpenSearchEngine implements SearchEngine {
} }
@Override @Override
public boolean createCollection(String collectionName, MapParam param) { public boolean createCollection(String collectionName, MapParam param,boolean approximateKnn) {
try { try {
//构建请求 //构建请求
CreateIndexRequest createIndexRequest = new CreateIndexRequest(collectionName); CreateIndexRequest createIndexRequest = new CreateIndexRequest(collectionName);
createIndexRequest.settings(Settings.builder() Settings.Builder builder = Settings.builder()
.put("index.number_of_shards", param.getIndexShardsNum()) .put("index.number_of_shards", param.getIndexShardsNum())
.put("index.number_of_replicas", param.getIndexReplicasNum()) .put("index.number_of_replicas", param.getIndexReplicasNum());
); if(approximateKnn){
//启用open search近似knn搜索支持
builder.put("index.knn",true);
builder.put("index.knn.algo_param.ef_search",param.getIndexAlgoParamEfSearch());
}
createIndexRequest.settings(builder);
HashMap<String, Object> properties = new HashMap<>(); HashMap<String, Object> properties = new HashMap<>();
properties.put(Constant.ColumnNameSampleId, Map.of("type", "keyword")); properties.put(Constant.ColumnNameSampleId, Map.of("type", "keyword"));
properties.put(Constant.ColumnNameFaceVector, Map.of("type", "knn_vector", "dimension", "512")); if(approximateKnn){
//启用open search近似knn搜索支持
properties.put(Constant.ColumnNameFaceVector, Map.of("type", "knn_vector", "dimension", "512",
"method",Map.of("engine","nmslib",
"space_type","cosinesimil",
"name","hnsw",
"parameters",Map.of())));
}else {
properties.put(Constant.ColumnNameFaceVector, Map.of("type", "knn_vector", "dimension", "512"));
}
createIndexRequest.mapping(Map.of("properties", properties)); createIndexRequest.mapping(Map.of("properties", properties));
//创建集合 //创建集合
CreateIndexResponse createIndexResponse = client.indices().create(createIndexRequest, RequestOptions.DEFAULT); CreateIndexResponse createIndexResponse = client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
@ -135,23 +151,37 @@ public class OpenSearchEngine implements SearchEngine {
} }
@Override @Override
public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK) { public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK,boolean approximateKnn) {
try { try {
//构建搜索请求 //构建搜索请求
MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
for(float[] feature : features){ for(float[] feature : features){
QueryBuilder queryBuilder = new MatchAllQueryBuilder(); if(approximateKnn){
Map<String, Object> params = new HashMap<>(); //近似knn搜索
params.put("field", Constant.ColumnNameFaceVector); Map<String, Object> params = new HashMap<>();
params.put("space_type", algorithm); params.put("vector",feature);
params.put("query_value", feature); params.put("k",topK);
Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "knn", "knn_score", params); ApproximateKnnQueryBuilder approximateKnnQueryBuilder = new ApproximateKnnQueryBuilder(params);
ScriptScoreQueryBuilder scriptScoreQueryBuilder = new ScriptScoreQueryBuilder(queryBuilder, script); SearchSourceBuilder searchSourceBuilder =new SearchSourceBuilder()
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() .query(approximateKnnQueryBuilder).size(topK)
.query(scriptScoreQueryBuilder).size(topK) .fetchSource(null,Constant.ColumnNameFaceVector);
.fetchSource(null, Constant.ColumnNameFaceVector); //是否需要向量字段 SearchRequest searchRequest = new SearchRequest(collectionName).source(searchSourceBuilder);
SearchRequest searchRequest = new SearchRequest(collectionName).source(searchSourceBuilder); multiSearchRequest.add(searchRequest);
multiSearchRequest.add(searchRequest); }else {
//常规搜索
QueryBuilder queryBuilder = new MatchAllQueryBuilder();
Map<String, Object> params = new HashMap<>();
params.put("field", Constant.ColumnNameFaceVector);
params.put("space_type", algorithm);
params.put("query_value", feature);
Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "knn", "knn_score", params);
ScriptScoreQueryBuilder scriptScoreQueryBuilder = new ScriptScoreQueryBuilder(queryBuilder, script);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
.query(scriptScoreQueryBuilder).size(topK)
.fetchSource(null, Constant.ColumnNameFaceVector); //是否需要向量字段
SearchRequest searchRequest = new SearchRequest(collectionName).source(searchSourceBuilder);
multiSearchRequest.add(searchRequest);
}
} }
//查询索引 //查询索引
MultiSearchResponse response = this.client.msearch(multiSearchRequest, RequestOptions.DEFAULT); MultiSearchResponse response = this.client.msearch(multiSearchRequest, RequestOptions.DEFAULT);
@ -167,7 +197,7 @@ public class OpenSearchEngine implements SearchEngine {
if(searchHits != null){ if(searchHits != null){
for(SearchHit searchHit : searchHits){ for(SearchHit searchHit : searchHits){
String faceId = searchHit.getId(); String faceId = searchHit.getId();
float score = searchHit.getScore()-1; float score = approximateKnn? searchHit.getScore() : (searchHit.getScore()-1);
Map<String, Object> sourceMap = searchHit.getSourceAsMap(); Map<String, Object> sourceMap = searchHit.getSourceAsMap();
String sampleId = MapUtils.getString(sourceMap, Constant.ColumnNameSampleId); String sampleId = MapUtils.getString(sourceMap, Constant.ColumnNameSampleId);
Object faceVector = MapUtils.getObject(sourceMap, Constant.ColumnNameFaceVector); Object faceVector = MapUtils.getObject(sourceMap, Constant.ColumnNameFaceVector);

View File

@ -0,0 +1,58 @@
package com.visual.face.search.engine.impl.query;
import com.visual.face.search.engine.conf.Constant;
import org.apache.lucene.search.Query;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
/**
 * Query builder that renders an approximate kNN ({@code knn}) query clause against the
 * face-vector field ({@link Constant#ColumnNameFaceVector}).
 *
 * <p>Only the XContent (JSON) rendering is implemented. {@link #doWriteTo(StreamOutput)} writes
 * nothing and {@link #doToQuery(QueryShardContext)} returns {@code null}.
 * NOTE(review): this presumably relies on the builder being used only from a REST client, where
 * just the XContent form is sent to the server; confirm before using it anywhere that needs
 * stream serialization or a Lucene query.
 *
 * @author Foy Lian
 * @since 2023/9/13
 */
public class ApproximateKnnQueryBuilder extends AbstractQueryBuilder<ApproximateKnnQueryBuilder> {

    /** Name of the rendered query clause; also returned as the writeable name. */
    private static final String NAME = "knn";

    /** Body written under the face-vector field; the caller in this project puts "vector" and "k" in it. */
    private final Map<String, Object> knnParams;

    /**
     * Creates a builder for the given kNN parameters.
     *
     * @param params the map rendered as the value of the face-vector field; must not be {@code null}
     * @throws NullPointerException if {@code params} is {@code null}
     */
    public ApproximateKnnQueryBuilder(Map<String, Object> params) {
        // Fail fast: a null map would otherwise only surface later as a null field in the rendered query.
        this.knnParams = Objects.requireNonNull(params, "params");
    }

    /** Intentionally writes nothing; stream serialization is not supported by this builder. */
    @Override
    protected void doWriteTo(StreamOutput streamOutput) throws IOException {
        // no-op by design
    }

    /**
     * Renders <code>"knn": { "&lt;face vector field&gt;": { ...params... } }</code> into the builder.
     *
     * @param xContentBuilder the target content builder
     * @param params          serialization parameters (unused)
     * @throws IOException if writing to the content builder fails
     */
    @Override
    protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
        xContentBuilder.startObject(NAME);
        xContentBuilder.field(Constant.ColumnNameFaceVector, knnParams);
        xContentBuilder.endObject();
    }

    /**
     * Not implemented: this builder never produces a Lucene query.
     *
     * @return always {@code null}
     */
    @Override
    protected Query doToQuery(QueryShardContext queryShardContext) throws IOException {
        return null;
    }

    /**
     * Compares the parameter maps with {@link Objects#equals(Object, Object)}.
     * Array values inside the maps are therefore compared by reference, not by content.
     */
    @Override
    protected boolean doEquals(ApproximateKnnQueryBuilder approximateKnnQueryBuilder) {
        return Objects.equals(this.knnParams, approximateKnnQueryBuilder.knnParams);
    }

    @Override
    protected int doHashCode() {
        return Objects.hash(this.knnParams);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }
}

View File

@ -71,4 +71,9 @@ public class MapParam extends ConcurrentHashMap<String, Object> {
return shardsNum; return shardsNum;
} }
/**
 * Returns the value stored under {@link Constant#IndexAlgoParamEfSearch}.
 *
 * <p>Falls back to 512 when the looked-up value is {@code null} or not positive.
 *
 * @return the configured value, or 512 when it is missing or invalid
 */
public Integer getIndexAlgoParamEfSearch() {
    final int fallback = 512;
    Integer configured = this.getInteger(Constant.IndexAlgoParamEfSearch, fallback);
    if (configured == null || configured <= 0) {
        return fallback;
    }
    return configured;
}
} }

View File

@ -46,6 +46,11 @@ public class CollectVo<ExtendsVo extends CollectVo<ExtendsVo>> extends BaseVo {
@ApiModelProperty(value="保留图片及人脸信息的存储组件", position = 9,required = false) @ApiModelProperty(value="保留图片及人脸信息的存储组件", position = 9,required = false)
private StorageEngine storageEngine; private StorageEngine storageEngine;
/**是否启用近似knn搜索**/
@ApiModelProperty(value="是否启用近似knn搜索", position = 10,required = false)
private boolean approximateKnn;
/** /**
* 构建集合对象 * 构建集合对象
* @param namespace 命名空间 * @param namespace 命名空间
@ -56,6 +61,15 @@ public class CollectVo<ExtendsVo extends CollectVo<ExtendsVo>> extends BaseVo {
return new CollectVo().setNamespace(namespace).setCollectionName(collectionName); return new CollectVo().setNamespace(namespace).setCollectionName(collectionName);
} }
/**
 * Reports whether approximate kNN search is enabled on this collection request.
 *
 * @return the current value of the {@code approximateKnn} flag
 */
public boolean isApproximateKnn() {
    return this.approximateKnn;
}
/**
 * Enables or disables approximate kNN search on this collection request.
 *
 * @param approximateKnn the new value of the flag
 * @return this object, cast to the self type so calls can be chained
 */
public ExtendsVo setApproximateKnn(boolean approximateKnn) {
    this.approximateKnn = approximateKnn;
    // The self-typed generic makes this cast unchecked; keep the suppression on the smallest scope.
    @SuppressWarnings("unchecked")
    ExtendsVo self = (ExtendsVo) this;
    return self;
}
public String getNamespace() { public String getNamespace() {
return namespace; return namespace;
} }
@ -144,4 +158,5 @@ public class CollectVo<ExtendsVo extends CollectVo<ExtendsVo>> extends BaseVo {
} }
return (ExtendsVo) this; return (ExtendsVo) this;
} }
} }

View File

@ -42,6 +42,12 @@ public class FaceSearchReqVo extends BaseVo {
@ApiModelProperty(value="对输入图像中多少个人脸进行检索比对默认5", position = 7, required = false) @ApiModelProperty(value="对输入图像中多少个人脸进行检索比对默认5", position = 7, required = false)
private Integer maxFaceNum; private Integer maxFaceNum;
/**是否使用近似knn搜索**/
@ApiModelProperty(value="是否使用近似knn搜索", position = 8, required = false)
private boolean approximateKnn;
/** /**
* 构建检索对象 * 构建检索对象
* @param namespace 命名空间 * @param namespace 命名空间
@ -130,4 +136,13 @@ public class FaceSearchReqVo extends BaseVo {
this.maxFaceNum = maxFaceNum; this.maxFaceNum = maxFaceNum;
return this; return this;
} }
/**
 * Reports whether this search request asks for approximate kNN search.
 *
 * @return the current value of the {@code approximateKnn} flag
 */
public boolean isApproximateKnn() {
    return this.approximateKnn;
}
/**
 * Sets whether this search request should use approximate kNN search.
 *
 * @param approximateKnn the new value of the flag
 * @return this request object, so calls can be chained
 */
public FaceSearchReqVo setApproximateKnn(boolean approximateKnn) {
    final FaceSearchReqVo self = this;
    self.approximateKnn = approximateKnn;
    return self;
}
} }

View File

@ -90,7 +90,7 @@ public class CollectServiceImpl extends BaseService implements CollectService {
MapParam param = MapParam.build() MapParam param = MapParam.build()
.put(Constant.IndexShardsNum, collect.getShardsNum()) .put(Constant.IndexShardsNum, collect.getShardsNum())
.put(Constant.IndexReplicasNum, collect.getReplicasNum()); .put(Constant.IndexReplicasNum, collect.getReplicasNum());
boolean createVectorFlag = searchEngine.createCollection(vectorTableName, param); boolean createVectorFlag = searchEngine.createCollection(vectorTableName, param,collect.isApproximateKnn());
if(!createVectorFlag){ if(!createVectorFlag){
throw new RuntimeException("create vector table error"); throw new RuntimeException("create vector table error");
} }

View File

@ -83,7 +83,7 @@ public class FaceSearchServiceImpl extends BaseService implements FaceSearchServ
} }
//特征搜索 //特征搜索
int topK = (null == search.getLimit() || search.getLimit() <= 0) ? 5 : search.getLimit(); int topK = (null == search.getLimit() || search.getLimit() <= 0) ? 5 : search.getLimit();
SearchResponse searchResponse =searchEngine.search(collection.getVectorTable(), vectors, search.getAlgorithm().algorithm(), topK); SearchResponse searchResponse =searchEngine.search(collection.getVectorTable(), vectors, search.getAlgorithm().algorithm(), topK,search.isApproximateKnn());
if(!searchResponse.getStatus().ok()){ if(!searchResponse.getStatus().ok()){
throw new RuntimeException(searchResponse.getStatus().getReason()); throw new RuntimeException(searchResponse.getStatus().getReason());
} }

View File

@ -26,6 +26,12 @@ public class FaceSearchExample {
public static String collectionName = "collect_20211201_v11"; public static String collectionName = "collect_20211201_v11";
public static FaceSearch faceSearch = FaceSearch.build(serverHost, namespace, collectionName); public static FaceSearch faceSearch = FaceSearch.build(serverHost, namespace, collectionName);
//是否启用近似knn,建议底库集比较大时启用.
public static boolean approximateKnn = false;
//底库集比较大时建议调大:32个分片
public static int shardsNum = 4;
/**集合创建*/ /**集合创建*/
public static void collect(){ public static void collect(){
//样本属性字段 //样本属性字段
@ -44,7 +50,12 @@ public class FaceSearchExample {
//是否保存人脸及图片数据信息 //是否保存人脸及图片数据信息
.setStorageFaceInfo(true) .setStorageFaceInfo(true)
//目前只实现了数据库存储对其他类型存储实现StorageImageService接口即可 //目前只实现了数据库存储对其他类型存储实现StorageImageService接口即可
.setStorageEngine(StorageEngine.CURR_DB); .setStorageEngine(StorageEngine.CURR_DB)
//设置分片大小
.setShardsNum(shardsNum)
//开启关闭近似knn搜索
.setApproximateKnn(approximateKnn);
//删除集合 //删除集合
Response<Boolean> deleteCollect = faceSearch.collect().deleteCollect(); Response<Boolean> deleteCollect = faceSearch.collect().deleteCollect();
System.out.println(deleteCollect); System.out.println(deleteCollect);
@ -97,6 +108,9 @@ public class FaceSearchExample {
.search(Search.build(imageBase64) .search(Search.build(imageBase64)
.setConfidenceThreshold(50f) //最小置信分50 .setConfidenceThreshold(50f) //最小置信分50
.setMaxFaceNum(10).setLimit(1) .setMaxFaceNum(10).setLimit(1)
//这里其实即使开了近似knn搜索这里设置为false也可以使用精确knn搜索
//这里数据量足够大时就能发现近似knn返回较快
.setApproximateKnn(approximateKnn)
); );
Long e = System.currentTimeMillis(); Long e = System.currentTimeMillis();
System.out.println("search cost:" + (e-s)+"ms"); System.out.println("search cost:" + (e-s)+"ms");