diff --git a/module/loss.py b/module/loss.py index 9658eb9..446411a 100644 --- a/module/loss.py +++ b/module/loss.py @@ -111,7 +111,7 @@ class DetectorLoss(nn.Module): tobj = torch.zeros_like(pobj) factor = torch.ones_like(pobj) * 0.75 - if len(gt_box[0]) > 0: + if len(gt_box) > 0: # 计算检测框回归loss b, gx, gy = ps_index[0] ptbox = torch.ones((preg[b, gy, gx].shape)).to(self.device)