diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh index b337f687a62..e125bf927ee 100755 --- a/.github/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -118,6 +118,7 @@ uv_pip_install \ "pybind11[global]>=2.13" \ pyyaml \ scipy \ + psutil \ hydra-core \ tensorboard \ "imageio==2.26.0" \ diff --git a/.github/unittest/linux_distributed/scripts/environment.yml b/.github/unittest/linux_distributed/scripts/environment.yml index 432eb99020c..0ae27d63113 100644 --- a/.github/unittest/linux_distributed/scripts/environment.yml +++ b/.github/unittest/linux_distributed/scripts/environment.yml @@ -21,6 +21,7 @@ dependencies: - pybind11[global] - pyyaml - scipy + - psutil - hydra-core - tensorboard - imageio==2.26.0 diff --git a/.github/unittest/linux_libs/scripts_minari/requirements.txt b/.github/unittest/linux_libs/scripts_minari/requirements.txt index ae21314f263..de947f62d2c 100644 --- a/.github/unittest/linux_libs/scripts_minari/requirements.txt +++ b/.github/unittest/linux_libs/scripts_minari/requirements.txt @@ -12,6 +12,7 @@ expecttest pybind11[global] pyyaml scipy +psutil hydra-core minari[gcs,hdf5,hf,create] gymnasium>=1.2.0 diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh index f98f5963872..850cfe4936d 100755 --- a/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh +++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh @@ -39,7 +39,7 @@ printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [ "${CU_VERSION:-}" == cpu ] ; then conda install pytorch==2.1 torchvision==0.16 cpuonly -c pytorch -y else - pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118 + python -m pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118 # conda install pytorch==2.1 torchvision==0.16 pytorch-cuda=11.8 "numpy<2.0" -c pytorch -c nvidia -y fi @@ -47,16 +47,17 @@ fi #pip install -U charset-normalizer # install tensordict -if [[ "$RELEASE" == 0 ]]; then - conda install anaconda::cmake -y - python -m pip install "pybind11[global]" - python -m pip install git+https://github.com/pytorch/tensordict.git -else - python -m pip install tensordict -fi +# +# NOTE: +# - The olddeps CI job runs on older Python/torch stacks. +# - Installing from tensordict `main` (git+https) is brittle as `main` may drop +# support for older Python versions at any time, which can lead to "tensordict +# not installed" failures in downstream smoke tests. +# - Use the same (pinned) range as TorchRL itself to keep this job stable. +python -m pip install "${TORCHRL_TENSORDICT_SPEC:-tensordict>=0.10.0,<0.11.0}" # smoke test -python -c "import tensordict" +python -c "import tensordict; print(f'tensordict: {tensordict.__version__}')" printf "* Installing torchrl\n" python -m pip install -e . 
--no-build-isolation diff --git a/.github/unittest/linux_optdeps/scripts/environment.yml b/.github/unittest/linux_optdeps/scripts/environment.yml index 3c16131fec9..97e33013b0d 100644 --- a/.github/unittest/linux_optdeps/scripts/environment.yml +++ b/.github/unittest/linux_optdeps/scripts/environment.yml @@ -17,5 +17,6 @@ dependencies: - pybind11[global] - pyyaml - scipy + - psutil - coverage - ray diff --git a/.github/unittest/linux_sota/scripts/environment.yml b/.github/unittest/linux_sota/scripts/environment.yml index e3258f5973e..5089c8570ca 100644 --- a/.github/unittest/linux_sota/scripts/environment.yml +++ b/.github/unittest/linux_sota/scripts/environment.yml @@ -20,6 +20,7 @@ dependencies: - pybind11[global] - pyyaml - scipy + - psutil - hydra-core - imageio==2.26.0 - dm_control diff --git a/.github/unittest/linux_sota/scripts/run_all.sh b/.github/unittest/linux_sota/scripts/run_all.sh index 4c8147ecbd1..c899dc5f693 100755 --- a/.github/unittest/linux_sota/scripts/run_all.sh +++ b/.github/unittest/linux_sota/scripts/run_all.sh @@ -84,6 +84,7 @@ uv pip install \ pybind11 \ pyyaml \ scipy \ + psutil \ hydra-core \ "imageio==2.26.0" \ dm_control \ diff --git a/.github/unittest/windows_optdepts/scripts/environment.yml b/.github/unittest/windows_optdepts/scripts/environment.yml index 2740c77f434..43f264d49ff 100644 --- a/.github/unittest/windows_optdepts/scripts/environment.yml +++ b/.github/unittest/windows_optdepts/scripts/environment.yml @@ -15,4 +15,5 @@ dependencies: - expecttest - pyyaml - scipy + - psutil - coverage diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml index 964bdd82296..72c17b4e04d 100644 --- a/.github/workflows/test-linux.yml +++ b/.github/workflows/test-linux.yml @@ -169,6 +169,9 @@ jobs: set -euo pipefail export PYTHON_VERSION="3.9" export CU_VERSION="cu118" + # Olddeps runs on Python 3.9: pin tensordict to a Python-3.9-compatible range. + # (Avoid installing tensordict from git main, which may drop older Python support.) + export TORCHRL_TENSORDICT_SPEC="tensordict>=0.10.0,<0.11.0" export TAR_OPTIONS="--no-same-owner" if [[ "${{ github.ref }}" =~ release/* ]]; then export RELEASE=1 diff --git a/Makefile b/Makefile new file mode 100644 index 00000000000..0a77ecfe940 --- /dev/null +++ b/Makefile @@ -0,0 +1,29 @@ +# TorchRL Development Makefile + +.PHONY: clean build develop test + +# Clean all build artifacts (use when switching Python/PyTorch versions) +clean: + rm -rf build/ + rm -rf dist/ + rm -rf *.egg-info/ + rm -rf torchrl/*.egg-info/ + rm -f torchrl/_torchrl*.so + rm -f torchrl/version.py + find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true + find . -type f -name "*.pyc" -delete 2>/dev/null || true + +# Build C++ extensions in-place +build: + python setup.py build_ext --inplace + +# Full clean + build +rebuild: clean build + +# Development install (editable) +develop: rebuild + pip install -e . --no-build-isolation + +# Run tests +test: + python -m pytest test/ -v --timeout 60 diff --git a/benchmarks/test_non_tensor_env_benchmark.py b/benchmarks/test_non_tensor_env_benchmark.py index fcddfee632b..94721754a37 100644 --- a/benchmarks/test_non_tensor_env_benchmark.py +++ b/benchmarks/test_non_tensor_env_benchmark.py @@ -38,7 +38,7 @@ def test_non_tensor_env_rollout_speed( ): """Benchmarks a single rollout, after a warmup rollout, for non-tensor stacking envs. 
- Mirrors `test/test_env.py::TestNonTensorEnv`'s option matrix (single/serial/parallel, + Mirrors `test/test_envs.py::TestNonTensorEnv`'s option matrix (single/serial/parallel, break_when_any_done, use_buffers). """ with set_capture_non_tensor_stack(False): diff --git a/setup.py b/setup.py index b63a367edba..f0e4408bbcd 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import contextlib import glob import importlib.util +import json import logging import os import re @@ -8,6 +11,7 @@ import sys from pathlib import Path +import torch from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CppExtension @@ -15,6 +19,67 @@ ROOT_DIR = Path(__file__).parent.resolve() _RELEASE_BRANCH_RE = re.compile(r"^release/v(?P.+)$") +_BUILD_INFO_FILE = ROOT_DIR / "build" / ".torchrl_build_info.json" + + +def _check_and_clean_stale_builds(): + """Check if existing build was made with a different PyTorch version and clean if so. + + This prevents ABI incompatibility issues when switching between PyTorch versions. + """ + current_torch_version = torch.__version__ + current_python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + + if _BUILD_INFO_FILE.exists(): + try: + with open(_BUILD_INFO_FILE) as f: + build_info = json.load(f) + old_torch = build_info.get("torch_version") + old_python = build_info.get("python_version") + + if ( + old_torch != current_torch_version + or old_python != current_python_version + ): + logger.warning( + f"Detected PyTorch/Python version change: " + f"PyTorch {old_torch} -> {current_torch_version}, " + f"Python {old_python} -> {current_python_version}. " + f"Cleaning stale build artifacts..." + ) + # Clean stale .so files for current Python version + so_pattern = ( + ROOT_DIR + / "torchrl" + / f"_torchrl.cpython-{sys.version_info.major}{sys.version_info.minor}*.so" + ) + for so_file in glob.glob(str(so_pattern)): + logger.warning(f"Removing stale: {so_file}") + os.remove(so_file) + # Clean build directory + build_dir = ROOT_DIR / "build" + if build_dir.exists(): + import shutil + + for item in build_dir.iterdir(): + if item.name.startswith("temp.") or item.name.startswith( + "lib." 
+ ): + logger.warning(f"Removing stale build dir: {item}") + shutil.rmtree(item) + except (json.JSONDecodeError, OSError) as e: + logger.warning(f"Could not read build info: {e}") + + # Write current build info + _BUILD_INFO_FILE.parent.mkdir(parents=True, exist_ok=True) + with open(_BUILD_INFO_FILE, "w") as f: + json.dump( + { + "torch_version": current_torch_version, + "python_version": current_python_version, + }, + f, + ) def get_extensions(): @@ -162,6 +227,9 @@ def set_version(): def main(): """Main setup function for building TorchRL with C++ extensions.""" + # Check for stale builds from different PyTorch/Python versions + _check_and_clean_stale_builds() + with set_version(): pretend_version = os.environ.get("SETUPTOOLS_SCM_PRETEND_VERSION") _has_setuptools_scm = importlib.util.find_spec("setuptools_scm") is not None diff --git a/test/llm/libs/test_mlgym.py b/test/llm/libs/test_mlgym.py index a86732e622f..723d215a39a 100644 --- a/test/llm/libs/test_mlgym.py +++ b/test/llm/libs/test_mlgym.py @@ -5,6 +5,7 @@ from __future__ import annotations import argparse +import importlib.util from functools import partial @@ -16,7 +17,9 @@ from torchrl.envs.llm import make_mlgym from torchrl.modules.llm import TransformersWrapper -pytest.importorskip("mlgym") +pytestmark = pytest.mark.skipif( + not importlib.util.find_spec("mlgym"), reason="mlgym not available" +) class TestMLGYM: diff --git a/test/llm/test_collectors.py b/test/llm/test_llm_collectors.py similarity index 100% rename from test/llm/test_collectors.py rename to test/llm/test_llm_collectors.py diff --git a/test/llm/test_envs.py b/test/llm/test_llm_envs.py similarity index 99% rename from test/llm/test_envs.py rename to test/llm/test_llm_envs.py index fa8b03db774..1eebec03a29 100644 --- a/test/llm/test_envs.py +++ b/test/llm/test_llm_envs.py @@ -29,7 +29,6 @@ ) from torchrl.modules.llm import TransformersWrapper, vLLMWrapper -from transformers import AutoTokenizer _has_ray = importlib.util.find_spec("ray") is not None _has_transformers = importlib.util.find_spec("transformers") is not None @@ -43,6 +42,11 @@ and (importlib.util.find_spec("immutabledict") is not None) ) +pytestmark = pytest.mark.skipif( + not (_has_datasets & _has_transformers & _has_vllm & _has_ray), + reason="requires datasets, transformers, vllm, and ray", +) + @pytest.fixture(scope="module", autouse=True) def set_seed(): @@ -75,6 +79,8 @@ def set_list_to_stack_for_test(): class TestChatEnv: @pytest.fixture def tokenizer(self): + from transformers import AutoTokenizer + return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B") @pytest.mark.parametrize("input_mode", ["text", "tokens", "history"]) @@ -789,6 +795,7 @@ def delayed_calculator(cls, operation: str, a: float, b: float) -> dict: @classmethod def make_env(cls): from torchrl.envs.llm.transforms.tools import SimpleToolTransform + from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B") env = ChatEnv( @@ -871,6 +878,7 @@ def test_async_mcp_tools(self): def test_mcp_python_execution(self): """Test actual MCP Python execution with mcp-run-python server.""" from torchrl.envs.llm.transforms import MCPToolTransform + from transformers import AutoTokenizer # Setup environment for MCP (Deno needs to be in PATH) environ = os.environ.copy() diff --git a/test/llm/test_objectives.py b/test/llm/test_llm_objectives.py similarity index 100% rename from test/llm/test_objectives.py rename to test/llm/test_llm_objectives.py diff --git a/test/llm/test_transforms.py 
b/test/llm/test_llm_transforms.py similarity index 100% rename from test/llm/test_transforms.py rename to test/llm/test_llm_transforms.py diff --git a/test/llm/test_updaters.py b/test/llm/test_llm_updaters.py similarity index 100% rename from test/llm/test_updaters.py rename to test/llm/test_llm_updaters.py diff --git a/test/llm/test_wrapper.py b/test/llm/test_wrapper.py index b496e749c78..aff217b8e5c 100644 --- a/test/llm/test_wrapper.py +++ b/test/llm/test_wrapper.py @@ -8,10 +8,10 @@ import gc import importlib.util import threading - import time from concurrent.futures import ThreadPoolExecutor, wait from functools import partial +from typing import Any, TYPE_CHECKING import pytest import torch @@ -31,18 +31,21 @@ ) from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper from torchrl.modules.llm.policies.vllm_wrapper import vLLMWrapper -from transformers import AutoTokenizer _has_transformers = importlib.util.find_spec("transformers") is not None _has_vllm = importlib.util.find_spec("vllm") is not None -_has_datasets = importlib.util.find_spec("datasets") is not None _has_ray = importlib.util.find_spec("ray") is not None +# _has_datasets = importlib.util.find_spec("datasets") is not None TransformersWrapperMaxTokens = partial( TransformersWrapper, generate_kwargs={"max_new_tokens": 10, "do_sample": True} ) +if TYPE_CHECKING: + from transformers import AutoModelForCausalLM, AutoTokenizer + from vllm import LLM + @pytest.fixture(scope="function", autouse=True) def set_seed(): @@ -62,9 +65,7 @@ def set_list_to_stack_fixture(): @pytest.fixture(scope="module") -def vllm_instance() -> tuple[ - vllm.LLM, transformers.AutoTokenizer # noqa # type: ignore -]: # noqa # type: ignore +def vllm_instance() -> tuple[LLM, AutoTokenizer]: # noqa # type: ignore """Create vLLM model and tokenizer for testing.""" if not _has_vllm: pytest.skip("vllm not available") @@ -83,6 +84,8 @@ def vllm_instance() -> tuple[ max_model_len=32768, gpu_memory_utilization=0.3, # Limit to 30% GPU memory to avoid OOM with multiple engines ) + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") tokenizer.pad_token = tokenizer.eos_token return model, tokenizer @@ -92,7 +95,7 @@ def vllm_instance() -> tuple[ @pytest.fixture(scope="module") def async_vllm_instance() -> tuple[ - Any, transformers.AutoTokenizer # noqa # type: ignore + Any, AutoTokenizer # noqa # type: ignore ]: # noqa # type: ignore """Create async vLLM engine and tokenizer for testing.""" if not _has_vllm: @@ -114,6 +117,8 @@ def async_vllm_instance() -> tuple[ max_num_batched_tokens=32768, gpu_memory_utilization=0.3, # Limit to 30% GPU memory to avoid OOM with multiple engines ) + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") tokenizer.pad_token = tokenizer.eos_token return async_engine, tokenizer @@ -123,7 +128,7 @@ def async_vllm_instance() -> tuple[ @pytest.fixture(scope="module") def transformers_instance() -> tuple[ - transformers.AutoModelForCausalLM, transformers.AutoTokenizer # noqa # type: ignore + AutoModelForCausalLM, AutoTokenizer # noqa # type: ignore ]: # noqa # type: ignore """Create transformers model and tokenizer for testing.""" if not _has_transformers: diff --git a/test/services/test_python_executor_service.py b/test/services/test_python_executor_service.py index b18181c573f..52a09c59a77 100644 --- a/test/services/test_python_executor_service.py +++ b/test/services/test_python_executor_service.py @@ -1,12 +1,15 
@@ """Tests for PythonExecutorService with Ray service registry.""" from __future__ import annotations +# Skip all tests if Ray is not available +import importlib.util + import pytest -# Skip all tests if Ray is not available -pytest.importorskip("ray") +pytestmark = pytest.mark.skipif( + not importlib.util.find_spec("ray"), reason="Ray not available" +) -import ray from torchrl.envs.llm.transforms import PythonExecutorService, PythonInterpreter from torchrl.services import get_services @@ -14,6 +17,8 @@ @pytest.fixture def ray_init(): """Initialize Ray for tests.""" + import ray + if not ray.is_initialized(): ray.init() yield @@ -73,6 +78,8 @@ def test_service_execution(self, ray_init): result = x + y print(f"Result: {result}") """ + import ray + result = ray.get(executor.execute.remote(code), timeout=10) assert result["success"] is True @@ -84,6 +91,8 @@ def test_service_execution(self, ray_init): def test_service_execution_error(self, ray_init): """Test that the service handles execution errors.""" + import ray + namespace = "test_executor_error" services = get_services(backend="ray", namespace=namespace) @@ -111,6 +120,8 @@ def test_service_execution_error(self, ray_init): def test_multiple_executions(self, ray_init): """Test multiple concurrent executions.""" + import ray + namespace = "test_executor_multi" services = get_services(backend="ray", namespace=namespace) diff --git a/test/services/test_services.py b/test/services/test_services.py index 60e25547c62..cc44cb3f8d9 100644 --- a/test/services/test_services.py +++ b/test/services/test_services.py @@ -5,16 +5,14 @@ from __future__ import annotations -import pytest - -pytest.importorskip("ray") +import importlib.util # Import from mocking_classes which is a proper module import sys from pathlib import Path -import ray -from ray.util.state import get_actor as get_actor_by_id +import pytest + from torchrl._utils import logger sys.path.insert(0, str(Path(__file__).parent)) @@ -23,12 +21,18 @@ from test_services_fixtures import SimpleService, TokenizerService from torchrl.services import get_services, RayService +pytestmark = pytest.mark.skipif( + not importlib.util.find_spec("ray"), reason="Ray not available" +) + @pytest.fixture(scope="module", autouse=True) def ray_init(): """Initialize Ray once for the entire test module.""" import os + import ray + if ray.is_initialized(): ray.shutdown(raise_on_error=False) @@ -48,6 +52,9 @@ def ray_init(): @pytest.fixture(scope="function", autouse=True) def kill_all_actors(): """Kill all actors after each test.""" + import ray + from ray.util.state import get_actor as get_actor_by_id + yield if not ray.is_initialized(): return @@ -102,6 +109,8 @@ def test_initialization_with_existing_ray(self): def test_register_service(self): """Test registering a new service.""" + import ray + services = RayService(namespace="test_register") try: actor = services.register("simple", SimpleService, value=42) @@ -115,6 +124,8 @@ def test_register_service(self): def test_register_with_ray_options(self): """Test registering a service with Ray options.""" + import ray + services = RayService(namespace="test_options") try: actor = services.register( @@ -142,8 +153,9 @@ def test_register_duplicate_raises(self): def test_get_service(self): """Test retrieving a registered service.""" - services = RayService(namespace="test_get") + import ray + services = RayService(namespace="test_get") try: # Register a service original_actor = services.register("simple", SimpleService, value=100) @@ -168,6 +180,8 @@ def 
test_get_nonexistent_raises(self): def test_getitem_access(self): """Test dict-like access with [].""" + import ray + services = RayService(namespace="test_getitem") try: @@ -211,6 +225,8 @@ def test_list_services(self): def test_cross_worker_visibility(self): """Test that services registered by one worker are visible to another.""" + import ray + namespace = "test_cross_worker" # Worker 1: Register a service @@ -229,6 +245,8 @@ def test_cross_worker_visibility(self): def test_namespace_isolation(self): """Test that different namespaces isolate services.""" + import ray + # Register in namespace A services_a = RayService(namespace="namespace_a") services_a.register("service", SimpleService, value=111) @@ -253,6 +271,8 @@ def test_namespace_isolation(self): def test_options_method(self): """Test the register_with_options() method for explicit configuration.""" + import ray + services = RayService(namespace="test_options_method") try: @@ -273,6 +293,8 @@ def test_options_method(self): def test_service_persistence(self): """Test that services persist across RayService instances.""" + import ray + namespace = "test_persistence" # Create first instance and register @@ -394,6 +416,8 @@ class TestIntegrationScenarios: def test_tokenizer_sharing(self): """Test sharing a tokenizer across workers.""" + import ray + namespace = "test_tokenizer_integration" # Setup: Register tokenizer @@ -421,6 +445,8 @@ def test_tokenizer_sharing(self): def test_stateful_service(self): """Test that services maintain state across calls.""" + import ray + services = RayService(namespace="test_stateful") services.register("counter", SimpleService, value=0) @@ -444,6 +470,8 @@ def test_stateful_service(self): def test_conditional_registration(self): """Test pattern: register only if not exists.""" + import ray + namespace = "test_conditional" services1 = get_services(backend="ray", namespace=namespace) @@ -480,6 +508,8 @@ def test_conditional_registration(self): def test_multiple_services_management(self): """Test managing multiple different services.""" + import ray + services = RayService(namespace="test_multiple") try: diff --git a/test/test_collector.py b/test/test_collectors.py similarity index 98% rename from test/test_collector.py rename to test/test_collectors.py index c243067706c..1c52fb37615 100644 --- a/test/test_collector.py +++ b/test/test_collectors.py @@ -23,6 +23,7 @@ import torchrl.collectors._multi_base import torchrl.collectors._runner from packaging import version +from pyvers import implement_for from tensordict import ( assert_allclose_td, LazyStackedTensorDict, @@ -147,6 +148,16 @@ _has_cuda = torch.cuda.is_available() +@implement_for("torch", "2.5") +def has_mps(): + return torch.mps.is_available() + + +@implement_for("torch", None, "2.5") +def has_mps(): # noqa: F811 + return torch.backends.mps.is_available() + + class WrappablePolicy(nn.Module): def __init__(self, out_features: int, multiple_outputs: bool = False): super().__init__() @@ -485,8 +496,6 @@ def policy_outplace(td): def test_collector_output_keys( self, collector_class, init_random_frames, explicit_spec, split_trajs ): - from torchrl.envs.libs.gym import GymEnv - out_features = 1 hidden_size = 12 total_frames = 200 @@ -636,8 +645,6 @@ def test_collector_vecnorm_envcreator(self, static_seed): are modified after the collector is run for more steps. 
""" - from torchrl.envs.libs.gym import GymEnv - num_envs = 4 env_make = EnvCreator( lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED()), VecNorm()) @@ -730,7 +737,7 @@ def env_fn(seed): return env policy = make_policy(env_name) - + torchrl_logger.info("Sync") collector = Collector( create_env_fn=env_fn, create_env_kwargs={"seed": seed}, @@ -740,6 +747,7 @@ def env_fn(seed): total_frames=20000, device="cpu", ) + torchrl_logger.info("Loop") try: assert collector._use_buffers for i, d in enumerate(collector): @@ -753,8 +761,10 @@ def env_fn(seed): with pytest.raises(AssertionError): assert_allclose_td(b1, b2) finally: + torchrl_logger.info("Shutting down sync") collector.shutdown() + torchrl_logger.info("Concurrent") ccollector = AsyncCollector( create_env_fn=env_fn, create_env_kwargs={"seed": seed}, @@ -763,6 +773,7 @@ def env_fn(seed): max_frames_per_traj=2000, total_frames=20000, ) + torchrl_logger.info("Loop") for i, d in enumerate(ccollector): if i == 0: b1c = d @@ -781,6 +792,7 @@ def env_fn(seed): assert_allclose_td(b1c, b1) assert_allclose_td(b2c, b2) finally: + torchrl_logger.info("Shutting down concurrent") ccollector.shutdown() del ccollector @@ -1062,10 +1074,23 @@ def test_no_deepcopy_policy(self, collector_type): # warnings.warn("Tensordict is registered in PyTree", category=UserWarning) + # Skip multi-collectors on macOS with older PyTorch when MPS is available. + # On macOS: "fork" causes segfaults after MPS initialization (even with CPU tensors), + # and "spawn" on older PyTorch (<2.5) can't handle some multiprocessing scenarios. + is_multi_collector = collector_type is not Collector + is_macos = sys.platform == "darwin" + is_old_pytorch = version.parse(torch.__version__).base_version < "2.5.0" + mps_available = torch.backends.mps.is_available() + if is_multi_collector and is_macos and is_old_pytorch and mps_available: + pytest.skip( + "Multi-collectors are not supported on macOS with MPS available and PyTorch < 2.5.0 " + "due to multiprocessing compatibility issues with MPS initialization." 
+ ) + shared_device = torch.device("cpu") if torch.cuda.is_available(): original_device = torch.device("cuda:0") - elif torch.mps.is_available(): + elif has_mps(): original_device = torch.device("mps") else: pytest.skip("No GPU or MPS device") @@ -2457,8 +2482,6 @@ def env_fn(seed): class TestAutoWrap: @pytest.fixture def env_maker(self): - from torchrl.envs.libs.gym import GymEnv - return lambda: GymEnv(PENDULUM_VERSIONED()) def _create_collector_kwargs(self, env_maker, collector_class, policy, num_envs): @@ -2740,17 +2763,6 @@ def test_collector_nested_env_combinations( @pytest.mark.parametrize("batch_size", [(), (5,), (5, 2)]) def test_nested_env_dims(self, batch_size, nested_dim=5, frames_per_batch=20): - if os.getenv("PYTORCH_TEST_FBCODE"): - from torchrl.testing.mocking_classes import ( - CountingEnvCountPolicy, - NestedCountingEnv, - ) - else: - from torchrl.testing.mocking_classes import ( - CountingEnvCountPolicy, - NestedCountingEnv, - ) - env = NestedCountingEnv(batch_size=batch_size, nested_dim=nested_dim) env_fn = lambda: NestedCountingEnv(batch_size=batch_size, nested_dim=nested_dim) torch.manual_seed(0) @@ -2962,7 +2974,7 @@ def test_multi_collector_consistency( @pytest.mark.skipif( - not torch.cuda.is_available() and not torch.mps.is_available(), + not torch.cuda.is_available() and (not has_mps()), reason="No casting if no cuda", ) class TestUpdateParams: @@ -3105,6 +3117,19 @@ def test_param_sync( def test_param_sync_mixed_device( self, give_weights, collector, policy_device, env_device, weight_sync_scheme ): + # Skip multi-collectors on macOS with older PyTorch when MPS is available. + # On macOS: "fork" causes segfaults after MPS initialization (even with CPU tensors), + # and "spawn" on older PyTorch (<2.5) can't handle some multiprocessing scenarios. + is_multi_collector = collector is not Collector + is_macos = sys.platform == "darwin" + is_old_pytorch = version.parse(torch.__version__).base_version < "2.5.0" + mps_available = torch.backends.mps.is_available() + if is_multi_collector and is_macos and is_old_pytorch and mps_available: + pytest.skip( + "Multi-collectors are not supported on macOS with MPS available and PyTorch < 2.5.0 " + "due to multiprocessing compatibility issues with MPS initialization." + ) + with torch.device("cpu"): policy = TestUpdateParams.Policy() policy.param = nn.Parameter(policy.param.data.to(policy_device)) @@ -4051,6 +4076,17 @@ def test_collector_postproc_zeros( 3. Postproc is not applied when replay buffer is used with extend_buffer=False 4. The behavior is consistent across Sync, MultiaSync, and MultiSync collectors """ + # Skip multi-collectors with replay buffer on older Python. + # There's a known shared memory visibility race condition with Python < 3.10 and the + # "spawn" multiprocessing start method. The child process writes to shared memory, + # but the main process may sample before the writes are fully visible. + is_multi_collector = collector_class != Collector + if is_multi_collector and use_replay_buffer and sys.version_info < (3, 10): + pytest.skip( + "Multi-collectors with replay buffer are not supported on Python < 3.10 " + "due to shared memory visibility issues with the 'spawn' start method." 
+ ) + # Create a simple dummy environment def make_env(): env = DiscreteActionVecMockEnv() diff --git a/test/test_env.py b/test/test_envs.py similarity index 100% rename from test/test_env.py rename to test/test_envs.py diff --git a/test/test_cost.py b/test/test_objectives.py similarity index 100% rename from test/test_cost.py rename to test/test_objectives.py diff --git a/torchrl/_utils.py b/torchrl/_utils.py index c72b6f7e11a..fe2ee81c853 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -644,8 +644,48 @@ def get_trace(): traceback.print_stack() +def _make_process_no_warn_cls(ctx=None): + """Create a _ProcessNoWarn class that inherits from the appropriate Process class. + + When using multiprocessing contexts (e.g., fork or spawn), the Process class + used must match the context to ensure synchronization primitives like locks + work correctly. This factory function creates a _ProcessNoWarn class that + inherits from the context's Process class. + + Args: + ctx: A multiprocessing context (e.g., from mp.get_context('fork')). + If None, uses the default mp.Process. + + Returns: + A _ProcessNoWarn class that inherits from the appropriate Process base. + + .. note:: + For the "spawn" start method, this returns pre-defined module-level classes + to ensure they can be pickled correctly. + """ + if ctx is None: + return _ProcessNoWarn + + start_method = ctx.get_start_method() + if start_method == "fork": + return _ProcessNoWarnFork + elif start_method == "spawn": + return _ProcessNoWarnSpawn + elif start_method == "forkserver": + return _ProcessNoWarnForkserver + else: + # For unknown start methods, fall back to default + return _ProcessNoWarn + + +# Keep the old class name as a default for backwards compatibility class _ProcessNoWarn(mp.Process): - """A private Process class that shuts down warnings on the subprocess and controls the number of threads in the subprocess.""" + """A private Process class that shuts down warnings on the subprocess and controls the number of threads in the subprocess. + + .. note:: + When using multiprocessing contexts with synchronization primitives (locks, etc.), + use :func:`_make_process_no_warn_cls` with the context to ensure compatibility. + """ @wraps(mp.Process.__init__) def __init__(self, *args, num_threads=None, _start_method=None, **kwargs): @@ -669,6 +709,81 @@ def run(self, *args, **kwargs): return mp.Process.run(self, *args, **kwargs) +# Pre-defined _ProcessNoWarn classes for different multiprocessing start methods. +# These must be defined at module level to be picklable with the "spawn" start method. +# +# We use a mixin pattern to avoid code duplication while still having +# distinct module-level classes that can be pickled. 
+ + +class _ProcessNoWarnMixin: + """Mixin class providing the common functionality for _ProcessNoWarn variants.""" + + def _init_process_no_warn(self, num_threads=None, _start_method=None): + import torchrl + + self.filter_warnings_subprocess = torchrl.filter_warnings_subprocess + self.num_threads = num_threads + if _start_method is not None: + self._start_method = _start_method + + def run(self, *args, **kwargs): + if self.num_threads is not None: + torch.set_num_threads(self.num_threads) + if self.filter_warnings_subprocess: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return super().run(*args, **kwargs) + return super().run(*args, **kwargs) + + +# Spawn-specific class (for macOS default and Windows) +try: + _spawn_ctx = mp.get_context("spawn") + + class _ProcessNoWarnSpawn(_ProcessNoWarnMixin, _spawn_ctx.Process): + """_ProcessNoWarn for the 'spawn' multiprocessing context.""" + + def __init__(self, *args, num_threads=None, _start_method=None, **kwargs): + self._init_process_no_warn(num_threads, _start_method) + super().__init__(*args, **kwargs) + +except ValueError: + _ProcessNoWarnSpawn = _ProcessNoWarn + + +# Fork-specific class (for Linux default, not available on Windows) +try: + _fork_ctx = mp.get_context("fork") + + class _ProcessNoWarnFork(_ProcessNoWarnMixin, _fork_ctx.Process): + """_ProcessNoWarn for the 'fork' multiprocessing context.""" + + def __init__(self, *args, num_threads=None, _start_method=None, **kwargs): + self._init_process_no_warn(num_threads, _start_method) + super().__init__(*args, **kwargs) + +except ValueError: + _ProcessNoWarnFork = _ProcessNoWarn + + +# Forkserver-specific class (not available on Windows) +try: + _forkserver_ctx = mp.get_context("forkserver") + + class _ProcessNoWarnForkserver(_ProcessNoWarnMixin, _forkserver_ctx.Process): + """_ProcessNoWarn for the 'forkserver' multiprocessing context.""" + + def __init__(self, *args, num_threads=None, _start_method=None, **kwargs): + self._init_process_no_warn(num_threads, _start_method) + super().__init__(*args, **kwargs) + +except ValueError: + _ProcessNoWarnForkserver = _ProcessNoWarn + + def print_directory_tree(path, indent="", display_metadata=True): """Prints the directory tree starting from the specified path. 
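Reviewer sketch (illustrative only, not part of the patch): the mixin + per-start-method Process classes added to torchrl/_utils.py above can be exercised in isolation as below. The names QuietProcessMixin / make_quiet_process_cls are hypothetical stand-ins for _ProcessNoWarnMixin / _make_process_no_warn_cls; the real classes additionally handle num_threads and torchrl.filter_warnings_subprocess.

import multiprocessing as mp
import warnings


class QuietProcessMixin:
    # Same idea as _ProcessNoWarnMixin: silence warnings inside the subprocess,
    # independent of which Process base class the context supplies.
    def run(self, *args, **kwargs):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return super().run(*args, **kwargs)


class QuietSpawnProcess(QuietProcessMixin, mp.get_context("spawn").Process):
    # Module-level definition so the "spawn" start method can pickle the class.
    pass


try:
    class QuietForkProcess(QuietProcessMixin, mp.get_context("fork").Process):
        pass
except ValueError:  # "fork" is unavailable on Windows
    QuietForkProcess = QuietSpawnProcess


def make_quiet_process_cls(ctx=None):
    # Mirrors _make_process_no_warn_cls: return the subclass whose Process base
    # matches the context's start method.
    if ctx is None:
        return QuietSpawnProcess
    return {"spawn": QuietSpawnProcess, "fork": QuietForkProcess}.get(
        ctx.get_start_method(), QuietSpawnProcess
    )


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    proc_cls = make_quiet_process_cls(ctx)
    p = proc_cls(target=print, args=("subprocess ran with warnings silenced",))
    p.start()
    p.join()

Defining one concrete subclass per context at module import time is what keeps the "spawn" path picklable; a class created lazily inside the factory would not be importable by name in the child process.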
diff --git a/torchrl/collectors/_base.py b/torchrl/collectors/_base.py index 8e9193803d6..d6cc18de6e0 100644 --- a/torchrl/collectors/_base.py +++ b/torchrl/collectors/_base.py @@ -16,7 +16,6 @@ from tensordict.nn import TensorDictModule, TensorDictModuleBase from torch import nn as nn from torch.utils.data import IterableDataset -from torchrl._utils import logger as torchrl_logger from torchrl.collectors.utils import _map_weight from torchrl.collectors.weight_update import WeightUpdaterBase @@ -484,9 +483,6 @@ def _weight_update_impl( weights_dict = {model_id: policy_or_weights} elif weights_dict is None: weights_dict = {model_id: policy_or_weights} - torchrl_logger.debug( - f"Calling weight update with {model_id=} and {weights_dict.keys()=}" - ) for target_model_id, weights in weights_dict.items(): if target_model_id not in self._weight_sync_schemes: raise KeyError( @@ -497,13 +493,9 @@ def _weight_update_impl( weights, target_model_id ) # Use new send() API with worker_ids support - torchrl_logger.debug("weight update -- getting scheme") scheme = self._weight_sync_schemes.get(target_model_id) if not isinstance(scheme, WeightSyncScheme): raise TypeError(f"Expected WeightSyncScheme, got {target_model_id}") - torchrl_logger.debug( - f"calling send() on scheme {type(scheme).__name__}" - ) self._send_weights_scheme( scheme=scheme, processed_weights=processed_weights, @@ -515,7 +507,6 @@ def _weight_update_impl( raise RuntimeError else: # No weight updater configured, try fallback - torchrl_logger.debug("No weight update configured, trying fallback.") self._maybe_fallback_update(policy_or_weights, model_id=model_id) def _maybe_fallback_update( @@ -543,12 +534,8 @@ def _receive_weights_scheme(self): if not hasattr(self, "_receiver_schemes"): raise RuntimeError("No receiver schemes registered.") - for model_id, scheme in self._receiver_schemes.items(): - torchrl_logger.debug( - f"Receiving weights for scheme {type(scheme).__name__} for model '{model_id}' on worker {self._worker_idx}" - ) - received_weights = scheme.receive() - torchrl_logger.debug(f"Received weights: {type(received_weights)=}") + for scheme in self._receiver_schemes.values(): + scheme.receive() # Overloads for receive_weights to support multiple calling conventions @overload @@ -723,11 +710,8 @@ def register_scheme_receiver( # Perform initial synchronization if synchronize_weights: - for model_id, scheme in weight_recv_schemes.items(): + for scheme in weight_recv_schemes.values(): if not scheme.synchronized_on_receiver: - torchrl_logger.debug( - f"Synchronizing weights for scheme {type(scheme).__name__} for model '{model_id}'" - ) scheme.connect(worker_idx=self.worker_idx) def __iter__(self) -> Iterator[TensorDictBase]: diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index 109ff5bcf6c..ae803f194b5 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -19,7 +19,7 @@ from torchrl._utils import ( _check_for_faulty_process, _get_mp_ctx, - _ProcessNoWarn, + _make_process_no_warn_cls, _set_mp_start_method_if_unset, RL_WARNINGS, ) @@ -956,6 +956,7 @@ def _run_processes(self) -> None: # Extract parent pipes for external use (e.g., polling, receiving messages) self.pipes = [pipe_parent for pipe_parent, _ in pipe_pairs] + _ProcessNoWarnCtx = _make_process_no_warn_cls(ctx) # Initialize all weight sync schemes now that pipes are available # Both SharedMemWeightSyncScheme (uses queues) and MultiProcessWeightSyncScheme (uses pipes) # can be initialized here since all 
required resources exist @@ -963,9 +964,9 @@ def _run_processes(self) -> None: for model_id, scheme in self._weight_sync_schemes.items(): if not scheme.initialized_on_sender: torchrl_logger.debug( - f"Init scheme {type(scheme)} on sender side of {type(self)} with {model_id=} and model {_resolve_model(self, model_id)}." + f"Init weight sync scheme {type(scheme).__name__} for {model_id=}." ) - scheme.init_on_sender(model_id=model_id, context=self) + scheme.init_on_sender(model_id=model_id, context=self, ctx=ctx) # Create a policy on the right device policy_factory = self.policy_factory @@ -1072,7 +1073,7 @@ def _run_processes(self) -> None: "weight_sync_schemes": self._weight_sync_schemes, "worker_idx": i, # Worker index for queue-based weight distribution } - proc = _ProcessNoWarn( + proc = _ProcessNoWarnCtx( target=_main_async_collector, num_threads=self.num_sub_threads, _start_method=ctx.get_start_method(), diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index 9e11563f6f9..0f6d1ec57c9 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -159,9 +159,6 @@ def _main_async_collector( if verbose: torchrl_logger.debug(f"mp worker {idx} received {msg}") except EOFError: - torchrl_logger.debug( - f"Failed to receive data. Last message received: {msg}" - ) raise elif not run_free: if verbose: @@ -198,8 +195,6 @@ def _main_async_collector( continue else: # placeholder, will be checked after - if msg != "continue": - torchrl_logger.debug(f"mp worker {idx} will reset {msg} to 'continue'") msg = "continue" if msg == "run_free": run_free = True @@ -208,9 +203,6 @@ def _main_async_collector( # Capture shutdown / update / seed signal, but continue should not be expected if pipe_child.poll(1e-4): data_in, msg = pipe_child.recv() - torchrl_logger.debug( - f"mp worker {idx} received {msg} while running free" - ) if msg == "continue": # Switch back to run_free = False run_free = False @@ -270,8 +262,6 @@ def _main_async_collector( has_timed_out = False continue except queue.Full: - if verbose: - torchrl_logger.debug(f"mp worker {idx} has timed out") has_timed_out = True continue diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index a567a15c040..8d18bafa8bf 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -14,7 +14,7 @@ from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase from tensordict.nn import CudaGraphModule, TensorDictModule, TensorDictModuleBase from torch import nn -from torchrl import compile_with_warmup, logger as torchrl_logger +from torchrl import compile_with_warmup from torchrl._utils import ( _ends_with, _make_ordinal_device, @@ -918,29 +918,23 @@ def _maybe_make_final_rollout(self, make_rollout: bool): # If the policy has a valid spec, we use it self._policy_output_keys = set() - if ( - make_rollout - and hasattr( - self._wrapped_policy_uncompiled - if has_meta_params - else self._wrapped_policy, - "spec", - ) - and ( - self._wrapped_policy_uncompiled - if has_meta_params - else self._wrapped_policy - ).spec - is not None - and all( - v is not None - for v in ( - self._wrapped_policy_uncompiled - if has_meta_params - else self._wrapped_policy - ).spec.values(True, True) - ) - ): + _policy_to_check = ( + self._wrapped_policy_uncompiled if has_meta_params else self._wrapped_policy + ) + _has_spec = hasattr(_policy_to_check, "spec") + _spec_not_none = False + _all_values_not_none = False + if _has_spec: + _spec = _policy_to_check.spec + _spec_not_none = _spec is 
not None + if _spec_not_none: + _all_values_not_none = all( + v is not None for v in _spec.values(True, True) + ) + _condition = ( + make_rollout and _has_spec and _spec_not_none and _all_values_not_none + ) + if _condition: if any( key not in self._final_rollout.keys(isinstance(key, tuple)) for key in ( @@ -1255,17 +1249,14 @@ def cuda_check(tensor: torch.Tensor): while self._frames < self.total_frames: self._iter += 1 - torchrl_logger.debug("Collector: rollout.") tensordict_out = self.rollout() if tensordict_out is None: # if a replay buffer is passed and self.extend_buffer=False, there is no tensordict_out # frames are updated within the rollout function - torchrl_logger.debug("Collector: No tensordict_out. Yielding.") yield continue self._increment_frames(tensordict_out.numel()) tensordict_out = self._postproc(tensordict_out) - torchrl_logger.debug("Collector: postproc done.") if self.return_same_td: # This is used with multiprocessed collectors to use the buffers # stored in the tensordict. @@ -1276,10 +1267,6 @@ def cuda_check(tensor: torch.Tensor): yield tensordict_out elif self.replay_buffer is not None and not self._ignore_rb: self.replay_buffer.extend(tensordict_out) - torchrl_logger.debug( - f"Collector: Added {tensordict_out.numel()} frames to replay buffer. " - "Buffer write count: {self.replay_buffer.write_count}. Yielding." - ) yield else: # we must clone the values, as the tensordict is updated in-place. @@ -1534,17 +1521,11 @@ def rollout(self) -> TensorDictBase: and not self._ignore_rb and not self.extend_buffer ): - torchrl_logger.debug( - f"Collector: Adding {env_output.numel()} frames to replay buffer using add()." - ) self.replay_buffer.add(self._carrier) if self._increment_frames(self._carrier.numel()): return else: if self.storing_device is not None: - torchrl_logger.debug( - f"Collector: Moving to {self.storing_device} and adding to queue." - ) non_blocking = ( not self.no_cuda_sync or self.storing_device.type == "cuda" ) @@ -1570,7 +1551,6 @@ def rollout(self) -> TensorDictBase: self.interruptor is not None and self.interruptor.collection_stopped() ): - torchrl_logger.debug("Collector: Interruptor stopped.") if ( self.replay_buffer is not None and not self._ignore_rb @@ -1597,7 +1577,6 @@ def rollout(self) -> TensorDictBase: break else: if self._use_buffers: - torchrl_logger.debug("Returning final rollout within buffer.") result = self._final_rollout try: result = torch.stack( @@ -1620,9 +1599,6 @@ def rollout(self) -> TensorDictBase: ): return else: - torchrl_logger.debug( - "Returning final rollout with NO buffer (maybe_dense_stack)." - ) result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) result.refine_names(..., "time") diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 323116c6995..ef0fbf40a4f 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -221,7 +221,6 @@ def _run_collector( warnings.warn(_NON_NN_POLICY_WEIGHTS) policy_weights = TensorDict(lock=True) - torchrl_logger.debug(f"RANK {rank} -- init collector") # NOTE: # - `weight_sync_schemes` here are the *distributed* schemes used to send # weights from the main process to this node. 
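Reviewer sketch (illustrative only, not part of the patch): the reason `_run_processes` now forwards `ctx` to `init_on_sender` and the weight-sync schemes build their queues from that same context is that multiprocessing primitives should not be mixed across start methods. A minimal reproduction of the intended usage, with hypothetical names:

import multiprocessing as mp


def _echo(q_in, q_out):
    # Worker: read one message and send it back.
    q_out.put(q_in.get())


if __name__ == "__main__":
    # All primitives come from one explicit context, matching how the collector
    # threads `ctx` through so that schemes create their Queues from the same
    # context that spawns the worker processes.
    ctx = mp.get_context("spawn")
    q_in, q_out = ctx.Queue(), ctx.Queue()
    worker = ctx.Process(target=_echo, args=(q_in, q_out))
    worker.start()
    q_in.put({"policy_version": 1})
    print(q_out.get(timeout=10))
    worker.join()

Mixing primitives across contexts (for example, a queue from the default fork context handed to a spawn-started worker) can hang or raise, because the underlying locks and semaphores are tied to the start method.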
@@ -242,7 +241,6 @@ def _run_collector( if weight_sync_schemes is not None: for model_id, scheme in weight_sync_schemes.items(): - torchrl_logger.debug(f"RANK {rank} -- init receiver for model '{model_id}'") # Provide both collector context and distributed store / rank so the # scheme can wire its transport correctly. scheme.init_on_receiver( @@ -251,15 +249,7 @@ def _run_collector( # store=_store, worker_idx=rank, ) - torchrl_logger.debug(f"RANK {rank} -- initial weight sync (if any)") scheme.connect() - torchrl_logger.debug( - f"RANK {rank} -- initial weight sync for '{model_id}' completed" - ) - else: - torchrl_logger.debug( - f"RANK {rank} -- {collector_class.__name__} without weight_sync_schemes \n\n" - ) total_frames = 0 while True: @@ -279,14 +269,9 @@ def _run_collector( torchrl_logger.debug( f"RANK {rank} -- got data, total frames = {total_frames}" ) - torchrl_logger.debug( - f"RANK {rank} -- data batch_size={data.batch_size}, " - f"keys={list(data.keys(False, True))}" - ) torchrl_logger.debug( f"RANK {rank} -- sending TensorDict payload to rank 0" ) - torchrl_logger.debug(f"RANK {rank} -- {data=}") if _store.get("TRAINER_status") == b"alive": data.isend(dst=0) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index d2033b2de14..8b1b3cd7cd0 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -9,6 +9,7 @@ import gc import os import time +import warnings import weakref from collections import OrderedDict from collections.abc import Callable, Mapping, Sequence @@ -35,7 +36,6 @@ from torchrl._utils import ( _check_for_faulty_process, _make_ordinal_device, - _ProcessNoWarn, logger as torchrl_logger, VERBOSE, ) @@ -1436,16 +1436,13 @@ def _start_workers(self) -> None: if self._mp_start_method is not None: ctx = mp.get_context(self._mp_start_method) - proc_fun = ctx.Process - num_sub_threads = self.num_sub_threads else: ctx = mp.get_context("spawn") - proc_fun = functools.partial( - _ProcessNoWarn, - num_threads=self.num_sub_threads, - _start_method=self._mp_start_method, - ) - num_sub_threads = None + # Use ctx.Process directly to ensure all multiprocessing primitives + # (Queue, Pipe, Process, Event) come from the same context. + # Warning filtering and num_threads are handled in the worker functions. 
+ proc_fun = ctx.Process + num_sub_threads = self.num_sub_threads _num_workers = self.num_workers @@ -1481,6 +1478,8 @@ def look_for_cuda(tensor, has_cuda=has_cuda): env_fun = self.create_env_fn[idx] if not isinstance(env_fun, (EnvCreator, CloudpickleWrapper)): env_fun = CloudpickleWrapper(env_fun) + import torchrl + kwargs[idx].update( { "parent_pipe": parent_pipe, @@ -1490,6 +1489,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): "has_lazy_inputs": self.has_lazy_inputs, "num_threads": num_sub_threads, "non_blocking": self.non_blocking, + "filter_warnings": torchrl.filter_warnings_subprocess, } ) if self._use_buffers: @@ -2410,7 +2410,12 @@ def _run_worker_pipe_shared_mem( has_lazy_inputs: bool = False, verbose: bool = False, num_threads: int | None = None, # for fork start method + filter_warnings: bool = False, ) -> None: + pid = os.getpid() + # Handle warning filtering (moved from _ProcessNoWarn) + if filter_warnings: + warnings.filterwarnings("ignore") if num_threads is not None: torch.set_num_threads(num_threads) device = shared_tensordict.device @@ -2430,7 +2435,6 @@ def look_for_cuda(tensor, has_cuda=has_cuda): else: event = None parent_pipe.close() - pid = os.getpid() if not isinstance(env_fun, EnvBase): env = env_fun(**env_fun_kwargs) else: @@ -2673,7 +2677,11 @@ def _run_worker_pipe_direct( verbose: bool = False, num_threads: int | None = None, # for fork start method consolidate: bool = True, + filter_warnings: bool = False, ) -> None: + # Handle warning filtering (moved from _ProcessNoWarn) + if filter_warnings: + warnings.filterwarnings("ignore") if num_threads is not None: torch.set_num_threads(num_threads) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 1b0aa359134..e002cd5797d 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -3871,7 +3871,6 @@ def fake_tensordict(self) -> TensorDictBase: fake_obs = observation_spec.zero() fake_reward = reward_spec.zero() fake_done = full_done_spec.zero() - fake_state = state_spec.zero() fake_action = action_spec.zero() diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 8253fda80ed..07a971e99b6 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -19,7 +19,6 @@ # from torchrl.modules.tensordict_module.rnn import GRUCell from torch.nn import GRUCell from torchrl._utils import timeit -from torchrl.envs.utils import step_mdp from torchrl.modules.models.models import MLP @@ -244,6 +243,7 @@ def forward(self, tensordict): which amends to q(s_{t+1} | s_t, a_t, o_{t+1}) """ + # from torchrl.envs.utils import step_mdp tensordict_out = [] *batch, time_steps = tensordict.shape @@ -261,9 +261,9 @@ def forward(self, tensordict): self.rssm_posterior(_tensordict) tensordict_out.append(_tensordict) + # _tensordict = step_mdp(_tensordict, keep_other=True) if t < time_steps - 1: # Translate ("next", *) to the non-next key required for the current step input - _tensordict = step_mdp(_tensordict, keep_other=True) _tensordict = _tensordict.select(*self.in_keys, strict=False) _tensordict = update_values[t + 1].update(_tensordict) diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py index fd2625002c8..c5e1bf15e15 100644 --- a/torchrl/weight_update/_distributed.py +++ b/torchrl/weight_update/_distributed.py @@ -177,10 +177,6 @@ def _make_store( self._store_port = initial_port try: - torchrl_logger.debug( - f"DistributedWeightSyncScheme: Creating TCPStore on {host}:{self._store_port} " - f"(attempt 
{attempt + 1}/{max_retries})" - ) store = torch.distributed.TCPStore( host_name=host, port=self._store_port, @@ -189,9 +185,6 @@ def _make_store( wait_for_workers=False, # Don't block - workers may not be started yet ) self._store_info = {"host": host, "port": self._store_port} - torchrl_logger.debug( - f"DistributedWeightSyncScheme: TCPStore created successfully: {self._store_info}" - ) return store except (RuntimeError, OSError) as e: error_msg = str(e).lower() @@ -199,10 +192,6 @@ def _make_store( "address already in use" in error_msg or "eaddrinuse" in error_msg ): - torchrl_logger.debug( - f"DistributedWeightSyncScheme: Port {self._store_port} already in use, " - f"retrying ({attempt + 1}/{max_retries})..." - ) last_error = e # Add small random delay to reduce collision probability time.sleep(random.uniform(0.01, 0.1)) @@ -218,10 +207,6 @@ def _make_store( # Connect as client if store_info is None: raise ValueError("store_info is required when connecting as client") - torchrl_logger.debug( - f"DistributedWeightSyncScheme: Connecting to TCPStore at " - f"{store_info['host']}:{store_info['port']}" - ) store = torch.distributed.TCPStore( host_name=store_info["host"], port=store_info["port"], @@ -414,20 +399,12 @@ def _background_receive_loop(self): 3. Sends an acknowledgment back 4. Repeats until stop event is set """ - torchrl_logger.debug( - f"DistributedWeightSyncScheme: Background receiver started for worker {self._worker_idx}" - ) while not self._stop_event.is_set(): try: instruction = self._wait_for_instruction() if instruction is None: continue if instruction in ("receive", "update_weights"): - torchrl_logger.debug( - f"DistributedWeightSyncScheme: Worker {self._worker_idx} " - "received 'receive' instruction" - ) - # Receive weights via torch.distributed weights = self._receiver_transport.receive_weights( model=self.model, @@ -440,9 +417,6 @@ def _background_receive_loop(self): if self.context is not None and hasattr( self.context, "update_policy_weights_" ): - torchrl_logger.debug( - f"DistributedWeightSyncScheme: Cascading weight update to sub-collectors for {model_id=}" - ) self.context.update_policy_weights_( model_id=model_id, policy_or_weights=weights ) @@ -450,15 +424,7 @@ def _background_receive_loop(self): # Send acknowledgment self._send_ack("updated") - torchrl_logger.debug( - f"DistributedWeightSyncScheme: Worker {self._worker_idx} " - "received and applied weights" - ) - elif instruction == "stop": - torchrl_logger.debug( - f"DistributedWeightSyncScheme: Worker {self._worker_idx} received 'stop' instruction" - ) break else: torchrl_logger.warning( @@ -471,10 +437,6 @@ def _background_receive_loop(self): f"DistributedWeightSyncScheme: Background receiver error: {e}" ) - torchrl_logger.debug( - f"DistributedWeightSyncScheme: Background receiver stopped for worker {self._worker_idx}" - ) - def _setup_connection_and_weights_on_sender_impl( self, *, worker_idx: int | None = None, weights: Any | None = None ) -> None: @@ -489,10 +451,6 @@ def _setup_connection_and_weights_on_sender_impl( # Initialize torch.distributed process group if not already done # This is a collective operation - all workers must call it if not torch.distributed.is_initialized(): - torchrl_logger.debug( - f"DistributedWeightSyncScheme: Initializing process group on sender " - f"(world_size={self._num_workers + 1})" - ) torch.distributed.init_process_group( backend=self.backend, rank=0, # Sender is always rank 0 @@ -502,9 +460,6 @@ def _setup_connection_and_weights_on_sender_impl( # Check if we have 
weights to send if weights is None and getattr(self, "model", None) is None: - torchrl_logger.debug( - "DistributedWeightSyncScheme: No model on sender, skipping initial weight sync" - ) self._store.set("STATELESS_MODEL", b"1") return @@ -512,15 +467,8 @@ def _setup_connection_and_weights_on_sender_impl( # Prepare weights from model weights = self._get_weights_buffer_from_model(self.model) if weights is None or weights.is_empty(): - torchrl_logger.debug( - "DistributedWeightSyncScheme: Empty weights, skipping initial weight sync" - ) return - torchrl_logger.debug( - f"DistributedWeightSyncScheme: Sending initial weights to {self._num_workers} workers" - ) - # Send to all workers using direct torch.distributed (no TCPStore signaling) for i, transport in enumerate(self._iterate_transports()): if worker_idx is not None and i != worker_idx: @@ -542,10 +490,6 @@ def _setup_connection_and_weights_on_receiver_impl( # Initialize torch.distributed process group if not already done # This is a collective operation - sender and all workers must call it if not torch.distributed.is_initialized(): - torchrl_logger.debug( - f"DistributedWeightSyncScheme: Initializing process group on worker {worker_idx} " - f"(world_size={self._num_workers + 1})" - ) torch.distributed.init_process_group( backend=self.backend, rank=worker_idx, @@ -559,28 +503,14 @@ def _setup_connection_and_weights_on_receiver_impl( ) return - torchrl_logger.debug( - f"DistributedWeightSyncScheme: Worker {worker_idx} waiting for STATELESS_MODEL key" - ) stateless_model = self.receiver_transport._store.get("STATELESS_MODEL") if stateless_model not in (b"0", b"1"): raise RuntimeError(f"Invalid STATELESS_MODEL value: {stateless_model}") - if stateless_model == b"1": - torchrl_logger.debug( - "DistributedWeightSyncScheme: Skipping initial weight sync on receiver because of stateless model." 
- ) - else: - torchrl_logger.debug( - f"DistributedWeightSyncScheme: Worker {worker_idx} waiting for initial weights" - ) - + if stateless_model != b"1": # Receive initial weights (blocking, no TCPStore coordination) weights = self._receiver_transport.receive_initial_weights() if weights is not None and self.model is not None: self._strategy.apply_weights(self.model, weights, inplace=False) - torchrl_logger.debug( - f"DistributedWeightSyncScheme: Worker {worker_idx} received and applied initial weights" - ) # Start background receiver thread AFTER initial weight sync is complete # This prevents the background thread from consuming the initial sync messages @@ -610,12 +540,9 @@ def model(self) -> Any | None: model = _resolve_model(self.context, self._model_id) if model is None: if self._model_id == "policy": - torchrl_logger.debug( - f"Creating policy from factory and setting in collector {type(self.context)}" - ) + torchrl_logger.debug("Creating policy from factory.") model = self.context.policy_factory[0]() self.context.policy = model - torchrl_logger.debug(f"{self.context.policy=}") else: raise AttributeError( f"Model {self._model_id} was `None` in context {self.context}" @@ -674,18 +601,15 @@ def send_weights(self, weights: Any) -> None: return # Instruct worker to expect weight update - torchrl_logger.debug("RANK 0 -- Setting weight sync instructions to store") self._store.set(f"NODE_{self._rank}_in", b"update_weights") # Send weights via torch.distributed - torchrl_logger.debug(f"RANK 0 -- Send {type(weights)=} to rank {self._rank}") if self._sync: weights.send(self._rank) else: weights.isend(self._rank) # Wait for acknowledgment - torchrl_logger.debug("RANK 0 -- Receiving acknowledgement from store") status = self._store.get(f"NODE_{self._rank}_out") if status != b"updated": raise RuntimeError(f"Expected 'updated' but got status {status}.") @@ -700,20 +624,13 @@ def send_weights_async(self, weights: Any) -> None: return # Instruct worker to expect weight update - torchrl_logger.debug( - f"RANK 0 -- Setting weight sync instructions to store for rank {self._rank}" - ) self._store.set(f"NODE_{self._rank}_in", b"update_weights") # Send weights via torch.distributed - torchrl_logger.debug( - f"RANK 0 -- Send {type(weights)=} to rank {self._rank} with sync={self._sync}" - ) if self._sync: weights.send(self._rank) else: weights.isend(self._rank) - torchrl_logger.debug(f"RANK 0 -- Weights successfully sent to {self._rank}") def wait_ack(self) -> None: """Wait for acknowledgment from distributed worker.""" @@ -760,16 +677,11 @@ def receive_weights( if self._sync or timeout is None: # Blocking receive - no timeout support if self._sync: - torchrl_logger.debug(f"Rank {self._rank} -- calling recv") weights_buffer.recv(src=0) else: - torchrl_logger.debug(f"Rank {self._rank} -- calling irecv") weights_buffer.irecv(src=0) else: # Non-blocking receive with timeout support - torchrl_logger.debug( - f"Rank {self._rank} -- calling irecv with premature return" - ) futures = weights_buffer.irecv(src=0, return_premature=True) if futures: start_time = time.monotonic() @@ -790,7 +702,6 @@ def receive_weights( if model is not None and strategy is not None: strategy.apply_weights(model, weights_buffer) - torchrl_logger.debug(f"Rank {self._rank} -- closing receive_weights") return weights_buffer def send_initial_weights(self, weights: Any) -> None: @@ -802,9 +713,6 @@ def send_initial_weights(self, weights: Any) -> None: if self._rank is None: return - torchrl_logger.debug( - f"DistributedTransport: Sending 
initial weights to rank {self._rank}" - ) # Note: No TCPStore signaling for initial sync - just direct send/recv if self._sync: weights.send(self._rank) @@ -820,9 +728,6 @@ def receive_initial_weights(self) -> Any: Returns: The received weights TensorDict. """ - torchrl_logger.debug( - "DistributedTransport: Receiving initial weights from rank 0" - ) if self._sync: self._weights_buffer.recv(src=0) else: diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py index 0a8f3734b4a..80c24ce40b2 100644 --- a/torchrl/weight_update/_mp.py +++ b/torchrl/weight_update/_mp.py @@ -88,6 +88,7 @@ def _init_on_sender_impl( devices: list[torch.device] | None = None, device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, num_workers: int | None = None, + ctx: Any = None, **kwargs, ) -> None: """Initialize on the main process (sender side). @@ -119,6 +120,7 @@ def _init_on_sender_impl( Allows full control over device mapping. Requires num_workers. num_workers: Number of workers. Required with device_map_fn, inferred from devices length otherwise. + ctx: The multiprocessing context to use. Defaults to `multiprocessing.get_context()`. **kwargs: Reserved for future use. Examples: @@ -186,15 +188,17 @@ def _init_on_sender_impl( if not hasattr(self, "_weight_init_queues"): self._weight_init_queues = {} + if ctx is None: + ctx = mp.get_context() for worker_idx in all_workers: if worker_idx not in self._weight_init_queues: - self._weight_init_queues[worker_idx] = mp.Queue() + self._weight_init_queues[worker_idx] = ctx.Queue() # Create instruction queues for background receiver if worker_idx not in self._instruction_queues: - self._instruction_queues[worker_idx] = mp.Queue() + self._instruction_queues[worker_idx] = ctx.Queue() # Create ack queues for synchronous mode if worker_idx not in self._ack_queues: - self._ack_queues[worker_idx] = mp.Queue() + self._ack_queues[worker_idx] = ctx.Queue() # Store model_id and context on scheme self.model_id = model_id @@ -294,8 +298,6 @@ def send( Note: If sync=True (default), this is a blocking call that ensures specified workers are updated before returning. """ - from torchrl._utils import logger as torchrl_logger - if not self.initialized_on_sender: raise RuntimeError("Must be initialized on sender before sending weights") if not self.synchronized_on_sender: @@ -315,7 +317,6 @@ def send( transports = list(self._iterate_transports(worker_ids)) # Send weights to all workers first via queue (non-blocking) - torchrl_logger.debug("Sending weights to queues") for transport in transports: if hasattr(transport, "send_weights_async"): # For MPTransport, pass model_id; other transports don't need it @@ -325,12 +326,10 @@ def send( transport.send_weights(prepared_weights) # Send instruction to workers' background threads to receive the weights - torchrl_logger.debug("Sending 'receive' instruction to workers") self._send_instruction(instruction="receive", worker_ids=worker_ids) # Wait for all acknowledgments if in synchronous mode if self.sync: - torchrl_logger.debug("Waiting for acknowledgments from workers") self._wait_for_ack(worker_ids=worker_ids) def _setup_connection_and_weights_on_sender_impl( @@ -415,8 +414,6 @@ def _setup_connection_and_weights_on_receiver_impl( Args: worker_idx: The worker index. 
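The `ctx` argument added to `_init_on_sender_impl` above lets the scheme build its per-worker queues from the same multiprocessing context the collector uses (for example "spawn"), rather than the module-level default. A minimal sketch of that pattern, with made-up helper names that are not part of this diff:

    import multiprocessing as mp

    def make_worker_queues(num_workers: int, ctx=None):
        # Fall back to the global default context, as the scheme does when ctx is None.
        if ctx is None:
            ctx = mp.get_context()
        # One (init, instruction, ack) queue triple per worker, all from the same context.
        return {
            i: {"init": ctx.Queue(), "instr": ctx.Queue(), "ack": ctx.Queue()}
            for i in range(num_workers)
        }

    # Example: match a collector that starts its workers with the "spawn" method.
    queues = make_worker_queues(2, ctx=mp.get_context("spawn"))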
""" - from torchrl._utils import logger as torchrl_logger - # Use stored worker_idx if not provided if worker_idx is None: worker_idx = self._worker_idx @@ -444,9 +441,6 @@ def _setup_connection_and_weights_on_receiver_impl( # Apply weights to model if weights is not None and self.model is not None: self._strategy.apply_weights(self.model, weights, inplace=False) - torchrl_logger.debug( - f"MultiProcessWeightSyncScheme: Worker {worker_idx} applied initial weights" - ) # Start background receiver thread self._start_background_receiver() @@ -463,9 +457,6 @@ def _background_receive_loop(self): """ from torchrl._utils import logger as torchrl_logger - torchrl_logger.debug( - f"MultiProcessWeightSyncScheme: Background receiver started for worker {self._worker_idx}" - ) while not self._stop_event.is_set(): try: instruction = self._wait_for_instruction() @@ -473,10 +464,6 @@ def _background_receive_loop(self): # Stop event was set or timeout continue if instruction == "receive": - torchrl_logger.debug( - f"MultiProcessWeightSyncScheme: Worker {self._worker_idx} received 'receive' instruction" - ) - # Receive weights from transport (blocking) if self._receiver_transport is not None: weights = self._receiver_transport.receive_weights( @@ -485,18 +472,11 @@ def _background_receive_loop(self): ) if weights is not None: - torchrl_logger.debug( - f"MultiProcessWeightSyncScheme: Worker {self._worker_idx} received and applied weights" - ) - # Cascade weight update to sub-collectors if context supports it model_id = self._model_id or "policy" if self.context is not None and hasattr( self.context, "update_policy_weights_" ): - torchrl_logger.debug( - f"MultiProcessWeightSyncScheme: Cascading weight update to sub-collectors for {model_id=}" - ) self.context.update_policy_weights_( model_id=model_id, policy_or_weights=weights ) @@ -505,9 +485,6 @@ def _background_receive_loop(self): self._send_ack("updated") elif instruction == "stop": - torchrl_logger.debug( - f"MultiProcessWeightSyncScheme: Worker {self._worker_idx} received 'stop' instruction" - ) break else: torchrl_logger.warning( @@ -519,10 +496,6 @@ def _background_receive_loop(self): f"MultiProcessWeightSyncScheme: Background receiver error: {e}" ) - torchrl_logger.debug( - f"MultiProcessWeightSyncScheme: Background receiver stopped for worker {self._worker_idx}" - ) - def create_transport(self, **kwargs) -> TransportBackend: """Create an MPTransport using the provided queue. 
diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py index 73c3a6894c8..a2ad64328d8 100644 --- a/torchrl/weight_update/_ray.py +++ b/torchrl/weight_update/_ray.py @@ -159,7 +159,6 @@ def send_weights(self, weights: Any) -> None: future = self._remote_actor._receive_weights_scheme.remote() # Step 2: Send weights via torch.distributed (async) - torchrl_logger.debug(f"RayTransport: Sending weights to rank {self._rank}") weights.isend(dst=self._rank) # Step 3: Wait for the Ray call to complete (receiver has applied weights) @@ -177,14 +176,10 @@ def send_weights_async(self, weights: Any) -> None: return # Step 1: Signal the actor via Ray to start receiving (async) - torchrl_logger.debug( - f"RayTransport: Sending weights async to rank {self._rank}" - ) self._pending_future = self._remote_actor._receive_weights_scheme.remote() # Step 2: Send weights via torch.distributed (async) self._pending_isend = weights.isend(dst=self._rank, return_early=True) - torchrl_logger.debug("RayTransport: Async send initiated") def wait_ack(self) -> None: """Wait for Ray actor to finish applying weights. @@ -194,13 +189,7 @@ def wait_ack(self) -> None: was not called before this method). """ if self._pending_future is not None: - torchrl_logger.debug( - f"RayTransport: Waiting for ack from rank {self._rank}" - ) self.ray.get(self._pending_future) - torchrl_logger.debug( - f"RayTransport: Ack received from rank {self._rank}. Waiting for isend to complete." - ) if self._pending_isend is not None: for fut in self._pending_isend: fut.wait() @@ -251,10 +240,6 @@ def receive_weights( self._weights_buffer = weights_buffer # Receive weights from rank 0 - torchrl_logger.debug( - f"RayTransport: Receiving weights from rank 0: {type(weights_buffer)=}" - ) - if timeout is None: # Blocking receive weights_buffer.irecv(src=0) @@ -272,9 +257,6 @@ def receive_weights( elapsed = time.monotonic() - start_time if elapsed >= timeout: # Timeout expired before receiving all weights - torchrl_logger.debug( - f"RayTransport: Timeout ({timeout}s) expired waiting for weights" - ) return None # Small sleep to avoid busy-waiting time.sleep(0.001) @@ -285,7 +267,6 @@ def receive_weights( raise RuntimeError( f"Cannot cast weights to model type: {type(model)} with weights: {weights_buffer}." 
) - torchrl_logger.debug("RayTransport: No weights to apply to model") return None if strategy is not None: @@ -293,7 +274,6 @@ def receive_weights( else: weights_buffer.to_module(model) - torchrl_logger.debug("RayTransport: Weights applied to model") return weights_buffer # ======================================================================== @@ -359,10 +339,6 @@ def setup_connection_and_weights_on_receiver( except ValueError: i += 1 time.sleep(0.1) - if i % 50 == 0: - torchrl_logger.debug( - f"RayTransport: Waiting for connection info (attempt {i}) on {worker_idx=}/{rank=}" - ) continue break @@ -374,11 +350,6 @@ def setup_connection_and_weights_on_receiver( ) self._stateful_model = stateful_model - torchrl_logger.debug( - f"RayTransport: Worker {worker_idx} joining process group with " - f"rank={rank}, master_addr={master_addr}, master_port={master_port} -- blocking" - ) - # Set environment variables for torch.distributed os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_PORT"] = str(master_port) @@ -389,7 +360,6 @@ def setup_connection_and_weights_on_receiver( rank=rank, world_size=world_size, ) - torchrl_logger.debug(f"RayTransport: Worker {worker_idx} joined process group") self._dist_initialized = True # Receive initial weights if model is stateful @@ -462,12 +432,9 @@ def model(self) -> Any | None: model = _resolve_model(self.context, self._model_id) if model is None: if self._model_id == "policy": - torchrl_logger.debug( - f"Creating policy from factory and setting in collector {type(self.context)}" - ) + torchrl_logger.debug("Creating policy from factory.") model = self.context.policy_factory[0]() self.context.policy = model - torchrl_logger.debug(f"{self.context.policy=}") else: raise AttributeError( f"Model {self._model_id} was `None` in context {self.context}" @@ -666,11 +633,6 @@ def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: master_port = self._find_free_port() world_size = self._num_workers + 1 # +1 for the sender (rank 0) - torchrl_logger.debug( - f"RayWeightSyncScheme: Setting up distributed connection with " - f"master_addr={master_addr}, master_port={master_port}, world_size={world_size}" - ) - try: self.weights stateful_model = True @@ -697,9 +659,6 @@ def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: # Note: Workers will call init_process_group in their transport's # setup_connection_and_weights_on_receiver. The init_process_group is # a collective operation, so all ranks must call it together. - torchrl_logger.debug( - "RayWeightSyncScheme: Initializing process group on sender (rank 0) -- blocking." 
- ) torch.distributed.init_process_group( backend=self._backend, rank=0, @@ -708,10 +667,6 @@ def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: ) self._dist_initialized = True - torchrl_logger.debug( - "RayWeightSyncScheme: Distributed connection setup complete -- all workers at rendez-vous" - ) - def _setup_connection_and_weights_on_sender_impl( self, *, @@ -734,9 +689,6 @@ def _setup_connection_and_weights_on_sender_impl( """ # Set up distributed connection (with wait for workers to be ready) if not self._dist_initialized: - torchrl_logger.debug( - "RayWeightSyncScheme: Setting up distributed connection (sender)" - ) self._setup_distributed_connection_sender() # Send the initial weights @@ -758,7 +710,6 @@ def _send_weights_distributed(self) -> None: futures = [] for worker_idx in range(self._num_workers): rank = worker_idx + 1 - torchrl_logger.debug(f"RayWeightSyncScheme: Sending weights to rank {rank}") futures.extend(weights.isend(dst=rank, return_early=True)) # Wait for all sends to complete for future in futures: @@ -1001,11 +952,6 @@ def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: master_port = self._find_free_port() world_size = 2 # Sender (rank 0) + Transform (rank 1) - torchrl_logger.debug( - f"RayModuleTransformScheme: Setting up distributed connection with " - f"master_addr={master_addr}, master_port={master_port}, world_size={world_size}" - ) - # Check if model has weights try: w = self.weights @@ -1031,9 +977,6 @@ def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: # Now initialize process group on sender (rank 0) # The receiver is concurrently joining via the Ray call above - torchrl_logger.debug( - "RayModuleTransformScheme: Initializing process group on sender (rank 0) -- blocking." - ) torch.distributed.init_process_group( backend=self._backend, rank=0, @@ -1042,10 +985,6 @@ def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: ) self._dist_initialized = True - torchrl_logger.debug( - "RayModuleTransformScheme: Distributed connection setup complete" - ) - def _setup_connection_and_weights_on_sender_impl( self, *, @@ -1060,26 +999,16 @@ def _setup_connection_and_weights_on_sender_impl( weights (optional): Pre-extracted weights to send. If None, weights are extracted from the model. 
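The sender-side setup in the `_ray.py` hunks above amounts to: pick a free port, export MASTER_ADDR/MASTER_PORT, and join a process group of size num_workers + 1 as rank 0 while the workers join concurrently with ranks 1..num_workers (they learn the address and port out of band, via Ray in this diff). A stripped-down sketch, with a hypothetical `init_sender_group` helper:

    import os
    import socket
    import torch.distributed as dist

    def _find_free_port() -> int:
        with socket.socket() as s:
            s.bind(("", 0))
            return s.getsockname()[1]

    def init_sender_group(num_workers: int, backend: str = "gloo") -> None:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(_find_free_port())
        # Collective call: it only returns once every worker has called
        # init_process_group with its own rank (worker_idx + 1).
        dist.init_process_group(backend=backend, rank=0, world_size=num_workers + 1)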
""" - torchrl_logger.debug( - "RayModuleTransformScheme: Signaling receiver to join process group" - ) receiver_future = self._ray_transform._actor._init_weight_sync_scheme.remote( scheme=self, model_id=self.model_id ) if not self._dist_initialized: - torchrl_logger.debug( - "RayModuleTransformScheme: Setting up distributed connection (sender)" - ) self._setup_distributed_connection_sender() if self._stateful_model: - torchrl_logger.debug( - "RayModuleTransformScheme: Sending first batch of weights (sender)" - ) self._send_weights_distributed(weights=weights) - torchrl_logger.debug("Waiting for receiver to join process group...") self.ray.get(receiver_future) def _send_weights_distributed(self, weights: Any | None = None) -> None: @@ -1098,7 +1027,6 @@ def _send_weights_distributed(self, weights: Any | None = None) -> None: raise RuntimeError("No weights available to send") # Send weights to the transform (rank 1) - torchrl_logger.debug("RayModuleTransformScheme: Sending weights to rank 1") futures = weights.isend(dst=1, return_early=True) for future in futures: future.wait() diff --git a/torchrl/weight_update/_rpc.py b/torchrl/weight_update/_rpc.py index 7bc829599c5..bf12358f1db 100644 --- a/torchrl/weight_update/_rpc.py +++ b/torchrl/weight_update/_rpc.py @@ -93,12 +93,9 @@ def model(self) -> Any | None: model = _resolve_model(self.context, self._model_id) if model is None: if self._model_id == "policy": - torchrl_logger.debug( - f"Creating policy from factory and setting in collector {type(self.context)}" - ) + torchrl_logger.debug("Creating policy from factory.") model = self.context.policy_factory[0]() self.context.policy = model - torchrl_logger.debug(f"{self.context.policy=}") else: raise AttributeError( f"Model {self._model_id} was `None` in context {self.context}" diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index 30f42544369..0b0fe54875f 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -83,7 +83,6 @@ def setup_connection_and_weights_on_sender(self) -> None: Each worker reads from its own dedicated queue, to avoid race conditions. """ - torchrl_logger.debug("Sending shared memory weights to workers.") if self._weight_queues is None: raise RuntimeError("Queues not created yet. Call init_on_sender() first.") @@ -114,9 +113,6 @@ def setup_connection_and_weights_on_receiver( Returns: The shared memory weights TensorDict. """ - torchrl_logger.debug( - f"Receiving shared memory weights from worker {worker_idx}." - ) if self._weight_queues is None: raise RuntimeError("Queues not created yet. Call init_on_sender() first.") @@ -187,14 +183,9 @@ def receive_weights( """ # Apply weights to model if provided (same pattern as other transports) if model is not None and strategy is not None and weights is not None: - torchrl_logger.debug( - f"Applying shared memory weights {type(weights)=} to model {model} with {strategy=}." - ) + torchrl_logger.debug("Applying shared memory weights to model.") strategy.apply_weights(model, weights) return weights - torchrl_logger.debug( - f"Not applying shared memory weights {type(weights)=} to model {model} with {strategy=}." - ) return None def send_ack(self, message: str = "updated") -> None: @@ -257,6 +248,7 @@ def _init_on_sender_impl( devices: list[torch.device] | None = None, device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, num_workers: int | None = None, + ctx: Any = None, ) -> None: """Initialize on the main process (sender side). 
@@ -279,6 +271,7 @@ def _init_on_sender_impl( devices: List of devices for each worker device_map_fn: Custom function to map worker_idx and weights to device-specific weights num_workers: Number of workers (required with device_map_fn) + ctx: Multiprocessing context. Defaults to `mp.get_context()`. Examples: Simple usage with collector context (stateful policy): @@ -347,15 +340,17 @@ def _init_on_sender_impl( # Collect all unique worker indices all_workers = list(params_map.keys()) + if ctx is None: + ctx = mp.get_context() for worker_idx in all_workers: if worker_idx not in self._weight_init_queues: - self._weight_init_queues[worker_idx] = mp.Queue() + self._weight_init_queues[worker_idx] = ctx.Queue() # Create instruction queues for background receiver if worker_idx not in self._instruction_queues: - self._instruction_queues[worker_idx] = mp.Queue() + self._instruction_queues[worker_idx] = ctx.Queue() # Create ack queues for synchronous mode if worker_idx not in self._ack_queues: - self._ack_queues[worker_idx] = mp.Queue() + self._ack_queues[worker_idx] = ctx.Queue() # Set worker info in transport self.shared_transport.register_weights(params_map, self._weight_init_queues) @@ -689,13 +684,11 @@ def prepare_weights( # Update the shared memory buffer in-place so workers see the change if self._shared_transport is not None and self.shared_transport.unique_weights: - torchrl_logger.debug("Updating shared memory buffer in-place") shared_weights = self.shared_transport.unique_weights[0] # In-place update of shared memory buffer with fresh weights shared_weights.data.update_(fresh_weights.data) return shared_weights - torchrl_logger.debug("No shared transport, returning fresh weights") # If no shared transport, just return the fresh weights return fresh_weights @@ -721,9 +714,6 @@ def send( raise RuntimeError("Must be synchronized on sender before sending weights") # prepare_weights updates the shared buffer in-place - torchrl_logger.debug( - "Sending weights via shared memory -- calling prepare_weights()" - ) self.prepare_weights( weights=weights, model_id=self._model_id, @@ -732,12 +722,10 @@ def send( ) # Send instruction to workers' background threads to apply the weights - torchrl_logger.debug("Sending 'receive' instruction to workers") self._send_instruction(instruction="receive", worker_ids=worker_ids) # Wait for acknowledgments if in synchronous mode if self.sync: - torchrl_logger.debug("Waiting for acknowledgments from workers") self._wait_for_ack(worker_ids=worker_ids) @property @@ -823,9 +811,6 @@ def _background_receive_loop(self): 3. Sends an acknowledgment back to the sender 4. 
Repeats until stop event is set or "stop" instruction received """ - torchrl_logger.debug( - f"SharedMemWeightSyncScheme: Background receiver started for worker {self._worker_idx}" - ) while not self._stop_event.is_set(): try: instruction = self._wait_for_instruction() @@ -833,9 +818,6 @@ def _background_receive_loop(self): # Stop event was set or timeout continue if instruction == "receive": - torchrl_logger.debug( - f"SharedMemWeightSyncScheme: Worker {self._worker_idx} received 'receive' instruction" - ) # Apply the current shared memory weights to the model # The weights are already updated in shared memory by the sender if ( @@ -845,18 +827,12 @@ def _background_receive_loop(self): self._strategy.apply_weights( self.model, self._receiver_shared_weights, inplace=True ) - torchrl_logger.debug( - f"SharedMemWeightSyncScheme: Worker {self._worker_idx} applied weights" - ) # Cascade weight update to sub-collectors if context supports it model_id = self._model_id or "policy" if self.context is not None and hasattr( self.context, "update_policy_weights_" ): - torchrl_logger.debug( - f"SharedMemWeightSyncScheme: Cascading weight update to sub-collectors for {model_id=}" - ) self.context.update_policy_weights_( model_id=model_id, policy_or_weights=self._receiver_shared_weights, @@ -865,9 +841,6 @@ def _background_receive_loop(self): # Send acknowledgment self._send_ack("updated") elif instruction == "stop": - torchrl_logger.debug( - f"SharedMemWeightSyncScheme: Worker {self._worker_idx} received 'stop' instruction" - ) break else: torchrl_logger.warning( @@ -879,10 +852,6 @@ def _background_receive_loop(self): f"SharedMemWeightSyncScheme: Background receiver error: {e}" ) - torchrl_logger.debug( - f"SharedMemWeightSyncScheme: Background receiver stopped for worker {self._worker_idx}" - ) - def __getstate__(self): """Prepare the scheme for pickling.""" state = super().__getstate__() diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index b381a4db55b..1448fd2a65f 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -474,6 +474,10 @@ def init_on_receiver( context: Optional context object (e.g., inner collector) **kwargs: Alternative to context (model, etc.) """ + if self.initialized_on_sender: + # emulate pickling to erase the current state + self.__setstate__(self.__getstate__()) + self._initialized_on_receiver = True try: result = self._init_on_receiver_impl( @@ -788,7 +792,6 @@ def send( context = self.context # Let the scheme prepare the weights - torchrl_logger.debug("Preparing weights") prepared_weights = self.prepare_weights( weights=weights, model_id=self._model_id, @@ -802,22 +805,14 @@ def send( raise RuntimeError("No transports available.") # Send to all workers first (non-blocking if transport supports it) - torchrl_logger.debug(f"Sending over transports {transports}") for transport in transports: if hasattr(transport, "send_weights_async"): - torchrl_logger.debug( - f"Sending {type(prepared_weights)=} through {type(transport)=} asynchronously." - ) transport.send_weights_async(prepared_weights) else: # Fallback for transports that don't support async send - torchrl_logger.debug( - f"Sending {type(prepared_weights)=} through {type(transport)=} synchronously." 
- ) transport.send_weights(prepared_weights) # Wait for all acknowledgments - torchrl_logger.debug("Waiting for acknowledgement") for transport in transports: if hasattr(transport, "wait_ack"): transport.wait_ack() @@ -913,7 +908,6 @@ def receive(self, timeout: float | None = None) -> TensorDictBase | None: return None # Try to receive weights - transport handles receiving and applying - torchrl_logger.debug(f"Calling receive_weights on transport {transport}") result = transport.receive_weights( timeout=timeout, weights=self.weights, @@ -925,20 +919,15 @@ def receive(self, timeout: float | None = None) -> TensorDictBase | None: weights = result model_id = self._model_id or "policy" - torchrl_logger.debug(f"Received weights for {model_id=}") # Cascade weight update to sub-collectors if context supports it if self.context is not None and hasattr(self.context, "update_policy_weights_"): - torchrl_logger.debug( - f"Cascading weight update to sub-collectors for {model_id=}" - ) self.context.update_policy_weights_( model_id=model_id, policy_or_weights=weights ) # Send acknowledgment if transport supports it if hasattr(transport, "send_ack"): - torchrl_logger.debug(f"Sending acknowledgement on {model_id=}") transport.send_ack("updated") return weights @@ -994,7 +983,6 @@ def connect( if self.synchronized_on_receiver or self.synchronized_on_sender: raise RuntimeError("Cannot synchronize weights on sender twice.") if self._initialized_on_sender: - torchrl_logger.debug("Synchronizing weights on sender") if worker_idx is not None: # Safety check, we can consider removing this in the future. raise RuntimeError( @@ -1007,7 +995,6 @@ def connect( self.synchronized_on_sender = False raise elif self._initialized_on_receiver: - torchrl_logger.debug(f"Synchronizing weights on receiver -- {worker_idx=}") if weights is not None: # safety check: weights are passed to sender, not receiver for initial sync raise RuntimeError( @@ -1118,9 +1105,6 @@ def _start_background_receiver(self): name=f"WeightReceiver-{self._worker_idx}", ) self._background_thread.start() - torchrl_logger.debug( - f"{type(self).__name__}: Started background receiver thread for worker {self._worker_idx}" - ) def _background_receive_loop(self): """Background thread loop that waits for instructions and receives weights.
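The `init_on_receiver` change in `weight_sync_schemes.py` earlier in this diff resets any sender-side state by round-tripping the scheme through its own pickle hooks. The toy class below shows why that works: whatever `__getstate__` already strips for pickling is also dropped by the in-place round trip (class and attribute names are invented for illustration):

    class SchemeLike:
        def __init__(self):
            self.config = {"backend": "gloo"}  # survives the round trip
            self._queues = None                # sender-side, transient

        def __getstate__(self):
            state = self.__dict__.copy()
            state["_queues"] = None            # never pickled
            return state

        def __setstate__(self, state):
            self.__dict__.update(state)

    scheme = SchemeLike()
    scheme._queues = ["q0", "q1"]               # pretend init_on_sender ran
    scheme.__setstate__(scheme.__getstate__())  # same trick as in the diff
    assert scheme._queues is None and scheme.config["backend"] == "gloo"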