add:添加SeetaFace6Open的人脸关键点遮挡模型

This commit is contained in:
divenswu 2023-04-27 16:21:20 +08:00
parent 142828a6d2
commit 91bf0b8853
9 changed files with 340 additions and 27 deletions

View File

@ -0,0 +1,21 @@
package com.visual.face.search.core.base;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.domain.QualityInfo;
import java.util.Map;
/**
* 人脸关键点检测
*/
public interface FaceMaskPoint {
/**
* 人脸关键点检测
* @param imageMat 图像数据
* @param params 参数信息
* @return
*/
QualityInfo.MaskPoints inference(ImageMat imageMat, Map<String, Object> params);
}

View File

@ -264,7 +264,7 @@ public class ImageMat implements Serializable {
} }
/** /**
* 对图像进行预处理,并释放原始图片数据 * 对图像进行预处理,并释放原始图片数据:先交换RB通道swapRB再减法mean最后缩放scale
* @param scale 图像各通道数值的缩放比例 * @param scale 图像各通道数值的缩放比例
* @param mean 用于各通道减去的值以降低光照的影响 * @param mean 用于各通道减去的值以降低光照的影响
* @param swapRB 交换RB通道默认为False. * @param swapRB 交换RB通道默认为False.

View File

@ -0,0 +1,112 @@
package com.visual.face.search.core.domain;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
public class QualityInfo {
public MaskPoints maskPoints;
private QualityInfo(MaskPoints maskPoints) {
this.maskPoints = maskPoints;
}
public static QualityInfo build(MaskPoints maskPoints){
return new QualityInfo(maskPoints);
}
public MaskPoints getMaskPoints() {
return maskPoints;
}
public boolean isMask(){
return true;
}
/**
* 遮挡类
*/
public static class Mask implements Serializable {
/**遮挡分数*/
public float score;
public static Mask build(float score){
return new QualityInfo.Mask(score);
}
private Mask(float score) {
this.score = score;
}
public float getScore() {
return score;
}
@Override
public String toString() {
return "Mask{" + "score=" + score + '}';
}
}
/**
* 点遮挡类
*/
public static class MaskPoint extends Mask{
/**坐标X的值**/
public float x;
/**坐标Y的值**/
public float y;
public static MaskPoint build(float x, float y, float score){
return new MaskPoint(x, y, score);
}
private MaskPoint(float x, float y, float score) {
super(score);
this.x = x;
this.y = y;
}
public float getX() {
return x;
}
public float getY() {
return y;
}
@Override
public String toString() {
return "MaskPoint{" + "x=" + x + ", y=" + y + ", score=" + score + '}';
}
}
/**
* 点遮挡类集合
*/
public static class MaskPoints extends ArrayList<MaskPoint> {
private MaskPoints(){}
/**
* 构建一个集合
* @return
*/
public static MaskPoints build(){
return new MaskPoints();
}
/**
* 添加点
* @param point
* @return
*/
public MaskPoints add(MaskPoint...point){
super.addAll(Arrays.asList(point));
return this;
}
}
}

View File

