Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
831 changes: 0 additions & 831 deletions test/_utils_internal.py

This file was deleted.

6 changes: 1 addition & 5 deletions test/smoke_test_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from __future__ import annotations

import argparse
import os
import sys
import tempfile

Expand Down Expand Up @@ -65,10 +64,7 @@ def test_gym():
import ale_py # noqa: F401
except Exception: # pragma: no cover
pytest.skip("ALE not available (missing ale_py); skipping Atari gym test.")
if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import PONG_VERSIONED
else:
from _utils_internal import PONG_VERSIONED
from torchrl.testing import PONG_VERSIONED

try:
env = GymEnv(PONG_VERSIONED())
Expand Down
6 changes: 1 addition & 5 deletions test/test_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import argparse
import importlib.util
import os

import pytest
import torch
Expand All @@ -33,10 +32,7 @@
ValueOperator,
)

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import get_default_devices
else:
from _utils_internal import get_default_devices
from torchrl.testing import get_default_devices
from torchrl.testing.mocking_classes import NestedCountingEnv

_has_vllm = importlib.util.find_spec("vllm") is not None
Expand Down
48 changes: 16 additions & 32 deletions test/test_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import contextlib
import functools
import gc
import os
import subprocess
import sys
import time
Expand Down Expand Up @@ -88,37 +87,17 @@
RandomPolicy,
SafeModule,
)
from torchrl.testing.modules import BiasModule, NonSerializableBiasModule
from torchrl.testing.mp_helpers import decorate_thread_sub_func
from torchrl.weight_update import (
MultiProcessWeightSyncScheme,
SharedMemWeightSyncScheme,
)

if os.getenv("PYTORCH_TEST_FBCODE"):
IS_FB = True
from pytorch.rl.test._utils_internal import (
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
generate_seeds,
get_available_devices,
get_default_devices,
LSTMNet,
PENDULUM_VERSIONED,
retry,
)
else:
IS_FB = False
from _utils_internal import (
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
generate_seeds,
get_available_devices,
get_default_devices,
LSTMNet,
PENDULUM_VERSIONED,
retry,
)
from torchrl.testing import (
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
generate_seeds,
get_available_devices,
get_default_devices,
LSTMNet,
PENDULUM_VERSIONED,
retry,
)
from torchrl.testing.mocking_classes import (
ContinuousActionVecMockEnv,
CountingBatchedEnv,
Expand All @@ -137,6 +116,12 @@
MultiKeyCountingEnvPolicy,
NestedCountingEnv,
)
from torchrl.testing.modules import BiasModule, NonSerializableBiasModule
from torchrl.testing.mp_helpers import decorate_thread_sub_func
from torchrl.weight_update import (
MultiProcessWeightSyncScheme,
SharedMemWeightSyncScheme,
)

# torch.set_default_dtype(torch.double)
IS_WINDOWS = sys.platform == "win32"
Expand Down Expand Up @@ -864,7 +849,6 @@ def test_env_that_errors(self, ctype):
break

@retry(AssertionError, tries=10, delay=0)
@pytest.mark.skipif(IS_FB, reason="Not compatible with fbcode")
@pytest.mark.parametrize("to", [3, 10])
@pytest.mark.parametrize(
"collector_cls", ["MultiSyncCollector", "MultiAsyncCollector"]
Expand Down
6 changes: 1 addition & 5 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import argparse
import importlib.util
import os

import pytest
import torch
Expand Down Expand Up @@ -36,10 +35,7 @@
LLMMaskedCategorical,
)

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import get_default_devices
else:
from _utils_internal import get_default_devices
from torchrl.testing import get_default_devices

_has_scipy = importlib.util.find_spec("scipy", None) is not None

Expand Down
32 changes: 10 additions & 22 deletions test/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,28 +174,16 @@ def check_no_lingering_multiprocessing_resources(request):
_atari_found = False
atari_confs = defaultdict(str)

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import (
_make_envs,
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
)
else:
from _utils_internal import (
_make_envs,
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
)
from torchrl.testing import (
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
get_default_devices,
HALFCHEETAH_VERSIONED,
make_envs as _make_envs,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
)
from torchrl.testing.mocking_classes import (
ActionObsMergeLinear,
AutoResetHeteroCountingEnv,
Expand Down
6 changes: 1 addition & 5 deletions test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from __future__ import annotations

import argparse
import os

import pytest
import torch
Expand Down Expand Up @@ -41,10 +40,7 @@
OrnsteinUhlenbeckProcessModule,
)

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import get_default_devices
else:
from _utils_internal import get_default_devices
from torchrl.testing import get_default_devices
from torchrl.testing.mocking_classes import (
ContinuousActionVecMockEnv,
CountingEnvCountModule,
Expand Down
22 changes: 9 additions & 13 deletions test/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import argparse
import dataclasses
import os
import pathlib
import sys
from time import sleep
Expand All @@ -24,6 +23,15 @@
FlattenObservation,
TransformedEnv,
)

