import numpy as np
import torch
import torch.distributed as dist

from monai.handlers import R2Score
from tests.test_utils import DistCall, DistTestCase


class TestHandlerR2Score(unittest.TestCase):
    def test_compute(self):
        # ... metric setup and updates elided in this excerpt ...
        np.testing.assert_allclose(0.867314, r2, rtol=1e-5)


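# Distributed variant of the test above: DistCall launches two worker processes
# (nnodes=1, nproc_per_node=2), each rank feeds the metric a different batch of
# predictions and targets, and both ranks assert the same expected value, so the
# handler is expected to aggregate its state across the process group.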
class DistributedR2Score(DistTestCase):

    @DistCall(nnodes=1, nproc_per_node=2, node_rank=0)
    def test_compute(self):
        r2_score = R2Score(multi_output="variance_weighted", p=1)

        device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
        if dist.get_rank() == 0:
            y_pred = [torch.tensor([0.1, 1.0], device=device), torch.tensor([-0.25, 0.5], device=device)]
            y = [torch.tensor([0.1, 0.82], device=device), torch.tensor([-0.2, 0.01], device=device)]

        if dist.get_rank() == 1:
            y_pred = [
                torch.tensor([3.0, -0.2], device=device),
                torch.tensor([0.99, 2.1], device=device),
                torch.tensor([-0.1, 0.0], device=device),
            ]
            y = [
                torch.tensor([2.7, -0.1], device=device),
                torch.tensor([1.58, 2.0], device=device),
                torch.tensor([-1.0, -0.1], device=device),
            ]

        # each rank contributes its own samples; the expected value covers all of them
        r2_score.update([y_pred, y])

        result = r2_score.compute()
        np.testing.assert_allclose(0.829185, result, rtol=1e-5)


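# A minimal sketch of how to run only the distributed case, assuming the file is
# importable as tests.handlers.test_handler_r2_score (the module path is an assumption):
#
#     python -m unittest tests.handlers.test_handler_r2_score.DistributedR2Score
#
# No external launcher (e.g. torchrun) should be needed here, since DistCall is
# expected to spawn the worker processes and initialize torch.distributed itself.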
if __name__ == "__main__":
    unittest.main()