@ -24,6 +24,10 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
public final static float defScoreTh = 0.5f; public final static float defScoreTh = 0.5f;
//人脸重叠iou阈值 //人脸重叠iou阈值
public final static float defIouTh = 0.7f; public final static float defIouTh = 0.7f;
//给人脸框一个默认的缩放
public final static float defBoxScale = 1.0f;
//人脸框缩放参数KEY
public final static String boxScaleParamKey = "boxScale";
/** /**
* 构造函数 * 构造函数
@ -48,6 +52,7 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
ImageMat imageMat = image.clone(); ImageMat imageMat = image.clone();
try { try {
float imgScale = 1.0f; float imgScale = 1.0f;
float boxScale = getBoxScale(params);
iouTh = iouTh <= 0 ? defIouTh : iouTh; iouTh = iouTh <= 0 ? defIouTh : iouTh;
scoreTh = scoreTh <= 0 ? defScoreTh : scoreTh; scoreTh = scoreTh <= 0 ? defScoreTh : scoreTh;
int imageWidth = imageMat.getWidth(), imageHeight = imageMat.getHeight(); int imageWidth = imageMat.getWidth(), imageHeight = imageMat.getHeight();
@ -68,7 +73,7 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
.blobFromImageAndDoReleaseMat(1.0/128, new Scalar(127.5, 127.5, 127.5), true) .blobFromImageAndDoReleaseMat(1.0/128, new Scalar(127.5, 127.5, 127.5), true)
.to4dFloatOnnxTensorAndDoReleaseMat(true); .to4dFloatOnnxTensorAndDoReleaseMat(true);
output = getSession().run(Collections.singletonMap(getInputName(), tensor)); output = getSession().run(Collections.singletonMap(getInputName(), tensor));
return fitterBoxes(output, scoreTh, iouTh, tensor.getInfo().getShape()[3], imgScale); return fitterBoxes(output, scoreTh, iouTh, tensor.getInfo().getShape()[3], imgScale, boxScale);
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException(e);
}finally { }finally {
@ -94,7 +99,7 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
* @return * @return
* @throws OrtException * @throws OrtException
*/ */
private List<FaceInfo> fitterBoxes(OrtSession.Result output, float scoreTh, float iouTh, long tensorWidth, float imgScale) throws OrtException { private List<FaceInfo> fitterBoxes(OrtSession.Result output, float scoreTh, float iouTh, long tensorWidth, float imgScale, float boxScale) throws OrtException {
//分数过滤及计算正确的人脸框值 //分数过滤及计算正确的人脸框值
List<FaceInfo> faceInfos = new ArrayList<>(); List<FaceInfo> faceInfos = new ArrayList<>();
for(int index=0; index< 3; index++) { for(int index=0; index< 3; index++) {
@ -122,7 +127,7 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
float pointY = (point[2*pointIndex+1] * strides[index] + anchorY) * imgScale; float pointY = (point[2*pointIndex+1] * strides[index] + anchorY) * imgScale;
keyPoints.add(FaceInfo.Point.build(pointX, pointY)); keyPoints.add(FaceInfo.Point.build(pointX, pointY));
} }
faceInfos.add(FaceInfo.build(scores[i][0], 0, FaceInfo.FaceBox.build(x1,y1,x2,y2), keyPoints)); faceInfos.add(FaceInfo.build(scores[i][0], 0, FaceInfo.FaceBox.build(x1,y1,x2,y2).scaling(boxScale), keyPoints));
} }
} }
} }
@ -147,4 +152,20 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete
return faces; return faces;
} }
private float getBoxScale(Map<String, Object> params){
float boxScale = 0;
try {
if(null != params && params.containsKey(boxScaleParamKey)){
Object value = params.get(boxScaleParamKey);
if(null != value){
if (value instanceof Number){
boxScale = ((Number) value).floatValue();
}else{
boxScale = Float.parseFloat(value.toString());
}
}
}
}catch (Exception e){}
return boxScale > 0 ? boxScale : defBoxScale;
}
} }

View File

