diff --git a/README.md b/README.md index b9679bd..157ac71 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Network|mAPval 0.5|mAPval 0.5:0.95|Resolution|Run Time(4xCore)|Run Time(1xCore)| [nanodet_m](https://github.com/RangiLyu/nanodet)|-|20.6%|320X320|49.24ms|160.35ms|0.95M [yolo-fastestv1.1](https://github.com/dog-qiuqiu/Yolo-Fastest/tree/master/ModelZoo/yolo-fastest-1.1_coco)|24.40%|-|320X320|26.60ms|75.74ms|0.35M [yolo-fastestv2](https://github.com/dog-qiuqiu/Yolo-FastestV2/tree/main/modelzoo)|24.10%|-|352X352|23.8ms|68.9ms|0.25M -FastestDet|25.0%|12.3%|352X352|23.51ms|70.62ms|0.24M +FastestDet|25.3%|13.0%|352X352|23.51ms|70.62ms|0.24M * ***Test platform Radxa Rock3A RK3568 ARM Cortex-A55 CPU,Based on [NCNN](https://github.com/Tencent/ncnn)*** * ***CPU lock frequency 2.0GHz*** # Improvement @@ -40,7 +40,7 @@ Intel|i7-8700(X86-cpu)|Linux(amd64)|ncnn|4.51ms|4.33ms ## Test * Picture test ``` - python3 test.py --yaml configs/coco.yaml --weight weights/coco_ap05_0.250_280epoch.pth --img data/3.jpg + python3 test.py --yaml configs/coco.yaml --weight weights/weight_AP05:0.253207_280-epoch.pth --img data/3.jpg ```
/> @@ -148,7 +148,7 @@ TRAIN: ### Evaluation * Calculate map evaluation ``` - python3 eval.py --yaml configs/coco.yaml --weight weights/coco_ap05_0.250_280epoch.pth + python3 eval.py --yaml configs/coco.yaml --weight weights/weight_AP05:0.253207_280-epoch.pth ``` * COCO2017 evaluation ``` @@ -161,30 +161,30 @@ TRAIN: DONE (t=30.85s). Accumulating evaluation results... DONE (t=4.97s). - Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.123 - Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.250 - Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.109 - Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.017 - Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.115 - Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.238 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.139 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.199 - Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.205 - Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.035 - Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.218 - Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.374 + Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.130 + Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.253 + Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.119 + Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.021 + Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.129 + Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.237 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.142 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.208 + Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.214 + 
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.043 + Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.236 + Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.372 ``` # Deploy ## Export onnx * You can export .onnx by adding the --onnx option when executing test.py ``` - python3 test.py --yaml configs/coco.yaml --weight weights/coco_ap05_0.250_280epoch.pth --img data/3.jpg --onnx + python3 test.py --yaml configs/coco.yaml --weight weights/weight_AP05:0.253207_280-epoch.pth --img data/3.jpg --onnx ``` ## Export torchscript * You can export .pt by adding the --torchscript option when executing test.py ``` - python3 test.py --yaml configs/coco.yaml --weight weights/coco_ap05_0.250_280epoch.pth --img data/3.jpg --torchscript + python3 test.py --yaml configs/coco.yaml --weight weights/weight_AP05:0.253207_280-epoch.pth --img data/3.jpg --torchscript ``` ## NCNN * Need to compile ncnn and opencv in advance and modify the path in build.sh diff --git a/data/0_result.png b/data/0_result.png deleted file mode 100644 index 972bc86..0000000 Binary files a/data/0_result.png and /dev/null differ diff --git a/data/1_result.png b/data/1_result.png deleted file mode 100644 index 541ed5e..0000000 Binary files a/data/1_result.png and /dev/null differ diff --git a/data/2_result.png b/data/2_result.png deleted file mode 100644 index 8468fe8..0000000 Binary files a/data/2_result.png and /dev/null differ diff --git a/data/3_result.png b/data/3_result.png deleted file mode 100644 index 5b11514..0000000 Binary files a/data/3_result.png and /dev/null differ diff --git a/data/4_result.png b/data/4_result.png deleted file mode 100644 index 7585b41..0000000 Binary files a/data/4_result.png and /dev/null differ diff --git a/example/onnx-runtime/1.jpg b/example/onnx-runtime/1.jpg deleted file mode 100644 index 7912379..0000000 Binary files a/example/onnx-runtime/1.jpg and /dev/null differ diff --git a/example/onnx-runtime/3.jpg 
b/example/onnx-runtime/3.jpg new file mode 100644 index 0000000..19c4fff Binary files /dev/null and b/example/onnx-runtime/3.jpg differ diff --git a/example/onnx-runtime/FastestDet.onnx b/example/onnx-runtime/FastestDet.onnx index cce8577..d60f81f 100644 Binary files a/example/onnx-runtime/FastestDet.onnx and b/example/onnx-runtime/FastestDet.onnx differ diff --git a/example/onnx-runtime/result.jpg b/example/onnx-runtime/result.jpg index 5cae23d..15453df 100644 Binary files a/example/onnx-runtime/result.jpg and b/example/onnx-runtime/result.jpg differ diff --git a/example/onnx-runtime/runtime.py b/example/onnx-runtime/runtime.py index 75a8eaa..622b073 100755 --- a/example/onnx-runtime/runtime.py +++ b/example/onnx-runtime/runtime.py @@ -20,7 +20,7 @@ def preprocess(src_img, size): return output.astype('float32') # nms算法 -def nms(dets, thresh=0.35): +def nms(dets, thresh=0.45): # dets:N*M,N是bbox的个数,M的前4位是对应的(x1,y1,x2,y2),第5位是对应的分数 # #thresh:0.3,0.5.... x1 = dets[:, 0] @@ -89,7 +89,7 @@ def detection(session, img, input_width, input_height, thresh): # 解析检测框置信度 obj_score, cls_score = data[0], data[5:].max() - score = obj_score * cls_score + score = (obj_score ** 0.6) * (cls_score ** 0.4) # 阈值筛选 if score > thresh: @@ -114,14 +114,14 @@ def detection(session, img, input_width, input_height, thresh): if __name__ == '__main__': # 读取图片 - img = cv2.imread("1.jpg") + img = cv2.imread("3.jpg") # 模型输入的宽高 input_width, input_height = 352, 352 # 加载模型 session = onnxruntime.InferenceSession('FastestDet.onnx') # 目标检测 start = time.perf_counter() - bboxes = detection(session, img, input_width, input_height, 0.8) + bboxes = detection(session, img, input_width, input_height, 0.65) end = time.perf_counter() time = (end - start) * 1000. 
print("forward time:%fms"%time) diff --git a/module/loss.py b/module/loss.py index 446411a..4569e4f 100644 --- a/module/loss.py +++ b/module/loss.py @@ -94,7 +94,8 @@ class DetectorLoss(nn.Module): # 定义obj和cls的损失函数 BCEcls = nn.NLLLoss() - BCEobj = nn.BCELoss(reduction='none') + # smooth L1相比于bce效果最好 + BCEobj = nn.SmoothL1Loss(reduction='none') # 构建ground truth gt_box, gt_cls, ps_index = self.build_target(preds, targets) @@ -134,7 +135,8 @@ class DetectorLoss(nn.Module): ps = torch.log(pcls[b, gy, gx]) cls_loss = BCEcls(ps, gt_cls[0][f]) - tobj[b, gy, gx] = 1.0 + # iou aware + tobj[b, gy, gx] = iou.float() # 统计每个图片正样本的数量 n = torch.bincount(b) factor[b, gy, gx] = (1. / (n[b] / (H * W))) * 0.25 diff --git a/result.png b/result.png index 5b11514..41bc89c 100644 Binary files a/result.png and b/result.png differ diff --git a/test.py b/test.py index 53ca3dd..8fd6fae 100644 --- a/test.py +++ b/test.py @@ -44,7 +44,7 @@ if __name__ == '__main__': # 模型加载 print("load weight from:%s"%opt.weight) model = Detector(cfg.category_num, True).to(device) - model.load_state_dict(torch.load(opt.weight)) + model.load_state_dict(torch.load(opt.weight, map_location=device)) #sets the module in eval node model.eval() diff --git a/utils/datasets.py b/utils/datasets.py index 886b992..48fd6d0 100644 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -5,7 +5,7 @@ import numpy as np import torch import random -def random_scale(image, boxes): +def random_crop(image, boxes): height, width, _ = image.shape # random crop imgage cw, ch = random.randint(int(width * 0.75), width), random.randint(int(height * 0.75), height) @@ -29,6 +29,30 @@ def random_scale(image, boxes): return roi, output +def random_narrow(image, boxes): + height, width, _ = image.shape + # random narrow + cw, ch = random.randint(width, int(width * 1.25)), random.randint(height, int(height * 1.25)) + cx, cy = random.randint(0, cw - width), random.randint(0, ch - height) + + background = np.ones((ch, cw, 3), np.uint8) * 128 + 
background[cy:cy + height, cx:cx + width] = image + + output = [] + for box in boxes: + index, category = box[0], box[1] + bx, by = box[2] * width, box[3] * height + bw, bh = box[4] * width, box[5] * height + + bx, by = (bx + cx)/cw, (by + cy)/ch + bw, bh = bw/cw, bh/ch + + output.append([index, category, bx, by, bw, bh]) + + output = np.array(output, dtype=float) + + return background, output + def collate_fn(batch): img, label = zip(*batch) for i, l in enumerate(label): @@ -84,7 +108,10 @@ class TensorDataset(): # 是否进行数据增强 if self.aug: - img, label = random_scale(img, label) + if random.randint(1, 10) % 2 == 0: + img, label = random_narrow(img, label) + else: + img, label = random_crop(img, label) img = cv2.resize(img, (self.img_width, self.img_height), interpolation = cv2.INTER_LINEAR) diff --git a/utils/tool.py b/utils/tool.py index fd4ee4f..bc7f07a 100644 --- a/utils/tool.py +++ b/utils/tool.py @@ -72,7 +72,7 @@ def handle_preds(preds, device, conf_thresh=0.25, nms_thresh=0.45): pcls = pred[:, :, :, 5:] # 检测框置信度 - bboxes[..., 4] = pobj.squeeze(-1) * pcls.max(dim=-1)[0] + bboxes[..., 4] = (pobj.squeeze(-1) ** 0.6) * (pcls.max(dim=-1)[0] ** 0.4) bboxes[..., 5] = pcls.argmax(dim=-1) # 检测框的坐标 diff --git a/weights/coco_ap05_0.250_280epoch.pth b/weights/coco_ap05_0.250_280epoch.pth deleted file mode 100644 index cd78427..0000000 Binary files a/weights/coco_ap05_0.250_280epoch.pth and /dev/null differ diff --git a/weights/weight_AP05:0.253207_280-epoch.pth b/weights/weight_AP05:0.253207_280-epoch.pth new file mode 100644 index 0000000..0f9d984 Binary files /dev/null and b/weights/weight_AP05:0.253207_280-epoch.pth differ