Skip to content

Commit

Permalink
missing targets
Browse files Browse the repository at this point in the history
  • Loading branch information
pfeatherstone committed Aug 30, 2024
1 parent e84e88a commit c8c9b92
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,10 +1079,10 @@ def __init__(self, variant, num_classes):
self.fpn = HeadV5(w, r, d)
self.head = Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*2)))

def forward(self, x):
def forward(self, x, targets=None):
x = self.net(x)
x = self.fpn(*x)
return self.head(x)
return self.head(x, targets)

class Yolov8(nn.Module):
def __init__(self, variant, num_classes):
Expand Down Expand Up @@ -1120,10 +1120,10 @@ def __init__(self, variant, num_classes):
self.fpn = CSPRepBiFPANNeck(w, d, csp_e=csp_e) if csp else RepBiFPANNeck(w, d)
self.head = DetectV6(num_classes, ch=(int(128*w), int(256*w), int(512*w)), use_dfl=True, distill=distill)

def forward(self, x):
def forward(self, x, targets=None):
x = self.net(x)
x = self.fpn(*x)
return self.head(x)
return self.head(x, targets)

@torch.no_grad()
def nms(preds: torch.Tensor, conf_thresh: float, nms_thresh: float , has_objectness: bool):
Expand Down

0 comments on commit c8c9b92

Please sign in to comment.