Commit 7ee3bf1

Formatting
Signed-off-by: Eric Kerfoot <[email protected]>
1 parent 199db95 commit 7ee3bf1

3 files changed: +6 -6 lines changed

monai/engines/evaluator.py

Lines changed: 1 addition & 1 deletion
@@ -492,7 +492,7 @@ def _iteration(self, engine: EnsembleEvaluator, batchdata: dict[str, torch.Tenso
         for idx, network in enumerate(engine.networks):
             with engine.mode(network):
                 if engine.amp:
-                    with torch.autocast("cuda",**engine.amp_kwargs):
+                    with torch.autocast("cuda", **engine.amp_kwargs):
                         if isinstance(engine.state.output, dict):
                             engine.state.output.update(
                                 {engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)}
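For context, a minimal standalone sketch of the torch.autocast pattern this hunk reformats; it is not MONAI's evaluator code, and the amp_kwargs values below are illustrative only:

import torch

# Hypothetical keyword arguments, standing in for engine.amp_kwargs in the diff above.
amp_kwargs = {"dtype": torch.float16, "cache_enabled": True}

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 4).to(device)
x = torch.randn(2, 8, device=device)

# torch.autocast takes the device type as its first argument and accepts the
# remaining settings as keyword arguments, which is what the call above forwards.
with torch.autocast("cuda", enabled=(device == "cuda"), **amp_kwargs):
    y = model(x)

print(y.dtype)  # float16 on a CUDA machine, float32 otherwise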

monai/engines/trainer.py

Lines changed: 3 additions & 3 deletions
@@ -255,7 +255,7 @@ def _compute_pred_loss():
         engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
 
         if engine.amp and engine.scaler is not None:
-            with torch.autocast("cuda",**engine.amp_kwargs):
+            with torch.autocast("cuda", **engine.amp_kwargs):
                 _compute_pred_loss()
             engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
@@ -689,7 +689,7 @@ def _compute_generator_loss() -> None:
         engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
 
         if engine.amp and engine.state.g_scaler is not None:
-            with torch.autocast("cuda",**engine.amp_kwargs):
+            with torch.autocast("cuda", **engine.amp_kwargs):
                 _compute_generator_loss()
 
             engine.state.output[Keys.LOSS] = (
@@ -737,7 +737,7 @@ def _compute_discriminator_loss() -> None:
         engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none)
 
         if engine.amp and engine.state.d_scaler is not None:
-            with torch.autocast("cuda",**engine.amp_kwargs):
+            with torch.autocast("cuda", **engine.amp_kwargs):
                 _compute_discriminator_loss()
 
             engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward()
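For reference, a minimal sketch of the autocast-plus-GradScaler training step these hunks touch; it is not MONAI's trainer code, and the model, optimizer, and loss names are illustrative:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = torch.nn.MSELoss()
# GradScaler plays the role of engine.scaler / g_scaler / d_scaler in the hunks above.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x = torch.randn(4, 16, device=device)
target = torch.randn(4, 1, device=device)

optimizer.zero_grad(set_to_none=True)
with torch.autocast("cuda", enabled=(device == "cuda")):
    loss = loss_fn(model(x), target)
# scale() multiplies the loss so small fp16 gradients do not underflow;
# step() unscales gradients before the optimizer update and update() adjusts the scale factor.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()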

monai/networks/layers/vector_quantizer.py

Lines changed: 2 additions & 2 deletions
@@ -100,7 +100,7 @@ def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, to
             torch.Tensor: Quantization indices of shape [B,H,W,D,1]
 
         """
-        with torch.autocast("cuda",enabled=False):
+        with torch.autocast("cuda", enabled=False):
             encoding_indices_view = list(inputs.shape)
             del encoding_indices_view[1]
 

@@ -138,7 +138,7 @@ def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor: Quantize space representation of encoding_indices in channel first format.
         """
-        with torch.autocast("cuda",enabled=False):
+        with torch.autocast("cuda", enabled=False):
             embedding: torch.Tensor = (
                 self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
             )
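These two hunks touch the pattern of locally switching autocast off so a precision-sensitive block runs in full float32. A small standalone sketch of that effect, unrelated to MONAI's quantizer internals:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(4, 8, device=device)
b = torch.randn(8, 4, device=device)

with torch.autocast("cuda", enabled=(device == "cuda")):
    c = a @ b  # on CUDA, autocast runs this matmul in float16
    with torch.autocast("cuda", enabled=False):
        # autocast is disabled inside this block, so the matmul keeps its
        # float32 inputs and output, analogous to quantize() and embed() above
        d = a @ b

print(c.dtype, d.dtype)  # torch.float16 / torch.float32 on a CUDA machine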

0 commit comments