Merge pull request #7 from zombie-einstein/refactor_batch_processing
Refactor batch processing
zombie-einstein committed May 13, 2023
2 parents a25c0a5 + 5d43dd7 commit b102d9b
Showing 11 changed files with 41 additions and 51 deletions.
2 changes: 1 addition & 1 deletion examples/lstm_usage.ipynb
@@ -149,7 +149,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, agent, hidden_states = jax_ppo.init_lstm_agent(\n",
"_, agent = jax_ppo.init_lstm_agent(\n",
" k, \n",
" params,\n",
" env.action_space().shape,\n",
16 changes: 4 additions & 12 deletions jax_ppo/lstm/agent.py
@@ -7,7 +7,6 @@
from flax import linen

from jax_ppo.data_types import Agent, PPOParams
from jax_ppo.lstm.data_types import HiddenStates

from .policy import RecurrentActorCritic, initialise_carry

@@ -19,12 +18,11 @@ def init_lstm_agent(
observation_space_shape: typing.Tuple[int, ...],
schedule: typing.Union[float, optax._src.base.Schedule],
seq_len: int,
n_agents: int = 1,
layer_width: int = 64,
n_layers: int = 2,
n_recurrent_layers: int = 1,
activation: linen.activation = linen.tanh,
) -> typing.Tuple[jax.random.PRNGKey, Agent, HiddenStates]:
) -> typing.Tuple[jax.random.PRNGKey, Agent]:

observation_size = np.prod(observation_space_shape)

@@ -42,19 +40,13 @@
learning_rate=schedule, eps=ppo_params.adam_eps
),
)
fake_args_model = jnp.zeros(
(
n_agents,
seq_len,
)
+ observation_space_shape
)
fake_args_model = jnp.zeros((seq_len,) + observation_space_shape)

hidden_states = initialise_carry(n_recurrent_layers, (n_agents,), observation_size)
hidden_states = initialise_carry(n_recurrent_layers, (), observation_size)

key, sub_key = jax.random.split(key)
params_model = policy.init(sub_key, fake_args_model, hidden_states)

agent = Agent.create(apply_fn=policy.apply, params=params_model, tx=tx)

return key, agent, hidden_states
return key, agent
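
With this change, `init_lstm_agent` no longer takes an `n_agents` argument and returns only `(key, agent)`; hidden states are created separately with `initialise_carry` when the policy is evaluated. A minimal sketch of the new calling convention, following the updated notebook and the `tests/conftest.py` fixture below (the concrete numbers are placeholders):

```python
import jax
import jax_ppo

key = jax.random.PRNGKey(0)

# The call now returns only (key, agent); n_agents is gone from the signature.
key, agent = jax_ppo.init_lstm_agent(
    key,
    jax_ppo.default_params,
    (2,),          # action_space_shape
    (4,),          # observation_space_shape
    0.01,          # learning-rate schedule
    8,             # seq_len
    layer_width=8,
    n_layers=1,
)
# Hidden states are built on demand, e.g. via jax_ppo.initialise_carry,
# rather than being returned from the initialiser.
```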
6 changes: 4 additions & 2 deletions jax_ppo/lstm/algos.py
@@ -12,7 +12,9 @@

@partial(jax.jit, static_argnames="apply_fn")
def policy(apply_fn, params, state, hidden_states: HiddenStates):
mean, log_std, value, hidden_states = apply_fn(params, state, hidden_states)
mean, log_std, value, hidden_states = jax.vmap(apply_fn, in_axes=(None, 0, 0))(
params, state, hidden_states
)
return mean, log_std, value, hidden_states


@@ -28,7 +30,7 @@ def sample_actions(
key, sub_key = jax.random.split(key)
dist = distrax.MultivariateNormalDiag(mean, jnp.exp(log_std))
actions, log_likelihood = dist.sample_and_log_prob(seed=sub_key)
return key, actions, log_likelihood, value[:, 0], hidden_states
return key, actions, log_likelihood, value, hidden_states


def max_action(agent: Agent, state, hidden_states: HiddenStates):
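
The recurrent `policy` helper now batches the unbatched network with `jax.vmap`, using `in_axes=(None, 0, 0)`: the parameters are broadcast to every call, while the observations and hidden states are mapped over their leading (agent) axis. Since the per-sample network now returns a scalar value, the batched output already has shape `(batch,)`, which is why `sample_actions` drops the old `value[:, 0]` indexing. A toy illustration of this vmap pattern (stand-in function, not the library's real network):

```python
import jax
import jax.numpy as jnp

def apply_fn(params, obs, carry):
    # Stand-in for a single-sample policy: returns a scalar value and a new carry.
    value = jnp.dot(params, obs)
    return value, carry + 1.0

params = jnp.ones(4)     # shared across the batch (in_axes=None)
obs = jnp.ones((3, 4))   # leading axis: batch of 3 agents
carry = jnp.zeros(3)     # one carry per agent

values, new_carry = jax.vmap(apply_fn, in_axes=(None, 0, 0))(params, obs, carry)
print(values.shape)      # (3,) -- already flat, no [:, 0] indexing needed
```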
11 changes: 6 additions & 5 deletions jax_ppo/lstm/policy.py
@@ -2,6 +2,7 @@
from functools import partial

import jax
import jax.numpy as jnp
from flax import linen

from jax_ppo.lstm.data_types import HiddenStates
@@ -12,8 +13,8 @@ class _LSTMLayer(linen.Module):
@partial(
linen.transforms.scan,
variable_broadcast="params",
in_axes=1,
out_axes=1,
in_axes=0,
out_axes=0,
split_rngs={"params": False},
)
@linen.compact
@@ -40,8 +41,8 @@ def __call__(self, x, hidden_states: HiddenStates):

new_hidden_states = tuple(new_hidden_states)

value = x.reshape(x.shape[0], -1)
mean = x.reshape(x.shape[0], -1)
x = jnp.reshape(x, (-1,))
value, mean = x, x

for _ in range(self.n_layers):
value = linen.Dense(self.layer_width, **layer_init())(value)
@@ -57,7 +58,7 @@ def __call__(self, x, hidden_states: HiddenStates):
"log_std", linen.initializers.zeros, (self.single_action_shape,)
)

return mean, log_std, value, new_hidden_states
return mean, log_std, value[0], new_hidden_states


def initialise_carry(
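
The `linen.transforms.scan` axes change from 1 to 0 because, after this refactor, the recurrent policy receives a single unbatched sequence of shape `(seq_len, n_obs)`; the batch dimension is added outside by `jax.vmap` in `algos.policy`. A toy `jax.lax.scan` sketch of carrying state over the leading time axis (placeholder update rule, not the real LSTM cell):

```python
import jax
import jax.numpy as jnp

def step(carry, x_t):
    # Placeholder recurrent update: accumulate the sum of each time step.
    new_carry = carry + x_t.sum()
    return new_carry, new_carry          # (next carry, per-step output)

seq = jnp.arange(12.0).reshape(4, 3)     # (seq_len, n_obs), no batch axis
final_carry, outputs = jax.lax.scan(step, 0.0, seq)  # scans over axis 0
print(outputs.shape)                     # (4,) -- one output per time step
```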
3 changes: 1 addition & 2 deletions jax_ppo/mlp/agent.py
@@ -16,7 +16,6 @@ def init_agent(
action_space_shape: typing.Tuple[int, ...],
observation_space_shape: typing.Tuple[int, ...],
schedule: typing.Union[float, optax._src.base.Schedule],
n_agents: int = 1,
layer_width: int = 64,
n_layers: int = 2,
activation: flax.linen.activation = flax.linen.tanh,
@@ -36,7 +35,7 @@
),
)

fake_args_model = jnp.zeros((n_agents,) + observation_space_shape)
fake_args_model = jnp.zeros(observation_space_shape)

key, sub_key = jax.random.split(key)
params_model = policy.init(sub_key, fake_args_model)
4 changes: 2 additions & 2 deletions jax_ppo/mlp/algos.py
@@ -11,7 +11,7 @@

@partial(jax.jit, static_argnames="apply_fn")
def policy(apply_fn, params, state):
mean, log_std, value = apply_fn(params, state)
mean, log_std, value = jax.vmap(apply_fn, in_axes=(None, 0))(params, state)
return mean, log_std, value


@@ -20,7 +20,7 @@ def sample_actions(key: jax.random.PRNGKey, agent: Agent, state):
key, sub_key = jax.random.split(key)
dist = distrax.MultivariateNormalDiag(mean, jnp.exp(log_std))
actions, log_likelihood = dist.sample_and_log_prob(seed=sub_key)
return key, actions, log_likelihood, value[:, 0]
return key, actions, log_likelihood, value


def max_action(agent: Agent, state):
5 changes: 2 additions & 3 deletions jax_ppo/mlp/policy.py
@@ -18,8 +18,7 @@ class ActorCritic(linen.module.Module):
@linen.compact
def __call__(self, x):

value = x
mean = x
value, mean = x, x

for _ in range(self.n_layers):
value = linen.Dense(self.layer_width, **layer_init())(value)
@@ -35,4 +34,4 @@ def __call__(self, x):
"log_std", linen.initializers.zeros, (self.single_action_shape,)
)

return mean, log_std, value
return mean, log_std, value[0]
6 changes: 2 additions & 4 deletions tests/conftest.py
@@ -25,7 +25,6 @@ def mlp_agent(key):
(N_ACTIONS,),
(N_OBS,),
0.01,
n_agents=1,
layer_width=8,
n_layers=1,
)
@@ -34,18 +33,17 @@

@pytest.fixture
def recurrent_agent(key):
_, agent, hidden_states = jax_ppo.init_lstm_agent(
_, agent = jax_ppo.init_lstm_agent(
key,
jax_ppo.default_params,
(N_ACTIONS,),
(N_OBS,),
0.01,
SEQ_LEN,
n_agents=N_AGENTS,
layer_width=8,
n_layers=1,
)
return agent, hidden_states
return agent


@pytest.fixture
6 changes: 3 additions & 3 deletions tests/mlp/test_policy.py
@@ -12,10 +12,10 @@ def observation():


def test_policy_output_shapes(mlp_agent, observation):
mean, log_std, value = mlp_agent.apply_fn(mlp_agent.params, observation)
assert mean.shape == (N_AGENTS, N_ACTIONS)
mean, log_std, value = mlp_agent.apply_fn(mlp_agent.params, observation[0])
assert mean.shape == (N_ACTIONS,)
assert log_std.shape == (N_ACTIONS,)
assert value.shape == (N_AGENTS, 1)
assert value.shape == ()


def test_policy_sampling_shape(key, mlp_agent, observation):
21 changes: 12 additions & 9 deletions tests/recurrent/test_policy.py
@@ -1,6 +1,7 @@
import jax.numpy as jnp
import pytest

from jax_ppo import initialise_carry
from jax_ppo.lstm import algos

from ..conftest import N_ACTIONS, N_AGENTS, N_OBS, SEQ_LEN
@@ -12,26 +13,28 @@


def test_policy_output_shapes(recurrent_agent, observation):
agent, hidden_states = recurrent_agent
mean, log_std, value, new_hidden_states = agent.apply_fn(
agent.params, observation, hidden_states
hidden_states = initialise_carry(1, (), N_OBS)
mean, log_std, value, new_hidden_states = recurrent_agent.apply_fn(
recurrent_agent.params, observation[0], hidden_states
)
assert mean.shape == (N_AGENTS, N_ACTIONS)
assert mean.shape == (N_ACTIONS,)
assert log_std.shape == (N_ACTIONS,)
assert value.shape == (N_AGENTS, 1)
assert value.shape == ()


def test_policy_sampling_shape(key, recurrent_agent, observation):
agent, hidden_states = recurrent_agent
hidden_states = initialise_carry(1, (N_AGENTS,), N_OBS)
_, actions, log_likelihood, values, new_hidden_states = algos.sample_actions(
key, agent, observation, hidden_states
key, recurrent_agent, observation, hidden_states
)
assert actions.shape == (N_AGENTS, N_ACTIONS)
assert log_likelihood.shape == (N_AGENTS,)
assert values.shape == (N_AGENTS,)


def test_greedy_policy_sampling(recurrent_agent, observation):
agent, hidden_states = recurrent_agent
actions, new_hidden_states = algos.max_action(agent, observation, hidden_states)
hidden_states = initialise_carry(1, (N_AGENTS,), N_OBS)
actions, new_hidden_states = algos.max_action(
recurrent_agent, observation, hidden_states
)
assert actions.shape == (N_AGENTS, N_ACTIONS)
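
In the updated tests, hidden states are constructed explicitly with `initialise_carry`, whose second argument is the batch shape of the carry: `()` when calling the raw, per-sample `apply_fn`, and `(N_AGENTS,)` when calling the batched helpers in `algos`. A short sketch of the two variants, with placeholder values standing in for the constants defined in `tests/conftest.py`:

```python
from jax_ppo import initialise_carry

N_OBS, N_AGENTS = 4, 3  # placeholders for the conftest constants

# First argument is the number of recurrent layers, second the carry batch shape.
unbatched_carry = initialise_carry(1, (), N_OBS)         # for agent.apply_fn on one sample
batched_carry = initialise_carry(1, (N_AGENTS,), N_OBS)  # for algos.sample_actions / max_action
```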
12 changes: 4 additions & 8 deletions tests/recurrent/test_sampling.py
@@ -9,11 +9,10 @@


def test_policy_sampling(key, recurrent_agent, dummy_env):
agent, _ = recurrent_agent
trajectories = training.generate_samples(
env=dummy_env,
env_params=dummy_env.default_params,
agent=agent,
agent=recurrent_agent,
n_samples=N_SAMPLES,
n_agents=None,
key=key,
@@ -33,11 +32,10 @@ def test_policy_sampling(key, recurrent_agent, dummy_env):


def test_marl_policy_sampling(key, recurrent_agent, dummy_marl_env):
agent, _ = recurrent_agent
trajectories = training.generate_samples(
env=dummy_marl_env,
env_params=dummy_marl_env.default_params,
agent=agent,
agent=recurrent_agent,
n_samples=N_SAMPLES,
n_agents=N_AGENTS,
key=key,
@@ -57,12 +55,11 @@ def test_marl_policy_sampling(key, recurrent_agent, dummy_marl_env):


def test_policy_testing(key, recurrent_agent, dummy_env):
agent, _ = recurrent_agent
burn_in = 3
state_ts, reward_ts, _ = training.test_policy(
env=dummy_env,
env_params=dummy_env.default_params,
agent=agent,
agent=recurrent_agent,
n_steps=N_SAMPLES,
n_agents=None,
key=key,
@@ -80,12 +77,11 @@ def test_shape(x):


def test_marl_policy_testing(key, recurrent_agent, dummy_marl_env):
agent, _ = recurrent_agent
burn_in = 3
state_ts, reward_ts, _ = training.test_policy(
env=dummy_marl_env,
env_params=dummy_marl_env.default_params,
agent=agent,
agent=recurrent_agent,
n_steps=N_SAMPLES,
n_agents=N_AGENTS,
key=key,
