Skip to content

Commit a7b615e

Browse files
committed
Cleaning up autocast usage
Signed-off-by: Eric Kerfoot <[email protected]>
1 parent c50b7aa commit a7b615e

File tree

11 files changed

+30
-32
lines changed

11 files changed

+30
-32
lines changed

monai/apps/deepedit/interaction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __call__(self, engine: SupervisedTrainer | SupervisedEvaluator, batchdata: d
7272

7373
with torch.no_grad():
7474
if engine.amp:
75-
with torch.cuda.amp.autocast():
75+
with torch.autocast("cuda"):
7676
predictions = engine.inferer(inputs, engine.network)
7777
else:
7878
predictions = engine.inferer(inputs, engine.network)

monai/apps/deepgrow/interaction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __call__(self, engine: SupervisedTrainer | SupervisedEvaluator, batchdata: d
6767
engine.network.eval()
6868
with torch.no_grad():
6969
if engine.amp:
70-
with torch.cuda.amp.autocast():
70+
with torch.autocast("cuda"):
7171
predictions = engine.inferer(inputs, engine.network)
7272
else:
7373
predictions = engine.inferer(inputs, engine.network)

monai/bundle/scripts.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,9 +1256,8 @@ def verify_net_in_out(
12561256
if input_dtype == torch.float16:
12571257
# fp16 can only be executed in gpu mode
12581258
net.to("cuda")
1259-
from torch.cuda.amp import autocast
12601259

1261-
with autocast():
1260+
with torch.autocast("cuda"):
12621261
output = net(test_data.cuda(), **extra_forward_args_)
12631262
net.to(device_)
12641263
else:

monai/engines/evaluator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ class Evaluator(Workflow):
8282
default to `True`.
8383
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
8484
`device`, `non_blocking`.
85-
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
86-
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
85+
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
86+
https://pytorch.org/docs/stable/amp.html#torch.autocast.
8787
8888
"""
8989

@@ -214,8 +214,8 @@ class SupervisedEvaluator(Evaluator):
214214
default to `True`.
215215
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
216216
`device`, `non_blocking`.
217-
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
218-
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
217+
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
218+
https://pytorch.org/docs/stable/amp.html#torch.autocast.
219219
compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
220220
`torch.Tensor` before forward pass, then converted back afterward with copied meta information.
221221
compile_kwargs: dict of the args for `torch.compile()` API, for more details:
@@ -329,7 +329,7 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten
329329
# execute forward computation
330330
with engine.mode(engine.network):
331331
if engine.amp:
332-
with torch.cuda.amp.autocast(**engine.amp_kwargs):
332+
with torch.autocast("cuda", **engine.amp_kwargs):
333333
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
334334
else:
335335
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
@@ -399,8 +399,8 @@ class EnsembleEvaluator(Evaluator):
399399
default to `True`.
400400
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
401401
`device`, `non_blocking`.
402-
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
403-
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
402+
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
403+
https://pytorch.org/docs/stable/amp.html#torch.autocast.
404404
405405
"""
406406

@@ -492,7 +492,7 @@ def _iteration(self, engine: EnsembleEvaluator, batchdata: dict[str, torch.Tenso
492492
for idx, network in enumerate(engine.networks):
493493
with engine.mode(network):
494494
if engine.amp:
495-
with torch.cuda.amp.autocast(**engine.amp_kwargs):
495+
with torch.autocast("cuda",**engine.amp_kwargs):
496496
if isinstance(engine.state.output, dict):
497497
engine.state.output.update(
498498
{engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)}

monai/engines/trainer.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ class SupervisedTrainer(Trainer):
126126
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
127127
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
128128
`device`, `non_blocking`.
129-
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
130-
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
129+
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
130+
https://pytorch.org/docs/stable/amp.html#torch.autocast.
131131
compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
132132
`torch.Tensor` before forward pass, then converted back afterward with copied meta information.
133133
compile_kwargs: dict of the args for `torch.compile()` API, for more details:
@@ -255,7 +255,7 @@ def _compute_pred_loss():
255255
engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
256256

257257
if engine.amp and engine.scaler is not None:
258-
with torch.cuda.amp.autocast(**engine.amp_kwargs):
258+
with torch.autocast("cuda",**engine.amp_kwargs):
259259
_compute_pred_loss()
260260
engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
261261
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
@@ -341,8 +341,8 @@ class GanTrainer(Trainer):
341341
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
342342
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
343343
`device`, `non_blocking`.
344-
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
345-
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
344+
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
345+
https://pytorch.org/docs/stable/amp.html#torch.autocast.
346346
347347
"""
348348

@@ -518,8 +518,8 @@ class AdversarialTrainer(Trainer):
518518
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
519519
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
520520
`device`, `non_blocking`.
521-
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
522-
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
521+
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
522+
https://pytorch.org/docs/stable/amp.html#torch.autocast.
523523
"""
524524

525525
def __init__(
@@ -689,7 +689,7 @@ def _compute_generator_loss() -> None:
689689
engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
690690

691691
if engine.amp and engine.state.g_scaler is not None:
692-
with torch.cuda.amp.autocast(**engine.amp_kwargs):
692+
with torch.autocast("cuda",**engine.amp_kwargs):
693693
_compute_generator_loss()
694694

695695
engine.state.output[Keys.LOSS] = (
@@ -737,7 +737,7 @@ def _compute_discriminator_loss() -> None:
737737
engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none)
738738

739739
if engine.amp and engine.state.d_scaler is not None:
740-
with torch.cuda.amp.autocast(**engine.amp_kwargs):
740+
with torch.autocast("cuda",**engine.amp_kwargs):
741741
_compute_discriminator_loss()
742742

743743
engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward()

monai/engines/workflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ class Workflow(Engine):
9090
default to `True`.
9191
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
9292
`device`, `non_blocking`.
93-
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
94-
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
93+
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
94+
https://pytorch.org/docs/stable/amp.html#torch.autocast.
9595
9696
Raises:
9797
TypeError: When ``data_loader`` is not a ``torch.utils.data.DataLoader``.

monai/networks/layers/vector_quantizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, to
100100
torch.Tensor: Quantization indices of shape [B,H,W,D,1]
101101
102102
"""
103-
with torch.cuda.amp.autocast(enabled=False):
103+
with torch.autocast("cuda",enabled=False):
104104
encoding_indices_view = list(inputs.shape)
105105
del encoding_indices_view[1]
106106

@@ -138,7 +138,7 @@ def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
138138
Returns:
139139
torch.Tensor: Quantize space representation of encoding_indices in channel first format.
140140
"""
141-
with torch.cuda.amp.autocast(enabled=False):
141+
with torch.autocast("cuda",enabled=False):
142142
embedding: torch.Tensor = (
143143
self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
144144
)

monai/networks/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,7 +1238,7 @@ def __init__(self, mod):
12381238

12391239
def forward(self, x):
12401240
dtype = x.dtype
1241-
with torch.amp.autocast("cuda", enabled=False):
1241+
with torch.autocast("cuda", enabled=False):
12421242
ret = self.mod.forward(x.to(torch.float32)).to(dtype)
12431243
return ret
12441244

@@ -1255,7 +1255,7 @@ def __init__(self, mod):
12551255

12561256
def forward(self, *args):
12571257
from_dtype = args[0].dtype
1258-
with torch.amp.autocast("cuda", enabled=False):
1258+
with torch.autocast("cuda", enabled=False):
12591259
ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
12601260
return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
12611261

tests/config/test_cv2_dist.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import numpy as np
1717
import torch
1818
import torch.distributed as dist
19-
from torch.cuda.amp import autocast
2019

2120
# FIXME: test for the workaround of https://github.com/Project-MONAI/MONAI/issues/5291
2221
from monai.config.deviceconfig import print_config
@@ -33,7 +32,7 @@ def main_worker(rank, ngpus_per_node, port):
3332
model, device_ids=[rank], output_device=rank, find_unused_parameters=False
3433
)
3534
x = torch.ones(1, 1, 12, 12, 12).to(rank)
36-
with autocast(enabled=True):
35+
with torch.autocast("cuda"):
3736
model(x)
3837

3938
if dist.is_initialized():

tests/data/meta_tensor/test_meta_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def test_amp(self):
256256
conv = torch.nn.Conv2d(im.shape[1], 5, 3)
257257
conv.to(device)
258258
im_conv = conv(im)
259-
with torch.cuda.amp.autocast():
259+
with torch.autocast("cuda"):
260260
im_conv2 = conv(im)
261261
self.check(im_conv2, im_conv, ids=False, rtol=1e-2, atol=1e-2)
262262

0 commit comments

Comments
 (0)