From b34e65cab83174dfa6929aa6838a2724e715bbaa Mon Sep 17 00:00:00 2001 From: zombie-einstein Date: Sat, 13 May 2023 15:38:45 +0100 Subject: [PATCH] Use vmap to apply policy to batches --- jax_ppo/lstm/agent.py | 16 ++++------------ jax_ppo/lstm/algos.py | 6 ++++-- jax_ppo/lstm/policy.py | 11 ++++++----- jax_ppo/mlp/agent.py | 3 +-- jax_ppo/mlp/algos.py | 4 ++-- jax_ppo/mlp/policy.py | 5 ++--- tests/conftest.py | 6 ++---- tests/mlp/test_policy.py | 6 +++--- tests/recurrent/test_policy.py | 21 ++++++++++++--------- tests/recurrent/test_sampling.py | 12 ++++-------- 10 files changed, 40 insertions(+), 50 deletions(-) diff --git a/jax_ppo/lstm/agent.py b/jax_ppo/lstm/agent.py index 1e4370c..e04626b 100644 --- a/jax_ppo/lstm/agent.py +++ b/jax_ppo/lstm/agent.py @@ -7,7 +7,6 @@ from flax import linen from jax_ppo.data_types import Agent, PPOParams -from jax_ppo.lstm.data_types import HiddenStates from .policy import RecurrentActorCritic, initialise_carry @@ -19,12 +18,11 @@ def init_lstm_agent( observation_space_shape: typing.Tuple[int, ...], schedule: typing.Union[float, optax._src.base.Schedule], seq_len: int, - n_agents: int = 1, layer_width: int = 64, n_layers: int = 2, n_recurrent_layers: int = 1, activation: linen.activation = linen.tanh, -) -> typing.Tuple[jax.random.PRNGKey, Agent, HiddenStates]: +) -> typing.Tuple[jax.random.PRNGKey, Agent]: observation_size = np.prod(observation_space_shape) @@ -42,19 +40,13 @@ def init_lstm_agent( learning_rate=schedule, eps=ppo_params.adam_eps ), ) - fake_args_model = jnp.zeros( - ( - n_agents, - seq_len, - ) - + observation_space_shape - ) + fake_args_model = jnp.zeros((seq_len,) + observation_space_shape) - hidden_states = initialise_carry(n_recurrent_layers, (n_agents,), observation_size) + hidden_states = initialise_carry(n_recurrent_layers, (), observation_size) key, sub_key = jax.random.split(key) params_model = policy.init(sub_key, fake_args_model, hidden_states) agent = Agent.create(apply_fn=policy.apply, params=params_model, tx=tx) - return key, agent, hidden_states + return key, agent diff --git a/jax_ppo/lstm/algos.py b/jax_ppo/lstm/algos.py index 5374ffb..ed4d493 100644 --- a/jax_ppo/lstm/algos.py +++ b/jax_ppo/lstm/algos.py @@ -12,7 +12,9 @@ @partial(jax.jit, static_argnames="apply_fn") def policy(apply_fn, params, state, hidden_states: HiddenStates): - mean, log_std, value, hidden_states = apply_fn(params, state, hidden_states) + mean, log_std, value, hidden_states = jax.vmap(apply_fn, in_axes=(None, 0, 0))( + params, state, hidden_states + ) return mean, log_std, value, hidden_states @@ -28,7 +30,7 @@ def sample_actions( key, sub_key = jax.random.split(key) dist = distrax.MultivariateNormalDiag(mean, jnp.exp(log_std)) actions, log_likelihood = dist.sample_and_log_prob(seed=sub_key) - return key, actions, log_likelihood, value[:, 0], hidden_states + return key, actions, log_likelihood, value, hidden_states def max_action(agent: Agent, state, hidden_states: HiddenStates): diff --git a/jax_ppo/lstm/policy.py b/jax_ppo/lstm/policy.py index e7d12f5..fb2c972 100644 --- a/jax_ppo/lstm/policy.py +++ b/jax_ppo/lstm/policy.py @@ -2,6 +2,7 @@ from functools import partial import jax +import jax.numpy as jnp from flax import linen from jax_ppo.lstm.data_types import HiddenStates @@ -12,8 +13,8 @@ class _LSTMLayer(linen.Module): @partial( linen.transforms.scan, variable_broadcast="params", - in_axes=1, - out_axes=1, + in_axes=0, + out_axes=0, split_rngs={"params": False}, ) @linen.compact @@ -40,8 +41,8 @@ def __call__(self, x, hidden_states: 
HiddenStates): new_hidden_states = tuple(new_hidden_states) - value = x.reshape(x.shape[0], -1) - mean = x.reshape(x.shape[0], -1) + x = jnp.reshape(x, (-1,)) + value, mean = x, x for _ in range(self.n_layers): value = linen.Dense(self.layer_width, **layer_init())(value) @@ -57,7 +58,7 @@ def __call__(self, x, hidden_states: HiddenStates): "log_std", linen.initializers.zeros, (self.single_action_shape,) ) - return mean, log_std, value, new_hidden_states + return mean, log_std, value[0], new_hidden_states def initialise_carry( diff --git a/jax_ppo/mlp/agent.py b/jax_ppo/mlp/agent.py index 9288981..db7719b 100644 --- a/jax_ppo/mlp/agent.py +++ b/jax_ppo/mlp/agent.py @@ -16,7 +16,6 @@ def init_agent( action_space_shape: typing.Tuple[int, ...], observation_space_shape: typing.Tuple[int, ...], schedule: typing.Union[float, optax._src.base.Schedule], - n_agents: int = 1, layer_width: int = 64, n_layers: int = 2, activation: flax.linen.activation = flax.linen.tanh, @@ -36,7 +35,7 @@ def init_agent( ), ) - fake_args_model = jnp.zeros((n_agents,) + observation_space_shape) + fake_args_model = jnp.zeros(observation_space_shape) key, sub_key = jax.random.split(key) params_model = policy.init(sub_key, fake_args_model) diff --git a/jax_ppo/mlp/algos.py b/jax_ppo/mlp/algos.py index 4f7d7a3..3f769b8 100644 --- a/jax_ppo/mlp/algos.py +++ b/jax_ppo/mlp/algos.py @@ -11,7 +11,7 @@ @partial(jax.jit, static_argnames="apply_fn") def policy(apply_fn, params, state): - mean, log_std, value = apply_fn(params, state) + mean, log_std, value = jax.vmap(apply_fn, in_axes=(None, 0))(params, state) return mean, log_std, value @@ -20,7 +20,7 @@ def sample_actions(key: jax.random.PRNGKey, agent: Agent, state): key, sub_key = jax.random.split(key) dist = distrax.MultivariateNormalDiag(mean, jnp.exp(log_std)) actions, log_likelihood = dist.sample_and_log_prob(seed=sub_key) - return key, actions, log_likelihood, value[:, 0] + return key, actions, log_likelihood, value def max_action(agent: Agent, state): diff --git a/jax_ppo/mlp/policy.py b/jax_ppo/mlp/policy.py index 4493fb4..4cc4ef6 100644 --- a/jax_ppo/mlp/policy.py +++ b/jax_ppo/mlp/policy.py @@ -18,8 +18,7 @@ class ActorCritic(linen.module.Module): @linen.compact def __call__(self, x): - value = x - mean = x + value, mean = x, x for _ in range(self.n_layers): value = linen.Dense(self.layer_width, **layer_init())(value) @@ -35,4 +34,4 @@ def __call__(self, x): "log_std", linen.initializers.zeros, (self.single_action_shape,) ) - return mean, log_std, value + return mean, log_std, value[0] diff --git a/tests/conftest.py b/tests/conftest.py index 07b32b2..c673d1d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,7 +25,6 @@ def mlp_agent(key): (N_ACTIONS,), (N_OBS,), 0.01, - n_agents=1, layer_width=8, n_layers=1, ) @@ -34,18 +33,17 @@ def mlp_agent(key): @pytest.fixture def recurrent_agent(key): - _, agent, hidden_states = jax_ppo.init_lstm_agent( + _, agent = jax_ppo.init_lstm_agent( key, jax_ppo.default_params, (N_ACTIONS,), (N_OBS,), 0.01, SEQ_LEN, - n_agents=N_AGENTS, layer_width=8, n_layers=1, ) - return agent, hidden_states + return agent @pytest.fixture diff --git a/tests/mlp/test_policy.py b/tests/mlp/test_policy.py index ce8528a..d34c390 100644 --- a/tests/mlp/test_policy.py +++ b/tests/mlp/test_policy.py @@ -12,10 +12,10 @@ def observation(): def test_policy_output_shapes(mlp_agent, observation): - mean, log_std, value = mlp_agent.apply_fn(mlp_agent.params, observation) - assert mean.shape == (N_AGENTS, N_ACTIONS) + mean, log_std, value = 
mlp_agent.apply_fn(mlp_agent.params, observation[0]) + assert mean.shape == (N_ACTIONS,) assert log_std.shape == (N_ACTIONS,) - assert value.shape == (N_AGENTS, 1) + assert value.shape == () def test_policy_sampling_shape(key, mlp_agent, observation): diff --git a/tests/recurrent/test_policy.py b/tests/recurrent/test_policy.py index 2fbb2ef..cf89f37 100644 --- a/tests/recurrent/test_policy.py +++ b/tests/recurrent/test_policy.py @@ -1,6 +1,7 @@ import jax.numpy as jnp import pytest +from jax_ppo import initialise_carry from jax_ppo.lstm import algos from ..conftest import N_ACTIONS, N_AGENTS, N_OBS, SEQ_LEN @@ -12,19 +13,19 @@ def observation(): def test_policy_output_shapes(recurrent_agent, observation): - agent, hidden_states = recurrent_agent - mean, log_std, value, new_hidden_states = agent.apply_fn( - agent.params, observation, hidden_states + hidden_states = initialise_carry(1, (), N_OBS) + mean, log_std, value, new_hidden_states = recurrent_agent.apply_fn( + recurrent_agent.params, observation[0], hidden_states ) - assert mean.shape == (N_AGENTS, N_ACTIONS) + assert mean.shape == (N_ACTIONS,) assert log_std.shape == (N_ACTIONS,) - assert value.shape == (N_AGENTS, 1) + assert value.shape == () def test_policy_sampling_shape(key, recurrent_agent, observation): - agent, hidden_states = recurrent_agent + hidden_states = initialise_carry(1, (N_AGENTS,), N_OBS) _, actions, log_likelihood, values, new_hidden_states = algos.sample_actions( - key, agent, observation, hidden_states + key, recurrent_agent, observation, hidden_states ) assert actions.shape == (N_AGENTS, N_ACTIONS) assert log_likelihood.shape == (N_AGENTS,) @@ -32,6 +33,8 @@ def test_policy_sampling_shape(key, recurrent_agent, observation): def test_greedy_policy_sampling(recurrent_agent, observation): - agent, hidden_states = recurrent_agent - actions, new_hidden_states = algos.max_action(agent, observation, hidden_states) + hidden_states = initialise_carry(1, (N_AGENTS,), N_OBS) + actions, new_hidden_states = algos.max_action( + recurrent_agent, observation, hidden_states + ) assert actions.shape == (N_AGENTS, N_ACTIONS) diff --git a/tests/recurrent/test_sampling.py b/tests/recurrent/test_sampling.py index c6720b6..fdb014f 100644 --- a/tests/recurrent/test_sampling.py +++ b/tests/recurrent/test_sampling.py @@ -9,11 +9,10 @@ def test_policy_sampling(key, recurrent_agent, dummy_env): - agent, _ = recurrent_agent trajectories = training.generate_samples( env=dummy_env, env_params=dummy_env.default_params, - agent=agent, + agent=recurrent_agent, n_samples=N_SAMPLES, n_agents=None, key=key, @@ -33,11 +32,10 @@ def test_policy_sampling(key, recurrent_agent, dummy_env): def test_marl_policy_sampling(key, recurrent_agent, dummy_marl_env): - agent, _ = recurrent_agent trajectories = training.generate_samples( env=dummy_marl_env, env_params=dummy_marl_env.default_params, - agent=agent, + agent=recurrent_agent, n_samples=N_SAMPLES, n_agents=N_AGENTS, key=key, @@ -57,12 +55,11 @@ def test_marl_policy_sampling(key, recurrent_agent, dummy_marl_env): def test_policy_testing(key, recurrent_agent, dummy_env): - agent, _ = recurrent_agent burn_in = 3 state_ts, reward_ts, _ = training.test_policy( env=dummy_env, env_params=dummy_env.default_params, - agent=agent, + agent=recurrent_agent, n_steps=N_SAMPLES, n_agents=None, key=key, @@ -80,12 +77,11 @@ def test_shape(x): def test_marl_policy_testing(key, recurrent_agent, dummy_marl_env): - agent, _ = recurrent_agent burn_in = 3 state_ts, reward_ts, _ = training.test_policy( env=dummy_marl_env, 
env_params=dummy_marl_env.default_params, - agent=agent, + agent=recurrent_agent, n_steps=N_SAMPLES, n_agents=N_AGENTS, key=key,
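
For reference, below is a minimal, self-contained sketch of the batching pattern this patch adopts: the policy's apply function now handles a single unbatched observation, and callers such as jax_ppo.mlp.algos.policy add the batch axis with jax.vmap, broadcasting the shared parameters via in_axes=(None, 0) (the recurrent variant also maps the hidden states, using in_axes=(None, 0, 0)). The toy apply_fn, parameter names, and shapes here are illustrative assumptions, not the library's actual ActorCritic.

import jax
import jax.numpy as jnp

def apply_fn(params, obs):
    # Toy stand-in for the unbatched policy apply: one observation in,
    # per-observation outputs out (an action-mean vector and a scalar value).
    mean = params["w_mean"] @ obs
    value = jnp.dot(params["w_value"], obs)
    return mean, value

params = {
    "w_mean": jnp.ones((3, 5)),   # (n_actions, n_obs), illustrative sizes
    "w_value": jnp.ones((5,)),
}
obs_batch = jnp.zeros((7, 5))     # batch of 7 observations

# Map over the leading batch axis of the observations only; params are shared.
mean, value = jax.vmap(apply_fn, in_axes=(None, 0))(params, obs_batch)
assert mean.shape == (7, 3) and value.shape == (7,)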