!17 添加open search近似knn支持

Merge pull request !17 from foy/approximateKnnSupport
This commit is contained in:
divenswu 2023-11-15 09:24:24 +00:00 committed by Gitee
commit 6f36d44563
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
14 changed files with 199 additions and 28 deletions

View File

@ -48,7 +48,8 @@ public class CollectHandler extends BaseHandler<CollectHandler>{
.setFaceColumns(collect.getFaceColumns())
.setShardsNum(collect.getShardsNum())
.setStorageFaceInfo(collect.getStorageFaceInfo())
.setStorageEngine(collect.getStorageEngine());
.setStorageEngine(collect.getStorageEngine())
.setApproximateKnn(collect.isApproximateKnn());
return HttpClient.post(Api.getUrl(this.serverHost, Api.collect_create), collectReq);
}

View File

@ -45,7 +45,8 @@ public class SearchHandler extends BaseHandler<SearchHandler>{
.setMaxFaceNum(search.getMaxFaceNum())
.setLimit(search.getLimit())
.setConfidenceThreshold(search.getConfidenceThreshold())
.setFaceScoreThreshold(search.getFaceScoreThreshold());
.setFaceScoreThreshold(search.getFaceScoreThreshold())
.setApproximateKnn(search.isApproximateKnn());
return HttpClient.post(Api.getUrl(this.serverHost, Api.visual_search), searchReq, new TypeReference<Response<List<SearchRep>>>(){});
}
}

View File

