
Commit 1a2a669

Quick n dirty (#4)
* First implementation of PPO
* Bug fixes
* Test that domain randomization actually works
* Nice names
* Temporary
* Add SAC
* Add state sampler again
* SAC works
* Add domain randomization to SAC, without a history-dependent policy
* Making it work
* Support the case where there is no domain randomization
* Add test to compare against dmc
* My version of SAC
* Fix bugs
* Make base SAC work
* Update hparams
* Standardize training script
* Fix domain randomization bugs
* Domain randomization should work now?
* Add domain randomization for the gear
* Add formatting to pre-commit
* Fix gear DR
* No negative mass
* Explicit privileged information
* Fix bug
* Use min-max and uniform
* Separate train and eval DR
* Mass DR
* Separate eval and train
* Remove ssac
* Remove PPO
* Fix lint
* Remove old things
* Remove SAC
1 parent 88d8fba commit 1a2a669

31 files changed: +1219 -629 lines changed

.pre-commit-config.yaml

Lines changed: 6 additions & 0 deletions
@@ -19,6 +19,12 @@ repos:
     rev: v0.1.5
     hooks:
       - id: ruff
+        name: lint with ruff
+      - id: ruff
+        name: sort imports with ruff
+        args: [--select, I, --fix]
+      - id: ruff-format
+        name: format with ruff
   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v1.2.0
     hooks:

main.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

ss2r/algorithms/jax/ppo/__init__.py

Whitespace-only changes.
File renamed without changes.

ss2r/algorithms/sac/losses.py

Lines changed: 137 additions & 0 deletions
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Soft Actor-Critic losses.

See: https://arxiv.org/pdf/1812.05905.pdf
"""

from typing import Any, TypeAlias

import jax
import jax.numpy as jnp
from brax.training import types
from brax.training.agents.sac import networks as sac_networks
from brax.training.types import Params, PRNGKey

Transition: TypeAlias = types.Transition


def make_losses(
    sac_network: sac_networks.SACNetworks,
    reward_scaling: float,
    discounting: float,
    action_size: int,
):
    """Creates the SAC losses."""

    target_entropy = -0.5 * action_size
    policy_network = sac_network.policy_network
    q_network = sac_network.q_network
    parametric_action_distribution = sac_network.parametric_action_distribution

    def alpha_loss(
        log_alpha: jnp.ndarray,
        policy_params: Params,
        normalizer_params: Any,
        transitions: Transition,
        key: PRNGKey,
    ) -> jnp.ndarray:
        """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf."""
        dist_params = policy_network.apply(
            normalizer_params, policy_params, transitions.observation
        )
        action = parametric_action_distribution.sample_no_postprocessing(
            dist_params, key
        )
        log_prob = parametric_action_distribution.log_prob(dist_params, action)
        alpha = jnp.exp(log_alpha)
        alpha_loss = alpha * jax.lax.stop_gradient(-log_prob - target_entropy)
        return jnp.mean(alpha_loss)

    def critic_loss(
        q_params: Params,
        policy_params: Params,
        normalizer_params: Any,
        target_q_params: Params,
        alpha: jnp.ndarray,
        transitions: Transition,
        key: PRNGKey,
    ) -> jnp.ndarray:
        domain_params = transitions.extras.get("domain_parameters", None)
        if domain_params is not None:
            action = jnp.concatenate([transitions.action, domain_params], axis=-1)
        else:
            action = transitions.action
        q_old_action = q_network.apply(
            normalizer_params, q_params, transitions.observation, action
        )
        next_dist_params = policy_network.apply(
            normalizer_params, policy_params, transitions.next_observation
        )
        next_action = parametric_action_distribution.sample_no_postprocessing(
            next_dist_params, key
        )
        next_log_prob = parametric_action_distribution.log_prob(
            next_dist_params, next_action
        )
        next_action = parametric_action_distribution.postprocess(next_action)
        if domain_params is not None:
            next_action = jnp.concatenate([next_action, domain_params], axis=-1)
        next_q = q_network.apply(
            normalizer_params,
            target_q_params,
            transitions.next_observation,
            next_action,
        )
        next_v = jnp.min(next_q, axis=-1) - alpha * next_log_prob
        target_q = jax.lax.stop_gradient(
            transitions.reward * reward_scaling
            + transitions.discount * discounting * next_v
        )
        q_error = q_old_action - jnp.expand_dims(target_q, -1)

        # Better bootstrapping for truncated episodes.
        truncation = transitions.extras["state_extras"]["truncation"]
        q_error *= jnp.expand_dims(1 - truncation, -1)

        q_loss = 0.5 * jnp.mean(jnp.square(q_error))
        return q_loss

    def actor_loss(
        policy_params: Params,
        normalizer_params: Any,
        q_params: Params,
        alpha: jnp.ndarray,
        transitions: Transition,
        key: PRNGKey,
    ) -> jnp.ndarray:
        dist_params = policy_network.apply(
            normalizer_params, policy_params, transitions.observation
        )
        action = parametric_action_distribution.sample_no_postprocessing(
            dist_params, key
        )
        log_prob = parametric_action_distribution.log_prob(dist_params, action)
        action = parametric_action_distribution.postprocess(action)
        domain_params = transitions.extras.get("domain_parameters", None)
        if domain_params is not None:
            action = jnp.concatenate([action, domain_params], axis=-1)
        q_action = q_network.apply(
            normalizer_params, q_params, transitions.observation, action
        )
        min_q = jnp.min(q_action, axis=-1)
        actor_loss = alpha * log_prob - min_q
        return jnp.mean(actor_loss)

    return alpha_loss, critic_loss, actor_loss
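
Note on the losses above: the critic treats the randomized domain parameters as explicit privileged information. Whenever transitions.extras carries a "domain_parameters" entry, it is concatenated onto the action before the Q-network is applied, while the policy network never sees it. A minimal sketch (not part of the commit) of how a batch of transitions could carry that entry; the batch, observation, action, and parameter sizes here are made-up for illustration:

import jax.numpy as jnp
from brax.training.types import Transition

batch, obs_size, act_size, dr_size = 32, 24, 6, 3  # hypothetical sizes

transitions = Transition(
    observation=jnp.zeros((batch, obs_size)),
    action=jnp.zeros((batch, act_size)),
    reward=jnp.zeros((batch,)),
    discount=jnp.ones((batch,)),
    next_observation=jnp.zeros((batch, obs_size)),
    extras={
        # Used by critic_loss for better bootstrapping of truncated episodes.
        "state_extras": {"truncation": jnp.zeros((batch,))},
        # Privileged information: read only where the Q-network is applied;
        # the policy network itself never receives it.
        "domain_parameters": jnp.zeros((batch, dr_size)),
    },
)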

ss2r/algorithms/sac/networks.py

Lines changed: 71 additions & 0 deletions
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SAC networks."""

from typing import Protocol, Sequence, TypeAlias, TypeVar

import brax.training.agents.sac.networks as sac_networks
from brax.training import distribution, networks, types
from flax import linen

make_inference_fn = sac_networks.make_inference_fn
SACNetworks: TypeAlias = sac_networks.SACNetworks
NetworkType = TypeVar("NetworkType", covariant=True)


class DomainRandomizationNetworkFactory(Protocol[NetworkType]):
    def __call__(
        self,
        observation_size: int,
        action_size: int,
        preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
        *,
        domain_randomization_size: int = 0,
    ) -> NetworkType:
        pass


def make_sac_networks(
    observation_size: int,
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: networks.ActivationFn = linen.relu,
    *,
    domain_randomization_size: int = 0,
) -> SACNetworks:
    """Make SAC networks."""
    parametric_action_distribution = distribution.NormalTanhDistribution(
        event_size=action_size
    )
    policy_network = networks.make_policy_network(
        parametric_action_distribution.param_size,
        observation_size,
        preprocess_observations_fn=preprocess_observations_fn,
        hidden_layer_sizes=hidden_layer_sizes,
        activation=activation,
    )
    q_network = networks.make_q_network(
        observation_size,
        action_size + domain_randomization_size,
        preprocess_observations_fn=preprocess_observations_fn,
        hidden_layer_sizes=hidden_layer_sizes,
        activation=activation,
    )
    return SACNetworks(
        policy_network=policy_network,
        q_network=q_network,
        parametric_action_distribution=parametric_action_distribution,
    )
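
A minimal sketch of how the two modules added in this commit could be wired together. The observation, action, and domain-parameter sizes are hypothetical, and this snippet is illustration only, not code from the commit:

from ss2r.algorithms.sac import losses as sac_losses
from ss2r.algorithms.sac import networks as sac_networks

obs_size, act_size, dr_size = 24, 6, 3  # hypothetical sizes

# The Q-network's action input is widened by domain_randomization_size so the
# critic can condition on the randomized physics parameters; the policy cannot.
network = sac_networks.make_sac_networks(
    observation_size=obs_size,
    action_size=act_size,
    hidden_layer_sizes=(256, 256),
    domain_randomization_size=dr_size,
)

# make_losses returns the alpha, critic, and actor losses defined in losses.py.
alpha_loss, critic_loss, actor_loss = sac_losses.make_losses(
    sac_network=network,
    reward_scaling=1.0,
    discounting=0.99,
    action_size=act_size,
)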
