|
| 1 | +# Copyright 2024 The Brax Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Soft Actor-Critic losses. |
| 16 | +
|
| 17 | +See: https://arxiv.org/pdf/1812.05905.pdf |
| 18 | +""" |
| 19 | +from typing import Any, TypeAlias |
| 20 | + |
| 21 | +import jax |
| 22 | +import jax.numpy as jnp |
| 23 | +from brax.training import types |
| 24 | +from brax.training.agents.sac import networks as sac_networks |
| 25 | +from brax.training.types import Params, PRNGKey |
| 26 | + |
| 27 | +Transition: TypeAlias = types.Transition |
| 28 | + |
| 29 | + |
| 30 | +def make_losses( |
| 31 | + sac_network: sac_networks.SACNetworks, |
| 32 | + reward_scaling: float, |
| 33 | + discounting: float, |
| 34 | + action_size: int, |
| 35 | +): |
| 36 | + """Creates the SAC losses.""" |
| 37 | + |
| 38 | + target_entropy = -0.5 * action_size |
| 39 | + policy_network = sac_network.policy_network |
| 40 | + q_network = sac_network.q_network |
| 41 | + parametric_action_distribution = sac_network.parametric_action_distribution |
| 42 | + |
| 43 | + def alpha_loss( |
| 44 | + log_alpha: jnp.ndarray, |
| 45 | + policy_params: Params, |
| 46 | + normalizer_params: Any, |
| 47 | + transitions: Transition, |
| 48 | + key: PRNGKey, |
| 49 | + ) -> jnp.ndarray: |
| 50 | + """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf.""" |
| 51 | + dist_params = policy_network.apply( |
| 52 | + normalizer_params, policy_params, transitions.observation |
| 53 | + ) |
| 54 | + action = parametric_action_distribution.sample_no_postprocessing( |
| 55 | + dist_params, key |
| 56 | + ) |
| 57 | + log_prob = parametric_action_distribution.log_prob(dist_params, action) |
| 58 | + alpha = jnp.exp(log_alpha) |
| 59 | + alpha_loss = alpha * jax.lax.stop_gradient(-log_prob - target_entropy) |
| 60 | + return jnp.mean(alpha_loss) |
| 61 | + |
| 62 | + def critic_loss( |
| 63 | + q_params: Params, |
| 64 | + policy_params: Params, |
| 65 | + normalizer_params: Any, |
| 66 | + target_q_params: Params, |
| 67 | + alpha: jnp.ndarray, |
| 68 | + transitions: Transition, |
| 69 | + key: PRNGKey, |
| 70 | + ) -> jnp.ndarray: |
| 71 | + domain_params = transitions.extras.get("domain_parameters", None) |
| 72 | + if domain_params is not None: |
| 73 | + action = jnp.concatenate([transitions.action, domain_params], axis=-1) |
| 74 | + else: |
| 75 | + action = transitions.action |
| 76 | + q_old_action = q_network.apply( |
| 77 | + normalizer_params, q_params, transitions.observation, action |
| 78 | + ) |
| 79 | + next_dist_params = policy_network.apply( |
| 80 | + normalizer_params, policy_params, transitions.next_observation |
| 81 | + ) |
| 82 | + next_action = parametric_action_distribution.sample_no_postprocessing( |
| 83 | + next_dist_params, key |
| 84 | + ) |
| 85 | + next_log_prob = parametric_action_distribution.log_prob( |
| 86 | + next_dist_params, next_action |
| 87 | + ) |
| 88 | + next_action = parametric_action_distribution.postprocess(next_action) |
| 89 | + if domain_params is not None: |
| 90 | + next_action = jnp.concatenate([next_action, domain_params], axis=-1) |
| 91 | + next_q = q_network.apply( |
| 92 | + normalizer_params, |
| 93 | + target_q_params, |
| 94 | + transitions.next_observation, |
| 95 | + next_action, |
| 96 | + ) |
| 97 | + next_v = jnp.min(next_q, axis=-1) - alpha * next_log_prob |
| 98 | + target_q = jax.lax.stop_gradient( |
| 99 | + transitions.reward * reward_scaling |
| 100 | + + transitions.discount * discounting * next_v |
| 101 | + ) |
| 102 | + q_error = q_old_action - jnp.expand_dims(target_q, -1) |
| 103 | + |
| 104 | + # Better bootstrapping for truncated episodes. |
| 105 | + truncation = transitions.extras["state_extras"]["truncation"] |
| 106 | + q_error *= jnp.expand_dims(1 - truncation, -1) |
| 107 | + |
| 108 | + q_loss = 0.5 * jnp.mean(jnp.square(q_error)) |
| 109 | + return q_loss |
| 110 | + |
| 111 | + def actor_loss( |
| 112 | + policy_params: Params, |
| 113 | + normalizer_params: Any, |
| 114 | + q_params: Params, |
| 115 | + alpha: jnp.ndarray, |
| 116 | + transitions: Transition, |
| 117 | + key: PRNGKey, |
| 118 | + ) -> jnp.ndarray: |
| 119 | + dist_params = policy_network.apply( |
| 120 | + normalizer_params, policy_params, transitions.observation |
| 121 | + ) |
| 122 | + action = parametric_action_distribution.sample_no_postprocessing( |
| 123 | + dist_params, key |
| 124 | + ) |
| 125 | + log_prob = parametric_action_distribution.log_prob(dist_params, action) |
| 126 | + action = parametric_action_distribution.postprocess(action) |
| 127 | + domain_params = transitions.extras.get("domain_parameters", None) |
| 128 | + if domain_params is not None: |
| 129 | + action = jnp.concatenate([action, domain_params], axis=-1) |
| 130 | + q_action = q_network.apply( |
| 131 | + normalizer_params, q_params, transitions.observation, action |
| 132 | + ) |
| 133 | + min_q = jnp.min(q_action, axis=-1) |
| 134 | + actor_loss = alpha * log_prob - min_q |
| 135 | + return jnp.mean(actor_loss) |
| 136 | + |
| 137 | + return alpha_loss, critic_loss, actor_loss |
0 commit comments