
Commit 160df23

Fixes for mypy type checking
Signed-off-by: Eric Kerfoot <[email protected]>
1 parent db1d2af commit 160df23

16 files changed, +99 -99 lines changed

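Note: most of the hunks below take the same approach: rather than a blanket "# type: ignore", mypy is silenced with an error-code-scoped comment such as "# type: ignore[arg-type]", so any other error reported on the same line still surfaces. A minimal sketch of the pattern, with illustrative names that are not from this commit:

    # layer is only known to be an nn.Module, so mypy types layer.weight roughly as
    # "Tensor | Module" and flags the initializer call with error code "arg-type";
    # the scoped comment suppresses just that code on this line.
    import torch

    def init_conv(layer: torch.nn.Module) -> None:
        torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
        torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]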

monai/apps/detection/networks/retinanet_network.py

Lines changed: 5 additions & 5 deletions
@@ -88,8 +88,8 @@ def __init__(

         for layer in self.conv.children():
             if isinstance(layer, conv_type):  # type: ignore
-                torch.nn.init.normal_(layer.weight, std=0.01)
-                torch.nn.init.constant_(layer.bias, 0)
+                torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
+                torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]

         self.cls_logits = conv_type(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
         torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
@@ -167,8 +167,8 @@ def __init__(self, in_channels: int, num_anchors: int, spatial_dims: int):

         for layer in self.conv.children():
             if isinstance(layer, conv_type):  # type: ignore
-                torch.nn.init.normal_(layer.weight, std=0.01)
-                torch.nn.init.zeros_(layer.bias)
+                torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
+                torch.nn.init.zeros_(layer.bias)  # type: ignore[arg-type]

     def forward(self, x: list[Tensor]) -> list[Tensor]:
         """
@@ -297,7 +297,7 @@ def __init__(
         )
         self.feature_extractor = feature_extractor

-        self.feature_map_channels: int = self.feature_extractor.out_channels
+        self.feature_map_channels: int = self.feature_extractor.out_channels  # type: ignore[assignment]
         self.num_anchors = num_anchors
         self.classification_head = RetinaNetClassificationHead(
             self.feature_map_channels, self.num_anchors, self.num_classes, spatial_dims=self.spatial_dims

monai/apps/detection/utils/box_coder.py

Lines changed: 2 additions & 2 deletions
@@ -221,15 +221,15 @@ def decode_single(self, rel_codes: Tensor, reference_boxes: Tensor) -> Tensor:

             pred_ctr_xyx_axis = dxyz_axis * whd_axis[:, None] + ctr_xyz_axis[:, None]
             pred_whd_axis = torch.exp(dwhd_axis) * whd_axis[:, None]
-            pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype)
+            pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype)  # type: ignore[union-attr]

             # When convert float32 to float16, Inf or Nan may occur
             if torch.isnan(pred_whd_axis).any() or torch.isinf(pred_whd_axis).any():
                 raise ValueError("pred_whd_axis is NaN or Inf.")

             # Distance from center to box's corner.
             c_to_c_whd_axis = (
-                torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis
+                torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis  # type: ignore[arg-type]
             )

             pred_boxes.append(pred_ctr_xyx_axis - c_to_c_whd_axis)

monai/apps/reconstruction/networks/blocks/varnetblock.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def soft_dc(self, x: Tensor, ref_kspace: Tensor, mask: Tensor) -> Tensor:
         Returns:
             Output of DC block with the same shape as x
         """
-        return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight
+        return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight  # type: ignore

     def forward(self, current_kspace: Tensor, ref_kspace: Tensor, mask: Tensor, sens_maps: Tensor) -> Tensor:
         """

monai/data/utils.py

Lines changed: 1 addition & 1 deletion
@@ -753,7 +753,7 @@ def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_z
     if isinstance(_affine, torch.Tensor):
         spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0))
     else:
-        spacing = np.sqrt(np.sum(_affine * _affine, axis=0))
+        spacing = np.sqrt(np.sum(_affine * _affine, axis=0))  # type: ignore[operator]
     if suppress_zeros:
         spacing[spacing == 0] = 1.0
     spacing_, *_ = convert_to_dst_type(spacing, dst=affine, dtype=dtype)

monai/data/video_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@ def get_available_codecs() -> dict[str, str]:
         for codec, ext in all_codecs.items():
             writer = cv2.VideoWriter()
             fname = os.path.join(tmp_dir, f"test{ext}")
-            fourcc = cv2.VideoWriter_fourcc(*codec)
+            fourcc = cv2.VideoWriter_fourcc(*codec)  # type: ignore[attr-defined]
             noviderr = writer.open(fname, fourcc, 1, (10, 10))
             if noviderr:
                 codecs[codec] = ext

monai/engines/utils.py

Lines changed: 1 addition & 1 deletion
@@ -309,7 +309,7 @@ def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_nam
         self.scheduler = scheduler

     def get_target(self, images, noise, timesteps):
-        return self.scheduler.get_velocity(images, noise, timesteps)
+        return self.scheduler.get_velocity(images, noise, timesteps)  # type: ignore[operator]


 def default_make_latent(
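Note: the "[operator]" code above arises because self.scheduler is annotated as nn.Module, so mypy does not know it has a get_velocity method; the attribute lookup falls back to Module.__getattr__ and the result is not known to be callable. A sketch of the situation, with a hypothetical class name standing in for the one in the diff:

    import torch.nn as nn

    class PrepareBatchSketch:  # hypothetical stand-in, not the MONAI class
        def __init__(self, scheduler: nn.Module) -> None:
            self.scheduler = scheduler  # statically just an nn.Module

        def get_target(self, images, noise, timesteps):
            # mypy cannot prove the looked-up attribute is callable -> [operator]
            return self.scheduler.get_velocity(images, noise, timesteps)  # type: ignore[operator]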

monai/inferers/inferer.py

Lines changed: 6 additions & 6 deletions
@@ -882,7 +882,7 @@ def sample(
             )

             # 2. compute previous image: x_t -> x_t-1
-            image, _ = scheduler.step(model_output, t, image)
+            image, _ = scheduler.step(model_output, t, image)  # type: ignore[operator]
             if save_intermediates and t % intermediate_steps == 0:
                 intermediates.append(image)
         if save_intermediates:
@@ -986,8 +986,8 @@ def get_likelihood(
             predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image

             # get the posterior mean and variance
-            posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
-            posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
+            posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)  # type: ignore[operator]
+            posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)  # type: ignore[operator]

             log_posterior_variance = torch.log(posterior_variance)
             log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
@@ -1436,7 +1436,7 @@ def sample(  # type: ignore[override]
             )

             # 3. compute previous image: x_t -> x_t-1
-            image, _ = scheduler.step(model_output, t, image)
+            image, _ = scheduler.step(model_output, t, image)  # type: ignore[operator]
             if save_intermediates and t % intermediate_steps == 0:
                 intermediates.append(image)
         if save_intermediates:
@@ -1562,8 +1562,8 @@ def get_likelihood(  # type: ignore[override]
             predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image

             # get the posterior mean and variance
-            posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
-            posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
+            posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)  # type: ignore[operator]
+            posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)  # type: ignore[operator]

             log_posterior_variance = torch.log(posterior_variance)
             log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance

monai/inferers/merger.py

Lines changed: 13 additions & 13 deletions
@@ -53,8 +53,8 @@ def __init__(
         cropped_shape: Sequence[int] | None = None,
         device: torch.device | str | None = None,
     ) -> None:
-        self.merged_shape = merged_shape
-        self.cropped_shape = self.merged_shape if cropped_shape is None else cropped_shape
+        self.merged_shape: tuple[int, ...] = tuple(merged_shape)
+        self.cropped_shape: tuple[int, ...] = tuple(self.merged_shape if cropped_shape is None else cropped_shape)
         self.device = device
         self.is_finalized = False

@@ -231,9 +231,9 @@ def __init__(
         dtype: np.dtype | str = "float32",
         value_dtype: np.dtype | str = "float32",
         count_dtype: np.dtype | str = "uint8",
-        store: zarr.storage.Store | str = "merged.zarr",
-        value_store: zarr.storage.Store | str | None = None,
-        count_store: zarr.storage.Store | str | None = None,
+        store: zarr.storage.Store | str = "merged.zarr",  # type: ignore
+        value_store: zarr.storage.Store | str | None = None,  # type: ignore
+        count_store: zarr.storage.Store | str | None = None,  # type: ignore
         compressor: str | None = None,
         value_compressor: str | None = None,
         count_compressor: str | None = None,
@@ -251,18 +251,18 @@ def __init__(
         if version_geq(get_package_version("zarr"), "3.0.0"):
             if value_store is None:
                 self.tmpdir = TemporaryDirectory()
-                self.value_store = zarr.storage.LocalStore(self.tmpdir.name)
+                self.value_store = zarr.storage.LocalStore(self.tmpdir.name)  # type: ignore
             else:
-                self.value_store = value_store
+                self.value_store = value_store  # type: ignore
             if count_store is None:
                 self.tmpdir = TemporaryDirectory()
-                self.count_store = zarr.storage.LocalStore(self.tmpdir.name)
+                self.count_store = zarr.storage.LocalStore(self.tmpdir.name)  # type: ignore
             else:
-                self.count_store = count_store
+                self.count_store = count_store  # type: ignore
         else:
             self.tmpdir = None
-            self.value_store = zarr.storage.TempStore() if value_store is None else value_store
-            self.count_store = zarr.storage.TempStore() if count_store is None else count_store
+            self.value_store = zarr.storage.TempStore() if value_store is None else value_store  # type: ignore
+            self.count_store = zarr.storage.TempStore() if count_store is None else count_store  # type: ignore
         self.chunks = chunks
         self.compressor = compressor
         self.value_compressor = value_compressor
@@ -314,7 +314,7 @@ def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None:
         map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True)
         with self.lock:
             self.values[map_slice] += values.numpy()
-            self.counts[map_slice] += 1
+            self.counts[map_slice] += 1  # type: ignore[operator]

     def finalize(self) -> zarr.Array:
         """
@@ -332,7 +332,7 @@ def finalize(self) -> zarr.Array:
         if not self.is_finalized:
             # use chunks for division to fit into memory
             for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape):
-                self.output[chunk] = self.values[chunk] / self.counts[chunk]
+                self.output[chunk] = self.values[chunk] / self.counts[chunk]  # type: ignore[operator]
             # finalize the shape
             self.output.resize(self.cropped_shape)
             # set finalize flag to protect performing in-place division again
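Note: the Merger.__init__ hunk above avoids an ignore comment altogether: converting both shapes with tuple() and annotating the attributes as tuple[int, ...] gives mypy one concrete type instead of whatever Sequence[int] subtype the caller passed. A rough sketch of that pattern (not the full MONAI class):

    from __future__ import annotations

    from collections.abc import Sequence

    class MergerSketch:  # illustrative only
        def __init__(self, merged_shape: Sequence[int], cropped_shape: Sequence[int] | None = None) -> None:
            # explicit annotations plus tuple() keep the attribute types stable for mypy
            self.merged_shape: tuple[int, ...] = tuple(merged_shape)
            self.cropped_shape: tuple[int, ...] = tuple(self.merged_shape if cropped_shape is None else cropped_shape)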

monai/losses/sure_loss.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@

 from __future__ import annotations

-from typing import Callable, Optional
+from typing import Callable, Optional, cast

 import torch
 import torch.nn as nn
@@ -92,7 +92,7 @@ def sure_loss_function(
     y_ref = operator(x)

     # get perturbed output
-    x_perturbed = x + eps * perturb_noise
+    x_perturbed = x + eps * perturb_noise  # type: ignore
     y_perturbed = operator(x_perturbed)
     # divergence
     divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref))  # type: ignore

monai/networks/blocks/feature_pyramid_network.py

Lines changed: 4 additions & 2 deletions
@@ -54,7 +54,9 @@

 from collections import OrderedDict
 from collections.abc import Callable
+from typing import cast

+import torch
 import torch.nn.functional as F
 from torch import Tensor, nn

@@ -194,8 +196,8 @@ def __init__(
         conv_type_: type[nn.Module] = Conv[Conv.CONV, spatial_dims]
         for m in self.modules():
             if isinstance(m, conv_type_):
-                nn.init.kaiming_uniform_(m.weight, a=1)
-                nn.init.constant_(m.bias, 0.0)
+                nn.init.kaiming_uniform_(cast(torch.Tensor, m.weight), a=1)
+                nn.init.constant_(cast(torch.Tensor, m.bias), 0.0)

         if extra_blocks is not None:
             if not isinstance(extra_blocks, ExtraFPNBlock):
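Note: feature_pyramid_network.py uses the other strategy seen in this commit: typing.cast tells mypy that m.weight and m.bias are plain tensors once isinstance has matched the conv type, and it returns its argument unchanged at runtime, so behaviour is identical. A minimal sketch with an illustrative helper name:

    from typing import cast

    import torch
    from torch import nn

    def reset_conv_params(m: nn.Module) -> None:  # illustrative helper, not in the diff
        # cast() affects only the type checker; at runtime it is the identity function
        nn.init.kaiming_uniform_(cast(torch.Tensor, m.weight), a=1)
        nn.init.constant_(cast(torch.Tensor, m.bias), 0.0)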
