# 1. Binary Case (single-channel logits): Prediction matches GT perfectly.
# Input Shape: (B, 1, H, W) -> auto-expanded internally by the loss.
# Layout of each parameterized case: [inputs_dict, expected_loss, loss_kwargs].
# Logits of +/-10 saturate the sigmoid, so a perfect match yields a loss of 0.0.
TEST_CASE_BINARY_LOGITS = [
    {
        "y_pred": torch.tensor([[[[10.0, -10.0], [-10.0, 10.0]]]]),
        "y_true": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]),
    },
    0.0,
    {"use_softmax": False, "to_onehot_y": False},
]
3131
# 2. Binary Case (2 Channels input): Prediction matches GT perfectly
# 4. Multi-Class Case: Wrong Prediction.
# y_pred is (B, C, H, W) with C=3 class channels of logits; argmax over the
# channel axis picks class 1 everywhere, while the integer label map y_true
# (B, 1, H, W) is all class 0 — i.e. every pixel is predicted incorrectly.
# Expected loss is None: the test only checks the call succeeds / a generic
# property rather than an exact value.
TEST_CASE_MULTICLASS_WRONG = [
    {
        "y_pred": torch.tensor(
            [[[[-10.0, -10.0], [-10.0, -10.0]],
              [[10.0, 10.0], [10.0, 10.0]],
              [[-10.0, -10.0], [-10.0, -10.0]]]]
        ),
        # GT is class 0, but Pred is class 1
        "y_true": torch.tensor([[[[0, 0], [0, 0]]]]),
    },
    None,
    {"use_softmax": True, "to_onehot_y": True},
]
6767
@@ -99,11 +99,11 @@ def test_with_cuda(self):
9999 # Binary logits case on GPU
100100 i = torch .tensor ([[[[10.0 , 0 ], [0 , 10.0 ]]], [[[10.0 , 0 ], [0 , 10.0 ]]]]).cuda ()
101101 j = torch .tensor ([[[[1.0 , 0 ], [0 , 1.0 ]]], [[[1.0 , 0 ], [0 , 1.0 ]]]]).cuda ()
102-
102+
103103 output = loss (i , j )
104104 print (f"CUDA Output: { output .item ()} " )
105105 self .assertTrue (output .is_cuda )
106- self .assertLess (output .item (), 1.0 )
106+ self .assertLess (output .item (), 1.0 )
107107
if __name__ == "__main__":
    # Allow running this test module directly (e.g. `python this_file.py`)
    # in addition to discovery via a test runner.
    unittest.main()