diff --git a/README.md b/README.md index b9679bd..157ac71 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Network|mAPval 0.5|mAPval 0.5:0.95|Resolution|Run Time(4xCore)|Run Time(1xCore)| [nanodet_m](https://github.com/RangiLyu/nanodet)|-|20.6%|320X320|49.24ms|160.35ms|0.95M [yolo-fastestv1.1](https://github.com/dog-qiuqiu/Yolo-Fastest/tree/master/ModelZoo/yolo-fastest-1.1_coco)|24.40%|-|320X320|26.60ms|75.74ms|0.35M [yolo-fastestv2](https://github.com/dog-qiuqiu/Yolo-FastestV2/tree/main/modelzoo)|24.10%|-|352X352|23.8ms|68.9ms|0.25M -FastestDet|25.0%|12.3%|352X352|23.51ms|70.62ms|0.24M +FastestDet|25.3%|13.0%|352X352|23.51ms|70.62ms|0.24M * ***Test platform Radxa Rock3A RK3568 ARM Cortex-A55 CPU,Based on [NCNN](https://github.com/Tencent/ncnn)*** * ***CPU lock frequency 2.0GHz*** # Improvement @@ -40,7 +40,7 @@ Intel|i7-8700(X86-cpu)|Linux(amd64)|ncnn|4.51ms|4.33ms ## Test * Picture test ``` - python3 test.py --yaml configs/coco.yaml --weight weights/coco_ap05_0.250_280epoch.pth --img data/3.jpg + python3 test.py --yaml configs/coco.yaml --weight weights/weight_AP05:0.253207_280-epoch.pth --img data/3.jpg ```
/>
@@ -148,7 +148,7 @@ TRAIN:
### Evaluation
* Calculate map evaluation
```
- python3 eval.py --yaml configs/coco.yaml --weight weights/coco_ap05_0.250_280epoch.pth
+ python3 eval.py --yaml configs/coco.yaml --weight weights/weight_AP05:0.253207_280-epoch.pth
```
* COCO2017 evaluation
```
@@ -161,30 +161,30 @@ TRAIN:
DONE (t=30.85s).
Accumulating evaluation results...
DONE (t=4.97s).
- Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.123
- Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.250
- Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.109
- Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.017
- Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.115
- Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.238
- Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.139
- Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.199
- Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.205
- Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.035
- Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.218
- Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.374
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.130
+ Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.253
+ Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.119
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.021
+ Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.129
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.237
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.142
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.208
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.214
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.043
+ Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.236
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.372
```
# Deploy
## Export onnx
* You can export .onnx by adding the --onnx option when executing test.py
```
- python3 test.py --yaml configs/coco.yaml --weight weights/coco_ap05_0.250_280epoch.pth --img data/3.jpg --onnx
+ python3 test.py --yaml configs/coco.yaml --weight weights/weight_AP05:0.253207_280-epoch.pth --img data/3.jpg --onnx
```
## Export torchscript
* You can export .pt by adding the --torchscript option when executing test.py
```
- python3 test.py --yaml configs/coco.yaml --weight weights/coco_ap05_0.250_280epoch.pth --img data/3.jpg --torchscript
+ python3 test.py --yaml configs/coco.yaml --weight weights/weight_AP05:0.253207_280-epoch.pth --img data/3.jpg --torchscript
```
## NCNN
* Need to compile ncnn and opencv in advance and modify the path in build.sh
diff --git a/data/0_result.png b/data/0_result.png
deleted file mode 100644
index 972bc86..0000000
Binary files a/data/0_result.png and /dev/null differ
diff --git a/data/1_result.png b/data/1_result.png
deleted file mode 100644
index 541ed5e..0000000
Binary files a/data/1_result.png and /dev/null differ
diff --git a/data/2_result.png b/data/2_result.png
deleted file mode 100644
index 8468fe8..0000000
Binary files a/data/2_result.png and /dev/null differ
diff --git a/data/3_result.png b/data/3_result.png
deleted file mode 100644
index 5b11514..0000000
Binary files a/data/3_result.png and /dev/null differ
diff --git a/data/4_result.png b/data/4_result.png
deleted file mode 100644
index 7585b41..0000000
Binary files a/data/4_result.png and /dev/null differ
diff --git a/example/onnx-runtime/1.jpg b/example/onnx-runtime/1.jpg
deleted file mode 100644
index 7912379..0000000
Binary files a/example/onnx-runtime/1.jpg and /dev/null differ
diff --git a/example/onnx-runtime/3.jpg b/example/onnx-runtime/3.jpg
new file mode 100644
index 0000000..19c4fff
Binary files /dev/null and b/example/onnx-runtime/3.jpg differ
diff --git a/example/onnx-runtime/FastestDet.onnx b/example/onnx-runtime/FastestDet.onnx
index cce8577..d60f81f 100644
Binary files a/example/onnx-runtime/FastestDet.onnx and b/example/onnx-runtime/FastestDet.onnx differ
diff --git a/example/onnx-runtime/result.jpg b/example/onnx-runtime/result.jpg
index 5cae23d..15453df 100644
Binary files a/example/onnx-runtime/result.jpg and b/example/onnx-runtime/result.jpg differ
diff --git a/example/onnx-runtime/runtime.py b/example/onnx-runtime/runtime.py
index 75a8eaa..622b073 100755
--- a/example/onnx-runtime/runtime.py
+++ b/example/onnx-runtime/runtime.py
@@ -20,7 +20,7 @@ def preprocess(src_img, size):
return output.astype('float32')
# nms算法
-def nms(dets, thresh=0.35):
+def nms(dets, thresh=0.45):
# dets:N*M,N是bbox的个数,M的前4位是对应的(x1,y1,x2,y2),第5位是对应的分数
# #thresh:0.3,0.5....
x1 = dets[:, 0]
@@ -89,7 +89,7 @@ def detection(session, img, input_width, input_height, thresh):
# 解析检测框置信度
obj_score, cls_score = data[0], data[5:].max()
- score = obj_score * cls_score
+ score = (obj_score ** 0.6) * (cls_score ** 0.4)
# 阈值筛选
if score > thresh:
@@ -114,14 +114,14 @@ def detection(session, img, input_width, input_height, thresh):
if __name__ == '__main__':
# 读取图片
- img = cv2.imread("1.jpg")
+ img = cv2.imread("3.jpg")
# 模型输入的宽高
input_width, input_height = 352, 352
# 加载模型
session = onnxruntime.InferenceSession('FastestDet.onnx')
# 目标检测
start = time.perf_counter()
- bboxes = detection(session, img, input_width, input_height, 0.8)
+ bboxes = detection(session, img, input_width, input_height, 0.65)
end = time.perf_counter()
time = (end - start) * 1000.
print("forward time:%fms"%time)
diff --git a/module/loss.py b/module/loss.py
index 446411a..4569e4f 100644
--- a/module/loss.py
+++ b/module/loss.py
@@ -94,7 +94,8 @@ class DetectorLoss(nn.Module):
# 定义obj和cls的损失函数
BCEcls = nn.NLLLoss()
- BCEobj = nn.BCELoss(reduction='none')
+ # smmoth L1相比于bce效果最好
+ BCEobj = nn.SmoothL1Loss(reduction='none')
# 构建ground truth
gt_box, gt_cls, ps_index = self.build_target(preds, targets)
@@ -134,7 +135,8 @@ class DetectorLoss(nn.Module):
ps = torch.log(pcls[b, gy, gx])
cls_loss = BCEcls(ps, gt_cls[0][f])
- tobj[b, gy, gx] = 1.0
+ # iou aware
+ tobj[b, gy, gx] = iou.float()
# 统计每个图片正样本的数量
n = torch.bincount(b)
factor[b, gy, gx] = (1. / (n[b] / (H * W))) * 0.25
diff --git a/result.png b/result.png
index 5b11514..41bc89c 100644
Binary files a/result.png and b/result.png differ
diff --git a/test.py b/test.py
index 53ca3dd..8fd6fae 100644
--- a/test.py
+++ b/test.py
@@ -44,7 +44,7 @@ if __name__ == '__main__':
# 模型加载
print("load weight from:%s"%opt.weight)
model = Detector(cfg.category_num, True).to(device)
- model.load_state_dict(torch.load(opt.weight))
+ model.load_state_dict(torch.load(opt.weight, map_location=device))
#sets the module in eval node
model.eval()
diff --git a/utils/datasets.py b/utils/datasets.py
index 886b992..48fd6d0 100644
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -5,7 +5,7 @@ import numpy as np
import torch
import random
-def random_scale(image, boxes):
+def random_crop(image, boxes):
height, width, _ = image.shape
# random crop imgage
cw, ch = random.randint(int(width * 0.75), width), random.randint(int(height * 0.75), height)
@@ -29,6 +29,30 @@ def random_scale(image, boxes):
return roi, output
+def random_narrow(image, boxes):
+ height, width, _ = image.shape
+ # random narrow
+ cw, ch = random.randint(width, int(width * 1.25)), random.randint(height, int(height * 1.25))
+ cx, cy = random.randint(0, cw - width), random.randint(0, ch - height)
+
+ background = np.ones((ch, cw, 3), np.uint8) * 128
+ background[cy:cy + height, cx:cx + width] = image
+
+ output = []
+ for box in boxes:
+ index, category = box[0], box[1]
+ bx, by = box[2] * width, box[3] * height
+ bw, bh = box[4] * width, box[5] * height
+
+ bx, by = (bx + cx)/cw, (by + cy)/ch
+ bw, bh = bw/cw, bh/ch
+
+ output.append([index, category, bx, by, bw, bh])
+
+ output = np.array(output, dtype=float)
+
+ return background, output
+
def collate_fn(batch):
img, label = zip(*batch)
for i, l in enumerate(label):
@@ -84,7 +108,10 @@ class TensorDataset():
# 是否进行数据增强
if self.aug:
- img, label = random_scale(img, label)
+ if random.randint(1, 10) % 2 == 0:
+ img, label = random_narrow(img, label)
+ else:
+ img, label = random_crop(img, label)
img = cv2.resize(img, (self.img_width, self.img_height), interpolation = cv2.INTER_LINEAR)
diff --git a/utils/tool.py b/utils/tool.py
index fd4ee4f..bc7f07a 100644
--- a/utils/tool.py
+++ b/utils/tool.py
@@ -72,7 +72,7 @@ def handle_preds(preds, device, conf_thresh=0.25, nms_thresh=0.45):
pcls = pred[:, :, :, 5:]
# 检测框置信度
- bboxes[..., 4] = pobj.squeeze(-1) * pcls.max(dim=-1)[0]
+ bboxes[..., 4] = (pobj.squeeze(-1) ** 0.6) * (pcls.max(dim=-1)[0] ** 0.4)
bboxes[..., 5] = pcls.argmax(dim=-1)
# 检测框的坐标
diff --git a/weights/coco_ap05_0.250_280epoch.pth b/weights/coco_ap05_0.250_280epoch.pth
deleted file mode 100644
index cd78427..0000000
Binary files a/weights/coco_ap05_0.250_280epoch.pth and /dev/null differ
diff --git a/weights/weight_AP05:0.253207_280-epoch.pth b/weights/weight_AP05:0.253207_280-epoch.pth
new file mode 100644
index 0000000..0f9d984
Binary files /dev/null and b/weights/weight_AP05:0.253207_280-epoch.pth differ