Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor batch processing #7

Merged
merged 2 commits into from
May 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/lstm_usage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@
"metadata": {},
"outputs": [],
"source": [
"_, agent, hidden_states = jax_ppo.init_lstm_agent(\n",
"_, agent = jax_ppo.init_lstm_agent(\n",
" k, \n",
" params,\n",
" env.action_space().shape,\n",
Expand Down
16 changes: 4 additions & 12 deletions jax_ppo/lstm/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from flax import linen

from jax_ppo.data_types import Agent, PPOParams
from jax_ppo.lstm.data_types import HiddenStates

from .policy import RecurrentActorCritic, initialise_carry

Expand All @@ -19,12 +18,11 @@ def init_lstm_agent(
observation_space_shape: typing.Tuple[int, ...],
schedule: typing.Union[float, optax._src.base.Schedule],
seq_len: int,
n_agents: int = 1,
layer_width: int = 64,
n_layers: int = 2,
n_recurrent_layers: int = 1,
activation: linen.activation = linen.tanh,
) -> typing.Tuple[jax.random.PRNGKey, Agent, HiddenStates]:
) -> typing.Tuple[jax.random.PRNGKey, Agent]:

observation_size = np.prod(observation_space_shape)

Expand All @@ -42,19 +40,13 @@ def init_lstm_agent(
learning_rate=schedule, eps=ppo_params.adam_eps
),
)
fake_args_model = jnp.zeros(
(
n_agents,
seq_len,
)
+ observation_space_shape
)
fake_args_model = jnp.zeros((seq_len,) + observation_space_shape)

hidden_states = initialise_carry(n_recurrent_layers, (n_agents,), observation_size)
hidden_states = initialise_carry(n_recurrent_layers, (), observation_size)

key, sub_key = jax.random.split(key)
params_model = policy.init(sub_key, fake_args_model, hidden_states)

agent = Agent.create(apply_fn=policy.apply, params=params_model, tx=tx)

return key, agent, hidden_states
return key, agent
6 changes: 4 additions & 2 deletions jax_ppo/lstm/algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

@partial(jax.jit, static_argnames="apply_fn")
def policy(apply_fn, params, state, hidden_states: HiddenStates):
mean, log_std, value, hidden_states = apply_fn(params, state, hidden_states)
mean, log_std, value, hidden_states = jax.vmap(apply_fn, in_axes=(None, 0, 0))(
params, state, hidden_states
)
return mean, log_std, value, hidden_states


Expand All @@ -28,7 +30,7 @@ def sample_actions(
key, sub_key = jax.random.split(key)
dist = distrax.MultivariateNormalDiag(mean, jnp.exp(log_std))
actions, log_likelihood = dist.sample_and_log_prob(seed=sub_key)
return key, actions, log_likelihood, value[:, 0], hidden_states
return key, actions, log_likelihood, value, hidden_states


def max_action(agent: Agent, state, hidden_states: HiddenStates):
Expand Down
11 changes: 6 additions & 5 deletions jax_ppo/lstm/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import partial

import jax
import jax.numpy as jnp
from flax import linen

from jax_ppo.lstm.data_types import HiddenStates
Expand All @@ -12,8 +13,8 @@ class _LSTMLayer(linen.Module):
@partial(
linen.transforms.scan,
variable_broadcast="params",
in_axes=1,
out_axes=1,
in_axes=0,
out_axes=0,
split_rngs={"params": False},
)
@linen.compact
Expand All @@ -40,8 +41,8 @@ def __call__(self, x, hidden_states: HiddenStates):

new_hidden_states = tuple(new_hidden_states)

value = x.reshape(x.shape[0], -1)
mean = x.reshape(x.shape[0], -1)
x = jnp.reshape(x, (-1,))
value, mean = x, x

for _ in range(self.n_layers):
value = linen.Dense(self.layer_width, **layer_init())(value)
Expand All @@ -57,7 +58,7 @@ def __call__(self, x, hidden_states: HiddenStates):
"log_std", linen.initializers.zeros, (self.single_action_shape,)
)

return mean, log_std, value, new_hidden_states
return mean, log_std, value[0], new_hidden_states


def initialise_carry(
Expand Down
3 changes: 1 addition & 2 deletions jax_ppo/mlp/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def init_agent(
action_space_shape: typing.Tuple[int, ...],
observation_space_shape: typing.Tuple[int, ...],
schedule: typing.Union[float, optax._src.base.Schedule],
n_agents: int = 1,
layer_width: int = 64,
n_layers: int = 2,
activation: flax.linen.activation = flax.linen.tanh,
Expand All @@ -36,7 +35,7 @@ def init_agent(
),
)

fake_args_model = jnp.zeros((n_agents,) + observation_space_shape)
fake_args_model = jnp.zeros(observation_space_shape)

key, sub_key = jax.random.split(key)
params_model = policy.init(sub_key, fake_args_model)
Expand Down
4 changes: 2 additions & 2 deletions jax_ppo/mlp/algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@partial(jax.jit, static_argnames="apply_fn")
def policy(apply_fn, params, state):
mean, log_std, value = apply_fn(params, state)
mean, log_std, value = jax.vmap(apply_fn, in_axes=(None, 0))(params, state)
return mean, log_std, value


Expand All @@ -20,7 +20,7 @@ def sample_actions(key: jax.random.PRNGKey, agent: Agent, state):
key, sub_key = jax.random.split(key)
dist = distrax.MultivariateNormalDiag(mean, jnp.exp(log_std))
actions, log_likelihood = dist.sample_and_log_prob(seed=sub_key)
return key, actions, log_likelihood, value[:, 0]
return key, actions, log_likelihood, value


def max_action(agent: Agent, state):
Expand Down
5 changes: 2 additions & 3 deletions jax_ppo/mlp/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ class ActorCritic(linen.module.Module):
@linen.compact
def __call__(self, x):

value = x
mean = x
value, mean = x, x

for _ in range(self.n_layers):
value = linen.Dense(self.layer_width, **layer_init())(value)
Expand All @@ -35,4 +34,4 @@ def __call__(self, x):
"log_std", linen.initializers.zeros, (self.single_action_shape,)
)

