Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from torch import nn, vmap

from torchrl._utils import implement_for, logger, RL_WARNINGS, seed_generator
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs import MultiThreadedEnv, ObservationNorm
from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
from torchrl.envs.libs.envpool import _has_envpool
Expand Down Expand Up @@ -594,14 +593,6 @@ def check_rollout_consistency_multikey_env(td: TensorDict, max_steps: int):
assert (td["next", "nested_2", "reward"][~action_is_count] == 0).all()


def decorate_thread_sub_func(func, num_threads):
def new_func(*args, **kwargs):
assert torch.get_num_threads() == num_threads
return func(*args, **kwargs)

return CloudpickleWrapper(new_func)


class LSTMNet(nn.Module):
"""An embedder for an LSTM preceded by an MLP.

Expand Down
3 changes: 1 addition & 2 deletions test/test_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
SafeModule,
)
from torchrl.testing.modules import BiasModule, NonSerializableBiasModule
from torchrl.testing.mp_helpers import decorate_thread_sub_func
from torchrl.weight_update import (
MultiProcessWeightSyncScheme,
SharedMemWeightSyncScheme,
Expand All @@ -99,7 +100,6 @@
from pytorch.rl.test._utils_internal import (
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
generate_seeds,
get_available_devices,
get_default_devices,
Expand All @@ -112,7 +112,6 @@
from _utils_internal import (
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
generate_seeds,
get_available_devices,
get_default_devices,
Expand Down
15 changes: 15 additions & 0 deletions torchrl/testing/mp_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from __future__ import annotations

import torch

from torchrl.data.utils import CloudpickleWrapper


def decorate_thread_sub_func(func, num_threads):
"""Decorate a function to assert that the number of threads is correct."""

def new_func(*args, **kwargs):
assert torch.get_num_threads() == num_threads
return func(*args, **kwargs)

return CloudpickleWrapper(new_func)
Loading