Skip to content

Commit 38d4c13

Browse files
committed
change tss formula
Signed-off-by: thibaultdvx <[email protected]>
1 parent 2b90caa commit 38d4c13

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

monai/metrics/r2_score.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def _check_r2_params(multi_output: MultiOutput | str, p: int) -> tuple[MultiOutp
120120
def _calculate(y_pred: np.ndarray, y: np.ndarray, p: int) -> float:
121121
num_obs = len(y)
122122
rss = np.sum((y_pred - y) ** 2)
123-
tss = np.sum(y**2) - np.sum(y) ** 2 / num_obs
123+
tss = np.sum((y - np.mean(y)) ** 2)
124124
r2 = 1 - (rss / tss)
125125
r2_adjusted = 1 - (1 - r2) * (num_obs - 1) / (num_obs - p - 1)
126126

tests/handlers/test_handler_r2_score.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,5 +67,6 @@ def test_compute(self):
6767
result = r2_score.compute()
6868
np.testing.assert_allclose(0.829185, result, rtol=1e-5)
6969

70+
7071
if __name__ == "__main__":
7172
unittest.main()

0 commit comments

Comments
 (0)