Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
RSSMPrior,
RSSMRollout,
)
from torchrl.modules.models.multiagent import MultiAgentNetBase
from torchrl.modules.models.utils import SquashDims
from torchrl.modules.planners.mppi import MPPIPlanner
from torchrl.objectives.value import TDLambdaEstimator
Expand Down Expand Up @@ -1010,6 +1011,89 @@ def test_multiagent_reset_mlp(
.any()
)

@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("agent_dim", [1, -3])
def test_multiagent_custom_agent_dim(self, share_params, agent_dim):
"""Test that custom agent_dim values work correctly.

Regression test for https://github.com/pytorch/rl/issues/3288
"""
n_agents = 3
obs_dim = 5
seq_len = 6
output_dim = 4

class SingleAgentMLP(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_dim, 32),
nn.Tanh(),
nn.Linear(32, out_dim),
)

def forward(self, x):
return self.net(x)

class MultiAgentPolicyNet(MultiAgentNetBase):
def __init__(
self,
obs_dim,
output_dim,
n_agents,
share_params,
agent_dim,
device=None,
):
self.obs_dim = obs_dim
self.output_dim = output_dim
self._agent_dim = agent_dim

super().__init__(
n_agents=n_agents,
centralized=False,
share_params=share_params,
agent_dim=agent_dim,
device=device,
)

def _build_single_net(self, *, device, **kwargs):
net = SingleAgentMLP(self.obs_dim, self.output_dim)
return net.to(device) if device is not None else net

def _pre_forward_check(self, inputs):
if inputs.shape[self._agent_dim] != self.n_agents:
raise ValueError(
f"Multi-agent network expected input with shape[{self._agent_dim}]={self.n_agents},"
f" but got {inputs.shape}"
)
return inputs

policy_net = MultiAgentPolicyNet(
obs_dim=obs_dim,
output_dim=output_dim,
n_agents=n_agents,
share_params=share_params,
agent_dim=agent_dim,
)

# Input shape: (batch, n_agents, seq_len, obs_dim) with agents at dim 1
batch_size = 4
obs = torch.randn(batch_size, n_agents, seq_len, obs_dim)
out = policy_net(obs)

# Output should preserve agent dimension position
expected_shape = (batch_size, n_agents, seq_len, output_dim)
assert (
out.shape == expected_shape
), f"Expected {expected_shape}, got {out.shape}"

# Verify different agents produce different outputs (unless share_params with same input)
if not share_params:
for i in range(n_agents):
for j in range(i + 1, n_agents):
assert not torch.allclose(out[:, i], out[:, j])

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralized", [True, False])
Expand Down
2 changes: 1 addition & 1 deletion torchrl/modules/llm/policies/transformers_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ def _from_transformers_generate_history(self, td, cfg, out) -> TensorDictBase:

if self._device is not None:
response_struct = response_struct.to(self._device)

tokens_prompt_padded = response_struct.get(
"input_ids",
as_padded_tensor=True,
Expand Down
36 changes: 26 additions & 10 deletions torchrl/modules/models/multiagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,18 +132,32 @@ def forward(self, *inputs: tuple[torch.Tensor]) -> torch.Tensor:
else:
inputs = inputs[0]

# Convert agent_dim to positive index for consistent output placement.
# This ensures the agent dimension stays at the same position relative
# to batch dimensions, even if the network changes the number of dimensions
# (e.g., ConvNet collapses spatial dims).
# NOTE: Must compute this BEFORE _pre_forward_check, which may modify input shape
# (e.g., centralized mode flattens the agent dimension).
agent_dim_positive = self.agent_dim
if agent_dim_positive < 0:
agent_dim_positive = inputs.ndim + agent_dim_positive

inputs = self._pre_forward_check(inputs)

# If parameters are not shared, each agent has its own network
if not self.share_params:
if self.centralized:
output = self.vmap_func_module(
self._empty_net, (0, None), (-2,), randomness=self.vmap_randomness
self._empty_net,
(0, None),
(agent_dim_positive,),
randomness=self.vmap_randomness,
)(self.params, inputs)
else:
output = self.vmap_func_module(
self._empty_net,
(0, self.agent_dim),
(-2,),
(0, agent_dim_positive),
(agent_dim_positive,),
randomness=self.vmap_randomness,
)(self.params, inputs)

Expand All @@ -157,14 +171,16 @@ def forward(self, *inputs: tuple[torch.Tensor]) -> torch.Tensor:
# We expand it to maintain the agent dimension, but values will be the same for all agents
n_agent_outputs = output.shape[-1]
output = output.view(*output.shape[:-1], n_agent_outputs)
output = output.unsqueeze(-2)
output = output.expand(
*output.shape[:-2], self.n_agents, n_agent_outputs
)

if output.shape[-2] != (self.n_agents):
# Insert agent dimension at the correct position
output = output.unsqueeze(agent_dim_positive)
# Build the expanded shape
expand_shape = list(output.shape)
expand_shape[agent_dim_positive] = self.n_agents
output = output.expand(*expand_shape)

if output.shape[agent_dim_positive] != (self.n_agents):
raise ValueError(
f"Multi-agent network expected output with shape[-2]={self.n_agents}"
f"Multi-agent network expected output with shape[{agent_dim_positive}]={self.n_agents}"
f" but got {output.shape}"
)

Expand Down
Loading