Skip to content

Commit ca81e4a

Browse files
committed
fix tensor shape to match expected (B=1, C=3, H=1, W=2) format
Signed-off-by: ytl0623 <[email protected]>
1 parent 41dccad commit ca81e4a

File tree

1 file changed

+2
-7
lines changed

1 file changed

+2
-7
lines changed

tests/losses/test_unified_focal_loss.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,8 @@
4848
"include_background": True,
4949
},
5050
{
51-
# Logits:
52-
# Pixel 1: Class 0 is target -> [10, -10, -10]
53-
# Pixel 2: Class 2 is target -> [-10, -10, 10]
54-
"y_pred": torch.tensor(
55-
[[[[logit_pos, logit_neg], [logit_neg, logit_neg], [logit_neg, logit_pos]]]] # Ch 0 # Ch 1 # Ch 2
56-
),
57-
"y_true": torch.tensor([[[[1.0, 0.0], [0.0, 0.0], [0.0, 1.0]]]]), # Ch 0 (Background) # Ch 1 # Ch 2
51+
"y_pred": torch.tensor([[[[logit_pos, logit_neg]], [[logit_neg, logit_neg]], [[logit_neg, logit_pos]]]]),
52+
"y_true": torch.tensor([[[[1.0, 0.0]], [[0.0, 0.0]], [[0.0, 1.0]]]]),
5853
},
5954
0.0,
6055
],

0 commit comments

Comments
 (0)