@ -0,0 +1,94 @@
package com.visual.face.search.core.models;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtSession;
import com.visual.face.search.core.base.BaseOnnxInfer;
import com.visual.face.search.core.base.FaceMaskPoint;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.domain.QualityInfo;
import com.visual.face.search.core.utils.SoftMaxUtil;
import org.opencv.core.*;
import org.opencv.imgproc.Imgproc;
import java.util.Collections;
import java.util.Map;
public class SeetaMaskFaceKeyPoint extends BaseOnnxInfer implements FaceMaskPoint {
private static final int stride = 8;
private static final int shape = 128;
/**
* 构造函数
* @param modelPath 模型路径
* @param threads 线程数
*/
public SeetaMaskFaceKeyPoint(String modelPath, int threads) {
super(modelPath, threads);
}
/**
* 人脸关键点检测
*
* @param imageMat 图像数据
* @param params 参数信息
* @return
*/
@Override
public QualityInfo.MaskPoints inference(ImageMat imageMat, Map<String, Object> params) {
Mat borderMat = null;
Mat resizeMat = null;
OnnxTensor tensor = null;
OrtSession.Result output = null;
try {
Mat image = imageMat.toCvMat();
//将图片转换为正方形
int w = imageMat.getWidth();
int h = imageMat.getHeight();
int new_w = Math.max(h, w);
int new_h = Math.max(h, w);
if (Math.max(h, w) % stride != 0){
new_w = new_w + (stride - Math.max(h, w) % stride);
new_h = new_h + (stride - Math.max(h, w) % stride);
}
int ow = (new_w - w) / 2;
int oh = (new_h - h) / 2;
borderMat = new Mat();
Core.copyMakeBorder(image, borderMat, oh, oh, ow, ow, Core.BORDER_CONSTANT, new Scalar(114, 114, 114));
//对图片进行resize
float ratio = 1.0f * shape / new_h;
resizeMat = new Mat();
Imgproc.resize(borderMat, resizeMat, new Size(shape, shape));
//模型推理
tensor = ImageMat.fromCVMat(resizeMat)
.blobFromImageAndDoReleaseMat(1.0/32, new Scalar(104, 117, 123), false)
.to4dFloatOnnxTensorAndDoReleaseMat(true);
output = this.getSession().run(Collections.singletonMap(this.getInputName(), tensor));
float[] value = ((float[][]) output.get(0).getValue())[0];
//转换为标准的坐标点
QualityInfo.MaskPoints pointList = QualityInfo.MaskPoints.build();
for(int i=0; i<5; i++){
float x = value[i * 4 + 0] / ratio * 128 - ow;
float y = value[i * 4 + 1] / ratio * 128 - oh;
double[] softMax = SoftMaxUtil.softMax(new double[]{value[i * 4 + 2], value[i * 4 + 3]});
pointList.add(QualityInfo.MaskPoint.build(x, y, Double.valueOf(softMax[1]).floatValue()));
}
return pointList;
} catch (Exception e) {
throw new RuntimeException(e);
}finally {
if(null != tensor){
tensor.close();
}
if(null != output){
output.close();
}
if(null != borderMat){
borderMat.release();
}
if(null != resizeMat){
resizeMat.release();
}
}
}
}

View File

