Merge pull request #6 from zombie-einstein/non_static_env_args
Allow non-static env_params and fix changes to reward shape returned …
zombie-einstein committed Apr 22, 2023
2 parents e030e77 + 2c925e0 commit a25c0a5
Showing 8 changed files with 22 additions and 23 deletions.
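
In practical terms, callers now pass the episode length explicitly as n_env_steps and unpack one extra return value from the training functions. A minimal sketch of the updated call, assembled from the gym notebook diff below (k, env, env_params, agent, params and the N_* hyperparameters are assumed to be defined earlier in the notebook):

_k, trained_agent, losses, ts, rewards, _ = jax_ppo.train(
    k, env, env_params, agent,
    N_TRAIN,
    N_TRAIN_ENV,
    N_EPOCHS,
    MINI_BATCH_SIZE,
    N_TEST_ENV,
    params,
    env_params.max_steps_in_episode,  # n_env_steps: now a required argument
    greedy_test_policy=True,
)
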
examples/gym_usage.ipynb (9 changes: 5 additions & 4 deletions)

@@ -167,18 +167,19 @@
 "execution_count": null,
 "id": "39e5b9b1",
 "metadata": {
-"scrolled": true
+"scrolled": false
 },
 "outputs": [],
 "source": [
-"_k, trained_agent, losses, ts, rewards = jax_ppo.train(\n",
+"_k, trained_agent, losses, ts, rewards, _ = jax_ppo.train(\n",
 " k, env, env_params, agent,\n",
 " N_TRAIN, \n",
 " N_TRAIN_ENV, \n",
 " N_EPOCHS, \n",
 " MINI_BATCH_SIZE, \n",
 " N_TEST_ENV, \n",
 " params, \n",
+" env_params.max_steps_in_episode,\n",
 " greedy_test_policy=True\n",
 ")"
 ]
@@ -190,7 +191,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"plt.plot(jnp.mean(jnp.sum(rewards[:, :, :, 0], axis=2), axis=1));\n",
+"plt.plot(jnp.mean(jnp.sum(rewards[:, :, :], axis=2), axis=1));\n",
 "plt.xlabel(\"Training Step\")\n",
 "plt.ylabel(\"Avg Total Rewards\");"
 ]
@@ -236,7 +237,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"plt.plot(rewards[-1, :, :, 0].T, drawstyle=\"steps-mid\");\n",
+"plt.plot(rewards[-1, :, :].T, drawstyle=\"steps-mid\");\n",
 "plt.xlabel(\"Step\");\n",
 "plt.ylabel(\"Reward\");"
 ]

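The plotting changes above reflect the new shape of the rewards array returned by the training call: the trailing singleton agent axis is gone for single-agent environments. A hedged sketch of the reductions used in the notebook, with axis names inferred from the plots rather than documented anywhere, given the rewards array returned by the call shown earlier:

import jax.numpy as jnp
import matplotlib.pyplot as plt

# Assumed shape: (n_training_steps, n_test_env, n_env_steps) for a single agent.
total_per_episode = jnp.sum(rewards, axis=2)             # sum rewards over env steps
avg_total_rewards = jnp.mean(total_per_episode, axis=1)  # average over test envs
plt.plot(avg_total_rewards)
plt.xlabel("Training Step")
plt.ylabel("Avg Total Rewards")
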
examples/lstm_usage.ipynb (7 changes: 4 additions & 3 deletions)

@@ -179,7 +179,7 @@
 },
 "outputs": [],
 "source": [
-"_k, trained_agent, losses, ts, rewards = jax_ppo.train_recurrent(\n",
+"_k, trained_agent, losses, ts, rewards, _ = jax_ppo.train_recurrent(\n",
 " k, env, env_params, agent, \n",
 " N_TRAIN, \n",
 " N_TRAIN_ENV, \n",
@@ -190,6 +190,7 @@
 " 1,\n",
 " N_BURN_IN,\n",
 " params,\n",
+" env_params.max_steps_in_episode,\n",
 " greedy_test_policy=True,\n",
 ")"
 ]
@@ -201,7 +202,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"plt.plot(jnp.mean(jnp.sum(rewards[:, :, :, 0], axis=2), axis=1));\n",
+"plt.plot(jnp.mean(jnp.sum(rewards[:, :, :], axis=2), axis=1));\n",
 "plt.xlabel(\"Training Step\")\n",
 "plt.ylabel(\"Avg Total Rewards\");"
 ]
@@ -247,7 +248,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"plt.plot(rewards[-1, :, :, 0].T, drawstyle=\"steps-mid\");\n",
+"plt.plot(rewards[-1, :, :].T, drawstyle=\"steps-mid\");\n",
 "plt.xlabel(\"Step\");\n",
 "plt.ylabel(\"Reward\");"
 ]

jax_ppo/lstm/training.py (8 changes: 4 additions & 4 deletions)

@@ -84,6 +84,7 @@ def _sample_step(carry, _):
 if n_agents is None:
     new_observation = new_observation[jnp.newaxis]
     _done = jnp.array([_done])
+    _reward = jnp.array([_reward])

 new_observation = jnp.hstack(
     (_observation.at[:, 1:].get(), new_observation[:, jnp.newaxis])
@@ -96,7 +97,7 @@ def _sample_step(carry, _):
     action=_action,
     log_likelihood=_log_likelihood,
     value=_value,
-    reward=_reward[0],
+    reward=_reward,
     done=_done,
     hidden_states=_hidden_state,
 ),
@@ -161,7 +162,7 @@ def _step(carry, _):

 return (
     (k, _agent, new_hidden_state, new_state, new_observation),
-    (new_state, _reward[0], _info),
+    (new_state, _reward, _info),
 )

 key, observation, state, hidden_states = _reset_env(
@@ -190,7 +191,6 @@ def _step(carry, _):
 jax.jit,
 static_argnames=(
     "env",
-    "env_params",
     "n_train",
     "n_train_env",
     "n_train_epochs",
@@ -218,8 +218,8 @@ def train(
     n_recurrent_layers: int,
     n_burn_in: int,
     ppo_params: data_types.PPOParams,
+    n_env_steps: int,
     n_agents: typing.Optional[int] = None,
-    n_env_steps: typing.Optional[int] = None,
     greedy_test_policy: bool = False,
     max_mini_batches: int = 10_000,
 ) -> typing.Tuple[

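Dropping "env_params" from static_argnames is what makes the environment parameters non-static: a static jit argument must be hashable and triggers a fresh compilation whenever its value changes, while a traced argument can vary between calls at no extra cost. A generic JAX sketch of the distinction (illustrative only, not jax_ppo code):

import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("n_steps",))
def rollout(x, scale, n_steps):
    # n_steps is static: it sets the Python loop length, so each new value
    # forces a re-trace. scale is traced, so new values reuse the compiled
    # function, which is the benefit env_params gains from this change.
    for _ in range(n_steps):
        x = x * scale
    return x


print(rollout(jnp.ones(3), 0.5, n_steps=4))
print(rollout(jnp.ones(3), 0.9, n_steps=4))  # new scale, no recompilation
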
jax_ppo/mlp/training.py (8 changes: 4 additions & 4 deletions)

@@ -47,6 +47,7 @@ def _sample_step(carry, _):
 if n_agents is None:
     new_observation = new_observation[jnp.newaxis]
     _done = jnp.array([_done])
+    _reward = jnp.array([_reward])

 return (
     (k, _agent, new_state, new_observation),
@@ -55,7 +56,7 @@
         action=_action,
         log_likelihood=_log_likelihood,
         value=_value,
-        reward=_reward[0],
+        reward=_reward,
         done=_done,
     ),
 )
@@ -118,7 +119,7 @@ def _step(carry, _):

 return (
     (k, _agent, new_state, new_observation),
-    (new_state, _reward[0], _info),
+    (new_state, _reward, _info),
 )

 key, reset_key = jax.random.split(key)
@@ -138,7 +139,6 @@ def _step(carry, _):
 jax.jit,
 static_argnames=(
     "env",
-    "env_params",
     "n_train",
     "n_train_env",
     "n_train_epochs",
@@ -160,8 +160,8 @@ def train(
     mini_batch_size: int,
     n_test_env: int,
     ppo_params: data_types.PPOParams,
+    n_env_steps: int,
     n_agents: typing.Optional[int] = None,
-    n_env_steps: typing.Optional[int] = None,
     greedy_test_policy: bool = False,
     max_mini_batches: int = 10_000,
 ) -> typing.Tuple[

jax_ppo/runner.py (5 changes: 1 addition & 4 deletions)

@@ -26,7 +26,7 @@ def train(
     n_agents: typing.Optional[int],
     greedy_test_policy: bool,
     max_mini_batches: int,
-    n_env_steps: typing.Optional[int],
+    n_env_steps: int,
     **static_kwargs,
 ) -> typing.Tuple[
     jax.random.PRNGKey,
@@ -37,9 +37,6 @@
     typing.Dict,
 ]:

-    if n_env_steps is None:
-        n_env_steps = env_params.max_steps_in_episode
-
     test_keys = jax.random.split(key, n_test_env + 1)
     key, test_keys = test_keys[0], test_keys[1:]

tests/conftest.py (4 changes: 2 additions & 2 deletions)

@@ -67,7 +67,7 @@ def default_params(self) -> EnvParams:
 def step_env(self, key, state: EnvState, action, params: EnvParams):
     new_state = EnvState(x=state.x + jnp.sum(action) + params.y, t=state.t + 1)
     new_obs = self.get_obs(new_state)
-    rewards = jnp.array([[2.0]])
+    rewards = 2.0
     dones = self.is_terminal(new_state, params)
     return new_obs, new_state, rewards, dones, dict()
@@ -124,7 +124,7 @@ def step_env(self, key, state: EnvState, action, params: EnvParams):
     t=state.t + 1,
 )
 new_obs = self.get_obs(new_state)
-rewards = jnp.arange(N_AGENTS, dtype=jnp.float32)[jnp.newaxis]
+rewards = jnp.arange(N_AGENTS, dtype=jnp.float32)
 dones = self.is_terminal(new_state, params)
 return new_obs, new_state, rewards, dones, dict()

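The dummy environments above pin down the reward convention the training code now expects: single-agent environments return a plain scalar reward per step, multi-agent environments return a 1-D array with one entry per agent, and the samplers wrap the scalar case with jnp.array([_reward]) before stacking. A small illustrative check of the resulting shapes (inferred from these tests, not a documented API):

import jax.numpy as jnp

N_AGENTS = 3
N_SAMPLES = 5

single_agent_reward = 2.0                                     # scalar per step
multi_agent_reward = jnp.arange(N_AGENTS, dtype=jnp.float32)  # shape (N_AGENTS,)
assert multi_agent_reward.shape == (N_AGENTS,)

# Sampled reward trajectories drop the old trailing singleton axis:
reward_ts = jnp.full((N_SAMPLES,), single_agent_reward)       # was (N_SAMPLES, 1)
assert reward_ts.shape == (N_SAMPLES,)
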
tests/mlp/test_sampling.py (2 changes: 1 addition & 1 deletion)

@@ -63,7 +63,7 @@ def test_shape(x):
     assert x.shape[0] == N_SAMPLES

 jax.tree_util.tree_map(test_shape, state_ts)
-assert reward_ts.shape == (N_SAMPLES, 1)
+assert reward_ts.shape == (N_SAMPLES,)


 def test_marl_policy_testing(key, mlp_agent, dummy_marl_env):

tests/recurrent/test_sampling.py (2 changes: 1 addition & 1 deletion)

@@ -76,7 +76,7 @@ def test_shape(x):
     assert x.shape[0] == n

 jax.tree_util.tree_map(test_shape, state_ts)
-assert reward_ts.shape == (n, 1)
+assert reward_ts.shape == (n,)


 def test_marl_policy_testing(key, recurrent_agent, dummy_marl_env):

