mirror of
https://gitee.com/open-visual/face-search.git
synced 2025-07-25 19:41:42 +08:00
add:添加SeetaFace6Open的人脸关键点遮挡模型
This commit is contained in:
parent
142828a6d2
commit
91bf0b8853
@ -0,0 +1,21 @@
|
||||
package com.visual.face.search.core.base;
|
||||
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.domain.QualityInfo;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 人脸关键点检测
|
||||
*/
|
||||
public interface FaceMaskPoint {
|
||||
|
||||
/**
|
||||
* 人脸关键点检测
|
||||
* @param imageMat 图像数据
|
||||
* @param params 参数信息
|
||||
* @return
|
||||
*/
|
||||
QualityInfo.MaskPoints inference(ImageMat imageMat, Map<String, Object> params);
|
||||
|
||||
}
|
@ -264,7 +264,7 @@ public class ImageMat implements Serializable {
|
||||
}
|
||||
|
||||
/**
|
||||
* 对图像进行预处理,并释放原始图片数据
|
||||
* 对图像进行预处理,并释放原始图片数据:(先交换RB通道(swapRB),再减法(mean),最后缩放(scale))
|
||||
* @param scale 图像各通道数值的缩放比例
|
||||
* @param mean 用于各通道减去的值,以降低光照的影响
|
||||
* @param swapRB 交换RB通道,默认为False.
|
||||
|
@ -0,0 +1,112 @@
|
||||
package com.visual.face.search.core.domain;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
||||
public class QualityInfo {
|
||||
|
||||
public MaskPoints maskPoints;
|
||||
|
||||
private QualityInfo(MaskPoints maskPoints) {
|
||||
this.maskPoints = maskPoints;
|
||||
}
|
||||
|
||||
public static QualityInfo build(MaskPoints maskPoints){
|
||||
return new QualityInfo(maskPoints);
|
||||
}
|
||||
|
||||
public MaskPoints getMaskPoints() {
|
||||
return maskPoints;
|
||||
}
|
||||
|
||||
|
||||
public boolean isMask(){
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 遮挡类
|
||||
*/
|
||||
public static class Mask implements Serializable {
|
||||
/**遮挡分数*/
|
||||
public float score;
|
||||
|
||||
public static Mask build(float score){
|
||||
return new QualityInfo.Mask(score);
|
||||
}
|
||||
|
||||
private Mask(float score) {
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
public float getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Mask{" + "score=" + score + '}';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 点遮挡类
|
||||
*/
|
||||
public static class MaskPoint extends Mask{
|
||||
/**坐标X的值**/
|
||||
public float x;
|
||||
/**坐标Y的值**/
|
||||
public float y;
|
||||
|
||||
public static MaskPoint build(float x, float y, float score){
|
||||
return new MaskPoint(x, y, score);
|
||||
}
|
||||
|
||||
private MaskPoint(float x, float y, float score) {
|
||||
super(score);
|
||||
this.x = x;
|
||||
this.y = y;
|
||||
}
|
||||
|
||||
public float getX() {
|
||||
return x;
|
||||
}
|
||||
|
||||
public float getY() {
|
||||
return y;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "MaskPoint{" + "x=" + x + ", y=" + y + ", score=" + score + '}';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 点遮挡类集合
|
||||
*/
|
||||
public static class MaskPoints extends ArrayList<MaskPoint> {
|
||||
|
||||
private MaskPoints(){}
|
||||
|
||||
/**
|
||||
* 构建一个集合
|
||||
* @return
|
||||
*/
|
||||
public static MaskPoints build(){
|
||||
return new MaskPoints();
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加点
|
||||
* @param point
|
||||
* @return
|
||||
*/
|
||||
public MaskPoints add(MaskPoint...point){
|
||||
super.addAll(Arrays.asList(point));
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
@ -24,6 +24,10 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
|
||||
public final static float defScoreTh = 0.5f;
|
||||
//人脸重叠iou阈值
|
||||
public final static float defIouTh = 0.7f;
|
||||
//给人脸框一个默认的缩放
|
||||
public final static float defBoxScale = 1.0f;
|
||||
//人脸框缩放参数KEY
|
||||
public final static String boxScaleParamKey = "boxScale";
|
||||
|
||||
/**
|
||||
* 构造函数
|
||||
@ -48,6 +52,7 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
|
||||
ImageMat imageMat = image.clone();
|
||||
try {
|
||||
float imgScale = 1.0f;
|
||||
float boxScale = getBoxScale(params);
|
||||
iouTh = iouTh <= 0 ? defIouTh : iouTh;
|
||||
scoreTh = scoreTh <= 0 ? defScoreTh : scoreTh;
|
||||
int imageWidth = imageMat.getWidth(), imageHeight = imageMat.getHeight();
|
||||
@ -68,7 +73,7 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
|
||||
.blobFromImageAndDoReleaseMat(1.0/128, new Scalar(127.5, 127.5, 127.5), true)
|
||||
.to4dFloatOnnxTensorAndDoReleaseMat(true);
|
||||
output = getSession().run(Collections.singletonMap(getInputName(), tensor));
|
||||
return fitterBoxes(output, scoreTh, iouTh, tensor.getInfo().getShape()[3], imgScale);
|
||||
return fitterBoxes(output, scoreTh, iouTh, tensor.getInfo().getShape()[3], imgScale, boxScale);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}finally {
|
||||
@ -94,7 +99,7 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
|
||||
* @return
|
||||
* @throws OrtException
|
||||
*/
|
||||
private List<FaceInfo> fitterBoxes(OrtSession.Result output, float scoreTh, float iouTh, long tensorWidth, float imgScale) throws OrtException {
|
||||
private List<FaceInfo> fitterBoxes(OrtSession.Result output, float scoreTh, float iouTh, long tensorWidth, float imgScale, float boxScale) throws OrtException {
|
||||
//分数过滤及计算正确的人脸框值
|
||||
List<FaceInfo> faceInfos = new ArrayList<>();
|
||||
for(int index=0; index< 3; index++) {
|
||||
@ -122,7 +127,7 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
|
||||
float pointY = (point[2*pointIndex+1] * strides[index] + anchorY) * imgScale;
|
||||
keyPoints.add(FaceInfo.Point.build(pointX, pointY));
|
||||
}
|
||||
faceInfos.add(FaceInfo.build(scores[i][0], 0, FaceInfo.FaceBox.build(x1,y1,x2,y2), keyPoints));
|
||||
faceInfos.add(FaceInfo.build(scores[i][0], 0, FaceInfo.FaceBox.build(x1,y1,x2,y2).scaling(boxScale), keyPoints));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -147,4 +152,20 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
|
||||
return faces;
|
||||
}
|
||||
|
||||
private float getBoxScale(Map<String, Object> params){
|
||||
float boxScale = 0;
|
||||
try {
|
||||
if(null != params && params.containsKey(boxScaleParamKey)){
|
||||
Object value = params.get(boxScaleParamKey);
|
||||
if(null != value){
|
||||
if (value instanceof Number){
|
||||
boxScale = ((Number) value).floatValue();
|
||||
}else{
|
||||
boxScale = Float.parseFloat(value.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
}catch (Exception e){}
|
||||
return boxScale > 0 ? boxScale : defBoxScale;
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,94 @@
|
||||
package com.visual.face.search.core.models;
|
||||
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import com.visual.face.search.core.base.BaseOnnxInfer;
|
||||
import com.visual.face.search.core.base.FaceMaskPoint;
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.domain.QualityInfo;
|
||||
import com.visual.face.search.core.utils.SoftMaxUtil;
|
||||
import org.opencv.core.*;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
|
||||
public class SeetaMaskFaceKeyPoint extends BaseOnnxInfer implements FaceMaskPoint {
|
||||
|
||||
private static final int stride = 8;
|
||||
private static final int shape = 128;
|
||||
|
||||
/**
|
||||
* 构造函数
|
||||
* @param modelPath 模型路径
|
||||
* @param threads 线程数
|
||||
*/
|
||||
public SeetaMaskFaceKeyPoint(String modelPath, int threads) {
|
||||
super(modelPath, threads);
|
||||
}
|
||||
|
||||
/**
|
||||
* 人脸关键点检测
|
||||
*
|
||||
* @param imageMat 图像数据
|
||||
* @param params 参数信息
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public QualityInfo.MaskPoints inference(ImageMat imageMat, Map<String, Object> params) {
|
||||
Mat borderMat = null;
|
||||
Mat resizeMat = null;
|
||||
OnnxTensor tensor = null;
|
||||
OrtSession.Result output = null;
|
||||
try {
|
||||
Mat image = imageMat.toCvMat();
|
||||
//将图片转换为正方形
|
||||
int w = imageMat.getWidth();
|
||||
int h = imageMat.getHeight();
|
||||
int new_w = Math.max(h, w);
|
||||
int new_h = Math.max(h, w);
|
||||
if (Math.max(h, w) % stride != 0){
|
||||
new_w = new_w + (stride - Math.max(h, w) % stride);
|
||||
new_h = new_h + (stride - Math.max(h, w) % stride);
|
||||
}
|
||||
int ow = (new_w - w) / 2;
|
||||
int oh = (new_h - h) / 2;
|
||||
borderMat = new Mat();
|
||||
Core.copyMakeBorder(image, borderMat, oh, oh, ow, ow, Core.BORDER_CONSTANT, new Scalar(114, 114, 114));
|
||||
//对图片进行resize
|
||||
float ratio = 1.0f * shape / new_h;
|
||||
resizeMat = new Mat();
|
||||
Imgproc.resize(borderMat, resizeMat, new Size(shape, shape));
|
||||
//模型推理
|
||||
tensor = ImageMat.fromCVMat(resizeMat)
|
||||
.blobFromImageAndDoReleaseMat(1.0/32, new Scalar(104, 117, 123), false)
|
||||
.to4dFloatOnnxTensorAndDoReleaseMat(true);
|
||||
output = this.getSession().run(Collections.singletonMap(this.getInputName(), tensor));
|
||||
float[] value = ((float[][]) output.get(0).getValue())[0];
|
||||
//转换为标准的坐标点
|
||||
QualityInfo.MaskPoints pointList = QualityInfo.MaskPoints.build();
|
||||
for(int i=0; i<5; i++){
|
||||
float x = value[i * 4 + 0] / ratio * 128 - ow;
|
||||
float y = value[i * 4 + 1] / ratio * 128 - oh;
|
||||
double[] softMax = SoftMaxUtil.softMax(new double[]{value[i * 4 + 2], value[i * 4 + 3]});
|
||||
pointList.add(QualityInfo.MaskPoint.build(x, y, Double.valueOf(softMax[1]).floatValue()));
|
||||
}
|
||||
return pointList;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}finally {
|
||||
if(null != tensor){
|
||||
tensor.close();
|
||||
}
|
||||
if(null != output){
|
||||
output.close();
|
||||
}
|
||||
if(null != borderMat){
|
||||
borderMat.release();
|
||||
}
|
||||
if(null != resizeMat){
|
||||
resizeMat.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Binary file not shown.
@ -15,6 +15,7 @@ import org.opencv.imgcodecs.Imgcodecs;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class FaceCompareTest extends BaseTest {
|
||||
|
||||
@ -25,19 +26,22 @@ public class FaceCompareTest extends BaseTest {
|
||||
private static String modelCoordPath = "face-search-core/src/main/resources/model/onnx/keypoint_coordinate/coordinate_106_mobilenet_05.onnx";
|
||||
private static String modelArcPath = "face-search-core/src/main/resources/model/onnx/recognition_face_arc/glint360k_cosface_r18_fp16_0.1.onnx";
|
||||
private static String modelSeetaPath = "face-search-core/src/main/resources/model/onnx/recognition_face_seeta/face_recognizer_512.onnx";
|
||||
// private static String modelSeetaPath = "face-search-core/src/main/resources/model/onnx/recognition_face_seeta/face_recognizer_1024.onnx";
|
||||
private static String modelArrPath = "face-search-core/src/main/resources/model/onnx/attribute_gender_age/insight_gender_age.onnx";
|
||||
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces";
|
||||
private static String imagePath = "face-search-test/src/main/resources/image/validate/index/马化腾/";
|
||||
private static String imagePath3 = "face-search-test/src/main/resources/image/validate/index/雷军/";
|
||||
// private static String imagePath1 = "face-search-core/src/test/resources/images/faces/debug/debug_0001.jpg";
|
||||
// private static String imagePath2 = "face-search-core/src/test/resources/images/faces/debug/debug_0001.jpg";
|
||||
private static String imagePath1 = "face-search-core/src/test/resources/images/faces/compare/1682052661610.jpg";
|
||||
private static String imagePath2 = "face-search-core/src/test/resources/images/faces/compare/1682052669004.jpg";
|
||||
// private static String imagePath1 = "face-search-core/src/test/resources/images/faces/compare/1682052661610.jpg";
|
||||
// private static String imagePath2 = "face-search-core/src/test/resources/images/faces/compare/1682052669004.jpg";
|
||||
// private static String imagePath2 = "face-search-core/src/test/resources/images/faces/compare/1682053163961.jpg";
|
||||
|
||||
|
||||
// private static String imagePath1 = "face-search-test/src/main/resources/image/validate/index/张一鸣/1c7abcaf2dabdd2bc08e90c224d4c381.jpeg";
|
||||
private static String imagePath1 = "face-search-test/src/main/resources/image/validate/index/张一鸣/0762c790db41a64f8f3f97598a825372.jpeg";
|
||||
private static String imagePath2 = "face-search-test/src/main/resources/image/validate/index/张一鸣/ea191e61bdd4be8fddc89b828f5399b6.jpeg";
|
||||
public static void main(String[] args) {
|
||||
// Map<String, String> map = getImagePathMap(imagePath);
|
||||
//口罩模型0.48,light模型0.52,normal模型0.62
|
||||
Map<String, String> map1 = getImagePathMap(imagePath);
|
||||
Map<String, String> map2 = getImagePathMap(imagePath3);
|
||||
FaceDetection insightScrfdFaceDetection = new InsightScrfdFaceDetection(modelScrfdPath, 1);
|
||||
FaceKeyPoint insightCoordFaceKeyPoint = new InsightCoordFaceKeyPoint(modelCoordPath, 1);
|
||||
FaceRecognition insightArcFaceRecognition = new InsightArcFaceRecognition(modelArcPath, 1);
|
||||
@ -49,9 +53,11 @@ public class FaceCompareTest extends BaseTest {
|
||||
|
||||
FaceFeatureExtractor extractor = new FaceFeatureExtractorImpl(
|
||||
insightScrfdFaceDetection, pcnNetworkFaceDetection, insightCoordFaceKeyPoint,
|
||||
simple005pFaceAlignment, insightSeetaFaceRecognition, insightFaceAttribute);
|
||||
simple005pFaceAlignment, insightArcFaceRecognition, insightFaceAttribute);
|
||||
|
||||
Mat image1 = Imgcodecs.imread(imagePath1);
|
||||
for(String file1 : map1.keySet()){
|
||||
for(String file2 : map2.keySet()){
|
||||
Mat image1 = Imgcodecs.imread(map1.get(file1));
|
||||
long s = System.currentTimeMillis();
|
||||
ExtParam extParam = ExtParam.build().setMask(false).setTopK(20).setScoreTh(0).setIouTh(0);
|
||||
FaceImage faceImage1 = extractor.extract(ImageMat.fromCVMat(image1), extParam, null);
|
||||
@ -59,14 +65,16 @@ public class FaceCompareTest extends BaseTest {
|
||||
long e = System.currentTimeMillis();
|
||||
System.out.println("image1 extract cost:"+(e-s)+"ms");;
|
||||
|
||||
Mat image2 = Imgcodecs.imread(imagePath2);
|
||||
Mat image2 = Imgcodecs.imread(map2.get(file2));
|
||||
s = System.currentTimeMillis();
|
||||
FaceImage faceImage2 = extractor.extract(ImageMat.fromCVMat(image2), extParam, null);
|
||||
List<FaceInfo> faceInfos2 = faceImage2.faceInfos();
|
||||
e = System.currentTimeMillis();
|
||||
System.out.println("image2 extract cost:"+(e-s)+"ms");
|
||||
float similarity = Similarity.cosineSimilarity(faceInfos1.get(0).embedding.embeds, faceInfos2.get(0).embedding.embeds);
|
||||
System.out.println("face similarity="+similarity);
|
||||
float similarity = Similarity.cosineSimilarityNorm(faceInfos1.get(0).embedding.embeds, faceInfos2.get(0).embedding.embeds);
|
||||
System.out.println(file1 + ","+ file2 + ",face similarity="+similarity);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,57 @@
|
||||
package com.visual.face.search.core.test.models;
|
||||
|
||||
import com.visual.face.search.core.domain.FaceInfo;
|
||||
import com.visual.face.search.core.domain.ImageMat;
|
||||
import com.visual.face.search.core.domain.QualityInfo;
|
||||
import com.visual.face.search.core.models.InsightCoordFaceKeyPoint;
|
||||
import com.visual.face.search.core.models.InsightScrfdFaceDetection;
|
||||
import com.visual.face.search.core.models.SeetaMaskFaceKeyPoint;
|
||||
import com.visual.face.search.core.test.base.BaseTest;
|
||||
import com.visual.face.search.core.utils.CropUtil;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.core.Point;
|
||||
import org.opencv.core.Scalar;
|
||||
import org.opencv.highgui.HighGui;
|
||||
import org.opencv.imgcodecs.Imgcodecs;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class SeetaMaskFaceKeyPointTest extends BaseTest {
|
||||
private static String modelDetectionPath = "face-search-core/src/main/resources/model/onnx/detection_face_scrfd/scrfd_500m_bnkps.onnx";
|
||||
private static String modelKeypointPath = "face-search-core/src/main/resources/model/onnx/keypoint_seeta_mask/landmarker_005_mask_pts5.onnx";
|
||||
private static String imagePath = "face-search-core/src/test/resources/images/faces";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces/compare";
|
||||
// private static String imagePath = "face-search-core/src/test/resources/images/faces/compare/1694353163955.jpg";
|
||||
|
||||
public static void main(String[] args) {
|
||||
Map<String, String> map = getImagePathMap(imagePath);
|
||||
InsightScrfdFaceDetection detectionInfer = new InsightScrfdFaceDetection(modelDetectionPath, 1);
|
||||
SeetaMaskFaceKeyPoint keyPointInfer = new SeetaMaskFaceKeyPoint(modelKeypointPath, 1);
|
||||
for(String fileName : map.keySet()) {
|
||||
System.out.println(fileName);
|
||||
String imageFilePath = map.get(fileName);
|
||||
Mat image = Imgcodecs.imread(imageFilePath);
|
||||
List<FaceInfo> faceInfos = detectionInfer.inference(ImageMat.fromCVMat(image), 0.5f, 0.7f, null);
|
||||
for(FaceInfo faceInfo : faceInfos){
|
||||
FaceInfo.FaceBox rotateFaceBox = faceInfo.rotateFaceBox();
|
||||
Mat cropFace = CropUtil.crop(image, rotateFaceBox.scaling(1.0f));
|
||||
ImageMat cropImageMat = ImageMat.fromCVMat(cropFace);
|
||||
QualityInfo.MaskPoints maskPoints = keyPointInfer.inference(cropImageMat, null);
|
||||
System.out.println(maskPoints);
|
||||
for(QualityInfo.MaskPoint maskPoint : maskPoints){
|
||||
if(maskPoint.score >= 0.5){
|
||||
Imgproc.circle(cropFace, new Point(maskPoint.x, maskPoint.y), 3, new Scalar(0,0,255), -1);
|
||||
}else{
|
||||
Imgproc.circle(cropFace, new Point(maskPoint.x, maskPoint.y), 3, new Scalar(255, 0,0), -1);
|
||||
}
|
||||
}
|
||||
HighGui.imshow(fileName, cropFace);
|
||||
HighGui.waitKey();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
Binary file not shown.
After Width: | Height: | Size: 75 KiB |
Loading…
x
Reference in New Issue
Block a user