@@ -126,8 +126,8 @@ class SupervisedTrainer(Trainer):
             more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.
         compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
             `torch.Tensor` before forward pass, then converted back afterward with copied meta information.
         compile_kwargs: dict of the args for `torch.compile()` API, for more details:
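For reference, the docstring change above tracks PyTorch's autocast API migration: `torch.cuda.amp.autocast(...)` is deprecated in recent PyTorch releases in favor of `torch.autocast("cuda", ...)`, with the device type passed as the first positional argument and all other keyword arguments unchanged. A minimal sketch of the two spellings side by side (the `net`, `x`, and example `amp_kwargs` objects are illustrative, not part of this diff; assumes a CUDA device):

```python
import torch

net = torch.nn.Linear(16, 4).cuda()          # hypothetical model
x = torch.randn(8, 16, device="cuda")        # hypothetical batch
amp_kwargs = {"dtype": torch.float16}        # forwarded verbatim by the trainer

# old spelling, removed by this diff (deprecated in recent PyTorch releases)
with torch.cuda.amp.autocast(**amp_kwargs):
    out_old = net(x)

# new spelling: the device type becomes the first positional argument
with torch.autocast("cuda", **amp_kwargs):
    out_new = net(x)

assert out_old.dtype == out_new.dtype == torch.float16
```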
@@ -255,7 +255,7 @@ def _compute_pred_loss():
         engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
 
         if engine.amp and engine.scaler is not None:
-            with torch.cuda.amp.autocast(**engine.amp_kwargs):
+            with torch.autocast("cuda", **engine.amp_kwargs):
                 _compute_pred_loss()
             engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
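The hunk above is the AMP branch of `SupervisedTrainer._iteration`: zero the gradients, run the forward pass and loss under autocast, then backpropagate through the `GradScaler`. A condensed, standalone sketch of that pattern with the updated autocast call (function and argument names are placeholders, not the trainer's internals; the trainer itself fires events and performs the optimizer step elsewhere):

```python
import torch

def amp_train_step(net, loss_fn, optimizer, scaler, inputs, targets, amp_kwargs):
    """One AMP iteration: zero_grad -> autocast forward -> scaled backward -> step."""
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast("cuda", **amp_kwargs):
        loss = loss_fn(net(inputs), targets)
    scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)          # unscale gradients and apply the optimizer step
    scaler.update()                 # adjust the scale factor for the next iteration
    return loss.detach()
```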
@@ -341,8 +341,8 @@ class GanTrainer(Trainer):
             more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.
 
     """
 
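Because `amp_kwargs` is forwarded verbatim to `torch.autocast("cuda", ...)`, callers can, for example, pin the autocast dtype when constructing a trainer. A toy usage sketch, assuming the keyword names shown in the docstring above and a CUDA device (the tiny network and in-memory dataset are stand-ins for a real MONAI pipeline; check the installed MONAI version for the exact constructor signature):

```python
import torch
from torch.utils.data import DataLoader
from monai.engines import SupervisedTrainer

# toy dict-style dataset, matching the default `prepare_batch` expectations
train_ds = [{"image": torch.randn(16), "label": torch.tensor(0)} for _ in range(32)]
network = torch.nn.Linear(16, 2).to("cuda:0")

trainer = SupervisedTrainer(
    device=torch.device("cuda:0"),
    max_epochs=1,
    train_data_loader=DataLoader(train_ds, batch_size=8),
    network=network,
    optimizer=torch.optim.Adam(network.parameters(), lr=1e-3),
    loss_function=torch.nn.CrossEntropyLoss(),
    amp=True,
    amp_kwargs={"dtype": torch.bfloat16},  # passed through to torch.autocast("cuda", ...)
)
trainer.run()
```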
@@ -518,8 +518,8 @@ class AdversarialTrainer(Trainer):
             more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
         to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
             `device`, `non_blocking`.
-        amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
-            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+        amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
+            https://pytorch.org/docs/stable/amp.html#torch.autocast.
     """
 
     def __init__(
@@ -689,7 +689,7 @@ def _compute_generator_loss() -> None:
         engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
 
         if engine.amp and engine.state.g_scaler is not None:
-            with torch.cuda.amp.autocast(**engine.amp_kwargs):
+            with torch.autocast("cuda", **engine.amp_kwargs):
                 _compute_generator_loss()
 
             engine.state.output[Keys.LOSS] = (
@@ -737,7 +737,7 @@ def _compute_discriminator_loss() -> None:
         engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none)
 
         if engine.amp and engine.state.d_scaler is not None:
-            with torch.cuda.amp.autocast(**engine.amp_kwargs):
+            with torch.autocast("cuda", **engine.amp_kwargs):
                 _compute_discriminator_loss()
 
             engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward()
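The generator and discriminator hunks apply the same autocast replacement, each under its own `GradScaler` (`g_scaler` and `d_scaler`) tied to its own optimizer. A compressed sketch of that two-scaler pattern outside the engine, with placeholder names rather than the trainer's internals:

```python
import torch

def adversarial_amp_step(g_net, d_net, g_opt, d_opt, g_scaler, d_scaler,
                         g_loss_fn, d_loss_fn, real, noise, amp_kwargs):
    """One generator update and one discriminator update, each with its own scaler."""
    # generator: forward under autocast, backward through the generator's scaler
    g_opt.zero_grad(set_to_none=True)
    with torch.autocast("cuda", **amp_kwargs):
        fake = g_net(noise)
        g_loss = g_loss_fn(d_net(fake))
    g_scaler.scale(g_loss).backward()
    g_scaler.step(g_opt)
    g_scaler.update()

    # discriminator: second autocast region, detached fakes, separate scaler
    d_opt.zero_grad(set_to_none=True)
    with torch.autocast("cuda", **amp_kwargs):
        d_loss = d_loss_fn(d_net(real), d_net(fake.detach()))
    d_scaler.scale(d_loss).backward()
    d_scaler.step(d_opt)
    d_scaler.update()
    return g_loss.detach(), d_loss.detach()
```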