diff --git a/shimmer_ssd/modules/domains/attribute.py b/shimmer_ssd/modules/domains/attribute.py index 7974692..b6b6315 100644 --- a/shimmer_ssd/modules/domains/attribute.py +++ b/shimmer_ssd/modules/domains/attribute.py @@ -324,8 +324,8 @@ def __init__(self): def compute_loss( self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any ) -> LossOutput: - pred_cat, pred_attr, _ = self.decode(pred) - target_cat, target_attr, _ = self.decode(target) + pred_cat, pred_attr = self.decode(pred) + target_cat, target_attr = self.decode(target) loss_attr = F.mse_loss(pred_attr, target_attr, reduction="mean") loss_cat = F.nll_loss(pred_cat, torch.argmax(target_cat, 1))