Skip to content

Commit c9ee60c

Browse files
committed
mypy issues
Signed-off-by: thibaultdvx <[email protected]>
1 parent b01b38b commit c9ee60c

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

monai/metrics/r2_score.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(self, multi_output: MultiOutput | str = MultiOutput.UNIFORM, p: int
6868
self.multi_output = multi_output
6969
self.p = p
7070

71-
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
71+
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # type: ignore[override]
7272
_check_dim(y_pred, y)
7373
return y_pred, y
7474

@@ -100,7 +100,7 @@ def _check_dim(y_pred: torch.Tensor, y: torch.Tensor) -> None:
100100
)
101101

102102

103-
def _check_r2_params(multi_output, p) -> tuple[MultiOutput, int]:
103+
def _check_r2_params(multi_output: MultiOutput | str, p: int) -> tuple[MultiOutput | str, int]:
104104
multi_output = look_up_option(multi_output, MultiOutput)
105105
if not isinstance(p, int) or p < 0:
106106
raise ValueError(f"`p` must be an integer larger or equal to 0, got {p}.")
@@ -115,7 +115,7 @@ def _calculate(y_pred: np.ndarray, y: np.ndarray, p: int) -> float:
115115
r2 = 1 - (rss / tss)
116116
r2_adjusted = 1 - (1 - r2) * (num_obs - 1) / (num_obs - p - 1)
117117

118-
return r2_adjusted
118+
return r2_adjusted # type: ignore[no-any-return]
119119

120120

121121
def compute_r2_score(
@@ -154,28 +154,29 @@ def compute_r2_score(
154154
_check_dim(y_pred, y)
155155
dim = y.ndimension()
156156
n = y.shape[0]
157-
y = y.cpu().numpy()
158-
y_pred = y_pred.cpu().numpy()
157+
y = y.cpu().numpy() # type: ignore[assignment]
158+
y_pred = y_pred.cpu().numpy() # type: ignore[assignment]
159159

160160
if n < 2:
161161
raise ValueError("There is no enough data for computing. Needs at least two samples to calculate r2 score.")
162162
if p >= n - 1:
163163
raise ValueError("`p` must be smaller than n_samples - 1, " f"got p={p}, n_samples={n}.")
164164

165165
if dim == 2 and y_pred.shape[1] == 1:
166-
y_pred = np.squeeze(y_pred, axis=-1)
167-
y = np.squeeze(y, axis=-1)
166+
y_pred = np.squeeze(y_pred, axis=-1) # type: ignore[assignment]
167+
y = np.squeeze(y, axis=-1) # type: ignore[assignment]
168168
dim = 1
169169

170170
if dim == 1:
171-
return _calculate(y_pred, y, p)
171+
return _calculate(y_pred, y, p) # type: ignore[arg-type]
172172

173-
y, y_pred = np.transpose(y, axes=(1, 0)), np.transpose(y_pred, axes=(1, 0))
173+
y, y_pred = np.transpose(y, axes=(1, 0)), np.transpose(y_pred, axes=(1, 0)) # type: ignore[assignment]
174174
r2_values = [_calculate(y_pred_, y_, p) for y_pred_, y_ in zip(y_pred, y)]
175175
if multi_output == MultiOutput.RAW:
176176
return r2_values
177177
if multi_output == MultiOutput.UNIFORM:
178178
return np.mean(r2_values)
179-
if multi_output == multi_output.VARIANCE:
179+
if multi_output == MultiOutput.VARIANCE:
180180
weights = np.var(y, axis=1)
181-
return np.average(r2_values, weights=weights)
181+
return np.average(r2_values, weights=weights) # type: ignore[no-any-return]
182+
raise ValueError(f'Unsupported multi_output: {multi_output}, available options are ["raw_values", "uniform_average", "variance_weighted"].')

0 commit comments

Comments
 (0)