Skip to content

Commit 2a7fdb5

Browse files
committed
fix issue in distributed test
Signed-off-by: thibaultdvx <[email protected]>
1 parent 5c260aa commit 2a7fdb5

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/handlers/test_handler_r2_score.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def test_compute(self):
4949
if dist.get_rank() == 0:
5050
y_pred = [torch.tensor([0.1, 1.0], device=device), torch.tensor([-0.25, 0.5], device=device)]
5151
y = [torch.tensor([0.1, 0.82], device=device), torch.tensor([-0.2, 0.01], device=device)]
52+
r2_score.update([y_pred, y])
5253

5354
if dist.get_rank() == 1:
5455
y_pred = [
@@ -61,8 +62,7 @@ def test_compute(self):
6162
torch.tensor([1.58, 2.0], device=device),
6263
torch.tensor([-1.0, -0.1], device=device),
6364
]
64-
65-
r2_score.update([y_pred, y])
65+
r2_score.update([y_pred, y])
6666

6767
result = r2_score.compute()
6868
np.testing.assert_allclose(0.829185, result, rtol=1e-5)

0 commit comments

Comments
 (0)