From 58a3d69c88e4efccb455d1d5317a8e9bced4d156 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Thu, 8 Jan 2026 08:26:50 +0000
Subject: [PATCH] Update

[ghstack-poisoned]
---
 torchrl/modules/distributions/continuous.py       |  9 +++++----
 torchrl/modules/distributions/truncated_normal.py | 10 ++++++----
 torchrl/modules/distributions/utils.py            |  8 +++++---
 3 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py
index 6ac5a991ead..99362a784d5 100644
--- a/torchrl/modules/distributions/continuous.py
+++ b/torchrl/modules/distributions/continuous.py
@@ -370,14 +370,15 @@ def __init__(
             if not all(high > low):
                 raise RuntimeError(err_msg)
 
+        _non_blocking = loc.device.type == "cuda"
         if not isinstance(high, torch.Tensor):
             high = torch.as_tensor(high, device=loc.device)
         elif high.device != loc.device:
-            high = high.to(loc.device)
+            high = high.to(loc.device, non_blocking=_non_blocking)
         if not isinstance(low, torch.Tensor):
             low = torch.as_tensor(low, device=loc.device)
         elif low.device != loc.device:
-            low = low.to(loc.device)
+            low = low.to(loc.device, non_blocking=_non_blocking)
         if not is_compiling() and not safe_is_current_stream_capturing():
             self.non_trivial_max = (high != 1.0).any()
             self.non_trivial_min = (low != -1.0).any()
@@ -391,10 +392,10 @@ def __init__(
         self.upscale = (
             upscale
             if not isinstance(upscale, torch.Tensor)
-            else upscale.to(self.device)
+            else upscale.to(self.device, non_blocking=_non_blocking)
         )
 
-        low = low.to(loc.device)
+        low = low.to(loc.device, non_blocking=_non_blocking)
         self.low = low
         self.high = high
 
diff --git a/torchrl/modules/distributions/truncated_normal.py b/torchrl/modules/distributions/truncated_normal.py
index f8d481265cb..c270a83a0da 100644
--- a/torchrl/modules/distributions/truncated_normal.py
+++ b/torchrl/modules/distributions/truncated_normal.py
@@ -35,8 +35,9 @@ class TruncatedStandardNormal(Distribution):
 
     def __init__(self, a, b, validate_args=None, device=None):
         self.a, self.b = broadcast_all(a, b)
-        self.a = self.a.to(device)
-        self.b = self.b.to(device)
+        _non_blocking = device is not None and torch.device(device).type == "cuda"
+        self.a = self.a.to(device, non_blocking=_non_blocking)
+        self.b = self.b.to(device, non_blocking=_non_blocking)
         if isinstance(a, Number) and isinstance(b, Number):
             batch_shape = torch.Size()
         else:
@@ -146,8 +147,9 @@ class TruncatedNormal(TruncatedStandardNormal):
     def __init__(self, loc, scale, a, b, validate_args=None, device=None):
         scale = scale.clamp_min(self.eps)
         self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b)
-        a = a.to(device)
-        b = b.to(device)
+        _non_blocking = device is not None and torch.device(device).type == "cuda"
+        a = a.to(device, non_blocking=_non_blocking)
+        b = b.to(device, non_blocking=_non_blocking)
         self._non_std_a = a
         self._non_std_b = b
         a = (a - self.loc) / self.scale
diff --git a/torchrl/modules/distributions/utils.py b/torchrl/modules/distributions/utils.py
index a64d55276c3..9c387ed9b50 100644
--- a/torchrl/modules/distributions/utils.py
+++ b/torchrl/modules/distributions/utils.py
@@ -16,21 +16,23 @@
 
 def _cast_device(elt: torch.Tensor | float, device) -> torch.Tensor | float:
     if isinstance(elt, torch.Tensor):
-        return elt.to(device)
+        _non_blocking = device is not None and torch.device(device).type == "cuda"
+        return elt.to(device, non_blocking=_non_blocking)
     return elt
 
 
 def _cast_transform_device(transform, device):
     if transform is None:
         return transform
-    elif isinstance(transform, d.ComposeTransform):
+    _non_blocking = device is not None and torch.device(device).type == "cuda"
+    if isinstance(transform, d.ComposeTransform):
         for i, t in enumerate(transform.parts):
             transform.parts[i] = _cast_transform_device(t, device)
     elif isinstance(transform, d.Transform):
         for attribute in dir(transform):
             value = getattr(transform, attribute)
             if isinstance(value, torch.Tensor):
-                setattr(transform, attribute, value.to(device))
+                setattr(transform, attribute, value.to(device, non_blocking=_non_blocking))
         return transform
     else:
         raise TypeError(
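
Note for reviewers (not part of the patch): a minimal sketch of the pattern applied in all three files, assuming only PyTorch. The helper name _to_device is hypothetical and does not exist in torchrl; it just isolates the device-type check the patch repeats before each .to() call.

    import torch

    def _to_device(t: torch.Tensor, device=None) -> torch.Tensor:
        # Mirror the patch: only request a non-blocking copy when the
        # destination device is CUDA; otherwise keep the default behavior.
        non_blocking = device is not None and torch.device(device).type == "cuda"
        return t.to(device, non_blocking=non_blocking)

    # Usage sketch: falls back to CPU when no CUDA device is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.zeros(3)         # CPU tensor
    y = _to_device(x, device)  # async H2D copy when the target is CUDA

Design note: a host-to-device copy only truly overlaps with other work when the source CPU tensor is in pinned (page-locked) memory; with pageable memory non_blocking=True is harmless but may not remove the synchronization.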