FastestDet/utils/tool.py
2022-07-14 21:11:22 +08:00

121 lines
4.0 KiB
Python

import yaml
import torch
import torchvision
# 解析yaml配置文件
class LoadYaml:
def __init__(self, path):
with open(path, encoding='utf8') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
self.val_txt = data["DATASET"]["VAL"]
self.train_txt = data["DATASET"]["TRAIN"]
self.names = data["DATASET"]["NAMES"]
self.learn_rate = data["TRAIN"]["LR"]
self.batch_size = data["TRAIN"]["BATCH_SIZE"]
self.milestones = data["TRAIN"]["MILESTIONES"]
self.end_epoch = data["TRAIN"]["END_EPOCH"]
self.input_width = data["MODEL"]["INPUT_WIDTH"]
self.input_height = data["MODEL"]["INPUT_HEIGHT"]
self.category_num = data["MODEL"]["NC"]
print("Load yaml sucess...")
class EMA():
def __init__(self, model, decay):
self.model = model
self.decay = decay
self.shadow = {}
self.backup = {}
def register(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
self.shadow[name] = param.data.clone()
def update(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.shadow
new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
self.shadow[name] = new_average.clone()
def apply_shadow(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.shadow
self.backup[name] = param.data
param.data = self.shadow[name]
def restore(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.backup
param.data = self.backup[name]
self.backup = {}
# 后处理(归一化后的坐标)
def handle_preds(preds, device, conf_thresh=0.25, nms_thresh=0.45):
total_bboxes, output_bboxes = [], []
# 将特征图转换为检测框的坐标
N, C, H, W = preds.shape
bboxes = torch.zeros((N, H, W, 6))
pred = preds.permute(0, 2, 3, 1)
# 前背景分类分支
pobj = pred[:, :, :, 0].unsqueeze(dim=-1)
# 检测框回归分支
preg = pred[:, :, :, 1:5]
# 目标类别分类分支
pcls = pred[:, :, :, 5:]
# 检测框置信度
bboxes[..., 4] = (pobj.squeeze(-1) ** 0.6) * (pcls.max(dim=-1)[0] ** 0.4)
bboxes[..., 5] = pcls.argmax(dim=-1)
# 检测框的坐标
gy, gx = torch.meshgrid([torch.arange(H), torch.arange(W)])
bw, bh = preg[..., 2].sigmoid(), preg[..., 3].sigmoid()
bcx = (preg[..., 0].tanh() + gx.to(device)) / W
bcy = (preg[..., 1].tanh() + gy.to(device)) / H
# cx,cy,w,h = > x1,y1,x2,y1
x1, y1 = bcx - 0.5 * bw, bcy - 0.5 * bh
x2, y2 = bcx + 0.5 * bw, bcy + 0.5 * bh
bboxes[..., 0], bboxes[..., 1] = x1, y1
bboxes[..., 2], bboxes[..., 3] = x2, y2
bboxes = bboxes.reshape(N, H*W, 6)
total_bboxes.append(bboxes)
batch_bboxes = torch.cat(total_bboxes, 1)
# 对检测框进行NMS处理
for p in batch_bboxes:
output, temp = [], []
b, s, c = [], [], []
# 阈值筛选
t = p[:, 4] > conf_thresh
pb = p[t]
for bbox in pb:
obj_score = bbox[4]
category = bbox[5]
x1, y1 = bbox[0], bbox[1]
x2, y2 = bbox[2], bbox[3]
s.append([obj_score])
c.append([category])
b.append([x1, y1, x2, y2])
temp.append([x1, y1, x2, y2, obj_score, category])
# Torchvision NMS
if len(b) > 0:
b = torch.Tensor(b).to(device)
c = torch.Tensor(c).squeeze(1).to(device)
s = torch.Tensor(s).squeeze(1).to(device)
keep = torchvision.ops.batched_nms(b, s, c, nms_thresh)
for i in keep:
output.append(temp[i])
output_bboxes.append(torch.Tensor(output))
return output_bboxes