from torchrl.testing import generate_seeds, get_default_devices
from torchrl.testing.mocking_classes import (
ContinuousActionConvMockEnvNumpy,
ContinuousActionVecMockEnv,
DiscreteActionConvMockEnvNumpy,
DiscreteActionVecMockEnv,
MockSerialEnv,
)
from torchrl.trainers.helpers import transformed_env_constructor
from torchrl.trainers.helpers.envs import (
EnvConfig,
Expand All @@ -36,18 +44,6 @@
make_dqn_actor,
)

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import generate_seeds, get_default_devices
else:
from _utils_internal import generate_seeds, get_default_devices
from torchrl.testing.mocking_classes import (
ContinuousActionConvMockEnvNumpy,
ContinuousActionVecMockEnv,
DiscreteActionConvMockEnvNumpy,
DiscreteActionVecMockEnv,
MockSerialEnv,
)

try:
from hydra import compose, initialize
from hydra.core.config_store import ConfigStore
Expand Down
41 changes: 13 additions & 28 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,34 +136,19 @@
_has_ray = importlib.util.find_spec("ray") is not None
_has_ale = importlib.util.find_spec("ale_py") is not None
_has_mujoco = importlib.util.find_spec("mujoco") is not None
if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import (
_make_multithreaded_env,
CARTPOLE_VERSIONED,
CLIFFWALKING_VERSIONED,
get_available_devices,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
retry,
rollout_consistency_assertion,
)
else:
from _utils_internal import (
_make_multithreaded_env,
CARTPOLE_VERSIONED,
CLIFFWALKING_VERSIONED,
get_available_devices,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
retry,
rollout_consistency_assertion,
)
from torchrl.testing import (
CARTPOLE_VERSIONED,
CLIFFWALKING_VERSIONED,
get_available_devices,
get_default_devices,
HALFCHEETAH_VERSIONED,
make_multithreaded_env as _make_multithreaded_env,
PENDULUM_VERSIONED,
PONG_VERSIONED,
rand_reset,
retry,
rollout_consistency_assertion,
)

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

Expand Down
8 changes: 2 additions & 6 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from __future__ import annotations

import argparse
import os
import re

from numbers import Number
Expand Down Expand Up @@ -58,12 +57,9 @@
from torchrl.modules.planners.mppi import MPPIPlanner
from torchrl.objectives.value import TDLambdaEstimator

from torchrl.testing.mocking_classes import MockBatchedUnLockedEnv
from torchrl.testing import get_default_devices, retry

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import get_default_devices, retry
else:
from _utils_internal import get_default_devices, retry
from torchrl.testing.mocking_classes import MockBatchedUnLockedEnv


@pytest.fixture
Expand Down
24 changes: 7 additions & 17 deletions test/test_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import importlib.util
import itertools
import operator
import os
import sys
import warnings
from copy import deepcopy
Expand Down Expand Up @@ -155,22 +154,13 @@
_split_and_pad_sequence,
)

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import ( # noqa
_call_value_nets,
dtype_fixture,
get_available_devices,
get_default_devices,
PENDULUM_VERSIONED,
)
else:
from _utils_internal import ( # noqa
_call_value_nets,
dtype_fixture,
get_available_devices,
get_default_devices,
PENDULUM_VERSIONED,
)
from torchrl.testing import ( # noqa
call_value_nets as _call_value_nets,
dtype_fixture,
get_available_devices,
get_default_devices,
PENDULUM_VERSIONED,
)
from torchrl.testing.mocking_classes import ContinuousActionConvMockEnv

_has_functorch = True
Expand Down
6 changes: 1 addition & 5 deletions test/test_postprocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import argparse
import functools
import os

import pytest
import torch
Expand All @@ -16,10 +15,7 @@
from torchrl.collectors.utils import split_trajectories
from torchrl.data.postprocs.postprocs import DensifyReward, MultiStep

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import get_default_devices
else:
from _utils_internal import get_default_devices
from torchrl.testing import get_default_devices


class TestMultiStep:
Expand Down
Loading
Loading