@ -23,6 +23,9 @@ public class Collect<ExtendsVo extends Collect<ExtendsVo>> implements Serializab
/**保留图片及人脸信息的存储组件**/
private StorageEngine storageEngine;
/**是否启用近似knn搜索**/
private boolean approximateKnn;
/**
* 构建集合对象
* @return
@ -51,6 +54,15 @@ public class Collect<ExtendsVo extends Collect<ExtendsVo>> implements Serializab
return (ExtendsVo) this;
}
public boolean isApproximateKnn(){
return this.approximateKnn;
}
public ExtendsVo setApproximateKnn(boolean approximateKnn){
this.approximateKnn = approximateKnn;
return (ExtendsVo) this;
}
public Integer getShardsNum() {
return shardsNum;
}

View File

@ -15,6 +15,8 @@ public class Search<ExtendsVo extends Search<ExtendsVo>> implements Serializable
/**对输入图像中多少个人脸进行检索比对默认5**/
private Integer maxFaceNum=5;
/**是否启用近似knn搜索**/
private boolean approximateKnn = false;
/**
* 构建检索对象
* @param imageBase64 待检索的图片
@ -33,6 +35,15 @@ public class Search<ExtendsVo extends Search<ExtendsVo>> implements Serializable
return (ExtendsVo) this;
}
public boolean isApproximateKnn() {
return approximateKnn;
}
public ExtendsVo setApproximateKnn(boolean approximateKnn) {
this.approximateKnn = approximateKnn;
return (ExtendsVo)this;
}
public Float getFaceScoreThreshold() {
return faceScoreThreshold;
}

View File

@ -14,7 +14,11 @@ public interface SearchEngine {
public boolean dropCollection(String collectionName);
public boolean createCollection(String collectionName, MapParam param);
default public boolean createCollection(String collectionName, MapParam param){
return createCollection(collectionName,param,false);
};
public boolean createCollection(String collectionName, MapParam param,boolean approximateKnn);
public boolean insertVector(String collectionName, String sampleId, String faceId, float[] vectors);
@ -22,7 +26,11 @@ public interface SearchEngine {
public boolean deleteVectorByKey(String collectionName, List<String> faceIds);
public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK);
default public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK){
return search(collectionName,features,algorithm,topK,false);
};
public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK, boolean approximateKnn);
public float searchMinScoreBySampleId(String collectionName, String sampleId,float[] feature, String algorithm);

View File

@ -4,6 +4,7 @@ public class Constant {
public final static String IndexShardsNum = "shardsNum";
public final static String IndexReplicasNum = "replicasNum";
public final static String IndexAlgoParamEfSearch = "algoParamEfSearch";
public final static String ColumnNameFaceId = "face_id";
public final static String ColumnNameSampleId = "sample_id";

View File

@ -3,6 +3,7 @@ package com.visual.face.search.engine.impl;
import com.visual.face.search.engine.api.SearchEngine;
import com.visual.face.search.engine.conf.Constant;
import com.visual.face.search.engine.exps.SearchEngineException;
import com.visual.face.search.engine.impl.query.ApproximateKnnQueryBuilder;
import com.visual.face.search.engine.model.*;
import org.apache.commons.collections4.MapUtils;
import org.opensearch.action.DocWriteResponse;
@ -71,17 +72,32 @@ public class OpenSearchEngine implements SearchEngine {
}
@Override
public boolean createCollection(String collectionName, MapParam param) {
public boolean createCollection(String collectionName, MapParam param,boolean approximateKnn) {
try {
//构建请求
CreateIndexRequest createIndexRequest = new CreateIndexRequest(collectionName);
createIndexRequest.settings(Settings.builder()
.put("index.number_of_shards", param.getIndexShardsNum())
.put("index.number_of_replicas", param.getIndexReplicasNum())
);
Settings.Builder builder = Settings.builder()
.put("index.number_of_shards", param.getIndexShardsNum())
.put("index.number_of_replicas", param.getIndexReplicasNum());
if(approximateKnn){
//启用open search近似knn搜索支持
builder.put("index.knn",true);
builder.put("index.knn.algo_param.ef_search",param.getIndexAlgoParamEfSearch());
}
createIndexRequest.settings(builder);
HashMap<String, Object> properties = new HashMap<>();
properties.put(Constant.ColumnNameSampleId, Map.of("type", "keyword"));
properties.put(Constant.ColumnNameFaceVector, Map.of("type", "knn_vector", "dimension", "512"));
if(approximateKnn){
//启用open search近似knn搜索支持
properties.put(Constant.ColumnNameFaceVector, Map.of("type", "knn_vector", "dimension", "512",
"method",Map.of("engine","nmslib",
"space_type","cosinesimil",
"name","hnsw",
"parameters",Map.of())));
}else {
properties.put(Constant.ColumnNameFaceVector, Map.of("type", "knn_vector", "dimension", "512"));
}
createIndexRequest.mapping(Map.of("properties", properties));
//创建集合
CreateIndexResponse createIndexResponse = client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
@ -135,23 +151,37 @@ public class OpenSearchEngine implements SearchEngine {
}
@Override
public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK) {
public SearchResponse search(String collectionName, float[][] features, String algorithm, int topK,boolean approximateKnn) {
try {
//构建搜索请求
MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
for(float[] feature : features){
QueryBuilder queryBuilder = new MatchAllQueryBuilder();
Map<String, Object> params = new HashMap<>();
params.put("field", Constant.ColumnNameFaceVector);
params.put("space_type", algorithm);
params.put("query_value", feature);
Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "knn", "knn_score", params);
ScriptScoreQueryBuilder scriptScoreQueryBuilder = new ScriptScoreQueryBuilder(queryBuilder, script);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
.query(scriptScoreQueryBuilder).size(topK)
.fetchSource(null, Constant.ColumnNameFaceVector); //是否需要向量字段
SearchRequest searchRequest = new SearchRequest(collectionName).source(searchSourceBuilder);
multiSearchRequest.add(searchRequest);
if(approximateKnn){
//近似knn搜索
Map<String, Object> params = new HashMap<>();
params.put("vector",feature);
params.put("k",topK);
ApproximateKnnQueryBuilder approximateKnnQueryBuilder = new ApproximateKnnQueryBuilder(params);
SearchSourceBuilder searchSourceBuilder =new SearchSourceBuilder()
.query(approximateKnnQueryBuilder).size(topK)
.fetchSource(null,Constant.ColumnNameFaceVector);
SearchRequest searchRequest = new SearchRequest(collectionName).source(searchSourceBuilder);
multiSearchRequest.add(searchRequest);
}else {
//常规搜索
QueryBuilder queryBuilder = new MatchAllQueryBuilder();
Map<String, Object> params = new HashMap<>();
params.put("field", Constant.ColumnNameFaceVector);
params.put("space_type", algorithm);
params.put("query_value", feature);
Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "knn", "knn_score", params);
ScriptScoreQueryBuilder scriptScoreQueryBuilder = new ScriptScoreQueryBuilder(queryBuilder, script);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
.query(scriptScoreQueryBuilder).size(topK)
.fetchSource(null, Constant.ColumnNameFaceVector); //是否需要向量字段
SearchRequest searchRequest = new SearchRequest(collectionName).source(searchSourceBuilder);
multiSearchRequest.add(searchRequest);
}
}
//查询索引
MultiSearchResponse response = this.client.msearch(multiSearchRequest, RequestOptions.DEFAULT);
@ -167,7 +197,7 @@ public class OpenSearchEngine implements SearchEngine {
if(searchHits != null){
for(SearchHit searchHit : searchHits){
String faceId = searchHit.getId();
float score = searchHit.getScore()-1;
float score = approximateKnn? searchHit.getScore() : (searchHit.getScore()-1);
Map<String, Object> sourceMap = searchHit.getSourceAsMap();
String sampleId = MapUtils.getString(sourceMap, Constant.ColumnNameSampleId);
Object faceVector = MapUtils.getObject(sourceMap, Constant.ColumnNameFaceVector);

View File

@ -0,0 +1,58 @@
package com.visual.face.search.engine.impl.query;
import com.visual.face.search.engine.conf.Constant;
import org.apache.lucene.search.Query;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
/**
* 近似KNN搜索
* @Author Foy Lian
* @Date 2023/9/13 14:26
**/
public class ApproximateKnnQueryBuilder extends AbstractQueryBuilder<ApproximateKnnQueryBuilder> {
private Map<String, Object> mParams;
public ApproximateKnnQueryBuilder(Map<String, Object> params){
this.mParams = params;
}
@Override
protected void doWriteTo(StreamOutput streamOutput) throws IOException {
}
@Override
protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
xContentBuilder.startObject("knn");
xContentBuilder.field(Constant.ColumnNameFaceVector,mParams);
xContentBuilder.endObject();
}
@Override
protected Query doToQuery(QueryShardContext queryShardContext) throws IOException {
return null;
}
@Override
protected boolean doEquals(ApproximateKnnQueryBuilder approximateKnnQueryBuilder) {
return Objects.equals(this.mParams, approximateKnnQueryBuilder.mParams);
}
@Override
protected int doHashCode() {
return Objects.hash(new Object[]{this.mParams});
}
@Override
public String getWriteableName() {
return "knn";
}
}

