fix bug
This commit is contained in:
parent
b1dfad4a21
commit
fdf9e05803
@ -111,7 +111,7 @@ class DetectorLoss(nn.Module):
|
|||||||
tobj = torch.zeros_like(pobj)
|
tobj = torch.zeros_like(pobj)
|
||||||
factor = torch.ones_like(pobj) * 0.75
|
factor = torch.ones_like(pobj) * 0.75
|
||||||
|
|
||||||
if len(gt_box[0]) > 0:
|
if len(gt_box) > 0:
|
||||||
# 计算检测框回归loss
|
# 计算检测框回归loss
|
||||||
b, gx, gy = ps_index[0]
|
b, gx, gy = ps_index[0]
|
||||||
ptbox = torch.ones((preg[b, gy, gx].shape)).to(self.device)
|
ptbox = torch.ones((preg[b, gy, gx].shape)).to(self.device)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user