295 changes: 295 additions & 0 deletions test/compile/test_value.py
@@ -0,0 +1,295 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Tests for torch.compile compatibility of value estimation functions."""
from __future__ import annotations

import pytest
import torch

from torchrl.objectives.value.functional import (
generalized_advantage_estimate,
td_lambda_return_estimate,
vec_generalized_advantage_estimate,
vec_td_lambda_return_estimate,
)


class TestValueFunctionCompile:
"""Test compilation of value estimation functions."""

@pytest.fixture
def value_data(self):
"""Create test data for value functions."""
batch_size = 32
time_steps = 15
feature_dim = 1

return {
"gamma": 0.99,
"lmbda": 0.95,
"state_value": torch.randn(batch_size, time_steps, feature_dim),
"next_state_value": torch.randn(batch_size, time_steps, feature_dim),
"reward": torch.randn(batch_size, time_steps, feature_dim),
"done": torch.zeros(batch_size, time_steps, feature_dim, dtype=torch.bool),
"terminated": torch.zeros(
batch_size, time_steps, feature_dim, dtype=torch.bool
),
}

def test_td_lambda_return_estimate_compiles_fullgraph(self, value_data):
"""Test that td_lambda_return_estimate (non-vectorized) compiles with fullgraph=True."""
result_eager = td_lambda_return_estimate(
gamma=value_data["gamma"],
lmbda=value_data["lmbda"],
next_state_value=value_data["next_state_value"],
reward=value_data["reward"],
done=value_data["done"],
terminated=value_data["terminated"],
)

compiled_fn = torch.compile(
td_lambda_return_estimate,
fullgraph=True,
backend="inductor",
)

result_compiled = compiled_fn(
gamma=value_data["gamma"],
lmbda=value_data["lmbda"],
next_state_value=value_data["next_state_value"],
reward=value_data["reward"],
done=value_data["done"],
terminated=value_data["terminated"],
)

torch.testing.assert_close(result_eager, result_compiled, rtol=1e-4, atol=1e-4)

def test_generalized_advantage_estimate_compiles_fullgraph(self, value_data):
"""Test that generalized_advantage_estimate (non-vectorized) compiles with fullgraph=True."""
advantage_eager, value_target_eager = generalized_advantage_estimate(
gamma=value_data["gamma"],
lmbda=value_data["lmbda"],
state_value=value_data["state_value"],
next_state_value=value_data["next_state_value"],
reward=value_data["reward"],
done=value_data["done"],
terminated=value_data["terminated"],
)

compiled_fn = torch.compile(
generalized_advantage_estimate,
fullgraph=True,
backend="inductor",
)

advantage_compiled, value_target_compiled = compiled_fn(
gamma=value_data["gamma"],
lmbda=value_data["lmbda"],
state_value=value_data["state_value"],
next_state_value=value_data["next_state_value"],
reward=value_data["reward"],
done=value_data["done"],
terminated=value_data["terminated"],
)

torch.testing.assert_close(
advantage_eager, advantage_compiled, rtol=1e-4, atol=1e-4
)
torch.testing.assert_close(
value_target_eager, value_target_compiled, rtol=1e-4, atol=1e-4
)

def test_vec_td_lambda_return_estimate_fails_fullgraph(self, value_data):
"""Test that vec_td_lambda_return_estimate fails with fullgraph=True due to data-dependent shapes."""
compiled_fn = torch.compile(
vec_td_lambda_return_estimate,
fullgraph=True,
backend="inductor",
)

# This should fail because of data-dependent shapes in _get_num_per_traj
with pytest.raises(Exception):
compiled_fn(
gamma=value_data["gamma"],
lmbda=value_data["lmbda"],
next_state_value=value_data["next_state_value"],
reward=value_data["reward"],
done=value_data["done"],
terminated=value_data["terminated"],
)

def test_vec_generalized_advantage_estimate_fails_fullgraph(self, value_data):
"""Test that vec_generalized_advantage_estimate fails with fullgraph=True due to data-dependent shapes."""
compiled_fn = torch.compile(
vec_generalized_advantage_estimate,
fullgraph=True,
backend="inductor",
)

# This should fail because of data-dependent shapes in _get_num_per_traj
with pytest.raises(Exception):
compiled_fn(
gamma=value_data["gamma"],
lmbda=value_data["lmbda"],
state_value=value_data["state_value"],
next_state_value=value_data["next_state_value"],
reward=value_data["reward"],
done=value_data["done"],
terminated=value_data["terminated"],
)
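
For context, here is a minimal standalone sketch of the failure mode shared by the two tests above (the helper is hypothetical; it only mirrors the kind of value-dependent output shape that _get_num_per_traj produces when it splits trajectories on `done`):

import torch

def traj_boundaries(done):
    # nonzero() has a value-dependent output shape; by default Dynamo
    # graph-breaks on it, which fullgraph=True turns into an error
    return done.nonzero()

compiled = torch.compile(traj_boundaries, fullgraph=True)
# compiled(torch.tensor([False, True, False, True]))  # raises under fullgraph=True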

def test_td_lambda_with_tensor_gamma_compiles_fullgraph(self, value_data):
"""Test that td_lambda_return_estimate compiles with 0-d tensor gamma (fullgraph=True).

This tests the fix for PendingUnbackedSymbolNotFound error that occurred when
torch.full_like received a 0-d tensor and internally called .item().
"""
# Use 0-d tensor gamma/lmbda - this was the problematic case
gamma_tensor = torch.tensor(value_data["gamma"])
lmbda_tensor = torch.tensor(value_data["lmbda"])

result_eager = td_lambda_return_estimate(
gamma=gamma_tensor,
lmbda=lmbda_tensor,
next_state_value=value_data["next_state_value"],
reward=value_data["reward"],
done=value_data["done"],
terminated=value_data["terminated"],
)

compiled_fn = torch.compile(
td_lambda_return_estimate,
fullgraph=True,
backend="inductor",
)

result_compiled = compiled_fn(
gamma=gamma_tensor,
lmbda=lmbda_tensor,
next_state_value=value_data["next_state_value"],
reward=value_data["reward"],
done=value_data["done"],
terminated=value_data["terminated"],
)

torch.testing.assert_close(result_eager, result_compiled, rtol=1e-4, atol=1e-4)

def test_gae_with_tensor_gamma_compiles_fullgraph(self, value_data):
"""Test that generalized_advantage_estimate compiles with 0-d tensor gamma (fullgraph=True).

This tests the fix for PendingUnbackedSymbolNotFound error that occurred when
torch.full_like received a 0-d tensor and internally called .item().
"""
# Use 0-d tensor gamma/lmbda - this was the problematic case
gamma_tensor = torch.tensor(value_data["gamma"])
lmbda_tensor = torch.tensor(value_data["lmbda"])

advantage_eager, value_target_eager = generalized_advantage_estimate(
gamma=gamma_tensor,
lmbda=lmbda_tensor,
state_value=value_data["state_value"],
next_state_value=value_data["next_state_value"],
reward=value_data["reward"],
done=value_data["done"],
terminated=value_data["terminated"],
)

compiled_fn = torch.compile(
generalized_advantage_estimate,
fullgraph=True,
backend="inductor",
)

advantage_compiled, value_target_compiled = compiled_fn(
gamma=gamma_tensor,
lmbda=lmbda_tensor,
state_value=value_data["state_value"],
next_state_value=value_data["next_state_value"],
reward=value_data["reward"],
done=value_data["done"],
terminated=value_data["terminated"],
)

torch.testing.assert_close(
advantage_eager, advantage_compiled, rtol=1e-4, atol=1e-4
)
torch.testing.assert_close(
value_target_eager, value_target_compiled, rtol=1e-4, atol=1e-4
)
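
The pattern exercised above, as a standalone sketch (not TorchRL code): torch.full_like(x, g) with a 0-d tensor g calls g.item() internally, minting an unbacked symbol during tracing, whereas expand stays a pure view operation:

import torch

def broadcast_scalar(x, g):
    if isinstance(g, torch.Tensor):
        # compile-friendly: no .item(), just a broadcasted view
        return g.to(x.device).expand(x.shape)
    return torch.full_like(x, g)

compiled = torch.compile(broadcast_scalar, fullgraph=True)
out = compiled(torch.randn(3, 4), torch.tensor(0.99))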


class TestTDLambdaEstimatorCompile:
"""Test TDLambdaEstimator compile-friendly vectorized property."""

def test_vectorized_property_returns_true_in_eager_mode(self):
"""Test that TDLambdaEstimator.vectorized returns True in eager mode when set to True."""
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.objectives.value.advantages import TDLambdaEstimator

value_net = TensorDictModule(
nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
)
estimator = TDLambdaEstimator(
gamma=0.99,
lmbda=0.95,
value_network=value_net,
vectorized=True,
)

assert estimator.vectorized is True
assert estimator._vectorized is True

def test_vectorized_property_returns_false_in_eager_mode_when_set_false(self):
"""Test that TDLambdaEstimator.vectorized returns False in eager mode when set to False."""
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.objectives.value.advantages import TDLambdaEstimator

value_net = TensorDictModule(
nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
)
estimator = TDLambdaEstimator(
gamma=0.99,
lmbda=0.95,
value_network=value_net,
vectorized=False,
)

assert estimator.vectorized is False
assert estimator._vectorized is False

def test_vectorized_setter_works(self):
"""Test that TDLambdaEstimator.vectorized setter works correctly."""
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.objectives.value.advantages import TDLambdaEstimator

value_net = TensorDictModule(
nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
)
estimator = TDLambdaEstimator(
gamma=0.99,
lmbda=0.95,
value_network=value_net,
vectorized=True,
)

assert estimator.vectorized is True

estimator.vectorized = False
assert estimator.vectorized is False
assert estimator._vectorized is False

estimator.vectorized = True
assert estimator.vectorized is True
assert estimator._vectorized is True
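
A plausible usage sketch tying the two halves of this PR together (an inference from the tests above, not something stated in the diff: the non-vectorized path is the one that survives fullgraph=True):

estimator.vectorized = False  # fall back to the compile-friendly, non-vectorized path
compiled_estimator = torch.compile(estimator, fullgraph=True)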


if __name__ == "__main__":
pytest.main([__file__, "-v"])
49 changes: 37 additions & 12 deletions torchrl/objectives/value/functional.py
@@ -333,13 +333,15 @@ def vec_generalized_advantage_estimate(
     gammalmbdas = _make_gammas_tensor(gammalmbdas, time_steps, True)
     gammalmbdas = gammalmbdas.cumprod(-2)
 
-    first_below_thr = gammalmbdas < 1e-7
-    # if we have multiple gammas, we only want to truncate if _all_ of
-    # the geometric sequences fall below the threshold
-    first_below_thr = first_below_thr.flatten(0, 1).all(0).all(-1)
-    if first_below_thr.any():
-        first_below_thr = torch.where(first_below_thr)[0][0].item()
-        gammalmbdas = gammalmbdas[..., :first_below_thr, :]
+    # Skip data-dependent truncation optimization during compile (causes guards)
+    if not is_dynamo_compiling():
+        first_below_thr = gammalmbdas < 1e-7
+        # if we have multiple gammas, we only want to truncate if _all_ of
+        # the geometric sequences fall below the threshold
+        first_below_thr = first_below_thr.flatten(0, 1).all(0).all(-1)
+        if first_below_thr.any():
+            first_below_thr = torch.where(first_below_thr)[0][0].item()
+            gammalmbdas = gammalmbdas[..., :first_below_thr, :]
 
     not_terminated = (~terminated).to(dtype)
     td0 = reward + not_terminated * gamma * next_state_value - state_value
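
For reference, a minimal standalone sketch of the eager-only guard used in this hunk (assuming is_dynamo_compiling is the torch.compiler helper available in recent PyTorch):

import torch
from torch.compiler import is_dynamo_compiling

def truncate_decay_weights(weights, thr=1e-7):
    # Data-dependent slicing (.item() on a traced tensor) is only safe in
    # eager mode; under torch.compile the full weight tensor is kept.
    if not is_dynamo_compiling():
        below = (weights < thr).all(-1)
        if below.any():
            cut = torch.where(below)[0][0].item()
            weights = weights[..., :cut, :]
    return weights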
@@ -524,7 +526,14 @@ def td1_return_estimate(
     single_gamma = False
     if not (isinstance(gamma, torch.Tensor) and gamma.shape == not_done.shape):
         single_gamma = True
-        gamma = torch.full_like(next_state_value, gamma)
+        if isinstance(gamma, torch.Tensor):
+            # Use expand instead of full_like to avoid .item() call which creates
+            # unbacked symbols during torch.compile tracing.
+            if gamma.device != next_state_value.device:
+                gamma = gamma.to(next_state_value.device)
+            gamma = gamma.expand(next_state_value.shape)
+        else:
+            gamma = torch.full_like(next_state_value, gamma)
 
     if rolling_gamma is None:
         rolling_gamma = True
@@ -847,12 +856,26 @@ def td_lambda_return_estimate(
     single_gamma = False
     if not (isinstance(gamma, torch.Tensor) and gamma.shape == done.shape):
         single_gamma = True
-        gamma = torch.full_like(next_state_value, gamma)
+        if isinstance(gamma, torch.Tensor):
+            # Use expand instead of full_like to avoid .item() call which creates
+            # unbacked symbols during torch.compile tracing.
+            if gamma.device != next_state_value.device:
+                gamma = gamma.to(next_state_value.device)
+            gamma = gamma.expand(next_state_value.shape)
+        else:
+            gamma = torch.full_like(next_state_value, gamma)
 
     single_lambda = False
     if not (isinstance(lmbda, torch.Tensor) and lmbda.shape == done.shape):
         single_lambda = True
-        lmbda = torch.full_like(next_state_value, lmbda)
+        if isinstance(lmbda, torch.Tensor):
+            # Use expand instead of full_like to avoid .item() call which creates
+            # unbacked symbols during torch.compile tracing.
+            if lmbda.device != next_state_value.device:
+                lmbda = lmbda.to(next_state_value.device)
+            lmbda = lmbda.expand(next_state_value.shape)
+        else:
+            lmbda = torch.full_like(next_state_value, lmbda)
 
     if rolling_gamma is None:
         rolling_gamma = True
@@ -1004,7 +1027,8 @@ def _fast_td_lambda_return_estimate(
     # the only valid next states are those where the trajectory does not terminate
     next_state_value = (~terminated).int() * next_state_value
 
-    gamma_tensor = torch.tensor([gamma], device=device)
+    # Use torch.full to create directly on device (avoids DeviceCopy in cudagraph)
+    gamma_tensor = torch.full((1,), gamma, device=device)
     gammalmbda = gamma_tensor * lmbda
 
     num_per_traj = _get_num_per_traj(done)
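
A sketch of the distinction the new comment draws (standalone, hypothetical function names): building the tensor from a Python list can lower to a CPU constant followed by a DeviceCopy node, which the cudagraph backend flags, while torch.full allocates directly on the target device.

import torch

def gamma_via_list(gamma, device):
    return torch.tensor([gamma], device=device)    # may lower to CPU tensor + DeviceCopy

def gamma_via_full(gamma, device):
    return torch.full((1,), gamma, device=device)  # filled directly on `device`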
@@ -1125,7 +1149,8 @@ def _is_scalar(tensor):
 
     if rolling_gamma is None:
         rolling_gamma = True
-    if not rolling_gamma:
+    if not rolling_gamma and not is_dynamo_compiling():
+        # Skip this validation during compile to avoid CUDA syncs
         terminated_follows_terminated = terminated[..., 1:, :][
             terminated[..., :-1, :]
         ].all()
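
Sketch of why this validation is eager-only (the raise and its message are assumed; the hunk cuts off before them): branching on the value of a CUDA boolean tensor forces a host-device synchronization, which the added guard avoids inside compiled code.

import torch
from torch.compiler import is_dynamo_compiling

def validate_terminated(terminated):
    if is_dynamo_compiling():
        return  # skip: reading the .all() result below would sync with the GPU
    follows = terminated[..., 1:, :][terminated[..., :-1, :]].all()
    if not follows:  # the host sync happens here, in eager mode only
        raise ValueError("hypothetical message: terminated must stay set once raised")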