This commit is contained in:
duoduo 2022-07-05 19:44:14 +08:00
parent b1dfad4a21
commit fdf9e05803

View File

@ -111,7 +111,7 @@ class DetectorLoss(nn.Module):
tobj = torch.zeros_like(pobj) tobj = torch.zeros_like(pobj)
factor = torch.ones_like(pobj) * 0.75 factor = torch.ones_like(pobj) * 0.75
if len(gt_box[0]) > 0: if len(gt_box) > 0:
# 计算检测框回归loss # 计算检测框回归loss
b, gx, gy = ps_index[0] b, gx, gy = ps_index[0]
ptbox = torch.ones((preg[b, gy, gx].shape)).to(self.device) ptbox = torch.ones((preg[b, gy, gx].shape)).to(self.device)