diff --git a/test/test_modules.py b/test/test_modules.py
index ff44b0266a8..4b43e3cd5ba 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -53,6 +53,7 @@
     RSSMPrior,
     RSSMRollout,
 )
+from torchrl.modules.models.multiagent import MultiAgentNetBase
 from torchrl.modules.models.utils import SquashDims
 from torchrl.modules.planners.mppi import MPPIPlanner
 from torchrl.objectives.value import TDLambdaEstimator
@@ -1010,6 +1011,89 @@ def test_multiagent_reset_mlp(
             .any()
         )
 
+    @pytest.mark.parametrize("share_params", [True, False])
+    @pytest.mark.parametrize("agent_dim", [1, -3])
+    def test_multiagent_custom_agent_dim(self, share_params, agent_dim):
+        """Test that custom agent_dim values work correctly.
+
+        Regression test for https://github.com/pytorch/rl/issues/3288
+        """
+        n_agents = 3
+        obs_dim = 5
+        seq_len = 6
+        output_dim = 4
+
+        class SingleAgentMLP(nn.Module):
+            def __init__(self, in_dim, out_dim):
+                super().__init__()
+                self.net = nn.Sequential(
+                    nn.Linear(in_dim, 32),
+                    nn.Tanh(),
+                    nn.Linear(32, out_dim),
+                )
+
+            def forward(self, x):
+                return self.net(x)
+
+        class MultiAgentPolicyNet(MultiAgentNetBase):
+            def __init__(
+                self,
+                obs_dim,
+                output_dim,
+                n_agents,
+                share_params,
+                agent_dim,
+                device=None,
+            ):
+                self.obs_dim = obs_dim
+                self.output_dim = output_dim
+                self._agent_dim = agent_dim
+
+                super().__init__(
+                    n_agents=n_agents,
+                    centralized=False,
+                    share_params=share_params,
+                    agent_dim=agent_dim,
+                    device=device,
+                )
+
+            def _build_single_net(self, *, device, **kwargs):
+                net = SingleAgentMLP(self.obs_dim, self.output_dim)
+                return net.to(device) if device is not None else net
+
+            def _pre_forward_check(self, inputs):
+                if inputs.shape[self._agent_dim] != self.n_agents:
+                    raise ValueError(
+                        f"Multi-agent network expected input with shape[{self._agent_dim}]={self.n_agents},"
+                        f" but got {inputs.shape}"
+                    )
+                return inputs
+
+        policy_net = MultiAgentPolicyNet(
+            obs_dim=obs_dim,
+            output_dim=output_dim,
+            n_agents=n_agents,
+            share_params=share_params,
+            agent_dim=agent_dim,
+        )
+
+        # Input shape: (batch, n_agents, seq_len, obs_dim) with agents at dim 1
+        batch_size = 4
+        obs = torch.randn(batch_size, n_agents, seq_len, obs_dim)
+        out = policy_net(obs)
+
+        # Output should preserve agent dimension position
+        expected_shape = (batch_size, n_agents, seq_len, output_dim)
+        assert (
+            out.shape == expected_shape
+        ), f"Expected {expected_shape}, got {out.shape}"
+
+        # Verify different agents produce different outputs (unless share_params with same input)
+        if not share_params:
+            for i in range(n_agents):
+                for j in range(i + 1, n_agents):
+                    assert not torch.allclose(out[:, i], out[:, j])
+
     @pytest.mark.parametrize("n_agents", [1, 3])
     @pytest.mark.parametrize("share_params", [True, False])
     @pytest.mark.parametrize("centralized", [True, False])
diff --git a/torchrl/modules/llm/policies/transformers_wrapper.py b/torchrl/modules/llm/policies/transformers_wrapper.py
index 77c759d7589..c007c2641d3 100644
--- a/torchrl/modules/llm/policies/transformers_wrapper.py
+++ b/torchrl/modules/llm/policies/transformers_wrapper.py
@@ -790,7 +790,7 @@ def _from_transformers_generate_history(self, td, cfg, out) -> TensorDictBase:
 
         if self._device is not None:
             response_struct = response_struct.to(self._device)
-            
+
         tokens_prompt_padded = response_struct.get(
             "input_ids",
             as_padded_tensor=True,
diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py
index c1fd12fb34f..50311bcf50a 100644
--- a/torchrl/modules/models/multiagent.py
+++ b/torchrl/modules/models/multiagent.py
@@ -132,18 +132,32 @@ def forward(self, *inputs: tuple[torch.Tensor]) -> torch.Tensor:
         else:
             inputs = inputs[0]
 
+        # Convert agent_dim to positive index for consistent output placement.
+        # This ensures the agent dimension stays at the same position relative
+        # to batch dimensions, even if the network changes the number of dimensions
+        # (e.g., ConvNet collapses spatial dims).
+        # NOTE: Must compute this BEFORE _pre_forward_check, which may modify input shape
+        # (e.g., centralized mode flattens the agent dimension).
+        agent_dim_positive = self.agent_dim
+        if agent_dim_positive < 0:
+            agent_dim_positive = inputs.ndim + agent_dim_positive
+
         inputs = self._pre_forward_check(inputs)
+
         # If parameters are not shared, each agent has its own network
         if not self.share_params:
             if self.centralized:
                 output = self.vmap_func_module(
-                    self._empty_net, (0, None), (-2,), randomness=self.vmap_randomness
+                    self._empty_net,
+                    (0, None),
+                    (agent_dim_positive,),
+                    randomness=self.vmap_randomness,
                 )(self.params, inputs)
             else:
                 output = self.vmap_func_module(
                     self._empty_net,
-                    (0, self.agent_dim),
-                    (-2,),
+                    (0, agent_dim_positive),
+                    (agent_dim_positive,),
                     randomness=self.vmap_randomness,
                 )(self.params, inputs)
 
@@ -157,14 +171,16 @@ def forward(self, *inputs: tuple[torch.Tensor]) -> torch.Tensor:
                 # We expand it to maintain the agent dimension, but values will be the same for all agents
                 n_agent_outputs = output.shape[-1]
                 output = output.view(*output.shape[:-1], n_agent_outputs)
-                output = output.unsqueeze(-2)
-                output = output.expand(
-                    *output.shape[:-2], self.n_agents, n_agent_outputs
-                )
-
-        if output.shape[-2] != (self.n_agents):
+                # Insert agent dimension at the correct position
+                output = output.unsqueeze(agent_dim_positive)
+                # Build the expanded shape
+                expand_shape = list(output.shape)
+                expand_shape[agent_dim_positive] = self.n_agents
+                output = output.expand(*expand_shape)
+
+        if output.shape[agent_dim_positive] != (self.n_agents):
             raise ValueError(
-                f"Multi-agent network expected output with shape[-2]={self.n_agents}"
+                f"Multi-agent network expected output with shape[{agent_dim_positive}]={self.n_agents}"
                 f" but got {output.shape}"
             )
 
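For reference, a minimal standalone sketch of the behavior the multiagent.py change fixes: vmap's out_dims decides where the mapped (agent) dimension is re-inserted, so a hard-coded -2 displaces the agent axis whenever it does not already sit at -2. This is not part of the patch and is deliberately simplified (it vmaps a plain callable instead of TorchRL's functional module over (params, inputs)); the names weight and per_agent_net are hypothetical stand-ins.

```python
import torch

# Hypothetical stand-in for a single agent's network: (..., obs_dim) -> (..., out_dim)
weight = torch.randn(4, 5)  # (out_dim, obs_dim)


def per_agent_net(x):
    return x @ weight.T


batch, n_agents, seq_len, obs_dim = 4, 3, 6, 5
inputs = torch.randn(batch, n_agents, seq_len, obs_dim)  # agents at dim 1 (agent_dim=1)

# Old behavior: out_dims hard-coded to -2 re-inserts the agent axis next to the
# feature dimension instead of back where it came from.
old = torch.vmap(per_agent_net, in_dims=1, out_dims=-2)(inputs)
print(old.shape)  # torch.Size([4, 6, 3, 4]) -- agent dim displaced

# Patched behavior: out_dims follows the (positive) agent_dim, preserving layout.
new = torch.vmap(per_agent_net, in_dims=1, out_dims=1)(inputs)
print(new.shape)  # torch.Size([4, 3, 6, 4]) -- agent dim stays at dim 1
```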