diff --git a/test/test_distributions.py b/test/test_distributions.py
index 4540a49fb36..8cee2868b6f 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -6,6 +6,7 @@
 
 import argparse
 import importlib.util
+from functools import partial
 
 import pytest
 import torch
@@ -16,6 +17,7 @@
 from torch import autograd, nn
 from torch.utils._pytree import tree_map
 from torchrl.modules import (
+    IndependentNormal,
     OneHotCategorical,
     OneHotOrdinal,
     Ordinal,
@@ -169,6 +171,184 @@ def test_tanhnormal_event_dims(self, event_dims):
             exp_shape,
         )
 
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize(
+        "callable_scale",
+        [torch.ones_like, partial(torch.full_like, fill_value=0.5)],
+        ids=["ones_like", "full_like_partial"],
+    )
+    def test_tanhnormal_callable_scale(self, device, callable_scale):
+        """Test that TanhNormal supports callable scale for compile-friendliness.
+
+        Using a callable scale (e.g., torch.ones_like or partial(torch.full_like, fill_value=...))
+        avoids explicit device transfers and prevents graph breaks in torch.compile.
+        """
+        torch.manual_seed(0)
+        loc = torch.randn(3, 4, device=device)
+
+        # Create distribution with callable scale
+        dist = TanhNormal(loc=loc, scale=callable_scale, low=-1, high=1)
+
+        # Check that the scale was properly resolved
+        expected_scale = callable_scale(loc)
+        torch.testing.assert_close(dist.scale, expected_scale)
+
+        # Test sampling
+        sample = dist.sample()
+        assert sample.shape == loc.shape
+        assert sample.device == loc.device
+        assert (sample >= -1).all()
+        assert (sample <= 1).all()
+
+        # Test log_prob
+        log_prob = dist.log_prob(sample)
+        assert torch.isfinite(log_prob).all()
+
+        # Test rsample with gradient
+        loc_grad = torch.randn(3, 4, device=device, requires_grad=True)
+        dist_grad = TanhNormal(loc=loc_grad, scale=callable_scale, low=-1, high=1)
+        sample_grad = dist_grad.rsample()
+        loss = sample_grad.sum()
+        loss.backward()
+        assert loc_grad.grad is not None
+        assert torch.isfinite(loc_grad.grad).all()
+
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_tanhnormal_callable_scale_update(self, device):
+        """Test that TanhNormal.update() works with callable scale."""
+        torch.manual_seed(0)
+        loc = torch.randn(3, 4, device=device)
+        callable_scale = torch.ones_like
+
+        dist = TanhNormal(loc=loc, scale=callable_scale, low=-1, high=1)
+
+        # Update with new loc and callable scale
+        new_loc = torch.randn(3, 4, device=device)
+        dist.update(new_loc, callable_scale)
+
+        # Check that scale was properly resolved
+        torch.testing.assert_close(dist.scale, torch.ones_like(new_loc))
+
+        # Verify distribution works after update
+        sample = dist.sample()
+        assert sample.shape == new_loc.shape
+        assert torch.isfinite(dist.log_prob(sample)).all()
+
+
+class TestIndependentNormal:
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize(
+        "callable_scale",
+        [torch.ones_like, partial(torch.full_like, fill_value=0.5)],
+        ids=["ones_like", "full_like_partial"],
+    )
+    def test_independentnormal_callable_scale(self, device, callable_scale):
+        """Test that IndependentNormal supports callable scale for compile-friendliness.
+
+        Using a callable scale (e.g., torch.ones_like or partial(torch.full_like, fill_value=...))
+        avoids explicit device transfers and prevents graph breaks in torch.compile.
+ """ + torch.manual_seed(0) + loc = torch.randn(3, 4, device=device) + + # Create distribution with callable scale + dist = IndependentNormal(loc=loc, scale=callable_scale) + + # Check that the scale was properly resolved + expected_scale = callable_scale(loc) + torch.testing.assert_close(dist.base_dist.scale, expected_scale) + + # Test sampling + sample = dist.sample() + assert sample.shape == loc.shape + assert sample.device == loc.device + + # Test log_prob + log_prob = dist.log_prob(sample) + assert torch.isfinite(log_prob).all() + + # Test rsample with gradient + loc_grad = torch.randn(3, 4, device=device, requires_grad=True) + dist_grad = IndependentNormal(loc=loc_grad, scale=callable_scale) + sample_grad = dist_grad.rsample() + loss = sample_grad.sum() + loss.backward() + assert loc_grad.grad is not None + assert torch.isfinite(loc_grad.grad).all() + + @pytest.mark.parametrize("device", get_default_devices()) + def test_independentnormal_callable_scale_update(self, device): + """Test that IndependentNormal.update() works with callable scale.""" + torch.manual_seed(0) + loc = torch.randn(3, 4, device=device) + callable_scale = torch.ones_like + + dist = IndependentNormal(loc=loc, scale=callable_scale) + + # Update with new loc and callable scale + new_loc = torch.randn(3, 4, device=device) + dist.update(new_loc, callable_scale) + + # Check that scale was properly resolved + torch.testing.assert_close(dist.base_dist.scale, torch.ones_like(new_loc)) + + # Verify distribution works after update + sample = dist.sample() + assert sample.shape == new_loc.shape + assert torch.isfinite(dist.log_prob(sample)).all() + + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("scale_type", ["tensor", "float", "callable"]) + def test_independentnormal_scale_types(self, device, scale_type): + """Test that IndependentNormal supports all scale types: tensor, float, callable.""" + torch.manual_seed(0) + loc = torch.randn(3, 4, device=device) + + if scale_type == "tensor": + scale = torch.ones(3, 4, device=device) + elif scale_type == "float": + scale = 1.0 + else: # callable + scale = torch.ones_like + + dist = IndependentNormal(loc=loc, scale=scale) + + # Test sampling + sample = dist.sample() + assert sample.shape == loc.shape + assert sample.device == loc.device + + # Test log_prob + log_prob = dist.log_prob(sample) + assert torch.isfinite(log_prob).all() + + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("scale_type", ["tensor", "float", "callable"]) + def test_tanhnormal_scale_types(self, device, scale_type): + """Test that TanhNormal supports all scale types: tensor, float, callable.""" + torch.manual_seed(0) + loc = torch.randn(3, 4, device=device) + + if scale_type == "tensor": + scale = torch.ones(3, 4, device=device) + elif scale_type == "float": + scale = 1.0 + else: # callable + scale = torch.ones_like + + dist = TanhNormal(loc=loc, scale=scale, low=-1, high=1) + + # Test sampling + sample = dist.sample() + assert sample.shape == loc.shape + assert sample.device == loc.device + assert (sample >= -1).all() + assert (sample <= 1).all() + + # Test log_prob + log_prob = dist.log_prob(sample) + assert torch.isfinite(log_prob).all() + class TestTruncatedNormal: @pytest.mark.parametrize( diff --git a/torchrl/__init__.py b/torchrl/__init__.py index 10cbdda7600..8a0c30785a5 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file 
in the root directory of this source tree. -import os import warnings import weakref from warnings import warn diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 99362a784d5..d6e537c8dc8 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -5,7 +5,7 @@ from __future__ import annotations import weakref -from collections.abc import Sequence +from collections.abc import Callable, Sequence from numbers import Number import numpy as np @@ -58,7 +58,11 @@ class IndependentNormal(D.Independent): Args: loc (torch.Tensor): normal distribution location parameter - scale (torch.Tensor): normal distribution sigma parameter (squared root of variance) + scale (torch.Tensor, float, or callable): normal distribution sigma parameter (squared root of variance). + Can be a tensor, a float, or a callable that takes the ``loc`` tensor as input and returns the scale tensor. + Using a callable (e.g., ``torch.ones_like`` or ``functools.partial(torch.full_like, fill_value=0.1)``) + avoids explicit device transfers like ``torch.tensor(val, device=device)`` and prevents graph breaks + in :func:`torch.compile`. upscale (torch.Tensor or number, optional): 'a' scaling factor in the formula: .. math:: @@ -69,6 +73,20 @@ class IndependentNormal(D.Independent): tanh_loc (bool, optional): if ``False``, the above formula is used for the location scaling, otherwise the raw value is kept. Default is ``False``; + + Example: + >>> import torch + >>> from functools import partial + >>> from torchrl.modules.distributions import IndependentNormal + >>> loc = torch.zeros(3, 4) + >>> # Using a callable scale avoids device transfers and graph breaks in torch.compile + >>> dist = IndependentNormal(loc, scale=torch.ones_like) + >>> # For a custom scale value, use partial to create a callable + >>> dist = IndependentNormal(loc, scale=partial(torch.full_like, fill_value=0.1)) + >>> sample = dist.sample() + >>> sample.shape + torch.Size([3, 4]) + """ num_params: int = 2 @@ -76,7 +94,7 @@ class IndependentNormal(D.Independent): def __init__( self, loc: torch.Tensor, - scale: torch.Tensor, + scale: torch.Tensor | float | Callable[[torch.Tensor], torch.Tensor], upscale: float = 5.0, tanh_loc: bool = False, event_dim: int = 1, @@ -86,11 +104,25 @@ def __init__( self.upscale = upscale self._event_dim = event_dim self._kwargs = kwargs + # Support callable scale (e.g., torch.ones_like) for compile-friendliness + if callable(scale) and not isinstance(scale, torch.Tensor): + scale = scale(loc) + elif not isinstance(scale, torch.Tensor): + scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype) + elif scale.device != loc.device: + scale = scale.to(loc.device, non_blocking=loc.device.type == "cuda") super().__init__(D.Normal(loc, scale, **kwargs), event_dim) def update(self, loc, scale): if self.tanh_loc: loc = self.upscale * (loc / self.upscale).tanh() + # Support callable scale (e.g., torch.ones_like) for compile-friendliness + if callable(scale) and not isinstance(scale, torch.Tensor): + scale = scale(loc) + elif not isinstance(scale, torch.Tensor): + scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype) + elif scale.device != loc.device: + scale = scale.to(loc.device, non_blocking=loc.device.type == "cuda") super().__init__(D.Normal(loc, scale, **self._kwargs), self._event_dim) @property @@ -316,7 +348,11 @@ class TanhNormal(FasterTransformedDistribution): Args: loc (torch.Tensor): normal distribution 
location parameter - scale (torch.Tensor): normal distribution sigma parameter (squared root of variance) + scale (torch.Tensor, float, or callable): normal distribution sigma parameter (squared root of variance). + Can be a tensor, a float, or a callable that takes the ``loc`` tensor as input and returns the scale tensor. + Using a callable (e.g., ``torch.ones_like`` or ``functools.partial(torch.full_like, fill_value=0.1)``) + avoids explicit device transfers like ``torch.tensor(val, device=device)`` and prevents graph breaks + in :func:`torch.compile`. upscale (torch.Tensor or number): 'a' scaling factor in the formula: .. math:: @@ -331,6 +367,20 @@ class TanhNormal(FasterTransformedDistribution): value is kept. Default is ``False``; safe_tanh (bool, optional): if ``True``, the Tanh transform is done "safely", to avoid numerical overflows. This will currently break with :func:`torch.compile`. + + Example: + >>> import torch + >>> from functools import partial + >>> from torchrl.modules.distributions import TanhNormal + >>> loc = torch.zeros(3, 4) + >>> # Using a callable scale avoids device transfers and graph breaks in torch.compile + >>> dist = TanhNormal(loc, scale=torch.ones_like) + >>> # For a custom scale value, use partial to create a callable + >>> dist = TanhNormal(loc, scale=partial(torch.full_like, fill_value=0.1)) + >>> sample = dist.sample() + >>> sample.shape + torch.Size([3, 4]) + """ arg_constraints = { @@ -343,7 +393,7 @@ class TanhNormal(FasterTransformedDistribution): def __init__( self, loc: torch.Tensor, - scale: torch.Tensor, + scale: torch.Tensor | float | Callable[[torch.Tensor], torch.Tensor], upscale: torch.Tensor | Number = 5.0, low: torch.Tensor | Number = -1.0, high: torch.Tensor | Number = 1.0, @@ -353,8 +403,14 @@ def __init__( ): if not isinstance(loc, torch.Tensor): loc = torch.as_tensor(loc, dtype=torch.get_default_dtype()) - if not isinstance(scale, torch.Tensor): - scale = torch.as_tensor(scale, dtype=torch.get_default_dtype()) + _non_blocking = loc.device.type == "cuda" + # Support callable scale (e.g., torch.ones_like) for compile-friendliness + if callable(scale) and not isinstance(scale, torch.Tensor): + scale = scale(loc) + elif not isinstance(scale, torch.Tensor): + scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype) + elif scale.device != loc.device: + scale = scale.to(loc.device, non_blocking=_non_blocking) if event_dims is None: event_dims = min(1, loc.ndim) @@ -370,7 +426,6 @@ def __init__( if not all(high > low): raise RuntimeError(err_msg) - _non_blocking = loc.device.type == "cuda" if not isinstance(high, torch.Tensor): high = torch.as_tensor(high, device=loc.device) elif high.device != loc.device: @@ -435,6 +490,13 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None: # loc must be rescaled if tanh_loc if is_compiling() or (self.non_trivial_max or self.non_trivial_min): loc = loc + (self.high - self.low) / 2 + self.low + # Support callable scale (e.g., torch.ones_like) for compile-friendliness + if callable(scale) and not isinstance(scale, torch.Tensor): + scale = scale(loc) + elif not isinstance(scale, torch.Tensor): + scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype) + elif scale.device != loc.device: + scale = scale.to(loc.device, non_blocking=loc.device.type == "cuda") self.loc = loc self.scale = scale diff --git a/torchrl/modules/distributions/utils.py b/torchrl/modules/distributions/utils.py index 9c387ed9b50..67999b73290 100644 --- a/torchrl/modules/distributions/utils.py +++ 
b/torchrl/modules/distributions/utils.py @@ -32,7 +32,9 @@ def _cast_transform_device(transform, device): for attribute in dir(transform): value = getattr(transform, attribute) if isinstance(value, torch.Tensor): - setattr(transform, attribute, value.to(device, non_blocking=_non_blocking)) + setattr( + transform, attribute, value.to(device, non_blocking=_non_blocking) + ) return transform else: raise TypeError(
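
For context, a minimal usage sketch of the callable-scale pattern this patch introduces, mirroring the docstring examples above. The shapes, fill value, and variable names are illustrative only, and the snippet assumes a torchrl build that includes this change:

    from functools import partial

    import torch
    from torchrl.modules import TanhNormal

    loc = torch.randn(3, 4)  # stand-in for a policy head output

    # The callable is resolved against `loc`, so no scale tensor has to be created
    # on a specific device up front (the torch.tensor(val, device=device) pattern
    # that triggers graph breaks under torch.compile).
    dist = TanhNormal(loc, scale=partial(torch.full_like, fill_value=0.1), low=-1, high=1)
    sample = dist.rsample()
    log_prob = dist.log_prob(sample)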