View File

@ -71,4 +71,9 @@ public class MapParam extends ConcurrentHashMap<String, Object> {
return shardsNum;
}
public Integer getIndexAlgoParamEfSearch(){
Integer algoParamEfSearch = this.getInteger(Constant.IndexAlgoParamEfSearch, 512);
algoParamEfSearch = (null == algoParamEfSearch || algoParamEfSearch <= 0) ? 512 : algoParamEfSearch;
return algoParamEfSearch;
}
}

View File

@ -46,6 +46,11 @@ public class CollectVo<ExtendsVo extends CollectVo<ExtendsVo>> extends BaseVo {
@ApiModelProperty(value="保留图片及人脸信息的存储组件", position = 9,required = false)
private StorageEngine storageEngine;
/**是否启用近似knn搜索**/
@ApiModelProperty(value="是否启用近似knn搜索", position = 10,required = false)
private boolean approximateKnn;
/**
* 构建集合对象
* @param namespace 命名空间
@ -56,6 +61,15 @@ public class CollectVo<ExtendsVo extends CollectVo<ExtendsVo>> extends BaseVo {
return new CollectVo().setNamespace(namespace).setCollectionName(collectionName);
}
public boolean isApproximateKnn() {
return approximateKnn;
}
public ExtendsVo setApproximateKnn(boolean approximateKnn){
this.approximateKnn = approximateKnn;
return (ExtendsVo) this;
}
public String getNamespace() {
return namespace;
}
@ -144,4 +158,5 @@ public class CollectVo<ExtendsVo extends CollectVo<ExtendsVo>> extends BaseVo {
}
return (ExtendsVo) this;
}
}

View File

@ -42,6 +42,12 @@ public class FaceSearchReqVo extends BaseVo {
@ApiModelProperty(value="对输入图像中多少个人脸进行检索比对默认5", position = 7, required = false)
private Integer maxFaceNum;
/**是否使用近似knn搜索**/
@ApiModelProperty(value="是否使用近似knn搜索", position = 8, required = false)
private boolean approximateKnn;
/**
* 构建检索对象
* @param namespace 命名空间
@ -130,4 +136,13 @@ public class FaceSearchReqVo extends BaseVo {
this.maxFaceNum = maxFaceNum;
return this;
}
public boolean isApproximateKnn() {
return approximateKnn;
}
public FaceSearchReqVo setApproximateKnn(boolean approximateKnn) {
this.approximateKnn = approximateKnn;
return this;
}
}

View File

@ -90,7 +90,7 @@ public class CollectServiceImpl extends BaseService implements CollectService {
MapParam param = MapParam.build()
.put(Constant.IndexShardsNum, collect.getShardsNum())
.put(Constant.IndexReplicasNum, collect.getReplicasNum());
boolean createVectorFlag = searchEngine.createCollection(vectorTableName, param);
boolean createVectorFlag = searchEngine.createCollection(vectorTableName, param,collect.isApproximateKnn());
if(!createVectorFlag){
throw new RuntimeException("create vector table error");
}

View File

@ -83,7 +83,7 @@ public class FaceSearchServiceImpl extends BaseService implements FaceSearchServ
}
//特征搜索
int topK = (null == search.getLimit() || search.getLimit() <= 0) ? 5 : search.getLimit();
SearchResponse searchResponse =searchEngine.search(collection.getVectorTable(), vectors, search.getAlgorithm().algorithm(), topK);
SearchResponse searchResponse =searchEngine.search(collection.getVectorTable(), vectors, search.getAlgorithm().algorithm(), topK,search.isApproximateKnn());
if(!searchResponse.getStatus().ok()){
throw new RuntimeException(searchResponse.getStatus().getReason());
}

View File

@ -26,6 +26,12 @@ public class FaceSearchExample {
public static String collectionName = "collect_20211201_v11";
public static FaceSearch faceSearch = FaceSearch.build(serverHost, namespace, collectionName);
//是否启用近似knn,建议底库集比较大时启用.
public static boolean approximateKnn = false;
//底库集比较大时建议调大:32个分片
public static int shardsNum = 4;
/**集合创建*/
public static void collect(){
//样本属性字段
@ -44,7 +50,12 @@ public class FaceSearchExample {
//是否保存人脸及图片数据信息
.setStorageFaceInfo(true)
//目前只实现了数据库存储对其他类型存储实现StorageImageService接口即可
.setStorageEngine(StorageEngine.CURR_DB);
.setStorageEngine(StorageEngine.CURR_DB)
//设置分片大小
.setShardsNum(shardsNum)
//开启关闭近似knn搜索
.setApproximateKnn(approximateKnn);
//删除集合
Response<Boolean> deleteCollect = faceSearch.collect().deleteCollect();
System.out.println(deleteCollect);
@ -97,6 +108,9 @@ public class FaceSearchExample {
.search(Search.build(imageBase64)
.setConfidenceThreshold(50f) //最小置信分50
.setMaxFaceNum(10).setLimit(1)
//collect()创建集合时即使开了近似knn搜索这里设置为false也可以使用精确knn搜索
//这里数据量足够大时就能发现近似knn返回较快
.setApproximateKnn(approximateKnn)
);
Long e = System.currentTimeMillis();
System.out.println("search cost:" + (e-s)+"ms");