diff --git a/face-search-core/src/main/java/com/visual/face/search/core/base/FaceMaskPoint.java b/face-search-core/src/main/java/com/visual/face/search/core/base/FaceMaskPoint.java new file mode 100755 index 0000000..35e0a57 --- /dev/null +++ b/face-search-core/src/main/java/com/visual/face/search/core/base/FaceMaskPoint.java @@ -0,0 +1,21 @@ +package com.visual.face.search.core.base; + +import com.visual.face.search.core.domain.ImageMat; +import com.visual.face.search.core.domain.QualityInfo; + +import java.util.Map; + +/** + * 人脸关键点检测 + */ +public interface FaceMaskPoint { + + /** + * 人脸关键点检测 + * @param imageMat 图像数据 + * @param params 参数信息 + * @return + */ + QualityInfo.MaskPoints inference(ImageMat imageMat, Map params); + +} diff --git a/face-search-core/src/main/java/com/visual/face/search/core/domain/ImageMat.java b/face-search-core/src/main/java/com/visual/face/search/core/domain/ImageMat.java index c755dc1..28d95af 100755 --- a/face-search-core/src/main/java/com/visual/face/search/core/domain/ImageMat.java +++ b/face-search-core/src/main/java/com/visual/face/search/core/domain/ImageMat.java @@ -264,7 +264,7 @@ public class ImageMat implements Serializable { } /** - * 对图像进行预处理,并释放原始图片数据 + * 对图像进行预处理,并释放原始图片数据:(先交换RB通道(swapRB),再减法(mean),最后缩放(scale)) * @param scale 图像各通道数值的缩放比例 * @param mean 用于各通道减去的值,以降低光照的影响 * @param swapRB 交换RB通道,默认为False. diff --git a/face-search-core/src/main/java/com/visual/face/search/core/domain/QualityInfo.java b/face-search-core/src/main/java/com/visual/face/search/core/domain/QualityInfo.java new file mode 100644 index 0000000..5c8999b --- /dev/null +++ b/face-search-core/src/main/java/com/visual/face/search/core/domain/QualityInfo.java @@ -0,0 +1,112 @@ +package com.visual.face.search.core.domain; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; + +public class QualityInfo { + + public MaskPoints maskPoints; + + private QualityInfo(MaskPoints maskPoints) { + this.maskPoints = maskPoints; + } + + public static QualityInfo build(MaskPoints maskPoints){ + return new QualityInfo(maskPoints); + } + + public MaskPoints getMaskPoints() { + return maskPoints; + } + + + public boolean isMask(){ + return true; + } + + + /** + * 遮挡类 + */ + public static class Mask implements Serializable { + /**遮挡分数*/ + public float score; + + public static Mask build(float score){ + return new QualityInfo.Mask(score); + } + + private Mask(float score) { + this.score = score; + } + + public float getScore() { + return score; + } + + @Override + public String toString() { + return "Mask{" + "score=" + score + '}'; + } + } + + /** + * 点遮挡类 + */ + public static class MaskPoint extends Mask{ + /**坐标X的值**/ + public float x; + /**坐标Y的值**/ + public float y; + + public static MaskPoint build(float x, float y, float score){ + return new MaskPoint(x, y, score); + } + + private MaskPoint(float x, float y, float score) { + super(score); + this.x = x; + this.y = y; + } + + public float getX() { + return x; + } + + public float getY() { + return y; + } + + @Override + public String toString() { + return "MaskPoint{" + "x=" + x + ", y=" + y + ", score=" + score + '}'; + } + } + + /** + * 点遮挡类集合 + */ + public static class MaskPoints extends ArrayList { + + private MaskPoints(){} + + /** + * 构建一个集合 + * @return + */ + public static MaskPoints build(){ + return new MaskPoints(); + } + + /** + * 添加点 + * @param point + * @return + */ + public MaskPoints add(MaskPoint...point){ + super.addAll(Arrays.asList(point)); + return this; + } + } +} diff --git a/face-search-core/src/main/java/com/visual/face/search/core/models/InsightScrfdFaceDetection.java b/face-search-core/src/main/java/com/visual/face/search/core/models/InsightScrfdFaceDetection.java index 6697e6b..5faad91 100755 --- a/face-search-core/src/main/java/com/visual/face/search/core/models/InsightScrfdFaceDetection.java +++ b/face-search-core/src/main/java/com/visual/face/search/core/models/InsightScrfdFaceDetection.java @@ -24,6 +24,10 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete public final static float defScoreTh = 0.5f; //人脸重叠iou阈值 public final static float defIouTh = 0.7f; + //给人脸框一个默认的缩放 + public final static float defBoxScale = 1.0f; + //人脸框缩放参数KEY + public final static String boxScaleParamKey = "boxScale"; /** * 构造函数 @@ -48,6 +52,7 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete ImageMat imageMat = image.clone(); try { float imgScale = 1.0f; + float boxScale = getBoxScale(params); iouTh = iouTh <= 0 ? defIouTh : iouTh; scoreTh = scoreTh <= 0 ? defScoreTh : scoreTh; int imageWidth = imageMat.getWidth(), imageHeight = imageMat.getHeight(); @@ -68,7 +73,7 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete .blobFromImageAndDoReleaseMat(1.0/128, new Scalar(127.5, 127.5, 127.5), true) .to4dFloatOnnxTensorAndDoReleaseMat(true); output = getSession().run(Collections.singletonMap(getInputName(), tensor)); - return fitterBoxes(output, scoreTh, iouTh, tensor.getInfo().getShape()[3], imgScale); + return fitterBoxes(output, scoreTh, iouTh, tensor.getInfo().getShape()[3], imgScale, boxScale); } catch (Exception e) { throw new RuntimeException(e); }finally { @@ -94,7 +99,7 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete * @return * @throws OrtException */ - private List fitterBoxes(OrtSession.Result output, float scoreTh, float iouTh, long tensorWidth, float imgScale) throws OrtException { + private List fitterBoxes(OrtSession.Result output, float scoreTh, float iouTh, long tensorWidth, float imgScale, float boxScale) throws OrtException { //分数过滤及计算正确的人脸框值 List faceInfos = new ArrayList<>(); for(int index=0; index< 3; index++) { @@ -122,7 +127,7 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete float pointY = (point[2*pointIndex+1] * strides[index] + anchorY) * imgScale; keyPoints.add(FaceInfo.Point.build(pointX, pointY)); } - faceInfos.add(FaceInfo.build(scores[i][0], 0, FaceInfo.FaceBox.build(x1,y1,x2,y2), keyPoints)); + faceInfos.add(FaceInfo.build(scores[i][0], 0, FaceInfo.FaceBox.build(x1,y1,x2,y2).scaling(boxScale), keyPoints)); } } } @@ -147,4 +152,20 @@ public class InsightScrfdFaceDetection extends BaseOnnxInfer implements FaceDete return faces; } + private float getBoxScale(Map params){ + float boxScale = 0; + try { + if(null != params && params.containsKey(boxScaleParamKey)){ + Object value = params.get(boxScaleParamKey); + if(null != value){ + if (value instanceof Number){ + boxScale = ((Number) value).floatValue(); + }else{ + boxScale = Float.parseFloat(value.toString()); + } + } + } + }catch (Exception e){} + return boxScale > 0 ? boxScale : defBoxScale; + } } diff --git a/face-search-core/src/main/java/com/visual/face/search/core/models/SeetaMaskFaceKeyPoint.java b/face-search-core/src/main/java/com/visual/face/search/core/models/SeetaMaskFaceKeyPoint.java new file mode 100755 index 0000000..e3983f4 --- /dev/null +++ b/face-search-core/src/main/java/com/visual/face/search/core/models/SeetaMaskFaceKeyPoint.java @@ -0,0 +1,94 @@ +package com.visual.face.search.core.models; + +import ai.onnxruntime.OnnxTensor; +import ai.onnxruntime.OrtSession; +import com.visual.face.search.core.base.BaseOnnxInfer; +import com.visual.face.search.core.base.FaceMaskPoint; +import com.visual.face.search.core.domain.ImageMat; +import com.visual.face.search.core.domain.QualityInfo; +import com.visual.face.search.core.utils.SoftMaxUtil; +import org.opencv.core.*; +import org.opencv.imgproc.Imgproc; + +import java.util.Collections; +import java.util.Map; + +public class SeetaMaskFaceKeyPoint extends BaseOnnxInfer implements FaceMaskPoint { + + private static final int stride = 8; + private static final int shape = 128; + + /** + * 构造函数 + * @param modelPath 模型路径 + * @param threads 线程数 + */ + public SeetaMaskFaceKeyPoint(String modelPath, int threads) { + super(modelPath, threads); + } + + /** + * 人脸关键点检测 + * + * @param imageMat 图像数据 + * @param params 参数信息 + * @return + */ + @Override + public QualityInfo.MaskPoints inference(ImageMat imageMat, Map params) { + Mat borderMat = null; + Mat resizeMat = null; + OnnxTensor tensor = null; + OrtSession.Result output = null; + try { + Mat image = imageMat.toCvMat(); + //将图片转换为正方形 + int w = imageMat.getWidth(); + int h = imageMat.getHeight(); + int new_w = Math.max(h, w); + int new_h = Math.max(h, w); + if (Math.max(h, w) % stride != 0){ + new_w = new_w + (stride - Math.max(h, w) % stride); + new_h = new_h + (stride - Math.max(h, w) % stride); + } + int ow = (new_w - w) / 2; + int oh = (new_h - h) / 2; + borderMat = new Mat(); + Core.copyMakeBorder(image, borderMat, oh, oh, ow, ow, Core.BORDER_CONSTANT, new Scalar(114, 114, 114)); + //对图片进行resize + float ratio = 1.0f * shape / new_h; + resizeMat = new Mat(); + Imgproc.resize(borderMat, resizeMat, new Size(shape, shape)); + //模型推理 + tensor = ImageMat.fromCVMat(resizeMat) + .blobFromImageAndDoReleaseMat(1.0/32, new Scalar(104, 117, 123), false) + .to4dFloatOnnxTensorAndDoReleaseMat(true); + output = this.getSession().run(Collections.singletonMap(this.getInputName(), tensor)); + float[] value = ((float[][]) output.get(0).getValue())[0]; + //转换为标准的坐标点 + QualityInfo.MaskPoints pointList = QualityInfo.MaskPoints.build(); + for(int i=0; i<5; i++){ + float x = value[i * 4 + 0] / ratio * 128 - ow; + float y = value[i * 4 + 1] / ratio * 128 - oh; + double[] softMax = SoftMaxUtil.softMax(new double[]{value[i * 4 + 2], value[i * 4 + 3]}); + pointList.add(QualityInfo.MaskPoint.build(x, y, Double.valueOf(softMax[1]).floatValue())); + } + return pointList; + } catch (Exception e) { + throw new RuntimeException(e); + }finally { + if(null != tensor){ + tensor.close(); + } + if(null != output){ + output.close(); + } + if(null != borderMat){ + borderMat.release(); + } + if(null != resizeMat){ + resizeMat.release(); + } + } + } +} diff --git a/face-search-core/src/main/resources/model/onnx/keypoint_seeta_mask/landmarker_005_mask_pts5.onnx b/face-search-core/src/main/resources/model/onnx/keypoint_seeta_mask/landmarker_005_mask_pts5.onnx new file mode 100644 index 0000000..eeb184f Binary files /dev/null and b/face-search-core/src/main/resources/model/onnx/keypoint_seeta_mask/landmarker_005_mask_pts5.onnx differ diff --git a/face-search-core/src/test/java/com/visual/face/search/core/test/extract/FaceCompareTest.java b/face-search-core/src/test/java/com/visual/face/search/core/test/extract/FaceCompareTest.java index 6ce7e8e..36821c4 100755 --- a/face-search-core/src/test/java/com/visual/face/search/core/test/extract/FaceCompareTest.java +++ b/face-search-core/src/test/java/com/visual/face/search/core/test/extract/FaceCompareTest.java @@ -15,6 +15,7 @@ import org.opencv.imgcodecs.Imgcodecs; import java.util.Arrays; import java.util.List; +import java.util.Map; public class FaceCompareTest extends BaseTest { @@ -25,19 +26,22 @@ public class FaceCompareTest extends BaseTest { private static String modelCoordPath = "face-search-core/src/main/resources/model/onnx/keypoint_coordinate/coordinate_106_mobilenet_05.onnx"; private static String modelArcPath = "face-search-core/src/main/resources/model/onnx/recognition_face_arc/glint360k_cosface_r18_fp16_0.1.onnx"; private static String modelSeetaPath = "face-search-core/src/main/resources/model/onnx/recognition_face_seeta/face_recognizer_512.onnx"; -// private static String modelSeetaPath = "face-search-core/src/main/resources/model/onnx/recognition_face_seeta/face_recognizer_1024.onnx"; private static String modelArrPath = "face-search-core/src/main/resources/model/onnx/attribute_gender_age/insight_gender_age.onnx"; -// private static String imagePath = "face-search-core/src/test/resources/images/faces"; + private static String imagePath = "face-search-test/src/main/resources/image/validate/index/马化腾/"; + private static String imagePath3 = "face-search-test/src/main/resources/image/validate/index/雷军/"; // private static String imagePath1 = "face-search-core/src/test/resources/images/faces/debug/debug_0001.jpg"; // private static String imagePath2 = "face-search-core/src/test/resources/images/faces/debug/debug_0001.jpg"; - private static String imagePath1 = "face-search-core/src/test/resources/images/faces/compare/1682052661610.jpg"; - private static String imagePath2 = "face-search-core/src/test/resources/images/faces/compare/1682052669004.jpg"; +// private static String imagePath1 = "face-search-core/src/test/resources/images/faces/compare/1682052661610.jpg"; +// private static String imagePath2 = "face-search-core/src/test/resources/images/faces/compare/1682052669004.jpg"; // private static String imagePath2 = "face-search-core/src/test/resources/images/faces/compare/1682053163961.jpg"; - - +// private static String imagePath1 = "face-search-test/src/main/resources/image/validate/index/张一鸣/1c7abcaf2dabdd2bc08e90c224d4c381.jpeg"; + private static String imagePath1 = "face-search-test/src/main/resources/image/validate/index/张一鸣/0762c790db41a64f8f3f97598a825372.jpeg"; + private static String imagePath2 = "face-search-test/src/main/resources/image/validate/index/张一鸣/ea191e61bdd4be8fddc89b828f5399b6.jpeg"; public static void main(String[] args) { -// Map map = getImagePathMap(imagePath); + //口罩模型0.48,light模型0.52,normal模型0.62 + Map map1 = getImagePathMap(imagePath); + Map map2 = getImagePathMap(imagePath3); FaceDetection insightScrfdFaceDetection = new InsightScrfdFaceDetection(modelScrfdPath, 1); FaceKeyPoint insightCoordFaceKeyPoint = new InsightCoordFaceKeyPoint(modelCoordPath, 1); FaceRecognition insightArcFaceRecognition = new InsightArcFaceRecognition(modelArcPath, 1); @@ -49,24 +53,28 @@ public class FaceCompareTest extends BaseTest { FaceFeatureExtractor extractor = new FaceFeatureExtractorImpl( insightScrfdFaceDetection, pcnNetworkFaceDetection, insightCoordFaceKeyPoint, - simple005pFaceAlignment, insightSeetaFaceRecognition, insightFaceAttribute); + simple005pFaceAlignment, insightArcFaceRecognition, insightFaceAttribute); - Mat image1 = Imgcodecs.imread(imagePath1); - long s = System.currentTimeMillis(); - ExtParam extParam = ExtParam.build().setMask(false).setTopK(20).setScoreTh(0).setIouTh(0); - FaceImage faceImage1 = extractor.extract(ImageMat.fromCVMat(image1), extParam, null); - List faceInfos1 = faceImage1.faceInfos(); - long e = System.currentTimeMillis(); - System.out.println("image1 extract cost:"+(e-s)+"ms");; + for(String file1 : map1.keySet()){ + for(String file2 : map2.keySet()){ + Mat image1 = Imgcodecs.imread(map1.get(file1)); + long s = System.currentTimeMillis(); + ExtParam extParam = ExtParam.build().setMask(false).setTopK(20).setScoreTh(0).setIouTh(0); + FaceImage faceImage1 = extractor.extract(ImageMat.fromCVMat(image1), extParam, null); + List faceInfos1 = faceImage1.faceInfos(); + long e = System.currentTimeMillis(); + System.out.println("image1 extract cost:"+(e-s)+"ms");; - Mat image2 = Imgcodecs.imread(imagePath2); - s = System.currentTimeMillis(); - FaceImage faceImage2 = extractor.extract(ImageMat.fromCVMat(image2), extParam, null); - List faceInfos2 = faceImage2.faceInfos(); - e = System.currentTimeMillis(); - System.out.println("image2 extract cost:"+(e-s)+"ms"); - float similarity = Similarity.cosineSimilarity(faceInfos1.get(0).embedding.embeds, faceInfos2.get(0).embedding.embeds); - System.out.println("face similarity="+similarity); + Mat image2 = Imgcodecs.imread(map2.get(file2)); + s = System.currentTimeMillis(); + FaceImage faceImage2 = extractor.extract(ImageMat.fromCVMat(image2), extParam, null); + List faceInfos2 = faceImage2.faceInfos(); + e = System.currentTimeMillis(); + System.out.println("image2 extract cost:"+(e-s)+"ms"); + float similarity = Similarity.cosineSimilarityNorm(faceInfos1.get(0).embedding.embeds, faceInfos2.get(0).embedding.embeds); + System.out.println(file1 + ","+ file2 + ",face similarity="+similarity); + } + } } } diff --git a/face-search-core/src/test/java/com/visual/face/search/core/test/models/SeetaMaskFaceKeyPointTest.java b/face-search-core/src/test/java/com/visual/face/search/core/test/models/SeetaMaskFaceKeyPointTest.java new file mode 100644 index 0000000..9b9c605 --- /dev/null +++ b/face-search-core/src/test/java/com/visual/face/search/core/test/models/SeetaMaskFaceKeyPointTest.java @@ -0,0 +1,57 @@ +package com.visual.face.search.core.test.models; + +import com.visual.face.search.core.domain.FaceInfo; +import com.visual.face.search.core.domain.ImageMat; +import com.visual.face.search.core.domain.QualityInfo; +import com.visual.face.search.core.models.InsightCoordFaceKeyPoint; +import com.visual.face.search.core.models.InsightScrfdFaceDetection; +import com.visual.face.search.core.models.SeetaMaskFaceKeyPoint; +import com.visual.face.search.core.test.base.BaseTest; +import com.visual.face.search.core.utils.CropUtil; +import org.opencv.core.Mat; +import org.opencv.core.Point; +import org.opencv.core.Scalar; +import org.opencv.highgui.HighGui; +import org.opencv.imgcodecs.Imgcodecs; +import org.opencv.imgproc.Imgproc; + +import java.util.List; +import java.util.Map; + +public class SeetaMaskFaceKeyPointTest extends BaseTest { + private static String modelDetectionPath = "face-search-core/src/main/resources/model/onnx/detection_face_scrfd/scrfd_500m_bnkps.onnx"; + private static String modelKeypointPath = "face-search-core/src/main/resources/model/onnx/keypoint_seeta_mask/landmarker_005_mask_pts5.onnx"; + private static String imagePath = "face-search-core/src/test/resources/images/faces"; +// private static String imagePath = "face-search-core/src/test/resources/images/faces/compare"; +// private static String imagePath = "face-search-core/src/test/resources/images/faces/compare/1694353163955.jpg"; + + public static void main(String[] args) { + Map map = getImagePathMap(imagePath); + InsightScrfdFaceDetection detectionInfer = new InsightScrfdFaceDetection(modelDetectionPath, 1); + SeetaMaskFaceKeyPoint keyPointInfer = new SeetaMaskFaceKeyPoint(modelKeypointPath, 1); + for(String fileName : map.keySet()) { + System.out.println(fileName); + String imageFilePath = map.get(fileName); + Mat image = Imgcodecs.imread(imageFilePath); + List faceInfos = detectionInfer.inference(ImageMat.fromCVMat(image), 0.5f, 0.7f, null); + for(FaceInfo faceInfo : faceInfos){ + FaceInfo.FaceBox rotateFaceBox = faceInfo.rotateFaceBox(); + Mat cropFace = CropUtil.crop(image, rotateFaceBox.scaling(1.0f)); + ImageMat cropImageMat = ImageMat.fromCVMat(cropFace); + QualityInfo.MaskPoints maskPoints = keyPointInfer.inference(cropImageMat, null); + System.out.println(maskPoints); + for(QualityInfo.MaskPoint maskPoint : maskPoints){ + if(maskPoint.score >= 0.5){ + Imgproc.circle(cropFace, new Point(maskPoint.x, maskPoint.y), 3, new Scalar(0,0,255), -1); + }else{ + Imgproc.circle(cropFace, new Point(maskPoint.x, maskPoint.y), 3, new Scalar(255, 0,0), -1); + } + } + HighGui.imshow(fileName, cropFace); + HighGui.waitKey(); + } + } + + } + +} diff --git a/face-search-core/src/test/resources/images/faces/compare/1694353163955.jpg b/face-search-core/src/test/resources/images/faces/compare/1694353163955.jpg new file mode 100644 index 0000000..525a4da Binary files /dev/null and b/face-search-core/src/test/resources/images/faces/compare/1694353163955.jpg differ