return mean, log_std, value
return mean, log_std, value[0]
6 changes: 2 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def mlp_agent(key):
(N_ACTIONS,),
(N_OBS,),
0.01,
n_agents=1,
layer_width=8,
n_layers=1,
)
Expand All @@ -34,18 +33,17 @@ def mlp_agent(key):

@pytest.fixture
def recurrent_agent(key):
_, agent, hidden_states = jax_ppo.init_lstm_agent(
_, agent = jax_ppo.init_lstm_agent(
key,
jax_ppo.default_params,
(N_ACTIONS,),
(N_OBS,),
0.01,
SEQ_LEN,
n_agents=N_AGENTS,
layer_width=8,
n_layers=1,
)
return agent, hidden_states
return agent


@pytest.fixture
Expand Down
6 changes: 3 additions & 3 deletions tests/mlp/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ def observation():


def test_policy_output_shapes(mlp_agent, observation):
mean, log_std, value = mlp_agent.apply_fn(mlp_agent.params, observation)
assert mean.shape == (N_AGENTS, N_ACTIONS)
mean, log_std, value = mlp_agent.apply_fn(mlp_agent.params, observation[0])
assert mean.shape == (N_ACTIONS,)
assert log_std.shape == (N_ACTIONS,)
assert value.shape == (N_AGENTS, 1)
assert value.shape == ()


def test_policy_sampling_shape(key, mlp_agent, observation):
Expand Down
21 changes: 12 additions & 9 deletions tests/recurrent/test_policy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import jax.numpy as jnp
import pytest

from jax_ppo import initialise_carry
from jax_ppo.lstm import algos

from ..conftest import N_ACTIONS, N_AGENTS, N_OBS, SEQ_LEN
Expand All @@ -12,26 +13,28 @@ def observation():


def test_policy_output_shapes(recurrent_agent, observation):
agent, hidden_states = recurrent_agent
mean, log_std, value, new_hidden_states = agent.apply_fn(
agent.params, observation, hidden_states
hidden_states = initialise_carry(1, (), N_OBS)
mean, log_std, value, new_hidden_states = recurrent_agent.apply_fn(
recurrent_agent.params, observation[0], hidden_states
)
assert mean.shape == (N_AGENTS, N_ACTIONS)
assert mean.shape == (N_ACTIONS,)
assert log_std.shape == (N_ACTIONS,)
assert value.shape == (N_AGENTS, 1)
assert value.shape == ()


def test_policy_sampling_shape(key, recurrent_agent, observation):
agent, hidden_states = recurrent_agent
hidden_states = initialise_carry(1, (N_AGENTS,), N_OBS)
_, actions, log_likelihood, values, new_hidden_states = algos.sample_actions(
key, agent, observation, hidden_states
key, recurrent_agent, observation, hidden_states
)
assert actions.shape == (N_AGENTS, N_ACTIONS)
assert log_likelihood.shape == (N_AGENTS,)
assert values.shape == (N_AGENTS,)


def test_greedy_policy_sampling(recurrent_agent, observation):
agent, hidden_states = recurrent_agent
actions, new_hidden_states = algos.max_action(agent, observation, hidden_states)
hidden_states = initialise_carry(1, (N_AGENTS,), N_OBS)
actions, new_hidden_states = algos.max_action(
recurrent_agent, observation, hidden_states
)
assert actions.shape == (N_AGENTS, N_ACTIONS)
12 changes: 4 additions & 8 deletions tests/recurrent/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@


def test_policy_sampling(key, recurrent_agent, dummy_env):
agent, _ = recurrent_agent
trajectories = training.generate_samples(
env=dummy_env,
env_params=dummy_env.default_params,
agent=agent,
agent=recurrent_agent,
n_samples=N_SAMPLES,
n_agents=None,
key=key,
Expand All @@ -33,11 +32,10 @@ def test_policy_sampling(key, recurrent_agent, dummy_env):


def test_marl_policy_sampling(key, recurrent_agent, dummy_marl_env):
agent, _ = recurrent_agent
trajectories = training.generate_samples(
env=dummy_marl_env,
env_params=dummy_marl_env.default_params,
agent=agent,
agent=recurrent_agent,
n_samples=N_SAMPLES,
n_agents=N_AGENTS,
key=key,
Expand All @@ -57,12 +55,11 @@ def test_marl_policy_sampling(key, recurrent_agent, dummy_marl_env):


def test_policy_testing(key, recurrent_agent, dummy_env):
agent, _ = recurrent_agent
burn_in = 3
state_ts, reward_ts, _ = training.test_policy(
env=dummy_env,
env_params=dummy_env.default_params,
agent=agent,
agent=recurrent_agent,
n_steps=N_SAMPLES,
n_agents=None,
key=key,
Expand All @@ -80,12 +77,11 @@ def test_shape(x):


def test_marl_policy_testing(key, recurrent_agent, dummy_marl_env):
agent, _ = recurrent_agent
burn_in = 3
state_ts, reward_ts, _ = training.test_policy(
env=dummy_marl_env,
env_params=dummy_marl_env.default_params,
agent=agent,
agent=recurrent_agent,
n_steps=N_SAMPLES,
n_agents=N_AGENTS,
key=key,
Expand Down