@ -15,6 +15,7 @@ import org.opencv.imgcodecs.Imgcodecs;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Map;
public class FaceCompareTest extends BaseTest { public class FaceCompareTest extends BaseTest {
@ -25,19 +26,22 @@ public class FaceCompareTest extends BaseTest {
private static String modelCoordPath = "face-search-core/src/main/resources/model/onnx/keypoint_coordinate/coordinate_106_mobilenet_05.onnx"; private static String modelCoordPath = "face-search-core/src/main/resources/model/onnx/keypoint_coordinate/coordinate_106_mobilenet_05.onnx";
private static String modelArcPath = "face-search-core/src/main/resources/model/onnx/recognition_face_arc/glint360k_cosface_r18_fp16_0.1.onnx"; private static String modelArcPath = "face-search-core/src/main/resources/model/onnx/recognition_face_arc/glint360k_cosface_r18_fp16_0.1.onnx";
private static String modelSeetaPath = "face-search-core/src/main/resources/model/onnx/recognition_face_seeta/face_recognizer_512.onnx"; private static String modelSeetaPath = "face-search-core/src/main/resources/model/onnx/recognition_face_seeta/face_recognizer_512.onnx";
// private static String modelSeetaPath = "face-search-core/src/main/resources/model/onnx/recognition_face_seeta/face_recognizer_1024.onnx";
private static String modelArrPath = "face-search-core/src/main/resources/model/onnx/attribute_gender_age/insight_gender_age.onnx"; private static String modelArrPath = "face-search-core/src/main/resources/model/onnx/attribute_gender_age/insight_gender_age.onnx";
// private static String imagePath = "face-search-core/src/test/resources/images/faces"; private static String imagePath = "face-search-test/src/main/resources/image/validate/index/马化腾/";
private static String imagePath3 = "face-search-test/src/main/resources/image/validate/index/雷军/";
// private static String imagePath1 = "face-search-core/src/test/resources/images/faces/debug/debug_0001.jpg"; // private static String imagePath1 = "face-search-core/src/test/resources/images/faces/debug/debug_0001.jpg";
// private static String imagePath2 = "face-search-core/src/test/resources/images/faces/debug/debug_0001.jpg"; // private static String imagePath2 = "face-search-core/src/test/resources/images/faces/debug/debug_0001.jpg";
private static String imagePath1 = "face-search-core/src/test/resources/images/faces/compare/1682052661610.jpg"; // private static String imagePath1 = "face-search-core/src/test/resources/images/faces/compare/1682052661610.jpg";
private static String imagePath2 = "face-search-core/src/test/resources/images/faces/compare/1682052669004.jpg"; // private static String imagePath2 = "face-search-core/src/test/resources/images/faces/compare/1682052669004.jpg";
// private static String imagePath2 = "face-search-core/src/test/resources/images/faces/compare/1682053163961.jpg"; // private static String imagePath2 = "face-search-core/src/test/resources/images/faces/compare/1682053163961.jpg";
// private static String imagePath1 = "face-search-test/src/main/resources/image/validate/index/张一鸣/1c7abcaf2dabdd2bc08e90c224d4c381.jpeg";
private static String imagePath1 = "face-search-test/src/main/resources/image/validate/index/张一鸣/0762c790db41a64f8f3f97598a825372.jpeg";
private static String imagePath2 = "face-search-test/src/main/resources/image/validate/index/张一鸣/ea191e61bdd4be8fddc89b828f5399b6.jpeg";
public static void main(String[] args) { public static void main(String[] args) {
// Map<String, String> map = getImagePathMap(imagePath); //口罩模型0.48light模型0.52normal模型0.62
Map<String, String> map1 = getImagePathMap(imagePath);
Map<String, String> map2 = getImagePathMap(imagePath3);
FaceDetection insightScrfdFaceDetection = new InsightScrfdFaceDetection(modelScrfdPath, 1); FaceDetection insightScrfdFaceDetection = new InsightScrfdFaceDetection(modelScrfdPath, 1);
FaceKeyPoint insightCoordFaceKeyPoint = new InsightCoordFaceKeyPoint(modelCoordPath, 1); FaceKeyPoint insightCoordFaceKeyPoint = new InsightCoordFaceKeyPoint(modelCoordPath, 1);
FaceRecognition insightArcFaceRecognition = new InsightArcFaceRecognition(modelArcPath, 1); FaceRecognition insightArcFaceRecognition = new InsightArcFaceRecognition(modelArcPath, 1);
@ -49,9 +53,11 @@ public class FaceCompareTest extends BaseTest {
FaceFeatureExtractor extractor = new FaceFeatureExtractorImpl( FaceFeatureExtractor extractor = new FaceFeatureExtractorImpl(
insightScrfdFaceDetection, pcnNetworkFaceDetection, insightCoordFaceKeyPoint, insightScrfdFaceDetection, pcnNetworkFaceDetection, insightCoordFaceKeyPoint,
simple005pFaceAlignment, insightSeetaFaceRecognition, insightFaceAttribute); simple005pFaceAlignment, insightArcFaceRecognition, insightFaceAttribute);
Mat image1 = Imgcodecs.imread(imagePath1); for(String file1 : map1.keySet()){
for(String file2 : map2.keySet()){
Mat image1 = Imgcodecs.imread(map1.get(file1));
long s = System.currentTimeMillis(); long s = System.currentTimeMillis();
ExtParam extParam = ExtParam.build().setMask(false).setTopK(20).setScoreTh(0).setIouTh(0); ExtParam extParam = ExtParam.build().setMask(false).setTopK(20).setScoreTh(0).setIouTh(0);
FaceImage faceImage1 = extractor.extract(ImageMat.fromCVMat(image1), extParam, null); FaceImage faceImage1 = extractor.extract(ImageMat.fromCVMat(image1), extParam, null);
@ -59,14 +65,16 @@ public class FaceCompareTest extends BaseTest {
long e = System.currentTimeMillis(); long e = System.currentTimeMillis();
System.out.println("image1 extract cost:"+(e-s)+"ms");; System.out.println("image1 extract cost:"+(e-s)+"ms");;
Mat image2 = Imgcodecs.imread(imagePath2); Mat image2 = Imgcodecs.imread(map2.get(file2));
s = System.currentTimeMillis(); s = System.currentTimeMillis();
FaceImage faceImage2 = extractor.extract(ImageMat.fromCVMat(image2), extParam, null); FaceImage faceImage2 = extractor.extract(ImageMat.fromCVMat(image2), extParam, null);
List<FaceInfo> faceInfos2 = faceImage2.faceInfos(); List<FaceInfo> faceInfos2 = faceImage2.faceInfos();
e = System.currentTimeMillis(); e = System.currentTimeMillis();
System.out.println("image2 extract cost:"+(e-s)+"ms"); System.out.println("image2 extract cost:"+(e-s)+"ms");
float similarity = Similarity.cosineSimilarity(faceInfos1.get(0).embedding.embeds, faceInfos2.get(0).embedding.embeds); float similarity = Similarity.cosineSimilarityNorm(faceInfos1.get(0).embedding.embeds, faceInfos2.get(0).embedding.embeds);
System.out.println("face similarity="+similarity); System.out.println(file1 + ","+ file2 + ",face similarity="+similarity);
}
}
} }
} }

View File

@ -0,0 +1,57 @@
package com.visual.face.search.core.test.models;
import com.visual.face.search.core.domain.FaceInfo;
import com.visual.face.search.core.domain.ImageMat;
import com.visual.face.search.core.domain.QualityInfo;
import com.visual.face.search.core.models.InsightCoordFaceKeyPoint;
import com.visual.face.search.core.models.InsightScrfdFaceDetection;
import com.visual.face.search.core.models.SeetaMaskFaceKeyPoint;
import com.visual.face.search.core.test.base.BaseTest;
import com.visual.face.search.core.utils.CropUtil;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.highgui.HighGui;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.util.List;
import java.util.Map;
public class SeetaMaskFaceKeyPointTest extends BaseTest {
private static String modelDetectionPath = "face-search-core/src/main/resources/model/onnx/detection_face_scrfd/scrfd_500m_bnkps.onnx";
private static String modelKeypointPath = "face-search-core/src/main/resources/model/onnx/keypoint_seeta_mask/landmarker_005_mask_pts5.onnx";
private static String imagePath = "face-search-core/src/test/resources/images/faces";
// private static String imagePath = "face-search-core/src/test/resources/images/faces/compare";
// private static String imagePath = "face-search-core/src/test/resources/images/faces/compare/1694353163955.jpg";
public static void main(String[] args) {
Map<String, String> map = getImagePathMap(imagePath);
InsightScrfdFaceDetection detectionInfer = new InsightScrfdFaceDetection(modelDetectionPath, 1);
SeetaMaskFaceKeyPoint keyPointInfer = new SeetaMaskFaceKeyPoint(modelKeypointPath, 1);
for(String fileName : map.keySet()) {
System.out.println(fileName);
String imageFilePath = map.get(fileName);
Mat image = Imgcodecs.imread(imageFilePath);
List<FaceInfo> faceInfos = detectionInfer.inference(ImageMat.fromCVMat(image), 0.5f, 0.7f, null);
for(FaceInfo faceInfo : faceInfos){
FaceInfo.FaceBox rotateFaceBox = faceInfo.rotateFaceBox();
Mat cropFace = CropUtil.crop(image, rotateFaceBox.scaling(1.0f));
ImageMat cropImageMat = ImageMat.fromCVMat(cropFace);
QualityInfo.MaskPoints maskPoints = keyPointInfer.inference(cropImageMat, null);
System.out.println(maskPoints);
for(QualityInfo.MaskPoint maskPoint : maskPoints){
if(maskPoint.score >= 0.5){
Imgproc.circle(cropFace, new Point(maskPoint.x, maskPoint.y), 3, new Scalar(0,0,255), -1);
}else{
Imgproc.circle(cropFace, new Point(maskPoint.x, maskPoint.y), 3, new Scalar(255, 0,0), -1);
}
}
HighGui.imshow(fileName, cropFace);
HighGui.waitKey();
}
}
}
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 75 KiB