
Commit 4223f69

Refactor tensor handling in __to_tensor method to optimize device management for scalar metrics.
Signed-off-by: Wil Kong <[email protected]>
Parent: 90f7c03


src/lightning/pytorch/core/module.py

Lines changed: 14 additions & 5 deletions
```diff
@@ -656,11 +656,20 @@ def __check_allowed(v: Any, name: str, value: Any) -> None:
         raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")
 
     def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
-        value = (
-            value.clone().detach()
-            if isinstance(value, Tensor)
-            else torch.tensor(value, device=self.device, dtype=_get_default_dtype())
-        )
+        if isinstance(value, Tensor):
+            # Keep the tensor on its original device to avoid unnecessary transfers.
+            value = value.clone().detach()
+        else:
+            if self.device.type == "cuda":
+                # Place scalar metrics on CPU to avoid a CPU-GPU transfer and synchronization:
+                # `torch.tensor(value, device="cuda")` triggers such a synchronization, while the
+                # metric itself is only used on the CPU side, so keeping scalars on CPU is more efficient.
+                device = "cpu"
+            else:
+                # For non-CUDA devices, keep the original behavior.
+                device = self.device
+            value = torch.tensor(value, device=device, dtype=_get_default_dtype())
+
         if not torch.numel(value) == 1:
             raise ValueError(
                 f"`self.log({name}, {value})` was called, but the tensor must have a single element."
```
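For illustration, the new dispatch can be reduced to a minimal standalone sketch (my own, not part of the commit; `to_tensor` and the `device` argument stand in for the private method and `self.device`, and the real code additionally passes `dtype=_get_default_dtype()`):

```python
import numbers
from typing import Union

import torch
from torch import Tensor


def to_tensor(value: Union[Tensor, numbers.Number], device: torch.device) -> Tensor:
    """Sketch mirroring the patched __to_tensor dispatch."""
    if isinstance(value, Tensor):
        # Tensor inputs are cloned/detached in place on their current device.
        return value.clone().detach()
    # Python scalars logged from a CUDA module are materialized on CPU to
    # skip the host-to-device copy; other accelerators keep the old behavior.
    target = "cpu" if device.type == "cuda" else device
    return torch.tensor(value, device=target)


# A scalar logged on a CUDA module now yields a CPU tensor:
print(to_tensor(0.123, torch.device("cuda")).device)  # cpu
print(to_tensor(1, torch.device("cpu")).device)       # cpu
```

The asymmetry is deliberate: tensor inputs may already live on the GPU, so moving them would itself cost a transfer, while a Python scalar has no device yet and can be placed on CPU for free.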

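The synchronization claim in the diff's comment can be sanity-checked with a rough timing sketch (again my own, assuming a CUDA-capable machine; absolute numbers vary by hardware):

```python
import time

import torch


def bench(device: str, iters: int = 10_000) -> float:
    """Time creating `iters` scalar tensors on `device`."""
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        torch.tensor(0.123, device=device)
    if device == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter() - start


if torch.cuda.is_available():
    print(f"cpu : {bench('cpu'):.3f}s")
    # Typically far slower per call: each construction copies the scalar
    # host-to-device and blocks on that copy.
    print(f"cuda: {bench('cuda'):.3f}s")
else:
    print("CUDA not available; only the CPU path applies.")
```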