diff --git a/.github/unittest/linux_sota/scripts/test_sota.py b/.github/unittest/linux_sota/scripts/test_sota.py index 2621bcf82eb..b35b3629427 100644 --- a/.github/unittest/linux_sota/scripts/test_sota.py +++ b/.github/unittest/linux_sota/scripts/test_sota.py @@ -284,6 +284,7 @@ collector.total_frames=600 \ collector.init_random_frames=10 \ collector.frames_per_batch=200 \ + collector.num_collectors=1 \ env.n_parallel_envs=1 \ optimization.optim_steps_per_batch=1 \ optimization.compile=False \ @@ -292,6 +293,7 @@ replay_buffer.buffer_size=120 \ replay_buffer.batch_size=24 \ replay_buffer.batch_length=12 \ + replay_buffer.prefetch=1 \ networks.rssm_hidden_dim=17 """, } diff --git a/sota-implementations/dreamer/README.md b/sota-implementations/dreamer/README.md index 94e28dc63d9..a64d899beab 100644 --- a/sota-implementations/dreamer/README.md +++ b/sota-implementations/dreamer/README.md @@ -1,7 +1,129 @@ -# Dreamer example +# Dreamer V1 -## Note: -This example is not included in the benchmarked results of the current release (v0.3). The intention is to include it in the -benchmarking of future releases, to ensure that it can be successfully run with the release code and that the -results are consistent. For now, be aware that this additional check has not been performed in the case of this -specific example. +This is an implementation of the Dreamer algorithm from the paper +["Dream to Control: Learning Behaviors by Latent Imagination"](https://arxiv.org/abs/1912.01603) (Hafner et al., ICLR 2020). + +Dreamer is a model-based reinforcement learning algorithm that: +1. Learns a **world model** (RSSM) from experience +2. **Imagines** future trajectories in latent space +3. Trains **actor and critic** using analytic gradients through the imagined rollouts + +## Setup + +### Dependencies + +```bash +# Create virtual environment +uv venv torchrl --python 3.12 +source torchrl/bin/activate + +# Install PyTorch (adjust for your CUDA version) +uv pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 + +# Install TorchRL and TensorDict +uv pip install tensordict torchrl + +# Install additional dependencies +uv pip install mujoco dm_control wandb tqdm hydra-core +``` + +### System Dependencies (for MuJoCo rendering) + +```bash +apt-get update && apt-get install -y \ + libegl1 \ + libgl1 \ + libgles2 \ + libglvnd0 +``` + +### Environment Variables + +```bash +export MUJOCO_GL=egl +export MUJOCO_EGL_DEVICE_ID=0 +``` + +## Running + +```bash +python dreamer.py +``` + +### Configuration + +The default configuration trains on DMControl's `cheetah-run` task. You can override settings via command line: + +```bash +# Different environment +python dreamer.py env.name=walker env.task=walk + +# Mixed precision options: false, true (=bfloat16), float16, bfloat16 +python dreamer.py optimization.autocast=bfloat16 # default +python dreamer.py optimization.autocast=float16 # for older GPUs +python dreamer.py optimization.autocast=false # disable autocast + +# Adjust batch size +python dreamer.py replay_buffer.batch_size=1000 +``` + +## Known Caveats + +### 1. Mixed Precision (Autocast) Compatibility + +Some GPU/cuBLAS combinations have issues with `bfloat16` autocast, resulting in: +``` +RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling cublasGemmEx +``` + +**Solutions:** +- Try float16: `optimization.autocast=float16` +- Or disable autocast entirely: `optimization.autocast=false` + +Note: Ensure your PyTorch CUDA version matches your driver. 
For example, with CUDA 13.0: +```bash +uv pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu130 +``` + +### 2. Benchmarking Status + +This implementation has not been fully benchmarked against the original paper's results. +Performance may differ from published numbers. + +### 3. Video Logging + +To enable video logging of both real and imagined rollouts: +```bash +python dreamer.py logger.video=True +``` + +This requires additional setup for rendering and significantly increases computation time. + +## Architecture Overview + +``` +World Model: + - ObsEncoder: pixels -> encoded_latents + - RSSMPrior: (state, belief, action) -> next_belief, prior_dist + - RSSMPosterior: (belief, encoded_latents) -> posterior_dist, state + - ObsDecoder: (state, belief) -> reconstructed_pixels + - RewardModel: (state, belief) -> predicted_reward + +Actor: (state, belief) -> action_distribution +Critic: (state, belief) -> state_value +``` + +## Training Loop + +1. **Collect** real experience from environment +2. **Train world model** on sequences from replay buffer (KL + reconstruction + reward loss) +3. **Imagine** trajectories starting from encoded real states +4. **Train actor** to maximize imagined returns (gradients flow through dynamics) +5. **Train critic** to predict lambda returns on imagined trajectories + +## References + +- Original Paper: [Dream to Control: Learning Behaviors by Latent Imagination](https://arxiv.org/abs/1912.01603) +- PlaNet (predecessor): [Learning Latent Dynamics for Planning from Pixels](https://arxiv.org/abs/1811.04551) +- DreamerV2: [Mastering Atari with Discrete World Models](https://arxiv.org/abs/2010.02193) +- DreamerV3: [Mastering Diverse Domains through World Models](https://arxiv.org/abs/2301.04104) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 604e1ac546a..424ad685aca 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -15,6 +15,9 @@ collector: total_frames: 5_000_000 init_random_frames: 3000 frames_per_batch: 1000 + # Number of parallel collector workers (async mode) + # On multi-GPU: must be <= num_gpus - 1 (cuda:0 reserved for training) + num_collectors: 7 device: optimization: @@ -26,13 +29,18 @@ optimization: value_lr: 8e-5 kl_scale: 1.0 free_nats: 3.0 - optim_steps_per_batch: 80 + optim_steps_per_batch: 20 gamma: 0.99 lmbda: 0.95 imagination_horizon: 15 - compile: False - compile_backend: inductor - use_autocast: True + compile: + enabled: True + backend: inductor # or cudagraphs + mode: reduce-overhead + # Which losses to compile (subset of: world_model, actor, value) + losses: ["world_model", "actor", "value"] + # Autocast options: false, true (=bfloat16), float16, bfloat16 + autocast: bfloat16 networks: exploration_noise: 0.3 @@ -41,13 +49,21 @@ networks: rssm_hidden_dim: 200 hidden_dim: 400 activation: "elu" + # Use torch.scan for RSSM rollout (faster, no graph breaks with torch.compile) + use_scan: False + rssm_rollout: + # Compile only the per-timestep RSSM rollout step (keeps Python loop, avoids scan/unrolling). 
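+    # Note: enabling this (or use_scan) swaps the RSSMPrior GRU for the Python-based GRUCell in make_dreamer so torch.compile can trace it.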
+ compile: False + compile_backend: inductor + compile_mode: reduce-overhead replay_buffer: - batch_size: 2500 + batch_size: 10000 buffer_size: 1000000 batch_length: 50 scratch_dir: null + prefetch: 8 logger: backend: wandb @@ -58,3 +74,27 @@ logger: eval_iter: 10 eval_rollout_steps: 500 video: False + +profiling: + # Enable PyTorch profiling (overrides total_frames to profiling_total_frames) + enabled: False + # Total frames to collect when profiling (default: 5005 = 5 collection iters + buffer warmup) + total_frames: 5005 + # Skip the first N optim steps (no profiling at all) + skip_first: 1 + # Warmup steps (profiler runs but data discarded for warmup) + warmup_steps: 1 + # Number of optim steps to profile (actual traced data) + active_steps: 1 + # Export chrome trace to this file (if set) + trace_file: dreamer_trace.json + # Profile CUDA kernels (VERY heavy on GPU - 13GB vs 1GB trace!) + profile_cuda: true + # Record tensor shapes + record_shapes: True + # Profile memory usage + profile_memory: True + # Record Python call stacks + with_stack: True + # Compute FLOPs + with_flops: True diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index 60eb4d82669..09d9137d54a 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -14,20 +14,23 @@ from dreamer_utils import ( _default_device, + DreamerProfiler, dump_video, log_metrics, make_collector, make_dreamer, make_environments, make_replay_buffer, + make_storage_transform, ) +from omegaconf import DictConfig # mixed precision training from torch.amp import GradScaler +from torch.autograd.profiler import record_function from torch.nn.utils import clip_grad_norm_ -from torchrl._utils import logger as torchrl_logger, timeit +from torchrl._utils import compile_with_warmup, logger as torchrl_logger, timeit from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import RSSMRollout from torchrl.objectives.dreamer import ( DreamerActorLoss, DreamerModelLoss, @@ -41,6 +44,7 @@ def main(cfg: DictConfig): # noqa: F821 # cfg = correct_for_frame_skip(cfg) device = _default_device(cfg.networks.device) + assert device.type == "cuda", "Dreamer only supports CUDA devices" # Create logger exp_name = generate_exp_name("Dreamer", cfg.logger.exp_name) @@ -53,7 +57,8 @@ def main(cfg: DictConfig): # noqa: F821 wandb_kwargs={"mode": cfg.logger.mode}, # "config": cfg}, ) - train_env, test_env = make_environments( + # make_environments returns (train_env_factory, test_env) for async collection + train_env_factory, test_env = make_environments( cfg=cfg, parallel_envs=cfg.env.n_parallel_envs, logger=logger, @@ -98,45 +103,66 @@ def main(cfg: DictConfig): # noqa: F821 value_model, discount_loss=True, gamma=cfg.optimization.gamma ) - # Make collector - collector = make_collector(cfg, train_env, policy) + # Make async multi-collector (uses env factory for worker processes) + # Device allocation: cuda:0 for training, cuda:1+ for collectors (if multi-GPU) + collector = make_collector(cfg, train_env_factory, policy, training_device=device) - # Make replay buffer + # Make replay buffer with minimal sample-time transforms batch_size = cfg.replay_buffer.batch_size batch_length = cfg.replay_buffer.batch_length buffer_size = cfg.replay_buffer.buffer_size scratch_dir = cfg.replay_buffer.scratch_dir + prefetch = cfg.replay_buffer.prefetch replay_buffer = make_replay_buffer( batch_size=batch_size, batch_seq_len=batch_length, buffer_size=buffer_size, 
buffer_scratch_dir=scratch_dir, device=device, + prefetch=prefetch if not (profiling_enabled := cfg.profiling.enabled) else None, + pixel_obs=cfg.env.from_pixels, + grayscale=cfg.env.grayscale, + image_size=cfg.env.image_size, + ) + + # Create storage transform for extend-time processing (applied once per frame) + storage_transform = make_storage_transform( pixel_obs=cfg.env.from_pixels, grayscale=cfg.env.grayscale, image_size=cfg.env.image_size, - use_autocast=cfg.optimization.use_autocast, - compile=( - {"backend": cfg.optimization.compile_backend} - if cfg.optimization.compile - else False - ), ) # Training loop collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) - # Make optimizer + # Make optimizer (fused=True for faster GPU execution) + use_fused = device.type == "cuda" world_model_opt = torch.optim.Adam( - world_model.parameters(), lr=cfg.optimization.world_model_lr + world_model.parameters(), lr=cfg.optimization.world_model_lr, fused=use_fused + ) + actor_opt = torch.optim.Adam( + actor_model.parameters(), lr=cfg.optimization.actor_lr, fused=use_fused + ) + value_opt = torch.optim.Adam( + value_model.parameters(), lr=cfg.optimization.value_lr, fused=use_fused ) - actor_opt = torch.optim.Adam(actor_model.parameters(), lr=cfg.optimization.actor_lr) - value_opt = torch.optim.Adam(value_model.parameters(), lr=cfg.optimization.value_lr) # Grad scaler for mixed precision training https://pytorch.org/docs/stable/amp.html - use_autocast = cfg.optimization.use_autocast - if use_autocast: + # autocast can be: false, true (=bfloat16), float16, bfloat16 + autocast_cfg = cfg.optimization.autocast + if autocast_cfg in (False, "false", "False"): + autocast_dtype = None + elif autocast_cfg in (True, "true", "True", "bfloat16"): + autocast_dtype = torch.bfloat16 + elif autocast_cfg == "float16": + autocast_dtype = torch.float16 + else: + raise ValueError( + f"Invalid autocast value: {autocast_cfg}. Use false, true, float16, or bfloat16." + ) + + if autocast_dtype is not None: scaler1 = GradScaler() scaler2 = GradScaler() scaler3 = GradScaler() @@ -147,124 +173,211 @@ def main(cfg: DictConfig): # noqa: F821 eval_iter = cfg.logger.eval_iter eval_rollout_steps = cfg.logger.eval_rollout_steps - if cfg.optimization.compile: - torch._dynamo.config.capture_scalar_outputs = True - - torchrl_logger.info("Compiling") - backend = cfg.optimization.compile_backend + # Enable TensorFloat32 for better performance on Ampere+ GPUs + if device.type == "cuda": + torch.set_float32_matmul_precision("high") - def compile_rssms(module): - if isinstance(module, RSSMRollout) and not getattr( - module, "_compiled", False - ): - module._compiled = True - module.rssm_prior.module = torch.compile( - module.rssm_prior.module, backend=backend - ) - module.rssm_posterior.module = torch.compile( - module.rssm_posterior.module, backend=backend - ) + compile_cfg = cfg.optimization.compile + compile_enabled = compile_cfg.enabled + compile_losses = set(compile_cfg.losses) + if compile_enabled: + torch._dynamo.config.capture_scalar_outputs = True - world_model_loss.apply(compile_rssms) + compile_warmup = 3 + torchrl_logger.info(f"Compiling loss modules with warmup={compile_warmup}") + backend = compile_cfg.backend + mode = compile_cfg.mode + + # Note: We do NOT compile rssm_prior/rssm_posterior here because they are + # shared with the policy used in the collector. Compiling them would cause + # issues with the MultiCollector workers. 
+ # + # Instead, we compile the loss modules themselves which wraps the forward pass. + # fullgraph=False allows graph breaks which can help with inductor issues. + # warmup=compile_warmup runs eagerly for first `compile_warmup` calls before compiling. + if "world_model" in compile_losses: + world_model_loss = compile_with_warmup( + world_model_loss, + backend=backend, + mode=mode, + fullgraph=False, + warmup=compile_warmup, + ) + if "actor" in compile_losses: + actor_loss = compile_with_warmup( + actor_loss, backend=backend, mode=mode, warmup=compile_warmup + ) + if "value" in compile_losses: + value_loss = compile_with_warmup( + value_loss, backend=backend, mode=mode, warmup=compile_warmup + ) + else: + compile_warmup = 0 + + # Throughput tracking + t_iter_start = time.time() + + # Profiling setup (encapsulated in helper class) + profiler = DreamerProfiler(cfg, device, pbar, compile_warmup=compile_warmup) - t_collect_init = time.time() for i, tensordict in enumerate(collector): - t_collect = time.time() - t_collect_init - - t_preproc_init = time.time() - pbar.update(tensordict.numel()) - current_frames = tensordict.numel() - collected_frames += current_frames - - ep_reward = tensordict.get("episode_reward")[..., -1, 0] - replay_buffer.extend(tensordict.cpu()) - t_preproc = time.time() - t_preproc_init + # Note: Collection time is implicitly measured by the collector's iteration + # The time between loop iterations that isn't training is effectively collection time + with timeit("collect/preproc"): + pbar.update(tensordict.numel()) + current_frames = tensordict.numel() + collected_frames += current_frames + + ep_reward = tensordict.get("episode_reward")[..., -1, 0] + # Apply storage transforms (ToTensorImage, Resize, GrayScale) once at extend-time + tensordict_cpu = tensordict.cpu() + if storage_transform is not None: + tensordict_cpu = storage_transform(tensordict_cpu) + replay_buffer.extend(tensordict_cpu) if collected_frames >= init_random_frames: - t_loss_actor = 0.0 - t_loss_critic = 0.0 - t_loss_model = 0.0 for _ in range(optim_steps_per_batch): # sample from replay buffer - t_sample_init = time.time() - sampled_tensordict = replay_buffer.sample().reshape(-1, batch_length) - t_sample = time.time() - t_sample_init + with timeit("train/sample"), record_function("## train/sample ##"): + sampled_tensordict = replay_buffer.sample().reshape( + -1, batch_length + ) + if profiling_enabled: + torch.cuda.synchronize() - t_loss_model_init = time.time() # update world model - with torch.autocast( - device_type=device.type, - dtype=torch.bfloat16, - ) if use_autocast else contextlib.nullcontext(): - model_loss_td, sampled_tensordict = world_model_loss( - sampled_tensordict + with timeit("train/world_model-forward"), record_function( + "## world_model/forward ##" + ): + # Mark step begin for CUDAGraph to prevent tensor overwrite issues + torch.compiler.cudagraph_mark_step_begin() + with torch.autocast( + device_type=device.type, + dtype=autocast_dtype, + ) if autocast_dtype else contextlib.nullcontext(): + assert ( + sampled_tensordict.device.type == "cuda" + ), "sampled_tensordict should be on CUDA" + model_loss_td, sampled_tensordict = world_model_loss( + sampled_tensordict + ) + loss_world_model = ( + model_loss_td["loss_model_kl"] + + model_loss_td["loss_model_reco"] + + model_loss_td["loss_model_reward"] + ) + + with timeit("train/world_model-backward"), record_function( + "## world_model/backward ##" + ): + world_model_opt.zero_grad() + if autocast_dtype: + 
scaler1.scale(loss_world_model).backward() + scaler1.unscale_(world_model_opt) + else: + loss_world_model.backward() + torchrl_logger.debug("world_model_loss backward OK") + world_model_grad = clip_grad_norm_( + world_model.parameters(), grad_clip ) - loss_world_model = ( - model_loss_td["loss_model_kl"] - + model_loss_td["loss_model_reco"] - + model_loss_td["loss_model_reward"] - ) - - world_model_opt.zero_grad() - if use_autocast: - scaler1.scale(loss_world_model).backward() - scaler1.unscale_(world_model_opt) - else: - loss_world_model.backward() - world_model_grad = clip_grad_norm_(world_model.parameters(), grad_clip) - if use_autocast: - scaler1.step(world_model_opt) - scaler1.update() - else: - world_model_opt.step() - t_loss_model += time.time() - t_loss_model_init + if autocast_dtype: + scaler1.step(world_model_opt) + scaler1.update() + else: + world_model_opt.step() # update actor network - t_loss_actor_init = time.time() - with torch.autocast( - device_type=device.type, dtype=torch.bfloat16 - ) if use_autocast else contextlib.nullcontext(): - actor_loss_td, sampled_tensordict = actor_loss( - sampled_tensordict.reshape(-1) + with timeit("train/actor-forward"), record_function( + "## actor/forward ##" + ): + # Mark step begin for CUDAGraph to prevent tensor overwrite issues + torch.compiler.cudagraph_mark_step_begin() + with torch.autocast( + device_type=device.type, dtype=autocast_dtype + ) if autocast_dtype else contextlib.nullcontext(): + actor_loss_td, sampled_tensordict = actor_loss( + sampled_tensordict.reshape(-1) + ) + + with timeit("train/actor-backward"), record_function( + "## actor/backward ##" + ): + actor_opt.zero_grad() + if autocast_dtype: + scaler2.scale(actor_loss_td["loss_actor"]).backward() + scaler2.unscale_(actor_opt) + else: + actor_loss_td["loss_actor"].backward() + torchrl_logger.debug("actor_loss backward OK") + actor_model_grad = clip_grad_norm_( + actor_model.parameters(), grad_clip ) - - actor_opt.zero_grad() - if use_autocast: - scaler2.scale(actor_loss_td["loss_actor"]).backward() - scaler2.unscale_(actor_opt) - else: - actor_loss_td["loss_actor"].backward() - actor_model_grad = clip_grad_norm_(actor_model.parameters(), grad_clip) - if use_autocast: - scaler2.step(actor_opt) - scaler2.update() - else: - actor_opt.step() - t_loss_actor += time.time() - t_loss_actor_init + if autocast_dtype: + scaler2.step(actor_opt) + scaler2.update() + else: + actor_opt.step() # update value network - t_loss_critic_init = time.time() - with torch.autocast( - device_type=device.type, dtype=torch.bfloat16 - ) if use_autocast else contextlib.nullcontext(): - value_loss_td, sampled_tensordict = value_loss(sampled_tensordict) - - value_opt.zero_grad() - if use_autocast: - scaler3.scale(value_loss_td["loss_value"]).backward() - scaler3.unscale_(value_opt) - else: - value_loss_td["loss_value"].backward() - critic_model_grad = clip_grad_norm_(value_model.parameters(), grad_clip) - if use_autocast: - scaler3.step(value_opt) - scaler3.update() - else: - value_opt.step() - t_loss_critic += time.time() - t_loss_critic_init + with timeit("train/value-forward"), record_function( + "## value/forward ##" + ): + # Mark step begin for CUDAGraph to prevent tensor overwrite issues + torch.compiler.cudagraph_mark_step_begin() + with torch.autocast( + device_type=device.type, dtype=autocast_dtype + ) if autocast_dtype else contextlib.nullcontext(): + value_loss_td, sampled_tensordict = value_loss( + sampled_tensordict + ) + + with timeit("train/value-backward"), record_function( + "## 
value/backward ##" + ): + value_opt.zero_grad() + if autocast_dtype: + scaler3.scale(value_loss_td["loss_value"]).backward() + scaler3.unscale_(value_opt) + else: + value_loss_td["loss_value"].backward() + torchrl_logger.debug("value_loss backward OK") + critic_model_grad = clip_grad_norm_( + value_model.parameters(), grad_clip + ) + if autocast_dtype: + scaler3.step(value_opt) + scaler3.update() + else: + value_opt.step() + + # Step profiler (returns True if profiling complete) + if profiler.step(): + break + + # Check if profiling is complete and we should exit + if profiler.should_exit(): + torchrl_logger.info("Profiling complete. Exiting training loop.") + break + + # Compute throughput metrics + t_iter_end = time.time() + iter_time = t_iter_end - t_iter_start + + # FPS: Frames (env steps) collected per second + fps = current_frames / iter_time if iter_time > 0 else 0 metrics_to_log = {"reward": ep_reward.mean().item()} if collected_frames >= init_random_frames: + # SPS: Samples (batch elements) processed per second + # Each optim step processes batch_size samples + total_samples = optim_steps_per_batch * batch_size + sps = total_samples / iter_time if iter_time > 0 else 0 + + # UPS: Updates (gradient steps) per second + # 3 updates per optim step (world_model, actor, value) + total_updates = optim_steps_per_batch * 3 + ups = total_updates / iter_time if iter_time > 0 else 0 + loss_metrics = { "loss_model_kl": model_loss_td["loss_model_kl"].item(), "loss_model_reco": model_loss_td["loss_model_reco"].item(), @@ -274,19 +387,26 @@ def compile_rssms(module): "world_model_grad": world_model_grad, "actor_model_grad": actor_model_grad, "critic_model_grad": critic_model_grad, - "t_loss_actor": t_loss_actor, - "t_loss_critic": t_loss_critic, - "t_loss_model": t_loss_model, - "t_sample": t_sample, - "t_preproc": t_preproc, - "t_collect": t_collect, + # Throughput metrics + "throughput/fps": fps, # Frames per second (collection) + "throughput/sps": sps, # Samples per second (training) + "throughput/ups": ups, # Updates per second (gradient steps) + "throughput/iter_time": iter_time, # Total iteration time + # Detailed timing from timeit (some metrics may be empty when compiling) **timeit.todict(prefix="time"), } metrics_to_log.update(loss_metrics) + else: + # During random collection phase, only log FPS + metrics_to_log["throughput/fps"] = fps + metrics_to_log["throughput/iter_time"] = iter_time if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) + # Reset timer for next iteration + t_iter_start = time.time() + policy[1].step(current_frames) collector.update_policy_weights_() # Evaluation @@ -325,16 +445,12 @@ def compile_rssms(module): if logger is not None: log_metrics(logger, eval_metrics, collected_frames) - t_collect_init = time.time() - if not test_env.is_closed: test_env.close() - if not train_env.is_closed: - train_env.close() + # Note: train envs are managed by the collector workers collector.shutdown() del test_env - del train_env del collector diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 18cee3ab07a..ccb326f7698 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -19,7 +19,8 @@ TensorDictModule, TensorDictSequential, ) -from torchrl.collectors import SyncDataCollector +from torchrl import logger as torchrl_logger +from torchrl.collectors import MultiCollector from torchrl.data import ( Composite, @@ -73,6 +74,190 @@ from 
torchrl.record import VideoRecorder +def allocate_collector_devices( + num_collectors: int, training_device: torch.device +) -> list[torch.device]: + """Allocate CUDA devices for collectors, reserving cuda:0 for training. + + Device allocation strategy: + - Training always uses cuda:0 + - Collectors use cuda:1, cuda:2, ..., cuda:N-1 if available + - If only 1 CUDA device, colocate training and inference on cuda:0 + - If num_collectors >= num_cuda_devices, raise an exception + + Args: + num_collectors: Number of collector workers requested + training_device: The device used for training (determines if CUDA is used) + + Returns: + List of devices for each collector worker + + Raises: + ValueError: If num_collectors >= num_cuda_devices (no device left for training) + """ + if training_device.type != "cuda": + # CPU training: all collectors on CPU + return [torch.device("cpu")] * num_collectors + + num_cuda_devices = torch.cuda.device_count() + + if num_cuda_devices == 0: + # No CUDA devices available, fall back to CPU + return [torch.device("cpu")] * num_collectors + + if num_cuda_devices == 1: + # Single GPU: colocate training and inference + torchrl_logger.info( + f"Single CUDA device available. Colocating {num_collectors} collectors " + "with training on cuda:0" + ) + return [torch.device("cuda:0")] * num_collectors + + # Multiple GPUs available + # Reserve cuda:0 for training, use cuda:1..cuda:N-1 for inference + inference_devices = num_cuda_devices - 1 # Devices available for collectors + + if num_collectors > inference_devices: + raise ValueError( + f"Requested {num_collectors} collectors but only {inference_devices} " + f"CUDA devices available for inference (cuda:1 to cuda:{num_cuda_devices - 1}). " + f"cuda:0 is reserved for training. Either reduce num_collectors to " + f"{inference_devices} or add more GPUs." + ) + + # Distribute collectors across available inference devices (round-robin) + collector_devices = [] + for i in range(num_collectors): + device_idx = (i % inference_devices) + 1 # +1 to skip cuda:0 + collector_devices.append(torch.device(f"cuda:{device_idx}")) + + device_str = ", ".join(str(d) for d in collector_devices) + torchrl_logger.info( + f"Allocated {num_collectors} collectors to devices: [{device_str}]. " + f"Training on cuda:0." + ) + + return collector_devices + + +class DreamerProfiler: + """Helper class for PyTorch profiling in Dreamer training. + + Encapsulates profiler setup, stepping, and trace export logic. + + Args: + cfg: Hydra config with profiling section. + device: Training device (used to determine CUDA profiling). + pbar: Progress bar to update total when profiling. 
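+        compile_warmup: Number of optim steps run eagerly by compile_with_warmup before
+            compilation; the profiler's skip_first is extended by compile_warmup + 1 so
+            compilation overhead is not traced. Defaults to 0.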
+ """ + + def __init__(self, cfg, device, pbar=None, *, compile_warmup: int = 0): + self.enabled = cfg.profiling.enabled + self.cfg = cfg + self.total_optim_steps = 0 + self._profiler = None + self._stopped = False + self._compile_warmup = compile_warmup + + if not self.enabled: + return + + # Override total_frames for profiling runs + torchrl_logger.info( + f"Profiling enabled: running {cfg.profiling.total_frames} frames " + f"(skip_first={cfg.profiling.skip_first}, warmup={cfg.profiling.warmup_steps}, " + f"active={cfg.profiling.active_steps})" + ) + if pbar is not None: + pbar.total = cfg.profiling.total_frames + + # Setup profiler schedule + # - skip_first: steps to skip entirely (no profiling) + # - warmup: steps to warm up profiler (data discarded) + # - active: steps to actually profile (data kept) + # + # When torch.compile is enabled via compile_with_warmup, the first `compile_warmup` + # calls run eagerly and the *next* call typically triggers compilation. Profiling + # these steps is usually undesirable because it captures compilation overhead and + # non-representative eager execution. + # + # Therefore we automatically extend skip_first by (compile_warmup + 1) optim steps. + extra_skip = self._compile_warmup + 1 if self._compile_warmup else 0 + skip_first = cfg.profiling.skip_first + extra_skip + profiler_schedule = torch.profiler.schedule( + skip_first=skip_first, + wait=0, + warmup=cfg.profiling.warmup_steps, + active=cfg.profiling.active_steps, + repeat=1, + ) + + # Determine profiler activities + activities = [torch.profiler.ProfilerActivity.CPU] + if cfg.profiling.profile_cuda and device.type == "cuda": + activities.append(torch.profiler.ProfilerActivity.CUDA) + + self._profiler = torch.profiler.profile( + activities=activities, + schedule=profiler_schedule, + on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler_logs") + if not cfg.profiling.trace_file + else None, + record_shapes=cfg.profiling.record_shapes, + profile_memory=cfg.profiling.profile_memory, + with_stack=cfg.profiling.with_stack, + with_flops=cfg.profiling.with_flops, + ) + self._profiler.start() + + def step(self) -> bool: + """Step the profiler and check if profiling is complete. + + Returns: + True if profiling is complete and training should exit. + """ + if not self.enabled or self._stopped: + return False + + self.total_optim_steps += 1 + self._profiler.step() + + # Check if we should stop profiling + extra_skip = self._compile_warmup + 1 if self._compile_warmup else 0 + target_steps = ( + self.cfg.profiling.skip_first + + extra_skip + + self.cfg.profiling.warmup_steps + + self.cfg.profiling.active_steps + ) + if self.total_optim_steps >= target_steps: + torchrl_logger.info( + f"Profiling complete after {self.total_optim_steps} optim steps. 
" + f"Exporting trace to {self.cfg.profiling.trace_file}" + ) + self._profiler.stop() + self._stopped = True + # Export trace if trace_file is set + if self.cfg.profiling.trace_file: + self._profiler.export_chrome_trace(self.cfg.profiling.trace_file) + return True + + return False + + def should_exit(self) -> bool: + """Check if training loop should exit due to profiling completion.""" + if not self.enabled: + return False + extra_skip = self._compile_warmup + 1 if self._compile_warmup else 0 + target_steps = ( + self.cfg.profiling.skip_first + + extra_skip + + self.cfg.profiling.warmup_steps + + self.cfg.profiling.active_steps + ) + return self.total_optim_steps >= target_steps + + def _make_env(cfg, device, from_pixels=False): lib = cfg.env.backend if lib in ("gym", "gymnasium"): @@ -129,15 +314,28 @@ def transform_env(cfg, env): def make_environments(cfg, parallel_envs=1, logger=None): - """Make environments for training and evaluation.""" - func = functools.partial(_make_env, cfg=cfg, device=_default_device(cfg.env.device)) - train_env = ParallelEnv( - parallel_envs, - EnvCreator(func), - serial_for_single=True, - ) - train_env = transform_env(cfg, train_env) - train_env.set_seed(cfg.env.seed) + """Make environments for training and evaluation. + + Returns: + train_env_factory: A callable that creates a training environment (for MultiCollector) + eval_env: The evaluation environment instance + """ + + def train_env_factory(): + """Factory function for creating training environments.""" + func = functools.partial( + _make_env, cfg=cfg, device=_default_device(cfg.env.device) + ) + train_env = ParallelEnv( + parallel_envs, + EnvCreator(func), + serial_for_single=True, + ) + train_env = transform_env(cfg, train_env) + train_env.set_seed(cfg.env.seed) + return train_env + + # Create eval env directly (not a factory) func = functools.partial( _make_env, cfg=cfg, @@ -153,9 +351,15 @@ def make_environments(cfg, parallel_envs=1, logger=None): eval_env.set_seed(cfg.env.seed + 1) if cfg.logger.video: eval_env.insert_transform(0, VideoRecorder(logger, tag="eval/video")) - check_env_specs(train_env) + + # Check specs on a temporary train env + temp_train_env = train_env_factory() + check_env_specs(temp_train_env) + temp_train_env.close() + del temp_train_env + check_env_specs(eval_env) - return train_env, eval_env + return train_env_factory, eval_env def dump_video(module): @@ -163,6 +367,17 @@ def dump_video(module): module.dump() +def _compute_encoder_output_size(image_size, channels=32, num_layers=4): + """Compute the flattened output size of ObsEncoder.""" + # Compute spatial size after each conv layer (kernel=4, stride=2) + size = image_size + for _ in range(num_layers): + size = (size - 4) // 2 + 1 + # Final channels = channels * 2^(num_layers-1) + final_channels = channels * (2 ** (num_layers - 1)) + return final_channels * size * size + + def make_dreamer( cfg, device, @@ -174,47 +389,89 @@ def make_dreamer( ): test_env = _make_env(cfg, device="cpu") test_env = transform_env(cfg, test_env) + + # Get dimensions for explicit module instantiation (avoids lazy modules) + state_dim = cfg.networks.state_dim + rssm_hidden_dim = cfg.networks.rssm_hidden_dim + action_dim = test_env.action_spec.shape[-1] + # Make encoder and decoder if cfg.env.from_pixels: - encoder = ObsEncoder() - decoder = ObsDecoder() + # Determine input channels (1 for grayscale, 3 for RGB) + in_channels = 1 if cfg.env.grayscale else 3 + image_size = cfg.env.image_size + + # Compute encoder output size for explicit posterior 
input + obs_embed_dim = _compute_encoder_output_size( + image_size, channels=32, num_layers=4 + ) + + encoder = ObsEncoder(in_channels=in_channels, device=device) + decoder = ObsDecoder(latent_dim=state_dim + rssm_hidden_dim, device=device) + observation_in_key = "pixels" observation_out_key = "reco_pixels" else: + obs_embed_dim = 1024 # MLP output size encoder = MLP( - out_features=1024, + out_features=obs_embed_dim, depth=2, num_cells=cfg.networks.hidden_dim, activation_class=get_activation(cfg.networks.activation), + device=device, ) decoder = MLP( out_features=test_env.observation_spec["observation"].shape[-1], depth=2, num_cells=cfg.networks.hidden_dim, activation_class=get_activation(cfg.networks.activation), + device=device, ) + observation_in_key = "observation" observation_out_key = "reco_observation" - # Make RSSM + # Make RSSM with explicit input sizes (no lazy modules) rssm_prior = RSSMPrior( - hidden_dim=cfg.networks.rssm_hidden_dim, - rnn_hidden_dim=cfg.networks.rssm_hidden_dim, - state_dim=cfg.networks.state_dim, + hidden_dim=rssm_hidden_dim, + rnn_hidden_dim=rssm_hidden_dim, + state_dim=state_dim, action_spec=test_env.action_spec, + action_dim=action_dim, + device=device, ) rssm_posterior = RSSMPosterior( - hidden_dim=cfg.networks.rssm_hidden_dim, state_dim=cfg.networks.state_dim + hidden_dim=rssm_hidden_dim, + state_dim=state_dim, + rnn_hidden_dim=rssm_hidden_dim, + obs_embed_dim=obs_embed_dim, + device=device, ) + + # When use_scan=True or rssm_rollout.compile=True, replace C++ GRU with Python-based GRU + # for torch.compile compatibility. The C++ GRU (cuBLAS) cannot be traced by torch.compile. + if cfg.networks.use_scan or cfg.networks.rssm_rollout.compile: + from torchrl.modules.tensordict_module.rnn import GRUCell as PythonGRUCell + + old_rnn = rssm_prior.rnn + python_rnn = PythonGRUCell( + old_rnn.input_size, old_rnn.hidden_size, device=device + ) + python_rnn.load_state_dict(old_rnn.state_dict()) + rssm_prior.rnn = python_rnn + torchrl_logger.info( + "Switched RSSMPrior to Python-based GRU for torch.compile compatibility" + ) # Make reward module reward_module = MLP( out_features=1, depth=2, num_cells=cfg.networks.hidden_dim, activation_class=get_activation(cfg.networks.activation), + device=device, ) - # Make combined world model + # Make combined world model (modules already on device) world_model = _dreamer_make_world_model( encoder, decoder, @@ -223,15 +480,16 @@ def make_dreamer( reward_module, observation_in_key=observation_in_key, observation_out_key=observation_out_key, + use_scan=cfg.networks.use_scan, + rssm_rollout_compile=cfg.networks.rssm_rollout.compile, + rssm_rollout_compile_backend=cfg.networks.rssm_rollout.compile_backend, + rssm_rollout_compile_mode=cfg.networks.rssm_rollout.compile_mode, ) - world_model.to(device) - # Initialize world model + # Initialize world model (already on device) with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): tensordict = ( - test_env.rollout(5, auto_cast_to_device=True) - .unsqueeze(-1) - .to(world_model.device) + test_env.rollout(5, auto_cast_to_device=True).unsqueeze(-1).to(device) ) tensordict = tensordict.to_tensordict() world_model(tensordict) @@ -256,7 +514,7 @@ def make_dreamer( # model_based_env = model_based_env.append_transform(detach_state_and_belief) check_env_specs(model_based_env) - # Make actor + # Make actor (modules already on device) actor_simulator, actor_realworld = _dreamer_make_actors( encoder=encoder, observation_in_key=observation_in_key, @@ -266,6 +524,7 @@ def make_dreamer( 
activation=get_activation(cfg.networks.activation), action_key=action_key, test_env=test_env, + device=device, ) # Exploration noise to be added to the actor_realworld actor_realworld = TensorDictSequential( @@ -281,16 +540,15 @@ def make_dreamer( ), ) - # Make Critic + # Make Critic (on device) value_model = _dreamer_make_value_model( hidden_dim=cfg.networks.hidden_dim, activation=cfg.networks.activation, value_key=value_key, + device=device, ) - actor_simulator.to(device) - value_model.to(device) - actor_realworld.to(device) + # Move model_based_env to device (it contains references to modules already on device) model_based_env.to(device) # Initialize model-based environment, actor and critic @@ -332,23 +590,75 @@ def float_to_int(data): ) -def make_collector(cfg, train_env, actor_model_explore): - """Make collector.""" - collector = SyncDataCollector( - train_env, - actor_model_explore, - init_random_frames=cfg.collector.init_random_frames, +def make_collector( + cfg, train_env_factory, actor_model_explore, training_device: torch.device +): + """Make async multi-collector for parallel data collection. + + Args: + cfg: Configuration object + train_env_factory: A callable that creates a training environment + actor_model_explore: The exploration policy + training_device: Device used for training (used to allocate collector devices) + + Returns: + MultiCollector in async mode with multiple worker processes + + Device allocation: + - If training on CUDA with multiple GPUs: collectors use cuda:1, cuda:2, etc. + - If training on CUDA with single GPU: collectors colocate on cuda:0 + - If training on CPU: collectors use CPU + """ + num_collectors = cfg.collector.num_collectors + + # Allocate devices for collectors (reserves cuda:0 for training if multi-GPU) + collector_devices = allocate_collector_devices(num_collectors, training_device) + + collector = MultiCollector( + create_env_fn=[train_env_factory] * num_collectors, + policy=actor_model_explore, frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, - policy_device=_default_device(cfg.collector.device), - env_device=train_env.device, + init_random_frames=cfg.collector.init_random_frames, + policy_device=collector_devices, storing_device="cpu", + sync=False, # Async mode for overlapping collection with training + update_at_each_batch=True, ) collector.set_seed(cfg.env.seed) return collector +def make_storage_transform( + *, + pixel_obs=True, + grayscale=True, + image_size, +): + """Create transforms to be applied at extend-time (once per frame). + + These heavy transforms (ToTensorImage, GrayScale, Resize) are applied once + when data is added to the buffer, rather than on every sample. + """ + if not pixel_obs: + return None + + storage_transforms = Compose( + ExcludeTransform("pixels", ("next", "pixels"), inverse=True), + ToTensorImage( + in_keys=["pixels_int", ("next", "pixels_int")], + out_keys=["pixels", ("next", "pixels")], + ), + ) + if grayscale: + storage_transforms.append(GrayScale(in_keys=["pixels", ("next", "pixels")])) + storage_transforms.append( + Resize(image_size, image_size, in_keys=["pixels", ("next", "pixels")]) + ) + return storage_transforms + + def make_replay_buffer( *, batch_size, @@ -356,39 +666,30 @@ def make_replay_buffer( buffer_size=1000000, buffer_scratch_dir=None, device=None, - prefetch=3, + prefetch=8, pixel_obs=True, grayscale=True, image_size, - use_autocast, - compile: bool | dict = False, ): + """Create replay buffer with minimal sample-time transforms. 
+ + Heavy image transforms are expected to be applied at extend-time using + make_storage_transform(). Only DeviceCastTransform is applied at sample-time. + + Note: We don't compile the SliceSampler because: + 1. Sampler operations (index computation) happen on CPU and are already fast + 2. torch.compile with inductor has bugs with the sampler's vectorized int64 operations + """ with ( tempfile.TemporaryDirectory() if buffer_scratch_dir is None else nullcontext(buffer_scratch_dir) ) as scratch_dir: - transforms = Compose() - if pixel_obs: - - def check_no_pixels(data): - assert "pixels" not in data.keys() - return data - - transforms = Compose( - ExcludeTransform("pixels", ("next", "pixels"), inverse=True), - check_no_pixels, # will be called only during forward - ToTensorImage( - in_keys=["pixels_int", ("next", "pixels_int")], - out_keys=["pixels", ("next", "pixels")], - ), - ) - if grayscale: - transforms.append(GrayScale(in_keys=["pixels", ("next", "pixels")])) - transforms.append( - Resize(image_size, image_size, in_keys=["pixels", ("next", "pixels")]) - ) - transforms.append(DeviceCastTransform(device=device)) + # Sample-time transforms: only device transfer (fast) + sample_transforms = Compose( + # Reshape on CPU before device transfer to avoid extra work / sync in the training loop. + DeviceCastTransform(device=device), + ) replay_buffer = TensorDictReplayBuffer( pin_memory=False, @@ -404,22 +705,26 @@ def check_no_pixels(data): strict_length=False, traj_key=("collector", "traj_ids"), cache_values=True, - compile=compile, + # Don't compile the sampler - inductor has C++ codegen bugs for int64 ops ), - transform=transforms, + transform=sample_transforms, batch_size=batch_size, ) return replay_buffer def _dreamer_make_value_model( - hidden_dim: int = 400, activation: str = "elu", value_key: str = "state_value" + hidden_dim: int = 400, + activation: str = "elu", + value_key: str = "state_value", + device=None, ): value_model = MLP( out_features=1, depth=3, num_cells=hidden_dim, activation_class=get_activation(activation), + device=device, ) value_model = ProbabilisticTensorDictSequential( TensorDictModule( @@ -447,12 +752,14 @@ def _dreamer_make_actors( activation, action_key, test_env, + device=None, ): actor_module = DreamerActor( out_features=test_env.action_spec.shape[-1], depth=3, num_cells=mlp_num_units, activation_class=activation, + device=device, ) actor_simulator = _dreamer_make_actor_sim(action_key, test_env, actor_module) actor_realworld = _dreamer_make_actor_real( @@ -633,28 +940,50 @@ def _dreamer_make_world_model( reward_module, observation_in_key: NestedKey = "pixels", observation_out_key: NestedKey = "reco_pixels", + use_scan: bool = False, + rssm_rollout_compile: bool = False, + rssm_rollout_compile_backend: str = "inductor", + rssm_rollout_compile_mode: str | None = "reduce-overhead", ): # World Model and reward model + # Note: in_keys uses dict form with out_to_in_map=True to map function args to tensordict keys. + # {"noise": "prior_noise"} means: read "prior_noise" from tensordict, pass as `noise` kwarg. + # With strict=False (default), missing noise keys pass None to the module. 
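+    # Roughly: each step calls rssm_prior(state=td["state"], belief=td["belief"], action=td["action"], noise=td.get("prior_noise", None)).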
rssm_rollout = RSSMRollout( TensorDictModule( rssm_prior, - in_keys=["state", "belief", "action"], + in_keys={ + "state": "state", + "belief": "belief", + "action": "action", + "noise": "prior_noise", + }, out_keys=[ ("next", "prior_mean"), ("next", "prior_std"), "_", ("next", "belief"), ], + out_to_in_map=True, ), TensorDictModule( rssm_posterior, - in_keys=[("next", "belief"), ("next", "encoded_latents")], + in_keys={ + "belief": ("next", "belief"), + "obs_embedding": ("next", "encoded_latents"), + "noise": "posterior_noise", + }, out_keys=[ ("next", "posterior_mean"), ("next", "posterior_std"), ("next", "state"), ], + out_to_in_map=True, ), + use_scan=use_scan, + compile_step=rssm_rollout_compile, + compile_backend=rssm_rollout_compile_backend, + compile_mode=rssm_rollout_compile_mode, ) event_dim = 3 if observation_out_key == "reco_pixels" else 1 # 3 for RGB decoder = ProbabilisticTensorDictSequential( diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py index 3e2035f1cd9..cc713fb56b2 100644 --- a/torchrl/modules/models/model_based.py +++ b/torchrl/modules/models/model_based.py @@ -18,7 +18,6 @@ # from torchrl.modules.tensordict_module.rnn import GRUCell from torch.nn import GRUCell -from torchrl._utils import timeit from torchrl.modules.models.models import MLP @@ -98,7 +97,9 @@ class ObsEncoder(nn.Module): Defaults to None (uses default device). """ - def __init__(self, channels=32, num_layers=4, in_channels=None, depth=None, device=None): + def __init__( + self, channels=32, num_layers=4, in_channels=None, depth=None, device=None + ): if depth is not None: warnings.warn( f"The depth argument in {type(self)} will soon be deprecated and " @@ -159,7 +160,15 @@ class ObsDecoder(nn.Module): Defaults to None (uses default device). 
""" - def __init__(self, channels=32, num_layers=4, kernel_sizes=None, latent_dim=None, depth=None, device=None): + def __init__( + self, + channels=32, + num_layers=4, + kernel_sizes=None, + latent_dim=None, + depth=None, + device=None, + ): if depth is not None: warnings.warn( f"The depth argument in {type(self)} will soon be deprecated and " @@ -198,7 +207,11 @@ def __init__(self, channels=32, num_layers=4, kernel_sizes=None, latent_dim=None if j != num_layers - 1: layers = [ nn.ConvTranspose2d( - channels * k * 2, channels * k, kernel_sizes[-1], stride=2, device=device + channels * k * 2, + channels * k, + kernel_sizes[-1], + stride=2, + device=device, ), ] + layers kernel_sizes = kernel_sizes[:-1] @@ -207,7 +220,13 @@ def __init__(self, channels=32, num_layers=4, kernel_sizes=None, latent_dim=None else: # Use explicit ConvTranspose2d - input is always channels * 8 from state_to_latent layers = [ - nn.ConvTranspose2d(linear_out, channels * k, kernel_sizes[-1], stride=2, device=device) + nn.ConvTranspose2d( + linear_out, + channels * k, + kernel_sizes[-1], + stride=2, + device=device, + ) ] + layers self.decoder = nn.Sequential(*layers) @@ -449,7 +468,9 @@ def forward(self, state, belief, action, noise=None): dtype = action_state.dtype device_type = action_state.device.type with torch.amp.autocast(device_type=device_type, enabled=False): - belief = self.rnn(action_state.float(), belief.float() if belief is not None else None) + belief = self.rnn( + action_state.float(), belief.float() if belief is not None else None + ) belief = belief.to(dtype) if unsqueeze: belief = belief.squeeze(0) @@ -485,7 +506,15 @@ class RSSMPosterior(nn.Module): """ - def __init__(self, hidden_dim=200, state_dim=30, scale_lb=0.1, rnn_hidden_dim=None, obs_embed_dim=None, device=None): + def __init__( + self, + hidden_dim=200, + state_dim=30, + scale_lb=0.1, + rnn_hidden_dim=None, + obs_embed_dim=None, + device=None, + ): super().__init__() # Use explicit Linear if both dims provided, else LazyLinear for backward compat if rnn_hidden_dim is not None and obs_embed_dim is not None: diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 23a31b800a6..8766e06d745 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -139,13 +139,18 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor: posterior_std = tensordict.get(("next", self.tensor_keys.posterior_std)) kl_loss = self.kl_loss( - prior_mean, prior_std, posterior_mean, posterior_std, + prior_mean, + prior_std, + posterior_mean, + posterior_std, ).unsqueeze(-1) # Ensure contiguous layout for torch.compile compatibility # The gradient from distance_loss flows back through decoder convolutions pixels = tensordict.get(("next", self.tensor_keys.pixels)).contiguous() - reco_pixels = tensordict.get(("next", self.tensor_keys.reco_pixels)).contiguous() + reco_pixels = tensordict.get( + ("next", self.tensor_keys.reco_pixels) + ).contiguous() reco_loss = distance_loss( pixels, reco_pixels, @@ -437,7 +442,9 @@ def forward(self, fake_data) -> torch.Tensor: self.value_model(tensordict_select) if self.discount_loss: - discount = self.gamma * torch.ones_like(lambda_target, device=lambda_target.device) + discount = self.gamma * torch.ones_like( + lambda_target, device=lambda_target.device + ) discount[..., 0, :] = 1 discount = discount.cumprod(dim=-2) value_loss = ( diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index c10e8a96540..1d7ded154a0 100644 --- 
a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -1028,7 +1028,11 @@ def _fast_td_lambda_return_estimate( next_state_value = (~terminated).int() * next_state_value # Use torch.full to create directly on device (avoids DeviceCopy in cudagraph) - gamma_tensor = torch.full((1,), gamma, device=device) + # Handle both scalar and single-element tensor gamma + if isinstance(gamma, torch.Tensor): + gamma_tensor = gamma.to(device).view(1) + else: + gamma_tensor = torch.full((1,), gamma, device=device) gammalmbda = gamma_tensor * lmbda num_per_traj = _get_num_per_traj(done)