180 changes: 180 additions & 0 deletions test/test_distributions.py
@@ -6,6 +6,7 @@

import argparse
import importlib.util
from functools import partial

import pytest
import torch
@@ -16,6 +17,7 @@
from torch import autograd, nn
from torch.utils._pytree import tree_map
from torchrl.modules import (
IndependentNormal,
OneHotCategorical,
OneHotOrdinal,
Ordinal,
@@ -169,6 +171,184 @@ def test_tanhnormal_event_dims(self, event_dims):
exp_shape,
)

@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize(
"callable_scale",
[torch.ones_like, partial(torch.full_like, fill_value=0.5)],
ids=["ones_like", "full_like_partial"],
)
def test_tanhnormal_callable_scale(self, device, callable_scale):
"""Test that TanhNormal supports callable scale for compile-friendliness.

Using a callable scale (e.g., torch.ones_like or partial(torch.full_like, fill_value=...))
avoids explicit device transfers and prevents graph breaks in torch.compile.
"""
torch.manual_seed(0)
loc = torch.randn(3, 4, device=device)

# Create distribution with callable scale
dist = TanhNormal(loc=loc, scale=callable_scale, low=-1, high=1)

# Check that the scale was properly resolved
expected_scale = callable_scale(loc)
torch.testing.assert_close(dist.scale, expected_scale)

# Test sampling
sample = dist.sample()
assert sample.shape == loc.shape
assert sample.device == loc.device
assert (sample >= -1).all()
assert (sample <= 1).all()

# Test log_prob
log_prob = dist.log_prob(sample)
assert torch.isfinite(log_prob).all()

# Test rsample with gradient
loc_grad = torch.randn(3, 4, device=device, requires_grad=True)
dist_grad = TanhNormal(loc=loc_grad, scale=callable_scale, low=-1, high=1)
sample_grad = dist_grad.rsample()
loss = sample_grad.sum()
loss.backward()
assert loc_grad.grad is not None
assert torch.isfinite(loc_grad.grad).all()

@pytest.mark.parametrize("device", get_default_devices())
def test_tanhnormal_callable_scale_update(self, device):
"""Test that TanhNormal.update() works with callable scale."""
torch.manual_seed(0)
loc = torch.randn(3, 4, device=device)
callable_scale = torch.ones_like

dist = TanhNormal(loc=loc, scale=callable_scale, low=-1, high=1)

# Update with new loc and callable scale
new_loc = torch.randn(3, 4, device=device)
dist.update(new_loc, callable_scale)

# Check that scale was properly resolved
torch.testing.assert_close(dist.scale, torch.ones_like(new_loc))

# Verify distribution works after update
sample = dist.sample()
assert sample.shape == new_loc.shape
assert torch.isfinite(dist.log_prob(sample)).all()


class TestIndependentNormal:
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize(
"callable_scale",
[torch.ones_like, partial(torch.full_like, fill_value=0.5)],
ids=["ones_like", "full_like_partial"],
)
def test_independentnormal_callable_scale(self, device, callable_scale):
"""Test that IndependentNormal supports callable scale for compile-friendliness.

Using a callable scale (e.g., torch.ones_like or partial(torch.full_like, fill_value=...))
avoids explicit device transfers and prevents graph breaks in torch.compile.
"""
torch.manual_seed(0)
loc = torch.randn(3, 4, device=device)

# Create distribution with callable scale
dist = IndependentNormal(loc=loc, scale=callable_scale)

# Check that the scale was properly resolved
expected_scale = callable_scale(loc)
torch.testing.assert_close(dist.base_dist.scale, expected_scale)

# Test sampling
sample = dist.sample()
assert sample.shape == loc.shape
assert sample.device == loc.device

# Test log_prob
log_prob = dist.log_prob(sample)
assert torch.isfinite(log_prob).all()

# Test rsample with gradient
loc_grad = torch.randn(3, 4, device=device, requires_grad=True)
dist_grad = IndependentNormal(loc=loc_grad, scale=callable_scale)
sample_grad = dist_grad.rsample()
loss = sample_grad.sum()
loss.backward()
assert loc_grad.grad is not None
assert torch.isfinite(loc_grad.grad).all()

@pytest.mark.parametrize("device", get_default_devices())
def test_independentnormal_callable_scale_update(self, device):
"""Test that IndependentNormal.update() works with callable scale."""
torch.manual_seed(0)
loc = torch.randn(3, 4, device=device)
callable_scale = torch.ones_like

dist = IndependentNormal(loc=loc, scale=callable_scale)

# Update with new loc and callable scale
new_loc = torch.randn(3, 4, device=device)
dist.update(new_loc, callable_scale)

# Check that scale was properly resolved
torch.testing.assert_close(dist.base_dist.scale, torch.ones_like(new_loc))

# Verify distribution works after update
sample = dist.sample()
assert sample.shape == new_loc.shape
assert torch.isfinite(dist.log_prob(sample)).all()

@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("scale_type", ["tensor", "float", "callable"])
def test_independentnormal_scale_types(self, device, scale_type):
"""Test that IndependentNormal supports all scale types: tensor, float, callable."""
torch.manual_seed(0)
loc = torch.randn(3, 4, device=device)

if scale_type == "tensor":
scale = torch.ones(3, 4, device=device)
elif scale_type == "float":
scale = 1.0
else: # callable
scale = torch.ones_like

dist = IndependentNormal(loc=loc, scale=scale)

# Test sampling
sample = dist.sample()
assert sample.shape == loc.shape
assert sample.device == loc.device

# Test log_prob
log_prob = dist.log_prob(sample)
assert torch.isfinite(log_prob).all()

@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("scale_type", ["tensor", "float", "callable"])
def test_tanhnormal_scale_types(self, device, scale_type):
"""Test that TanhNormal supports all scale types: tensor, float, callable."""
torch.manual_seed(0)
loc = torch.randn(3, 4, device=device)

if scale_type == "tensor":
scale = torch.ones(3, 4, device=device)
elif scale_type == "float":
scale = 1.0
else: # callable
scale = torch.ones_like

dist = TanhNormal(loc=loc, scale=scale, low=-1, high=1)

# Test sampling
sample = dist.sample()
assert sample.shape == loc.shape
assert sample.device == loc.device
assert (sample >= -1).all()
assert (sample <= 1).all()

# Test log_prob
log_prob = dist.log_prob(sample)
assert torch.isfinite(log_prob).all()


class TestTruncatedNormal:
@pytest.mark.parametrize(
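The tests above target the callable-scale code path through the public constructors. As a minimal usage sketch (not part of this PR; the function name sample_action is hypothetical), the pattern they exercise looks roughly like this: the scale is derived from ``loc`` itself, so no ``torch.tensor(val, device=device)`` transfer appears when the step is wrapped in ``torch.compile``.

# Hypothetical usage sketch, not part of this PR: the compile-friendly
# callable-scale pattern exercised by the tests above.
from functools import partial

import torch
from torchrl.modules import TanhNormal


def sample_action(loc: torch.Tensor) -> torch.Tensor:
    # The callable builds the scale from ``loc`` (same shape/device/dtype),
    # so the traced graph contains no explicit device transfer.
    dist = TanhNormal(
        loc=loc,
        scale=partial(torch.full_like, fill_value=0.5),
        low=-1,
        high=1,
        safe_tanh=False,  # per the TanhNormal docstring below, safe_tanh=True can break torch.compile
    )
    return dist.rsample()


compiled_sample = torch.compile(sample_action)
action = compiled_sample(torch.randn(3, 4))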
1 change: 0 additions & 1 deletion torchrl/__init__.py
@@ -2,7 +2,6 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import warnings
import weakref
from warnings import warn
78 changes: 70 additions & 8 deletions torchrl/modules/distributions/continuous.py
@@ -5,7 +5,7 @@
from __future__ import annotations

import weakref
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from numbers import Number

import numpy as np
@@ -58,7 +58,11 @@ class IndependentNormal(D.Independent):

Args:
loc (torch.Tensor): normal distribution location parameter
scale (torch.Tensor): normal distribution sigma parameter (squared root of variance)
scale (torch.Tensor, float, or callable): normal distribution sigma parameter (square root of the variance).
Can be a tensor, a float, or a callable that takes the ``loc`` tensor as input and returns the scale tensor.
Using a callable (e.g., ``torch.ones_like`` or ``functools.partial(torch.full_like, fill_value=0.1)``)
avoids explicit device transfers like ``torch.tensor(val, device=device)`` and prevents graph breaks
in :func:`torch.compile`.
upscale (torch.Tensor or number, optional): 'a' scaling factor in the formula:

.. math::
@@ -69,14 +73,28 @@
tanh_loc (bool, optional): if ``False``, the above formula is used for
the location scaling, otherwise the raw value
is kept. Default is ``False``;

Example:
>>> import torch
>>> from functools import partial
>>> from torchrl.modules.distributions import IndependentNormal
>>> loc = torch.zeros(3, 4)
>>> # Using a callable scale avoids device transfers and graph breaks in torch.compile
>>> dist = IndependentNormal(loc, scale=torch.ones_like)
>>> # For a custom scale value, use partial to create a callable
>>> dist = IndependentNormal(loc, scale=partial(torch.full_like, fill_value=0.1))
>>> sample = dist.sample()
>>> sample.shape
torch.Size([3, 4])

"""

num_params: int = 2

def __init__(
self,
loc: torch.Tensor,
scale: torch.Tensor,
scale: torch.Tensor | float | Callable[[torch.Tensor], torch.Tensor],
upscale: float = 5.0,
tanh_loc: bool = False,
event_dim: int = 1,
@@ -86,11 +104,25 @@ def __init__(
self.upscale = upscale
self._event_dim = event_dim
self._kwargs = kwargs
# Support callable scale (e.g., torch.ones_like) for compile-friendliness
if callable(scale) and not isinstance(scale, torch.Tensor):
scale = scale(loc)
elif not isinstance(scale, torch.Tensor):
scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
elif scale.device != loc.device:
scale = scale.to(loc.device, non_blocking=loc.device.type == "cuda")
super().__init__(D.Normal(loc, scale, **kwargs), event_dim)

def update(self, loc, scale):
if self.tanh_loc:
loc = self.upscale * (loc / self.upscale).tanh()
# Support callable scale (e.g., torch.ones_like) for compile-friendliness
if callable(scale) and not isinstance(scale, torch.Tensor):
scale = scale(loc)
elif not isinstance(scale, torch.Tensor):
scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
elif scale.device != loc.device:
scale = scale.to(loc.device, non_blocking=loc.device.type == "cuda")
super().__init__(D.Normal(loc, scale, **self._kwargs), self._event_dim)

@property
Expand Down Expand Up @@ -316,7 +348,11 @@ class TanhNormal(FasterTransformedDistribution):

Args:
loc (torch.Tensor): normal distribution location parameter
scale (torch.Tensor): normal distribution sigma parameter (squared root of variance)
scale (torch.Tensor, float, or callable): normal distribution sigma parameter (square root of the variance).
Can be a tensor, a float, or a callable that takes the ``loc`` tensor as input and returns the scale tensor.
Using a callable (e.g., ``torch.ones_like`` or ``functools.partial(torch.full_like, fill_value=0.1)``)
avoids explicit device transfers like ``torch.tensor(val, device=device)`` and prevents graph breaks
in :func:`torch.compile`.
upscale (torch.Tensor or number): 'a' scaling factor in the formula:

.. math::
@@ -331,6 +367,20 @@
value is kept. Default is ``False``;
safe_tanh (bool, optional): if ``True``, the Tanh transform is done "safely", to avoid numerical overflows.
This will currently break with :func:`torch.compile`.

Example:
>>> import torch
>>> from functools import partial
>>> from torchrl.modules.distributions import TanhNormal
>>> loc = torch.zeros(3, 4)
>>> # Using a callable scale avoids device transfers and graph breaks in torch.compile
>>> dist = TanhNormal(loc, scale=torch.ones_like)
>>> # For a custom scale value, use partial to create a callable
>>> dist = TanhNormal(loc, scale=partial(torch.full_like, fill_value=0.1))
>>> sample = dist.sample()
>>> sample.shape
torch.Size([3, 4])

"""

arg_constraints = {
@@ -343,7 +393,7 @@ class TanhNormal(FasterTransformedDistribution):
def __init__(
self,
loc: torch.Tensor,
scale: torch.Tensor,
scale: torch.Tensor | float | Callable[[torch.Tensor], torch.Tensor],
upscale: torch.Tensor | Number = 5.0,
low: torch.Tensor | Number = -1.0,
high: torch.Tensor | Number = 1.0,
@@ -353,8 +403,14 @@
):
if not isinstance(loc, torch.Tensor):
loc = torch.as_tensor(loc, dtype=torch.get_default_dtype())
if not isinstance(scale, torch.Tensor):
scale = torch.as_tensor(scale, dtype=torch.get_default_dtype())
_non_blocking = loc.device.type == "cuda"
# Support callable scale (e.g., torch.ones_like) for compile-friendliness
if callable(scale) and not isinstance(scale, torch.Tensor):
scale = scale(loc)
elif not isinstance(scale, torch.Tensor):
scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
elif scale.device != loc.device:
scale = scale.to(loc.device, non_blocking=_non_blocking)
if event_dims is None:
event_dims = min(1, loc.ndim)

@@ -370,7 +426,6 @@ def __init__(
if not all(high > low):
raise RuntimeError(err_msg)

_non_blocking = loc.device.type == "cuda"
if not isinstance(high, torch.Tensor):
high = torch.as_tensor(high, device=loc.device)
elif high.device != loc.device:
@@ -435,6 +490,13 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
# loc must be rescaled if tanh_loc
if is_compiling() or (self.non_trivial_max or self.non_trivial_min):
loc = loc + (self.high - self.low) / 2 + self.low
# Support callable scale (e.g., torch.ones_like) for compile-friendliness
if callable(scale) and not isinstance(scale, torch.Tensor):
scale = scale(loc)
elif not isinstance(scale, torch.Tensor):
scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
elif scale.device != loc.device:
scale = scale.to(loc.device, non_blocking=loc.device.type == "cuda")
self.loc = loc
self.scale = scale

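The same three-way scale-resolution branch now appears in both constructors and both update() methods above. Purely as an illustrative sketch (the helper name _resolve_scale is hypothetical and not introduced by this PR), the shared logic amounts to:

# Illustrative sketch only; ``_resolve_scale`` is a hypothetical helper that
# mirrors the branch duplicated in the constructors and update() methods above.
from __future__ import annotations

from collections.abc import Callable

import torch


def _resolve_scale(
    scale: torch.Tensor | float | Callable[[torch.Tensor], torch.Tensor],
    loc: torch.Tensor,
) -> torch.Tensor:
    # A callable scale (e.g. torch.ones_like) is evaluated on ``loc`` so the
    # result inherits its shape, device and dtype with no device transfer.
    if callable(scale) and not isinstance(scale, torch.Tensor):
        return scale(loc)
    # Plain numbers are promoted to a tensor on loc's device and dtype.
    if not isinstance(scale, torch.Tensor):
        return torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
    # A tensor scale on another device is moved, non-blocking on CUDA.
    if scale.device != loc.device:
        return scale.to(loc.device, non_blocking=loc.device.type == "cuda")
    return scale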
4 changes: 3 additions & 1 deletion torchrl/modules/distributions/utils.py
@@ -32,7 +32,9 @@ def _cast_transform_device(transform, device):
for attribute in dir(transform):
value = getattr(transform, attribute)
if isinstance(value, torch.Tensor):
setattr(transform, attribute, value.to(device, non_blocking=_non_blocking))
setattr(
transform, attribute, value.to(device, non_blocking=_non_blocking)
)
return transform
else:
raise TypeError(
Expand Down
Loading