diff --git a/test/_utils_internal.py b/test/_utils_internal.py deleted file mode 100644 index 19deb33546f..00000000000 --- a/test/_utils_internal.py +++ /dev/null @@ -1,831 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -from __future__ import annotations - -import contextlib -import logging -import os -import os.path -import sys -import time -import unittest -from collections.abc import Callable -from functools import wraps - -import pytest -import torch -import torch.cuda -from tensordict import NestedKey, tensorclass, TensorDict, TensorDictBase -from tensordict.nn import TensorDictModuleBase -from torch import nn, vmap - -from torchrl._utils import implement_for, logger, RL_WARNINGS, seed_generator -from torchrl.envs import MultiThreadedEnv, ObservationNorm -from torchrl.envs.batched_envs import ParallelEnv, SerialEnv -from torchrl.envs.libs.envpool import _has_envpool -from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv -from torchrl.envs.transforms import ( - Compose, - RewardClipping, - ToTensorImage, - TransformedEnv, -) -from torchrl.modules import MLP -from torchrl.objectives.value.advantages import _vmap_func - -# Get relative file path -# this returns relative path from current file. 
- -# Specified for test_utils.py -__version__ = "0.3" - -IS_WIN = sys.platform == "win32" -if IS_WIN: - mp_ctx = "spawn" -else: - mp_ctx = "fork" - -PYTHON_3_9 = sys.version_info.major == 3 and sys.version_info.minor <= 9 - - -def CARTPOLE_VERSIONED(): - # load gym - if gym_backend() is not None: - _set_gym_environments() - return _CARTPOLE_VERSIONED - - -def HALFCHEETAH_VERSIONED(): - # load gym - if gym_backend() is not None: - _set_gym_environments() - return _HALFCHEETAH_VERSIONED - - -def PONG_VERSIONED(): - # load gym - # Gymnasium says that the ale_py behavior changes from 1.0 - # but with python 3.12 it is already the case with 0.29.1 - try: - import ale_py # noqa - except ImportError: - pass - - if gym_backend() is not None: - _set_gym_environments() - return _PONG_VERSIONED - - -def CLIFFWALKING_VERSIONED(): - if gym_backend() is not None: - _set_gym_environments() - return _CLIFFWALKING_VERSIONED - - -def BREAKOUT_VERSIONED(): - # load gym - # Gymnasium says that the ale_py behavior changes from 1.0 - # but with python 3.12 it is already the case with 0.29.1 - try: - import ale_py # noqa - except ImportError: - pass - - if gym_backend() is not None: - _set_gym_environments() - return _BREAKOUT_VERSIONED - - -def PENDULUM_VERSIONED(): - # load gym - if gym_backend() is not None: - _set_gym_environments() - return _PENDULUM_VERSIONED - - -def _set_gym_environments(): - global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED - - _CARTPOLE_VERSIONED = None - _HALFCHEETAH_VERSIONED = None - _PENDULUM_VERSIONED = None - _PONG_VERSIONED = None - _BREAKOUT_VERSIONED = None - _CLIFFWALKING_VERSIONED = None - - -@implement_for("gym", None, "0.21.0") -def _set_gym_environments(): # noqa: F811 - global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED - - _CARTPOLE_VERSIONED = "CartPole-v0" - 
_HALFCHEETAH_VERSIONED = "HalfCheetah-v2" - _PENDULUM_VERSIONED = "Pendulum-v0" - _PONG_VERSIONED = "Pong-v4" - _BREAKOUT_VERSIONED = "Breakout-v4" - _CLIFFWALKING_VERSIONED = "CliffWalking-v0" - - -@implement_for("gym", "0.21.0", None) -def _set_gym_environments(): # noqa: F811 - global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED - - _CARTPOLE_VERSIONED = "CartPole-v1" - _HALFCHEETAH_VERSIONED = "HalfCheetah-v4" - _PENDULUM_VERSIONED = "Pendulum-v1" - _PONG_VERSIONED = "ALE/Pong-v5" - _BREAKOUT_VERSIONED = "ALE/Breakout-v5" - _CLIFFWALKING_VERSIONED = "CliffWalking-v0" - - -@implement_for("gymnasium", None, "1.0.0") -def _set_gym_environments(): # noqa: F811 - global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED - - _CARTPOLE_VERSIONED = "CartPole-v1" - _HALFCHEETAH_VERSIONED = "HalfCheetah-v4" - _PENDULUM_VERSIONED = "Pendulum-v1" - _PONG_VERSIONED = "ALE/Pong-v5" - _BREAKOUT_VERSIONED = "ALE/Breakout-v5" - _CLIFFWALKING_VERSIONED = "CliffWalking-v0" - - -@implement_for("gymnasium", "1.0.0", "1.1.0") -def _set_gym_environments(): # noqa: F811 - raise ImportError - - -@implement_for("gymnasium", "1.1.0") -def _set_gym_environments(): # noqa: F811 - global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED - - _CARTPOLE_VERSIONED = "CartPole-v1" - _HALFCHEETAH_VERSIONED = "HalfCheetah-v5" - _PENDULUM_VERSIONED = "Pendulum-v1" - _PONG_VERSIONED = "ALE/Pong-v5" - _BREAKOUT_VERSIONED = "ALE/Breakout-v5" - _CLIFFWALKING_VERSIONED = "CliffWalking-v1" if not PYTHON_3_9 else "CliffWalking-v0" - - -if _has_gym: - _set_gym_environments() - - -def get_relative_path(curr_file, *path_components): - return os.path.join(os.path.dirname(curr_file), *path_components) - - -def get_available_devices(): - devices = [torch.device("cpu")] - n_cuda 
= torch.cuda.device_count() - if n_cuda > 0: - for i in range(n_cuda): - devices += [torch.device(f"cuda:{i}")] - return devices - - -def get_default_devices(): - num_cuda = torch.cuda.device_count() - if num_cuda == 0: - # if torch.mps.is_available(): - # return [torch.device("mps:0")] - return [torch.device("cpu")] - elif num_cuda == 1: - return [torch.device("cuda:0")] - else: - # then run on all devices - return get_available_devices() - - -def generate_seeds(seed, repeat): - seeds = [seed] - for _ in range(repeat - 1): - seed = seed_generator(seed) - seeds.append(seed) - return seeds - - -# Decorator to retry upon certain Exceptions. -def retry( - ExceptionToCheck: type[Exception], - tries: int = 3, - delay: int = 3, - skip_after_retries: bool = False, -) -> Callable[[Callable], Callable]: - def deco_retry(f): - @wraps(f) - def f_retry(*args, **kwargs): - mtries, mdelay = tries, delay - while mtries > 1: - try: - return f(*args, **kwargs) - except ExceptionToCheck as e: - msg = "%s, Retrying in %d seconds..." % (str(e), mdelay) - logger.info(msg) - time.sleep(mdelay) - mtries -= 1 - try: - return f(*args, **kwargs) - except ExceptionToCheck as e: - if skip_after_retries: - raise pytest.skip( - f"Skipping after {tries} consecutive {str(e)}" - ) from e - else: - raise e - - return f_retry # true decorator - - return deco_retry - - -# After calling this function, any log record whose name contains 'record_name' -# and is emitted from the logger that has qualified name 'logger_qname' is -# appended to the 'records' list. 
-# NOTE: This function is based on testing utilities for 'torch._logging' -def capture_log_records(records, logger_qname, record_name): - assert isinstance(records, list) - logger = logging.getLogger(logger_qname) - - class EmitWrapper: - def __init__(self, old_emit): - self.old_emit = old_emit - - def __call__(self, record): - nonlocal records # noqa: F824 - self.old_emit(record) - if record_name in record.name: - records.append(record) - - for handler in logger.handlers: - new_emit = EmitWrapper(handler.emit) - contextlib.ExitStack().enter_context( - unittest.mock.patch.object(handler, "emit", new_emit) - ) - - -@pytest.fixture -def dtype_fixture(): - dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.double) - yield dtype - torch.set_default_dtype(dtype) - - -@contextlib.contextmanager -def set_global_var(module, var_name, value): - old_value = getattr(module, var_name) - setattr(module, var_name, value) - try: - yield - finally: - setattr(module, var_name, old_value) - - -def _make_envs( - env_name, - frame_skip, - transformed_in, - transformed_out, - N, - device="cpu", - kwargs=None, - local_mp_ctx=mp_ctx, -): - torch.manual_seed(0) - if not transformed_in: - - def create_env_fn(): - return GymEnv(env_name, frame_skip=frame_skip, device=device) - - else: - if env_name == PONG_VERSIONED(): - - def create_env_fn(): - base_env = GymEnv(env_name, frame_skip=frame_skip, device=device) - in_keys = list(base_env.observation_spec.keys(True, True))[:1] - return TransformedEnv( - base_env, - Compose(*[ToTensorImage(in_keys=in_keys), RewardClipping(0, 0.1)]), - ) - - else: - - def create_env_fn(): - - base_env = GymEnv(env_name, frame_skip=frame_skip, device=device) - in_keys = list(base_env.observation_spec.keys(True, True))[:1] - - return TransformedEnv( - base_env, - Compose( - ObservationNorm(in_keys=in_keys, loc=0.5, scale=1.1), - RewardClipping(0, 0.1), - ), - ) - - env0 = create_env_fn() - env_parallel = ParallelEnv( - N, create_env_fn, 
create_env_kwargs=kwargs, mp_start_method=local_mp_ctx - ) - env_serial = SerialEnv(N, create_env_fn, create_env_kwargs=kwargs) - - for key in env0.observation_spec.keys(True, True): - obs_key = key - break - else: - obs_key = None - - if transformed_out: - t_out = get_transform_out(env_name, transformed_in, obs_key=obs_key) - - env0 = TransformedEnv( - env0, - t_out(), - ) - env_parallel = TransformedEnv( - env_parallel, - t_out(), - ) - env_serial = TransformedEnv( - env_serial, - t_out(), - ) - else: - t_out = None - - if _has_envpool: - env_multithread = _make_multithreaded_env( - env_name, - frame_skip, - t_out, - N, - device="cpu", - kwargs=None, - ) - else: - env_multithread = None - - return env_parallel, env_serial, env_multithread, env0 - - -def _make_multithreaded_env( - env_name, - frame_skip, - transformed_out, - N, - device="cpu", - kwargs=None, -): - - torch.manual_seed(0) - multithreaded_kwargs = ( - {"frame_skip": frame_skip} if env_name == PONG_VERSIONED() else {} - ) - env_multithread = MultiThreadedEnv( - N, - env_name, - create_env_kwargs=multithreaded_kwargs, - device=device, - ) - - if transformed_out: - for key in env_multithread.observation_spec.keys(True, True): - obs_key = key - break - else: - obs_key = None - env_multithread = TransformedEnv( - env_multithread, - get_transform_out(env_name, transformed_in=False, obs_key=obs_key)(), - ) - return env_multithread - - -def get_transform_out(env_name, transformed_in, obs_key=None): - - if env_name == PONG_VERSIONED(): - if obs_key is None: - obs_key = "pixels" - - def t_out(): - return ( - Compose(*[ToTensorImage(in_keys=[obs_key]), RewardClipping(0, 0.1)]) - if not transformed_in - else Compose(*[ObservationNorm(in_keys=[obs_key], loc=0, scale=1)]) - ) - - elif env_name == HALFCHEETAH_VERSIONED: - if obs_key is None: - obs_key = ("observation", "velocity") - - def t_out(): - return Compose( - ObservationNorm(in_keys=[obs_key], loc=0.5, scale=1.1), - RewardClipping(0, 0.1), - ) - - else: - 
if obs_key is None: - obs_key = "observation" - - def t_out(): - return ( - Compose( - ObservationNorm(in_keys=[obs_key], loc=0.5, scale=1.1), - RewardClipping(0, 0.1), - ) - if not transformed_in - else Compose(ObservationNorm(in_keys=[obs_key], loc=1.0, scale=1.0)) - ) - - return t_out - - -def make_tc(td): - """Makes a tensorclass from a tensordict instance.""" - - class MyClass: - pass - - MyClass.__annotations__ = {} - for key in td.keys(): - MyClass.__annotations__[key] = torch.Tensor - return tensorclass(MyClass) - - -def rollout_consistency_assertion( - rollout, *, done_key="done", observation_key="observation", done_strict=False -): - """Tests that observations in "next" match observations in the next root tensordict when done is False, and don't match otherwise.""" - - done = rollout[..., :-1]["next", done_key].squeeze(-1) - # data resulting from step, when it's not done - r_not_done = rollout[..., :-1]["next"][~done] - # data resulting from step, when it's not done, after step_mdp - r_not_done_tp1 = rollout[:, 1:][~done] - torch.testing.assert_close( - r_not_done[observation_key], - r_not_done_tp1[observation_key], - msg=f"Key {observation_key} did not match", - ) - - if done_strict and not done.any(): - raise RuntimeError("No done detected, test could not complete.") - if done.any(): - # data resulting from step, when it's done - r_done = rollout[..., :-1]["next"][done] - # data resulting from step, when it's done, after step_mdp and reset - r_done_tp1 = rollout[..., 1:][done] - # check that at least one obs after reset does not match the version before reset - assert not torch.isclose( - r_done[observation_key], r_done_tp1[observation_key] - ).all() - - -def rand_reset(env): - """Generates a tensordict with reset keys that mimic the done spec. - - Values are drawn at random until at least one reset is present. 
- - """ - full_done_spec = env.full_done_spec - result = {} - for reset_key, list_of_done in zip(env.reset_keys, env.done_keys_groups): - val = full_done_spec[list_of_done[0]].rand() - while not val.any(): - val = full_done_spec[list_of_done[0]].rand() - result[reset_key] = val - # create a data structure that keeps the batch size of the nested specs - result = ( - full_done_spec.zero().update(result).exclude(*full_done_spec.keys(True, True)) - ) - return result - - -def check_rollout_consistency_multikey_env(td: TensorDict, max_steps: int): - index_batch_size = (0,) * (len(td.batch_size) - 1) - - # Check done and reset for root - observation_is_max = td["next", "observation"][..., 0, 0, 0] == max_steps + 1 - next_is_done = td["next", "done"][index_batch_size][:-1].squeeze(-1) - assert (td["next", "done"][observation_is_max]).all() - assert (~td["next", "done"][~observation_is_max]).all() - # Obs after done is 0 - assert (td["observation"][index_batch_size][1:][next_is_done] == 0).all() - # Obs after not done is previous obs - assert ( - td["observation"][index_batch_size][1:][~next_is_done] - == td["next", "observation"][index_batch_size][:-1][~next_is_done] - ).all() - # Check observation and reward update with count action for root - action_is_count = td["action"].long().argmax(-1).to(torch.bool) - assert ( - td["next", "observation"][action_is_count] - == td["observation"][action_is_count] + 1 - ).all() - assert (td["next", "reward"][action_is_count] == 1).all() - # Check observation and reward do not update with no-count action for root - assert ( - td["next", "observation"][~action_is_count] - == td["observation"][~action_is_count] - ).all() - assert (td["next", "reward"][~action_is_count] == 0).all() - - # Check done and reset for nested_1 - observation_is_max = td["next", "nested_1", "observation"][..., 0] == max_steps + 1 - # done at the root always prevail - next_is_done = td["next", "done"][index_batch_size][:-1].squeeze(-1) - assert (td["next", 
"nested_1", "done"][observation_is_max]).all() - assert (~td["next", "nested_1", "done"][~observation_is_max]).all() - # Obs after done is 0 - assert ( - td["nested_1", "observation"][index_batch_size][1:][next_is_done] == 0 - ).all() - # Obs after not done is previous obs - assert ( - td["nested_1", "observation"][index_batch_size][1:][~next_is_done] - == td["next", "nested_1", "observation"][index_batch_size][:-1][~next_is_done] - ).all() - # Check observation and reward update with count action for nested_1 - action_is_count = td["nested_1"]["action"].to(torch.bool) - assert ( - td["next", "nested_1", "observation"][action_is_count] - == td["nested_1", "observation"][action_is_count] + 1 - ).all() - assert (td["next", "nested_1", "gift"][action_is_count] == 1).all() - # Check observation and reward do not update with no-count action for nested_1 - assert ( - td["next", "nested_1", "observation"][~action_is_count] - == td["nested_1", "observation"][~action_is_count] - ).all() - assert (td["next", "nested_1", "gift"][~action_is_count] == 0).all() - - # Check done and reset for nested_2 - observation_is_max = td["next", "nested_2", "observation"][..., 0] == max_steps + 1 - # done at the root always prevail - next_is_done = td["next", "done"][index_batch_size][:-1].squeeze(-1) - assert (td["next", "nested_2", "done"][observation_is_max]).all() - assert (~td["next", "nested_2", "done"][~observation_is_max]).all() - # Obs after done is 0 - assert ( - td["nested_2", "observation"][index_batch_size][1:][next_is_done] == 0 - ).all() - # Obs after not done is previous obs - assert ( - td["nested_2", "observation"][index_batch_size][1:][~next_is_done] - == td["next", "nested_2", "observation"][index_batch_size][:-1][~next_is_done] - ).all() - # Check observation and reward update with count action for nested_2 - action_is_count = td["nested_2"]["azione"].squeeze(-1).to(torch.bool) - assert ( - td["next", "nested_2", "observation"][action_is_count] - == td["nested_2", 
"observation"][action_is_count] + 1 - ).all() - assert (td["next", "nested_2", "reward"][action_is_count] == 1).all() - # Check observation and reward do not update with no-count action for nested_2 - assert ( - td["next", "nested_2", "observation"][~action_is_count] - == td["nested_2", "observation"][~action_is_count] - ).all() - assert (td["next", "nested_2", "reward"][~action_is_count] == 0).all() - - -class LSTMNet(nn.Module): - """An embedder for an LSTM preceded by an MLP. - - The forward method returns the hidden states of the current state - (input hidden states) and the output, as - the environment returns the 'observation' and 'next_observation'. - - Because the LSTM kernel only returns the last hidden state, hidden states - are padded with zeros such that they have the right size to be stored in a - TensorDict of size [batch x time_steps]. - - If a 2D tensor is provided as input, it is assumed that it is a batch of data - with only one time step. This means that we explicitly assume that users will - unsqueeze inputs of a single batch with multiple time steps. - - Args: - out_features (int): number of output features. - lstm_kwargs (dict): the keyword arguments for the - :class:`~torch.nn.LSTM` layer. - mlp_kwargs (dict): the keyword arguments for the - :class:`~torchrl.modules.MLP` layer. - device (torch.device, optional): the device where the module should - be instantiated. - - Keyword Args: - lstm_backend (str, optional): one of ``"torchrl"`` or ``"torch"`` that - indeicates where the LSTM class is to be retrieved. The ``"torchrl"`` - backend (:class:`~torchrl.modules.LSTM`) is slower but works with - :func:`~torch.vmap` and should work with :func:`~torch.compile`. - Defaults to ``"torch"``. - - Examples: - >>> batch = 7 - >>> time_steps = 6 - >>> in_features = 4 - >>> out_features = 10 - >>> hidden_size = 5 - >>> net = LSTMNet( - ... out_features, - ... {"input_size": hidden_size, "hidden_size": hidden_size}, - ... 
{"out_features": hidden_size}, - ... ) - >>> # test single step vs multi-step - >>> x = torch.randn(batch, time_steps, in_features) # >3 dims = multi-step - >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x) - >>> x = torch.randn(batch, in_features) # 2 dims = single step - >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x) - - """ - - def __init__( - self, - out_features: int, - lstm_kwargs, - mlp_kwargs, - device=None, - *, - lstm_backend: str | None = None, - ) -> None: - super().__init__() - lstm_kwargs.update({"batch_first": True}) - self.mlp = MLP(device=device, **mlp_kwargs) - if lstm_backend is None: - lstm_backend = "torch" - self.lstm_backend = lstm_backend - if self.lstm_backend == "torch": - LSTM = nn.LSTM - else: - from torchrl.modules.tensordict_module.rnn import LSTM - self.lstm = LSTM(device=device, **lstm_kwargs) - self.linear = nn.LazyLinear(out_features, device=device) - - def _lstm( - self, - input: torch.Tensor, - hidden0_in: torch.Tensor | None = None, - hidden1_in: torch.Tensor | None = None, - ): - squeeze0 = False - squeeze1 = False - if input.ndimension() == 1: - squeeze0 = True - input = input.unsqueeze(0).contiguous() - - if input.ndimension() == 2: - squeeze1 = True - input = input.unsqueeze(1).contiguous() - batch, steps = input.shape[:2] - - if hidden1_in is None and hidden0_in is None: - shape = (batch, steps) if not squeeze1 else (batch,) - hidden0_in, hidden1_in = ( - torch.zeros( - *shape, - self.lstm.num_layers, - self.lstm.hidden_size, - device=input.device, - dtype=input.dtype, - ) - for _ in range(2) - ) - elif hidden1_in is None or hidden0_in is None: - raise RuntimeError( - f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}" - ) - elif squeeze0: - hidden0_in = hidden0_in.unsqueeze(0) - hidden1_in = hidden1_in.unsqueeze(0) - - # we only need the first hidden state - if not squeeze1: - _hidden0_in = hidden0_in[:, 0] - _hidden1_in = hidden1_in[:, 0] - else: - _hidden0_in = 
hidden0_in - _hidden1_in = hidden1_in - hidden = ( - _hidden0_in.transpose(-3, -2).contiguous(), - _hidden1_in.transpose(-3, -2).contiguous(), - ) - - y0, hidden = self.lstm(input, hidden) - # dim 0 in hidden is num_layers, but that will conflict with tensordict - hidden = tuple(_h.transpose(0, 1) for _h in hidden) - y = self.linear(y0) - - out = [y, hidden0_in, hidden1_in, *hidden] - if squeeze1: - # squeezes time - out[0] = out[0].squeeze(1) - if not squeeze1: - # we pad the hidden states with zero to make tensordict happy - for i in range(3, 5): - out[i] = torch.stack( - [torch.zeros_like(out[i]) for _ in range(input.shape[1] - 1)] - + [out[i]], - 1, - ) - if squeeze0: - out = [_out.squeeze(0) for _out in out] - return tuple(out) - - def forward( - self, - input: torch.Tensor, - hidden0_in: torch.Tensor | None = None, - hidden1_in: torch.Tensor | None = None, - ): - input = self.mlp(input) - return self._lstm(input, hidden0_in, hidden1_in) - - -def _call_value_nets( - value_net: TensorDictModuleBase, - data: TensorDictBase, - params: TensorDictBase, - next_params: TensorDictBase, - single_call: bool, - value_key: NestedKey, - detach_next: bool, - vmap_randomness: str = "error", -): - in_keys = value_net.in_keys - if single_call: - for i, name in enumerate(data.names): - if name == "time": - ndim = i + 1 - break - else: - ndim = None - if ndim is not None: - # get data at t and last of t+1 - idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),) - idx = (slice(None),) * (ndim - 1) + (slice(None, -1),) - idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),) - data_in = torch.cat( - [ - data.select(*in_keys, value_key, strict=False), - data.get("next").select(*in_keys, value_key, strict=False)[idx0], - ], - ndim - 1, - ) - else: - if RL_WARNINGS: - logger.warning( - "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. " - "This warning can be turned off by setting the environment variable RL_WARNINGS to False." 
- ) - ndim = data.ndim - idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),) - idx_ = (slice(None),) * (ndim - 1) + (slice(data.shape[ndim - 1], None),) - data_in = torch.cat( - [ - data.select(*in_keys, value_key, strict=False), - data.get("next").select(*in_keys, value_key, strict=False), - ], - ndim - 1, - ) - - # next_params should be None or be identical to params - if next_params is not None and next_params is not params: - raise ValueError( - "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed." - ) - if params is not None: - with params.to_module(value_net): - value_est = value_net(data_in).get(value_key) - else: - value_est = value_net(data_in).get(value_key) - value, value_ = value_est[idx], value_est[idx_] - else: - data_in = torch.stack( - [ - data.select(*in_keys, value_key, strict=False), - data.get("next").select(*in_keys, value_key, strict=False), - ], - 0, - ) - if (params is not None) ^ (next_params is not None): - raise ValueError( - "params and next_params must be either both provided or not." 
- ) - elif params is not None: - params_stack = torch.stack([params, next_params], 0).contiguous() - data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)( - data_in, params_stack - ) - else: - data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in) - value_est = data_out.get(value_key) - value, value_ = value_est[0], value_est[1] - data.set(value_key, value) - data.set(("next", value_key), value_) - if detach_next: - value_ = value_.detach() - return value, value_ diff --git a/test/smoke_test_deps.py b/test/smoke_test_deps.py index 941107199fe..3381088d09b 100644 --- a/test/smoke_test_deps.py +++ b/test/smoke_test_deps.py @@ -5,7 +5,6 @@ from __future__ import annotations import argparse -import os import sys import tempfile @@ -65,10 +64,7 @@ def test_gym(): import ale_py # noqa: F401 except Exception: # pragma: no cover pytest.skip("ALE not available (missing ale_py); skipping Atari gym test.") - if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import PONG_VERSIONED - else: - from _utils_internal import PONG_VERSIONED + from torchrl.testing import PONG_VERSIONED try: env = GymEnv(PONG_VERSIONED()) diff --git a/test/test_actors.py b/test/test_actors.py index daec1718ddc..df8281e5d22 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -6,7 +6,6 @@ import argparse import importlib.util -import os import pytest import torch @@ -33,10 +32,7 @@ ValueOperator, ) -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import get_default_devices -else: - from _utils_internal import get_default_devices +from torchrl.testing import get_default_devices from torchrl.testing.mocking_classes import NestedCountingEnv _has_vllm = importlib.util.find_spec("vllm") is not None diff --git a/test/test_collectors.py b/test/test_collectors.py index 7ef4d98e0a1..1981ac5a552 100644 --- a/test/test_collectors.py +++ b/test/test_collectors.py @@ -8,7 +8,6 @@ import contextlib import functools import gc 
-import os import subprocess import sys import time @@ -88,37 +87,17 @@ RandomPolicy, SafeModule, ) -from torchrl.testing.modules import BiasModule, NonSerializableBiasModule -from torchrl.testing.mp_helpers import decorate_thread_sub_func -from torchrl.weight_update import ( - MultiProcessWeightSyncScheme, - SharedMemWeightSyncScheme, -) -if os.getenv("PYTORCH_TEST_FBCODE"): - IS_FB = True - from pytorch.rl.test._utils_internal import ( - CARTPOLE_VERSIONED, - check_rollout_consistency_multikey_env, - generate_seeds, - get_available_devices, - get_default_devices, - LSTMNet, - PENDULUM_VERSIONED, - retry, - ) -else: - IS_FB = False - from _utils_internal import ( - CARTPOLE_VERSIONED, - check_rollout_consistency_multikey_env, - generate_seeds, - get_available_devices, - get_default_devices, - LSTMNet, - PENDULUM_VERSIONED, - retry, - ) +from torchrl.testing import ( + CARTPOLE_VERSIONED, + check_rollout_consistency_multikey_env, + generate_seeds, + get_available_devices, + get_default_devices, + LSTMNet, + PENDULUM_VERSIONED, + retry, +) from torchrl.testing.mocking_classes import ( ContinuousActionVecMockEnv, CountingBatchedEnv, @@ -137,6 +116,12 @@ MultiKeyCountingEnvPolicy, NestedCountingEnv, ) +from torchrl.testing.modules import BiasModule, NonSerializableBiasModule +from torchrl.testing.mp_helpers import decorate_thread_sub_func +from torchrl.weight_update import ( + MultiProcessWeightSyncScheme, + SharedMemWeightSyncScheme, +) # torch.set_default_dtype(torch.double) IS_WINDOWS = sys.platform == "win32" @@ -864,7 +849,6 @@ def test_env_that_errors(self, ctype): break @retry(AssertionError, tries=10, delay=0) - @pytest.mark.skipif(IS_FB, reason="Not compatible with fbcode") @pytest.mark.parametrize("to", [3, 10]) @pytest.mark.parametrize( "collector_cls", ["MultiSyncCollector", "MultiAsyncCollector"] diff --git a/test/test_distributions.py b/test/test_distributions.py index eda1b63a5d6..4540a49fb36 100644 --- a/test/test_distributions.py +++ 
b/test/test_distributions.py @@ -6,7 +6,6 @@ import argparse import importlib.util -import os import pytest import torch @@ -36,10 +35,7 @@ LLMMaskedCategorical, ) -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import get_default_devices -else: - from _utils_internal import get_default_devices +from torchrl.testing import get_default_devices _has_scipy = importlib.util.find_spec("scipy", None) is not None diff --git a/test/test_envs.py b/test/test_envs.py index fab8e48fe36..e7a4337229c 100644 --- a/test/test_envs.py +++ b/test/test_envs.py @@ -174,28 +174,16 @@ def check_no_lingering_multiprocessing_resources(request): _atari_found = False atari_confs = defaultdict(str) -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import ( - _make_envs, - CARTPOLE_VERSIONED, - check_rollout_consistency_multikey_env, - get_default_devices, - HALFCHEETAH_VERSIONED, - PENDULUM_VERSIONED, - PONG_VERSIONED, - rand_reset, - ) -else: - from _utils_internal import ( - _make_envs, - CARTPOLE_VERSIONED, - check_rollout_consistency_multikey_env, - get_default_devices, - HALFCHEETAH_VERSIONED, - PENDULUM_VERSIONED, - PONG_VERSIONED, - rand_reset, - ) +from torchrl.testing import ( + CARTPOLE_VERSIONED, + check_rollout_consistency_multikey_env, + get_default_devices, + HALFCHEETAH_VERSIONED, + make_envs as _make_envs, + PENDULUM_VERSIONED, + PONG_VERSIONED, + rand_reset, +) from torchrl.testing.mocking_classes import ( ActionObsMergeLinear, AutoResetHeteroCountingEnv, diff --git a/test/test_exploration.py b/test/test_exploration.py index 4be173e957d..377cc7c1ec6 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -5,7 +5,6 @@ from __future__ import annotations import argparse -import os import pytest import torch @@ -41,10 +40,7 @@ OrnsteinUhlenbeckProcessModule, ) -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import get_default_devices -else: - from _utils_internal import 
get_default_devices +from torchrl.testing import get_default_devices from torchrl.testing.mocking_classes import ( ContinuousActionVecMockEnv, CountingEnvCountModule, diff --git a/test/test_helpers.py b/test/test_helpers.py index 026b3fc6490..d9665770891 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -6,7 +6,6 @@ import argparse import dataclasses -import os import pathlib import sys from time import sleep @@ -24,6 +23,15 @@ FlattenObservation, TransformedEnv, ) + +from torchrl.testing import generate_seeds, get_default_devices +from torchrl.testing.mocking_classes import ( + ContinuousActionConvMockEnvNumpy, + ContinuousActionVecMockEnv, + DiscreteActionConvMockEnvNumpy, + DiscreteActionVecMockEnv, + MockSerialEnv, +) from torchrl.trainers.helpers import transformed_env_constructor from torchrl.trainers.helpers.envs import ( EnvConfig, @@ -36,18 +44,6 @@ make_dqn_actor, ) -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import generate_seeds, get_default_devices -else: - from _utils_internal import generate_seeds, get_default_devices -from torchrl.testing.mocking_classes import ( - ContinuousActionConvMockEnvNumpy, - ContinuousActionVecMockEnv, - DiscreteActionConvMockEnvNumpy, - DiscreteActionVecMockEnv, - MockSerialEnv, -) - try: from hydra import compose, initialize from hydra.core.config_store import ConfigStore diff --git a/test/test_libs.py b/test/test_libs.py index 8c302c8726a..6ae625b381a 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -136,34 +136,19 @@ _has_ray = importlib.util.find_spec("ray") is not None _has_ale = importlib.util.find_spec("ale_py") is not None _has_mujoco = importlib.util.find_spec("mujoco") is not None -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import ( - _make_multithreaded_env, - CARTPOLE_VERSIONED, - CLIFFWALKING_VERSIONED, - get_available_devices, - get_default_devices, - HALFCHEETAH_VERSIONED, - PENDULUM_VERSIONED, - PONG_VERSIONED, - 
rand_reset, - retry, - rollout_consistency_assertion, - ) -else: - from _utils_internal import ( - _make_multithreaded_env, - CARTPOLE_VERSIONED, - CLIFFWALKING_VERSIONED, - get_available_devices, - get_default_devices, - HALFCHEETAH_VERSIONED, - PENDULUM_VERSIONED, - PONG_VERSIONED, - rand_reset, - retry, - rollout_consistency_assertion, - ) +from torchrl.testing import ( + CARTPOLE_VERSIONED, + CLIFFWALKING_VERSIONED, + get_available_devices, + get_default_devices, + HALFCHEETAH_VERSIONED, + make_multithreaded_env as _make_multithreaded_env, + PENDULUM_VERSIONED, + PONG_VERSIONED, + rand_reset, + retry, + rollout_consistency_assertion, +) TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) diff --git a/test/test_modules.py b/test/test_modules.py index 9f7c47f9566..ff44b0266a8 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -5,7 +5,6 @@ from __future__ import annotations import argparse -import os import re from numbers import Number @@ -58,12 +57,9 @@ from torchrl.modules.planners.mppi import MPPIPlanner from torchrl.objectives.value import TDLambdaEstimator -from torchrl.testing.mocking_classes import MockBatchedUnLockedEnv +from torchrl.testing import get_default_devices, retry -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import get_default_devices, retry -else: - from _utils_internal import get_default_devices, retry +from torchrl.testing.mocking_classes import MockBatchedUnLockedEnv @pytest.fixture diff --git a/test/test_objectives.py b/test/test_objectives.py index d1b354b4c0b..fe3ef984ec2 100644 --- a/test/test_objectives.py +++ b/test/test_objectives.py @@ -10,7 +10,6 @@ import importlib.util import itertools import operator -import os import sys import warnings from copy import deepcopy @@ -155,22 +154,13 @@ _split_and_pad_sequence, ) -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import ( # noqa - _call_value_nets, - dtype_fixture, - 
get_available_devices, - get_default_devices, - PENDULUM_VERSIONED, - ) -else: - from _utils_internal import ( # noqa - _call_value_nets, - dtype_fixture, - get_available_devices, - get_default_devices, - PENDULUM_VERSIONED, - ) +from torchrl.testing import ( # noqa + call_value_nets as _call_value_nets, + dtype_fixture, + get_available_devices, + get_default_devices, + PENDULUM_VERSIONED, +) from torchrl.testing.mocking_classes import ContinuousActionConvMockEnv _has_functorch = True diff --git a/test/test_postprocs.py b/test/test_postprocs.py index 37a8bcb166b..f6409249039 100644 --- a/test/test_postprocs.py +++ b/test/test_postprocs.py @@ -6,7 +6,6 @@ import argparse import functools -import os import pytest import torch @@ -16,10 +15,7 @@ from torchrl.collectors.utils import split_trajectories from torchrl.data.postprocs.postprocs import DensifyReward, MultiStep -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import get_default_devices -else: - from _utils_internal import get_default_devices +from torchrl.testing import get_default_devices class TestMultiStep: diff --git a/test/test_rb.py b/test/test_rb.py index 1c66e9b0119..3c7d2c850b3 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -109,21 +109,12 @@ ) from torchrl.modules import RandomPolicy - -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import ( - capture_log_records, - CARTPOLE_VERSIONED, - get_default_devices, - make_tc, - ) -else: - from _utils_internal import ( - capture_log_records, - CARTPOLE_VERSIONED, - get_default_devices, - make_tc, - ) +from torchrl.testing import ( + capture_log_records, + CARTPOLE_VERSIONED, + get_default_devices, + make_tc, +) from torchrl.testing.mocking_classes import CountingEnv OLD_TORCH = parse(torch.__version__) < parse("2.0.0") @@ -3956,10 +3947,7 @@ def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls, sampler_cls): def test_rb_multidim_collector( self, rbtype, storage_cls, writer_cls, 
sampler_cls, transform, env_device ): - if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import CARTPOLE_VERSIONED - else: - from _utils_internal import CARTPOLE_VERSIONED + from torchrl.testing import CARTPOLE_VERSIONED torch.manual_seed(0) env = SerialEnv(2, lambda: GymEnv(CARTPOLE_VERSIONED()), device=env_device) diff --git a/test/test_rlhf.py b/test/test_rlhf.py index 8d2e102f327..89d627eda82 100644 --- a/test/test_rlhf.py +++ b/test/test_rlhf.py @@ -5,7 +5,6 @@ from __future__ import annotations import argparse -import os import zipfile from copy import deepcopy from pathlib import Path @@ -34,10 +33,7 @@ from torchrl.data.llm.utils import RolloutFromModel from torchrl.modules.models.llm import GPT2RewardModel -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import get_default_devices -else: - from _utils_internal import get_default_devices +from torchrl.testing import get_default_devices HERE = Path(__file__).parent diff --git a/test/test_specs.py b/test/test_specs.py index 691d52ffddc..ba68e0a3fb2 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -6,7 +6,6 @@ import argparse import contextlib -import os import numpy as np import pytest @@ -42,18 +41,7 @@ ) from torchrl.data.utils import check_no_exclusive_keys, consolidate_spec -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import ( - get_available_devices, - get_default_devices, - set_global_var, - ) -else: - from _utils_internal import ( - get_available_devices, - get_default_devices, - set_global_var, - ) +from torchrl.testing import get_available_devices, get_default_devices, set_global_var pytestmark = [ pytest.mark.filterwarnings("error"), diff --git a/test/test_storage_map.py b/test/test_storage_map.py index d8bd63db4ba..2311d7d1200 100644 --- a/test/test_storage_map.py +++ b/test/test_storage_map.py @@ -7,7 +7,6 @@ import argparse import functools import importlib.util -import os import sys import pytest @@ 
-25,10 +24,7 @@ ) from torchrl.envs import GymEnv -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import PENDULUM_VERSIONED -else: - from _utils_internal import PENDULUM_VERSIONED +from torchrl.testing import PENDULUM_VERSIONED _has_gym = importlib.util.find_spec("gymnasium", None) or importlib.util.find_spec( "gym", None diff --git a/test/test_trainer.py b/test/test_trainer.py index dd77db913c7..f0dbdfc25fc 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -25,10 +25,6 @@ except ImportError: _has_tb = False -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import PONG_VERSIONED -else: - from _utils_internal import PONG_VERSIONED from tensordict import TensorDict from torchrl.data import ( LazyMemmapStorage, @@ -38,6 +34,7 @@ TensorDictReplayBuffer, ) from torchrl.envs.libs.gym import _has_gym +from torchrl.testing import PONG_VERSIONED from torchrl.trainers import LogValidationReward, Trainer from torchrl.trainers.helpers import transformed_env_constructor from torchrl.trainers.trainers import ( diff --git a/test/test_transforms.py b/test/test_transforms.py index d2b4b8e3b61..73029e9f2c4 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -10,7 +10,6 @@ import contextlib import importlib.util import itertools -import os import pickle import re @@ -148,31 +147,17 @@ ) from torchrl.modules.utils import get_primers_from_module from torchrl.record.recorder import VideoRecorder -from torchrl.testing.modules import BiasModule -from torchrl.weight_update import RayModuleTransformScheme -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import ( # noqa - BREAKOUT_VERSIONED, - dtype_fixture, - get_default_devices, - HALFCHEETAH_VERSIONED, - PENDULUM_VERSIONED, - PONG_VERSIONED, - rand_reset, - retry, - ) -else: - from _utils_internal import ( # noqa - BREAKOUT_VERSIONED, - dtype_fixture, - get_default_devices, - HALFCHEETAH_VERSIONED, - PENDULUM_VERSIONED, - 
PONG_VERSIONED, - rand_reset, - retry, - ) +from torchrl.testing import ( # noqa + BREAKOUT_VERSIONED, + dtype_fixture, + get_default_devices, + HALFCHEETAH_VERSIONED, + PENDULUM_VERSIONED, + PONG_VERSIONED, + rand_reset, + retry, +) from torchrl.testing.mocking_classes import ( ContinuousActionVecMockEnv, CountingBatchedEnv, @@ -191,6 +176,8 @@ NestedCountingEnv, StateLessCountingEnv, ) +from torchrl.testing.modules import BiasModule +from torchrl.weight_update import RayModuleTransformScheme _has_ray = importlib.util.find_spec("ray") is not None _has_ale = importlib.util.find_spec("ale_py") is not None @@ -860,10 +847,7 @@ def test_catframes_batching( self, batched_class, break_when_any_done, maybe_fork_ParallelEnv ): - if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import CARTPOLE_VERSIONED - else: - from _utils_internal import CARTPOLE_VERSIONED + from torchrl.testing import CARTPOLE_VERSIONED if batched_class is ParallelEnv: batched_class = maybe_fork_ParallelEnv @@ -1778,10 +1762,7 @@ def test_step_count_dmc(self): @pytest.mark.parametrize("batched_class", [ParallelEnv, SerialEnv]) @pytest.mark.parametrize("break_when_any_done", [True, False]) def test_stepcount_batching(self, batched_class, break_when_any_done): - if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import CARTPOLE_VERSIONED - else: - from _utils_internal import CARTPOLE_VERSIONED + from torchrl.testing import CARTPOLE_VERSIONED env = TransformedEnv( batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED())), @@ -6620,10 +6601,7 @@ def test_transform_env(self, out_key, reward_spec): @pytest.mark.parametrize("batched_class", [ParallelEnv, SerialEnv]) @pytest.mark.parametrize("break_when_any_done", [True, False]) def test_rewardsum_batching(self, batched_class, break_when_any_done): - if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import CARTPOLE_VERSIONED - else: - from _utils_internal import CARTPOLE_VERSIONED + 
from torchrl.testing import CARTPOLE_VERSIONED env = TransformedEnv( batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED())), RewardSum() @@ -7751,10 +7729,7 @@ def test_trans_parallel_env_check(self, mode, device, maybe_fork_ParallelEnv): @pytest.mark.parametrize("batched_class", [SerialEnv, ParallelEnv]) @pytest.mark.parametrize("break_when_any_done", [True, False]) def test_targetreturn_batching(self, batched_class, break_when_any_done): - if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import CARTPOLE_VERSIONED - else: - from _utils_internal import CARTPOLE_VERSIONED + from torchrl.testing import CARTPOLE_VERSIONED env = TransformedEnv( batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED())), @@ -8282,10 +8257,7 @@ def make_env(): @pytest.mark.parametrize("batched_class", [ParallelEnv, SerialEnv]) @pytest.mark.parametrize("break_when_any_done", [True, False]) def test_tensordictprimer_batching(self, batched_class, break_when_any_done): - if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import CARTPOLE_VERSIONED - else: - from _utils_internal import CARTPOLE_VERSIONED + from torchrl.testing import CARTPOLE_VERSIONED env = TransformedEnv( batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED())), @@ -8523,10 +8495,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): @pytest.mark.parametrize("batched_class", [ParallelEnv, SerialEnv]) @pytest.mark.parametrize("break_when_any_done", [True, False]) def test_timemax_batching(self, batched_class, break_when_any_done): - if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import CARTPOLE_VERSIONED - else: - from _utils_internal import CARTPOLE_VERSIONED + from torchrl.testing import CARTPOLE_VERSIONED env = TransformedEnv( batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED())), diff --git a/test/test_utils.py b/test/test_utils.py index bbe6556e6cc..d815b8ab482 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -11,22 
+11,22 @@ from importlib import import_module from unittest import mock -import _utils_internal import pytest import torch - -from torchrl.objectives.utils import _pseudo_vmap - -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import capture_log_records, get_default_devices -else: - from _utils_internal import capture_log_records, get_default_devices from packaging import version from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend +from torchrl.objectives.utils import _pseudo_vmap + +from torchrl.testing import ( + capture_log_records, + get_default_devices, + gym_helpers as _gym_helpers, +) + TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) @@ -105,25 +105,25 @@ class implement_for_test_functions: """ @staticmethod - @implement_for(lambda: import_module("_utils_internal"), "0.3") + @implement_for(lambda: import_module("torchrl.testing.utils"), "0.3") def select_correct_version(): """To test from+ range and that this function is not selected as the implementation.""" return "0.3+V1" @staticmethod - @implement_for("_utils_internal", "0.3") + @implement_for("torchrl.testing.utils", "0.3") def select_correct_version(): # noqa: F811 """To test that this function is selected as the implementation (last implementation).""" return "0.3+" @staticmethod - @implement_for(lambda: import_module("_utils_internal"), "0.2", "0.3") + @implement_for(lambda: import_module("torchrl.testing.utils"), "0.2", "0.3") def select_correct_version(): # noqa: F811 """To test that right bound is not included.""" return "0.2-0.3" @staticmethod - @implement_for("_utils_internal", "0.1", "0.2") + @implement_for("torchrl.testing.utils", "0.1", "0.2") def select_correct_version(): # noqa: F811 """To test that function with missing from-to range is ignored.""" return "0.1-0.2" @@ -135,12 +135,12 @@ def missing_module(): return "missing" 
@staticmethod - @implement_for("_utils_internal", None, "0.3") + @implement_for("torchrl.testing.utils", None, "0.3") def missing_version(): return "0-0.3" @staticmethod - @implement_for("_utils_internal", "0.4") + @implement_for("torchrl.testing.utils", "0.4") def missing_version(): # noqa: F811 return "0.4+" @@ -259,17 +259,17 @@ def test_set_gym_environments( with set_gym_backend(gymnasium): assert ( - _utils_internal._set_gym_environments is expected_fn_gymnasium + _gym_helpers._set_gym_environments is expected_fn_gymnasium ), expected_fn_gym with set_gym_backend(gym): assert ( - _utils_internal._set_gym_environments is expected_fn_gym + _gym_helpers._set_gym_environments is expected_fn_gym ), expected_fn_gymnasium with set_gym_backend(gymnasium): assert ( - _utils_internal._set_gym_environments is expected_fn_gymnasium + _gym_helpers._set_gym_environments is expected_fn_gymnasium ), expected_fn_gym @@ -290,7 +290,7 @@ def test_set_gym_environments_no_version_gymnasium_found(): msg = f"could not set anything related to gym backend {gymnasium_name} with version={gymnasium_version}." with pytest.raises(ImportError, match=msg): with set_gym_backend(gymnasium): - _utils_internal._set_gym_environments() + _gym_helpers._set_gym_environments() def test_set_gym_backend_types(): diff --git a/torchrl/testing/__init__.py b/torchrl/testing/__init__.py index 6a7be0f118e..ab7cb470346 100644 --- a/torchrl/testing/__init__.py +++ b/torchrl/testing/__init__.py @@ -9,24 +9,95 @@ particularly for distributed and Ray-based tests that require importable classes. 
""" +from torchrl.testing.assertions import ( + check_rollout_consistency_multikey_env, + rand_reset, + rollout_consistency_assertion, +) +from torchrl.testing.env_creators import ( + get_transform_out, + make_envs, + make_multithreaded_env, +) +from torchrl.testing.gym_helpers import ( + BREAKOUT_VERSIONED, + CARTPOLE_VERSIONED, + CLIFFWALKING_VERSIONED, + HALFCHEETAH_VERSIONED, + PENDULUM_VERSIONED, + PONG_VERSIONED, +) from torchrl.testing.llm_mocks import ( MockTransformerConfig, MockTransformerModel, MockTransformerOutput, ) +from torchrl.testing.modules import ( + BiasModule, + call_value_nets, + LSTMNet, + NonSerializableBiasModule, +) from torchrl.testing.ray_helpers import ( WorkerTransformerDoubleBuffer, WorkerTransformerNCCL, WorkerVLLMDoubleBuffer, WorkerVLLMNCCL, ) +from torchrl.testing.utils import ( + capture_log_records, + dtype_fixture, + generate_seeds, + get_available_devices, + get_default_devices, + IS_WIN, + make_tc, + mp_ctx, + PYTHON_3_9, + retry, + set_global_var, +) __all__ = [ - "WorkerVLLMNCCL", - "WorkerTransformerNCCL", - "WorkerVLLMDoubleBuffer", - "WorkerTransformerDoubleBuffer", + # Assertions + "check_rollout_consistency_multikey_env", + "rand_reset", + "rollout_consistency_assertion", + # Environment creators + "get_transform_out", + "make_envs", + "make_multithreaded_env", + # Gym helpers + "BREAKOUT_VERSIONED", + "CARTPOLE_VERSIONED", + "CLIFFWALKING_VERSIONED", + "HALFCHEETAH_VERSIONED", + "PENDULUM_VERSIONED", + "PONG_VERSIONED", + # LLM mocks "MockTransformerConfig", "MockTransformerModel", "MockTransformerOutput", + # Modules + "BiasModule", + "call_value_nets", + "LSTMNet", + "NonSerializableBiasModule", + # Ray helpers + "WorkerTransformerDoubleBuffer", + "WorkerTransformerNCCL", + "WorkerVLLMDoubleBuffer", + "WorkerVLLMNCCL", + # Utils + "capture_log_records", + "dtype_fixture", + "generate_seeds", + "get_available_devices", + "get_default_devices", + "IS_WIN", + "make_tc", + "mp_ctx", + "PYTHON_3_9", + "retry", + 
"set_global_var", ] diff --git a/torchrl/testing/assertions.py b/torchrl/testing/assertions.py new file mode 100644 index 00000000000..25738df53ca --- /dev/null +++ b/torchrl/testing/assertions.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Assertions and validation utilities for TorchRL tests.""" + +from __future__ import annotations + +import torch +from tensordict import TensorDict + +__all__ = [ + "check_rollout_consistency_multikey_env", + "rand_reset", + "rollout_consistency_assertion", +] + + +def rollout_consistency_assertion( + rollout, *, done_key="done", observation_key="observation", done_strict=False +): + """Test that observations in 'next' match observations in the next root tensordict. + + Verifies consistency: when done is False the next observation should match, + and when done is True they should differ (indicating a reset occurred). + + Args: + rollout: The rollout tensordict to validate. + done_key: The key for the done signal. + observation_key: The key for observations. + done_strict: If True, raise an error if no done is detected. 
+ """ + done = rollout[..., :-1]["next", done_key].squeeze(-1) + # data resulting from step, when it's not done + r_not_done = rollout[..., :-1]["next"][~done] + # data resulting from step, when it's not done, after step_mdp + r_not_done_tp1 = rollout[..., 1:][~done] + torch.testing.assert_close( + r_not_done[observation_key], + r_not_done_tp1[observation_key], + msg=f"Key {observation_key} did not match", + ) + + if done_strict and not done.any(): + raise RuntimeError("No done detected, test could not complete.") + if done.any(): + # data resulting from step, when it's done + r_done = rollout[..., :-1]["next"][done] + # data resulting from step, when it's done, after step_mdp and reset + r_done_tp1 = rollout[..., 1:][done] + # check that at least one obs after reset does not match the version before reset + assert not torch.isclose( + r_done[observation_key], r_done_tp1[observation_key] + ).all() + + +def rand_reset(env): + """Generate a tensordict with reset keys that mimic the done spec. + + Values are drawn at random until at least one reset is present. + + Args: + env: The environment to generate reset keys for. + + Returns: + A TensorDict containing the reset signals. + """ + full_done_spec = env.full_done_spec + result = {} + for reset_key, list_of_done in zip(env.reset_keys, env.done_keys_groups): + val = full_done_spec[list_of_done[0]].rand() + while not val.any(): + val = full_done_spec[list_of_done[0]].rand() + result[reset_key] = val + # create a data structure that keeps the batch size of the nested specs + result = ( + full_done_spec.zero().update(result).exclude(*full_done_spec.keys(True, True)) + ) + return result + + +def check_rollout_consistency_multikey_env(td: TensorDict, max_steps: int): + """Check rollout consistency for environments with multiple observation/action keys. 
+ + Validates that: + - Done and reset behavior is correct for root, nested_1, and nested_2 + - Observations update correctly based on actions + - Rewards are computed correctly + + Args: + td: The rollout tensordict to validate. + max_steps: The maximum steps before done in the environment. + """ + index_batch_size = (0,) * (len(td.batch_size) - 1) + + # Check done and reset for root + observation_is_max = td["next", "observation"][..., 0, 0, 0] == max_steps + 1 + next_is_done = td["next", "done"][index_batch_size][:-1].squeeze(-1) + assert (td["next", "done"][observation_is_max]).all() + assert (~td["next", "done"][~observation_is_max]).all() + # Obs after done is 0 + assert (td["observation"][index_batch_size][1:][next_is_done] == 0).all() + # Obs after not done is previous obs + assert ( + td["observation"][index_batch_size][1:][~next_is_done] + == td["next", "observation"][index_batch_size][:-1][~next_is_done] + ).all() + # Check observation and reward update with count action for root + action_is_count = td["action"].long().argmax(-1).to(torch.bool) + assert ( + td["next", "observation"][action_is_count] + == td["observation"][action_is_count] + 1 + ).all() + assert (td["next", "reward"][action_is_count] == 1).all() + # Check observation and reward do not update with no-count action for root + assert ( + td["next", "observation"][~action_is_count] + == td["observation"][~action_is_count] + ).all() + assert (td["next", "reward"][~action_is_count] == 0).all() + + # Check done and reset for nested_1 + observation_is_max = td["next", "nested_1", "observation"][..., 0] == max_steps + 1 + # done at the root always prevail + next_is_done = td["next", "done"][index_batch_size][:-1].squeeze(-1) + assert (td["next", "nested_1", "done"][observation_is_max]).all() + assert (~td["next", "nested_1", "done"][~observation_is_max]).all() + # Obs after done is 0 + assert ( + td["nested_1", "observation"][index_batch_size][1:][next_is_done] == 0 + ).all() + # Obs after not done 
is previous obs + assert ( + td["nested_1", "observation"][index_batch_size][1:][~next_is_done] + == td["next", "nested_1", "observation"][index_batch_size][:-1][~next_is_done] + ).all() + # Check observation and reward update with count action for nested_1 + action_is_count = td["nested_1"]["action"].to(torch.bool) + assert ( + td["next", "nested_1", "observation"][action_is_count] + == td["nested_1", "observation"][action_is_count] + 1 + ).all() + assert (td["next", "nested_1", "gift"][action_is_count] == 1).all() + # Check observation and reward do not update with no-count action for nested_1 + assert ( + td["next", "nested_1", "observation"][~action_is_count] + == td["nested_1", "observation"][~action_is_count] + ).all() + assert (td["next", "nested_1", "gift"][~action_is_count] == 0).all() + + # Check done and reset for nested_2 + observation_is_max = td["next", "nested_2", "observation"][..., 0] == max_steps + 1 + # done at the root always prevail + next_is_done = td["next", "done"][index_batch_size][:-1].squeeze(-1) + assert (td["next", "nested_2", "done"][observation_is_max]).all() + assert (~td["next", "nested_2", "done"][~observation_is_max]).all() + # Obs after done is 0 + assert ( + td["nested_2", "observation"][index_batch_size][1:][next_is_done] == 0 + ).all() + # Obs after not done is previous obs + assert ( + td["nested_2", "observation"][index_batch_size][1:][~next_is_done] + == td["next", "nested_2", "observation"][index_batch_size][:-1][~next_is_done] + ).all() + # Check observation and reward update with count action for nested_2 + action_is_count = td["nested_2"]["azione"].squeeze(-1).to(torch.bool) + assert ( + td["next", "nested_2", "observation"][action_is_count] + == td["nested_2", "observation"][action_is_count] + 1 + ).all() + assert (td["next", "nested_2", "reward"][action_is_count] == 1).all() + # Check observation and reward do not update with no-count action for nested_2 + assert ( + td["next", "nested_2", 
"observation"][~action_is_count] + == td["nested_2", "observation"][~action_is_count] + ).all() + assert (td["next", "nested_2", "reward"][~action_is_count] == 0).all() diff --git a/torchrl/testing/env_creators.py b/torchrl/testing/env_creators.py new file mode 100644 index 00000000000..e2f635be081 --- /dev/null +++ b/torchrl/testing/env_creators.py @@ -0,0 +1,227 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Environment creation utilities for TorchRL tests.""" + +from __future__ import annotations + +import torch + +from torchrl.envs import MultiThreadedEnv, ObservationNorm +from torchrl.envs.batched_envs import ParallelEnv, SerialEnv +from torchrl.envs.libs.envpool import _has_envpool +from torchrl.envs.libs.gym import GymEnv +from torchrl.envs.transforms import ( + Compose, + RewardClipping, + ToTensorImage, + TransformedEnv, +) +from torchrl.testing.gym_helpers import HALFCHEETAH_VERSIONED, PONG_VERSIONED +from torchrl.testing.utils import mp_ctx + +__all__ = [ + "get_transform_out", + "make_envs", + "make_multithreaded_env", +] + + +def make_envs( + env_name, + frame_skip, + transformed_in, + transformed_out, + N, + device="cpu", + kwargs=None, + local_mp_ctx=mp_ctx, +): + """Create parallel, serial, multithreaded, and single environment instances. + + This helper creates environments suitable for testing batched environment behavior. + + Args: + env_name: The gym environment name. + frame_skip: Number of frames to skip. + transformed_in: Whether to apply transforms inside the base env. + transformed_out: Whether to apply transforms outside the batched env. + N: Number of environments in the batch. + device: Device for the environments. + kwargs: Additional keyword arguments for environment creation. + local_mp_ctx: Multiprocessing context ('fork' or 'spawn'). 
+ + Returns: + Tuple of (env_parallel, env_serial, env_multithread, env0). + """ + torch.manual_seed(0) + if not transformed_in: + + def create_env_fn(): + return GymEnv(env_name, frame_skip=frame_skip, device=device) + + else: + if env_name == PONG_VERSIONED(): + + def create_env_fn(): + base_env = GymEnv(env_name, frame_skip=frame_skip, device=device) + in_keys = list(base_env.observation_spec.keys(True, True))[:1] + return TransformedEnv( + base_env, + Compose(*[ToTensorImage(in_keys=in_keys), RewardClipping(0, 0.1)]), + ) + + else: + + def create_env_fn(): + + base_env = GymEnv(env_name, frame_skip=frame_skip, device=device) + in_keys = list(base_env.observation_spec.keys(True, True))[:1] + + return TransformedEnv( + base_env, + Compose( + ObservationNorm(in_keys=in_keys, loc=0.5, scale=1.1), + RewardClipping(0, 0.1), + ), + ) + + env0 = create_env_fn() + env_parallel = ParallelEnv( + N, create_env_fn, create_env_kwargs=kwargs, mp_start_method=local_mp_ctx + ) + env_serial = SerialEnv(N, create_env_fn, create_env_kwargs=kwargs) + + for key in env0.observation_spec.keys(True, True): + obs_key = key + break + else: + obs_key = None + + if transformed_out: + t_out = get_transform_out(env_name, transformed_in, obs_key=obs_key) + + env0 = TransformedEnv( + env0, + t_out(), + ) + env_parallel = TransformedEnv( + env_parallel, + t_out(), + ) + env_serial = TransformedEnv( + env_serial, + t_out(), + ) + else: + t_out = None + + if _has_envpool: + env_multithread = make_multithreaded_env( + env_name, + frame_skip, + t_out, + N, + device="cpu", + kwargs=None, + ) + else: + env_multithread = None + + return env_parallel, env_serial, env_multithread, env0 + + +def make_multithreaded_env( + env_name, + frame_skip, + transformed_out, + N, + device="cpu", + kwargs=None, +): + """Create a multithreaded environment using envpool. + + Args: + env_name: The gym environment name. + frame_skip: Number of frames to skip. + transformed_out: Transform factory to apply, or None. 
+ N: Number of environments in the batch. + device: Device for the environment. + kwargs: Additional keyword arguments (unused, for API compatibility). + + Returns: + A MultiThreadedEnv instance, optionally wrapped with transforms. + """ + torch.manual_seed(0) + multithreaded_kwargs = ( + {"frame_skip": frame_skip} if env_name == PONG_VERSIONED() else {} + ) + env_multithread = MultiThreadedEnv( + N, + env_name, + create_env_kwargs=multithreaded_kwargs, + device=device, + ) + + if transformed_out: + for key in env_multithread.observation_spec.keys(True, True): + obs_key = key + break + else: + obs_key = None + env_multithread = TransformedEnv( + env_multithread, + get_transform_out(env_name, transformed_in=False, obs_key=obs_key)(), + ) + return env_multithread + + +def get_transform_out(env_name, transformed_in, obs_key=None): + """Create a transform factory for output transforms based on environment type. + + Args: + env_name: The gym environment name. + transformed_in: Whether transforms were already applied inside. + obs_key: The observation key to transform. + + Returns: + A callable that returns a Compose transform. 
+ """ + if env_name == PONG_VERSIONED(): + if obs_key is None: + obs_key = "pixels" + + def t_out(): + return ( + Compose(*[ToTensorImage(in_keys=[obs_key]), RewardClipping(0, 0.1)]) + if not transformed_in + else Compose(*[ObservationNorm(in_keys=[obs_key], loc=0, scale=1)]) + ) + + elif env_name == HALFCHEETAH_VERSIONED(): + if obs_key is None: + obs_key = ("observation", "velocity") + + def t_out(): + return Compose( + ObservationNorm(in_keys=[obs_key], loc=0.5, scale=1.1), + RewardClipping(0, 0.1), + ) + + else: + if obs_key is None: + obs_key = "observation" + + def t_out(): + return ( + Compose( + ObservationNorm(in_keys=[obs_key], loc=0.5, scale=1.1), + RewardClipping(0, 0.1), + ) + if not transformed_in + else Compose(ObservationNorm(in_keys=[obs_key], loc=1.0, scale=1.0)) + ) + + return t_out diff --git a/torchrl/testing/gym_helpers.py b/torchrl/testing/gym_helpers.py new file mode 100644 index 00000000000..2e873249274 --- /dev/null +++ b/torchrl/testing/gym_helpers.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Versioned gym environment name helpers for TorchRL tests.""" + +from __future__ import annotations + +import sys + +from torchrl._utils import implement_for +from torchrl.envs.libs.gym import _has_gym, gym_backend + +__all__ = [ + "BREAKOUT_VERSIONED", + "CARTPOLE_VERSIONED", + "CLIFFWALKING_VERSIONED", + "HALFCHEETAH_VERSIONED", + "PENDULUM_VERSIONED", + "PONG_VERSIONED", +] + +PYTHON_3_9 = sys.version_info.major == 3 and sys.version_info.minor <= 9 + +# Module-level variables that will be set by _set_gym_environments +_CARTPOLE_VERSIONED = None +_HALFCHEETAH_VERSIONED = None +_PENDULUM_VERSIONED = None +_PONG_VERSIONED = None +_BREAKOUT_VERSIONED = None +_CLIFFWALKING_VERSIONED = None + + +def CARTPOLE_VERSIONED(): + """Return the versioned CartPole environment name for the current gym backend.""" + if gym_backend() is not None: + _set_gym_environments() + return _CARTPOLE_VERSIONED + + +def HALFCHEETAH_VERSIONED(): + """Return the versioned HalfCheetah environment name for the current gym backend.""" + if gym_backend() is not None: + _set_gym_environments() + return _HALFCHEETAH_VERSIONED + + +def PONG_VERSIONED(): + """Return the versioned Pong environment name for the current gym backend.""" + # Gymnasium says that the ale_py behavior changes from 1.0 + # but with python 3.12 it is already the case with 0.29.1 + try: + import ale_py # noqa: F401 + except ImportError: + pass + + if gym_backend() is not None: + _set_gym_environments() + return _PONG_VERSIONED + + +def CLIFFWALKING_VERSIONED(): + """Return the versioned CliffWalking environment name for the current gym backend.""" + if gym_backend() is not None: + _set_gym_environments() + return _CLIFFWALKING_VERSIONED + + +def BREAKOUT_VERSIONED(): + """Return the versioned Breakout environment name for the current gym backend.""" + # Gymnasium says that the ale_py behavior changes from 1.0 + # but with python 3.12 it is already the case with 0.29.1 + try: + import ale_py # noqa: F401 + except 
ImportError: + pass + + if gym_backend() is not None: + _set_gym_environments() + return _BREAKOUT_VERSIONED + + +def PENDULUM_VERSIONED(): + """Return the versioned Pendulum environment name for the current gym backend.""" + if gym_backend() is not None: + _set_gym_environments() + return _PENDULUM_VERSIONED + + +def _set_gym_environments(): + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED + + _CARTPOLE_VERSIONED = None + _HALFCHEETAH_VERSIONED = None + _PENDULUM_VERSIONED = None + _PONG_VERSIONED = None + _BREAKOUT_VERSIONED = None + _CLIFFWALKING_VERSIONED = None + + +@implement_for("gym", None, "0.21.0") +def _set_gym_environments(): # noqa: F811 + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED + + _CARTPOLE_VERSIONED = "CartPole-v0" + _HALFCHEETAH_VERSIONED = "HalfCheetah-v2" + _PENDULUM_VERSIONED = "Pendulum-v0" + _PONG_VERSIONED = "Pong-v4" + _BREAKOUT_VERSIONED = "Breakout-v4" + _CLIFFWALKING_VERSIONED = "CliffWalking-v0" + + +@implement_for("gym", "0.21.0", None) +def _set_gym_environments(): # noqa: F811 + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED + + _CARTPOLE_VERSIONED = "CartPole-v1" + _HALFCHEETAH_VERSIONED = "HalfCheetah-v4" + _PENDULUM_VERSIONED = "Pendulum-v1" + _PONG_VERSIONED = "ALE/Pong-v5" + _BREAKOUT_VERSIONED = "ALE/Breakout-v5" + _CLIFFWALKING_VERSIONED = "CliffWalking-v0" + + +@implement_for("gymnasium", None, "1.0.0") +def _set_gym_environments(): # noqa: F811 + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED + + _CARTPOLE_VERSIONED = "CartPole-v1" + _HALFCHEETAH_VERSIONED = "HalfCheetah-v4" + _PENDULUM_VERSIONED = "Pendulum-v1" + _PONG_VERSIONED = "ALE/Pong-v5" + _BREAKOUT_VERSIONED = 
"ALE/Breakout-v5" + _CLIFFWALKING_VERSIONED = "CliffWalking-v0" + + +@implement_for("gymnasium", "1.0.0", "1.1.0") +def _set_gym_environments(): # noqa: F811 + raise ImportError + + +@implement_for("gymnasium", "1.1.0") +def _set_gym_environments(): # noqa: F811 + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED, _CLIFFWALKING_VERSIONED + + _CARTPOLE_VERSIONED = "CartPole-v1" + _HALFCHEETAH_VERSIONED = "HalfCheetah-v5" + _PENDULUM_VERSIONED = "Pendulum-v1" + _PONG_VERSIONED = "ALE/Pong-v5" + _BREAKOUT_VERSIONED = "ALE/Breakout-v5" + _CLIFFWALKING_VERSIONED = "CliffWalking-v1" if not PYTHON_3_9 else "CliffWalking-v0" + + +if _has_gym: + _set_gym_environments() diff --git a/torchrl/testing/modules.py b/torchrl/testing/modules.py index 84dffae8485..e0730371d2b 100644 --- a/torchrl/testing/modules.py +++ b/torchrl/testing/modules.py @@ -1,7 +1,20 @@ from __future__ import annotations import torch -from torch import nn +from tensordict import NestedKey, TensorDictBase +from tensordict.nn import TensorDictModuleBase +from torch import nn, vmap + +from torchrl._utils import logger, RL_WARNINGS +from torchrl.modules import MLP +from torchrl.objectives.value.advantages import _vmap_func + +__all__ = [ + "BiasModule", + "LSTMNet", + "NonSerializableBiasModule", + "call_value_nets", +] class BiasModule(nn.Module): @@ -24,3 +37,259 @@ class NonSerializableBiasModule(BiasModule): def __getstate__(self): # Simulate a non-serializable policy by raising on pickling raise RuntimeError("NonSerializableBiasModule cannot be pickled") + + +class LSTMNet(nn.Module): + """An embedder for an LSTM preceded by an MLP. + + The forward method returns the hidden states of the current state + (input hidden states) and the output, as + the environment returns the 'observation' and 'next_observation'. 
+ + Because the LSTM kernel only returns the last hidden state, hidden states + are padded with zeros such that they have the right size to be stored in a + TensorDict of size [batch x time_steps]. + + If a 2D tensor is provided as input, it is assumed that it is a batch of data + with only one time step. This means that we explicitly assume that users will + unsqueeze inputs of a single batch with multiple time steps. + + Args: + out_features (int): number of output features. + lstm_kwargs (dict): the keyword arguments for the + :class:`~torch.nn.LSTM` layer. + mlp_kwargs (dict): the keyword arguments for the + :class:`~torchrl.modules.MLP` layer. + device (torch.device, optional): the device where the module should + be instantiated. + + Keyword Args: + lstm_backend (str, optional): one of ``"torchrl"`` or ``"torch"`` that + indicates where the LSTM class is to be retrieved. The ``"torchrl"`` + backend (:class:`~torchrl.modules.LSTM`) is slower but works with + :func:`~torch.vmap` and should work with :func:`~torch.compile`. + Defaults to ``"torch"``. + + Examples: + >>> batch = 7 + >>> time_steps = 6 + >>> in_features = 4 + >>> out_features = 10 + >>> hidden_size = 5 + >>> net = LSTMNet( + ... out_features, + ... {"input_size": hidden_size, "hidden_size": hidden_size}, + ... {"out_features": hidden_size}, + ... 
) + >>> # test single step vs multi-step + >>> x = torch.randn(batch, time_steps, in_features) # >3 dims = multi-step + >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x) + >>> x = torch.randn(batch, in_features) # 2 dims = single step + >>> y, hidden0_in, hidden1_in, hidden0_out, hidden1_out = net(x) + + """ + + def __init__( + self, + out_features: int, + lstm_kwargs, + mlp_kwargs, + device=None, + *, + lstm_backend: str | None = None, + ) -> None: + super().__init__() + lstm_kwargs.update({"batch_first": True}) + self.mlp = MLP(device=device, **mlp_kwargs) + if lstm_backend is None: + lstm_backend = "torch" + self.lstm_backend = lstm_backend + if self.lstm_backend == "torch": + LSTM = nn.LSTM + else: + from torchrl.modules.tensordict_module.rnn import LSTM + self.lstm = LSTM(device=device, **lstm_kwargs) + self.linear = nn.LazyLinear(out_features, device=device) + + def _lstm( + self, + input: torch.Tensor, + hidden0_in: torch.Tensor | None = None, + hidden1_in: torch.Tensor | None = None, + ): + squeeze0 = False + squeeze1 = False + if input.ndimension() == 1: + squeeze0 = True + input = input.unsqueeze(0).contiguous() + + if input.ndimension() == 2: + squeeze1 = True + input = input.unsqueeze(1).contiguous() + batch, steps = input.shape[:2] + + if hidden1_in is None and hidden0_in is None: + shape = (batch, steps) if not squeeze1 else (batch,) + hidden0_in, hidden1_in = ( + torch.zeros( + *shape, + self.lstm.num_layers, + self.lstm.hidden_size, + device=input.device, + dtype=input.dtype, + ) + for _ in range(2) + ) + elif hidden1_in is None or hidden0_in is None: + raise RuntimeError( + f"got type(hidden0)={type(hidden0_in)} and type(hidden1)={type(hidden1_in)}" + ) + elif squeeze0: + hidden0_in = hidden0_in.unsqueeze(0) + hidden1_in = hidden1_in.unsqueeze(0) + + # we only need the first hidden state + if not squeeze1: + _hidden0_in = hidden0_in[:, 0] + _hidden1_in = hidden1_in[:, 0] + else: + _hidden0_in = hidden0_in + _hidden1_in = hidden1_in + 
hidden = ( + _hidden0_in.transpose(-3, -2).contiguous(), + _hidden1_in.transpose(-3, -2).contiguous(), + ) + + y0, hidden = self.lstm(input, hidden) + # dim 0 in hidden is num_layers, but that will conflict with tensordict + hidden = tuple(_h.transpose(0, 1) for _h in hidden) + y = self.linear(y0) + + out = [y, hidden0_in, hidden1_in, *hidden] + if squeeze1: + # squeezes time + out[0] = out[0].squeeze(1) + if not squeeze1: + # we pad the hidden states with zero to make tensordict happy + for i in range(3, 5): + out[i] = torch.stack( + [torch.zeros_like(out[i]) for _ in range(input.shape[1] - 1)] + + [out[i]], + 1, + ) + if squeeze0: + out = [_out.squeeze(0) for _out in out] + return tuple(out) + + def forward( + self, + input: torch.Tensor, + hidden0_in: torch.Tensor | None = None, + hidden1_in: torch.Tensor | None = None, + ): + input = self.mlp(input) + return self._lstm(input, hidden0_in, hidden1_in) + + +def call_value_nets( + value_net: TensorDictModuleBase, + data: TensorDictBase, + params: TensorDictBase, + next_params: TensorDictBase, + single_call: bool, + value_key: NestedKey, + detach_next: bool, + vmap_randomness: str = "error", +): + """Call value networks to compute values at t and t+1. + + This is a testing utility for computing value estimates in advantage + calculations. + + Args: + value_net: The value network module. + data: Input tensordict with observations. + params: Parameters for the value network at time t. + next_params: Parameters for the value network at time t+1. + single_call: Whether to use a single forward pass for both t and t+1. + value_key: The key where values are stored. + detach_next: Whether to detach the next value from the computation graph. + vmap_randomness: Randomness mode for vmap. + + Returns: + Tuple of (value, value_next). 
+ """ + in_keys = value_net.in_keys + if single_call: + for i, name in enumerate(data.names): + if name == "time": + ndim = i + 1 + break + else: + ndim = None + if ndim is not None: + # get data at t and last of t+1 + idx0 = (slice(None),) * (ndim - 1) + (slice(-1, None),) + idx = (slice(None),) * (ndim - 1) + (slice(None, -1),) + idx_ = (slice(None),) * (ndim - 1) + (slice(1, None),) + data_in = torch.cat( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False)[idx0], + ], + ndim - 1, + ) + else: + if RL_WARNINGS: + logger.warning( + "Got a tensordict without a time-marked dimension, assuming time is along the last dimension. " + "This warning can be turned off by setting the environment variable RL_WARNINGS to False." + ) + ndim = data.ndim + idx = (slice(None),) * (ndim - 1) + (slice(None, data.shape[ndim - 1]),) + idx_ = (slice(None),) * (ndim - 1) + (slice(data.shape[ndim - 1], None),) + data_in = torch.cat( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False), + ], + ndim - 1, + ) + + # next_params should be None or be identical to params + if next_params is not None and next_params is not params: + raise ValueError( + "the value at t and t+1 cannot be retrieved in a single call without recurring to vmap when both params and next params are passed." + ) + if params is not None: + with params.to_module(value_net): + value_est = value_net(data_in).get(value_key) + else: + value_est = value_net(data_in).get(value_key) + value, value_ = value_est[idx], value_est[idx_] + else: + data_in = torch.stack( + [ + data.select(*in_keys, value_key, strict=False), + data.get("next").select(*in_keys, value_key, strict=False), + ], + 0, + ) + if (params is not None) ^ (next_params is not None): + raise ValueError( + "params and next_params must be either both provided or not." 
+ ) + elif params is not None: + params_stack = torch.stack([params, next_params], 0).contiguous() + data_out = _vmap_func(value_net, (0, 0), randomness=vmap_randomness)( + data_in, params_stack + ) + else: + data_out = vmap(value_net, (0,), randomness=vmap_randomness)(data_in) + value_est = data_out.get(value_key) + value, value_ = value_est[0], value_est[1] + data.set(value_key, value) + data.set(("next", value_key), value_) + if detach_next: + value_ = value_.detach() + return value, value_ diff --git a/torchrl/testing/utils.py b/torchrl/testing/utils.py new file mode 100644 index 00000000000..0e3e0d020bb --- /dev/null +++ b/torchrl/testing/utils.py @@ -0,0 +1,190 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""General testing utilities for TorchRL tests.""" + +from __future__ import annotations + +# Version for testing implement_for decorator +__version__ = "0.3" + +import contextlib +import logging +import sys +import time +import unittest +from collections.abc import Callable +from functools import wraps + +import pytest +import torch +import torch.cuda +from tensordict import tensorclass + +from torchrl._utils import logger, seed_generator + +__all__ = [ + "capture_log_records", + "dtype_fixture", + "generate_seeds", + "get_available_devices", + "get_default_devices", + "IS_WIN", + "make_tc", + "mp_ctx", + "PYTHON_3_9", + "retry", + "set_global_var", +] + +IS_WIN = sys.platform == "win32" +if IS_WIN: + mp_ctx = "spawn" +else: + mp_ctx = "fork" + +PYTHON_3_9 = sys.version_info.major == 3 and sys.version_info.minor <= 9 + + +def get_available_devices(): + """Return a list of all available torch devices (CPU and all CUDA devices).""" + devices = [torch.device("cpu")] + n_cuda = torch.cuda.device_count() + if n_cuda > 0: + for i in range(n_cuda): + devices += [torch.device(f"cuda:{i}")] + return devices + + +def 
get_default_devices(): + """Return a sensible default list of devices for testing. + + Returns [cpu] if no CUDA, [cuda:0] if one GPU, all devices if multiple GPUs. + """ + num_cuda = torch.cuda.device_count() + if num_cuda == 0: + return [torch.device("cpu")] + elif num_cuda == 1: + return [torch.device("cuda:0")] + else: + return get_available_devices() + + +def generate_seeds(seed, repeat): + """Generate a list of seeds from a starting seed using the seed_generator.""" + seeds = [seed] + for _ in range(repeat - 1): + seed = seed_generator(seed) + seeds.append(seed) + return seeds + + +def retry( + ExceptionToCheck: type[Exception], + tries: int = 3, + delay: int = 3, + skip_after_retries: bool = False, +) -> Callable[[Callable], Callable]: + """Decorator to retry a function upon certain Exceptions. + + Args: + ExceptionToCheck: The exception type to catch and retry. + tries: Number of attempts before giving up. + delay: Seconds to wait between retries. + skip_after_retries: If True, skip the test after all retries fail. + + Returns: + A decorator that wraps the function with retry logic. + """ + + def deco_retry(f): + @wraps(f) + def f_retry(*args, **kwargs): + mtries, mdelay = tries, delay + while mtries > 1: + try: + return f(*args, **kwargs) + except ExceptionToCheck as e: + msg = "%s, Retrying in %d seconds..." % (str(e), mdelay) + logger.info(msg) + time.sleep(mdelay) + mtries -= 1 + try: + return f(*args, **kwargs) + except ExceptionToCheck as e: + if skip_after_retries: + raise pytest.skip( + f"Skipping after {tries} consecutive {str(e)}" + ) from e + else: + raise e + + return f_retry + + return deco_retry + + +def capture_log_records(records, logger_qname, record_name): + """Capture log records matching a name pattern from a specific logger. + + After calling this function, any log record whose name contains 'record_name' + and is emitted from the logger that has qualified name 'logger_qname' is + appended to the 'records' list. 
+ + NOTE: This function is based on testing utilities for 'torch._logging'. + """ + assert isinstance(records, list) + log = logging.getLogger(logger_qname) + + class EmitWrapper: + def __init__(self, old_emit): + self.old_emit = old_emit + + def __call__(self, record): + nonlocal records # noqa: F824 + self.old_emit(record) + if record_name in record.name: + records.append(record) + + for handler in log.handlers: + new_emit = EmitWrapper(handler.emit) + contextlib.ExitStack().enter_context( + unittest.mock.patch.object(handler, "emit", new_emit) + ) + + +@pytest.fixture +def dtype_fixture(): + """Pytest fixture that sets the default dtype to double for the test duration.""" + dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.double) + yield dtype + torch.set_default_dtype(dtype) + + +@contextlib.contextmanager +def set_global_var(module, var_name, value): + """Context manager to temporarily set a module's global variable.""" + old_value = getattr(module, var_name) + setattr(module, var_name, value) + try: + yield + finally: + setattr(module, var_name, old_value) + + +def make_tc(td): + """Create a tensorclass type from a tensordict instance. + + Creates a new tensorclass with fields matching the keys of the input tensordict. + """ + + class MyClass: + pass + + MyClass.__annotations__ = {} + for key in td.keys(): + MyClass.__annotations__[key] = torch.Tensor + return tensorclass(MyClass)