Skip to content

Commit f7cad77

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent edb01ce commit f7cad77

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

tests/losses/test_unified_focal_loss.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
# Input Shape: (B, 1, H, W) -> Auto expanded internally
2323
TEST_CASE_BINARY_LOGITS = [
2424
{
25-
"y_pred": torch.tensor([[[[10.0, -10.0], [-10.0, 10.0]]]]),
26-
"y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]),
25+
"y_pred": torch.tensor([[[[10.0, -10.0], [-10.0, 10.0]]]]),
26+
"y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]),
2727
},
2828
0.0,
29-
{"use_softmax": False, "to_onehot_y": False}
29+
{"use_softmax": False, "to_onehot_y": False}
3030
]
3131

3232
# 2. Binary Case (2 Channels input): Prediction matches GT perfectly
@@ -56,12 +56,12 @@
5656
# 4. Multi-Class Case: Wrong Prediction
5757
TEST_CASE_MULTICLASS_WRONG = [
5858
{
59-
"y_pred": torch.tensor([[[[-10.0, -10.0], [-10.0, -10.0]],
60-
[[10.0, 10.0], [10.0, 10.0]],
61-
[[-10.0, -10.0], [-10.0, -10.0]]]]),
59+
"y_pred": torch.tensor([[[[-10.0, -10.0], [-10.0, -10.0]],
60+
[[10.0, 10.0], [10.0, 10.0]],
61+
[[-10.0, -10.0], [-10.0, -10.0]]]]),
6262
"y_true": torch.tensor([[[[0, 0], [0, 0]]]]), # GT is class 0, but Pred is class 1
6363
},
64-
None,
64+
None,
6565
{"use_softmax": True, "to_onehot_y": True}
6666
]
6767

@@ -99,11 +99,11 @@ def test_with_cuda(self):
9999
# Binary logits case on GPU
100100
i = torch.tensor([[[[10.0, 0], [0, 10.0]]], [[[10.0, 0], [0, 10.0]]]]).cuda()
101101
j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]).cuda()
102-
102+
103103
output = loss(i, j)
104104
print(f"CUDA Output: {output.item()}")
105105
self.assertTrue(output.is_cuda)
106-
self.assertLess(output.item(), 1.0)
106+
self.assertLess(output.item(), 1.0)
107107

108108
if __name__ == "__main__":
109-
unittest.main()
109+
unittest.main()

0 commit comments

Comments
 (0)