diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 70f9f2f4e7f..b5e1051eab5 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -1556,14 +1556,30 @@ def _get_index( out_of_traj = relative_starts < 0 if out_of_traj.any(): # a negative start means sampling fewer elements + # Convert seq_length to tensor to avoid torch.compile inductor C++ codegen + # bug with mixed scalar/tensor int64 in blendv operations (see PyTorch #xyz) + seq_length_t = torch.as_tensor( + seq_length, + dtype=relative_starts.dtype, + device=relative_starts.device, + ) seq_length = torch.where( - ~out_of_traj, seq_length, seq_length + relative_starts + ~out_of_traj, seq_length_t, seq_length_t + relative_starts + ) + relative_starts = torch.where( + ~out_of_traj, relative_starts, torch.zeros_like(relative_starts) ) - relative_starts = torch.where(~out_of_traj, relative_starts, 0) if self.span[1]: out_of_traj = relative_starts + seq_length > lengths[traj_idx] if out_of_traj.any(): # a negative start means sampling fewer elements + # Convert seq_length to tensor if it's still a scalar + if not isinstance(seq_length, torch.Tensor): + seq_length = torch.as_tensor( + seq_length, + dtype=relative_starts.dtype, + device=relative_starts.device, + ) seq_length = torch.minimum( seq_length, lengths[traj_idx] - relative_starts )