|
| 1 | +from typing import Callable, Protocol |
| 2 | + |
| 3 | +import jax |
| 4 | +import jax.numpy as jnp |
| 5 | +from brax.training.types import Params, Transition |
| 6 | + |
| 7 | + |
| 8 | +class QTransformation(Protocol): |
| 9 | + def __call__( |
| 10 | + self, |
| 11 | + transitions: Transition, |
| 12 | + q_fn: Callable[[Params, jax.Array], jax.Array], |
| 13 | + policy: Callable[[jax.Array], tuple[jax.Array, jax.Array]], |
| 14 | + gamma: float, |
| 15 | + domain_params: jax.Array | None = None, |
| 16 | + alpha: jax.Array | None = None, |
| 17 | + reward_scaling: float = 1.0, |
| 18 | + ): |
| 19 | + ... |
| 20 | + |
| 21 | + |
| 22 | +class LCB(QTransformation): |
| 23 | + def __init__(self, lambda_: float) -> None: |
| 24 | + self.lambda_ = lambda_ |
| 25 | + |
| 26 | + def __call__( |
| 27 | + self, |
| 28 | + transitions: Transition, |
| 29 | + q_fn: Callable[[Params, jax.Array], jax.Array], |
| 30 | + policy: Callable[[jax.Array], tuple[jax.Array, jax.Array]], |
| 31 | + gamma: float, |
| 32 | + domain_params: jax.Array | None = None, |
| 33 | + alpha: jax.Array | None = None, |
| 34 | + reward_scaling: float = 1.0, |
| 35 | + ): |
| 36 | + next_obs = transitions.extras["state_extras"]["state_propagation"]["next_obs"] |
| 37 | + next_action, _ = policy(next_obs) |
| 38 | + if domain_params is not None: |
| 39 | + domain_params = jnp.tile( |
| 40 | + domain_params[:, None], (1, next_action.shape[1], 1) |
| 41 | + ) |
| 42 | + next_action = jnp.concatenate([next_action, domain_params], axis=-1) |
| 43 | + next_q = q_fn(next_obs, next_action) |
| 44 | + next_v = next_q.mean(axis=-1) |
| 45 | + std = jnp.std(next_v, axis=-1) |
| 46 | + cost = transitions.extras["state_extras"]["cost"] |
| 47 | + cost += self.lambda_ * std |
| 48 | + target_q = jax.lax.stop_gradient( |
| 49 | + cost * reward_scaling + transitions.discount * gamma * next_v |
| 50 | + ) |
| 51 | + return target_q |
| 52 | + |
| 53 | + |
| 54 | +class CVaR(QTransformation): |
| 55 | + def __init__(self, confidence: float) -> None: |
| 56 | + self.confidence = confidence |
| 57 | + |
| 58 | + def __call__( |
| 59 | + self, |
| 60 | + transitions: Transition, |
| 61 | + q_fn: Callable[[Params, jax.Array], jax.Array], |
| 62 | + policy: Callable[[jax.Array], tuple[jax.Array, jax.Array]], |
| 63 | + gamma: float, |
| 64 | + domain_params: jax.Array | None = None, |
| 65 | + alpha: jax.Array | None = None, |
| 66 | + reward_scaling: float = 1.0, |
| 67 | + ): |
| 68 | + next_obs = transitions.extras["state_extras"]["state_propagation"]["next_obs"] |
| 69 | + next_action, _ = policy(next_obs) |
| 70 | + if domain_params is not None: |
| 71 | + domain_params = jnp.tile( |
| 72 | + domain_params[:, None], (1, next_action.shape[1], 1) |
| 73 | + ) |
| 74 | + next_action = jnp.concatenate([next_action, domain_params], axis=-1) |
| 75 | + next_q = q_fn(next_obs, next_action) |
| 76 | + next_v = next_q.mean(axis=-1) |
| 77 | + sort_next_v = jnp.sort(next_v, axis=-1) |
| 78 | + cvar_index = int((1 - self.confidence) * next_v.shape[1]) |
| 79 | + next_v = jnp.mean(sort_next_v[:, :cvar_index], axis=-1) |
| 80 | + cost = transitions.extras["state_extras"]["cost"] |
| 81 | + target_q = jax.lax.stop_gradient( |
| 82 | + cost * reward_scaling + transitions.discount * gamma * next_v |
| 83 | + ) |
| 84 | + return target_q |
| 85 | + |
| 86 | + |
| 87 | +class SACBase(QTransformation): |
| 88 | + def __call__( |
| 89 | + self, |
| 90 | + transitions: Transition, |
| 91 | + q_fn: Callable[[Params, jax.Array], jax.Array], |
| 92 | + policy: Callable[[jax.Array], tuple[jax.Array, jax.Array]], |
| 93 | + gamma: float, |
| 94 | + domain_params: jax.Array | None = None, |
| 95 | + alpha: jax.Array | None = None, |
| 96 | + reward_scaling: float = 1.0, |
| 97 | + ): |
| 98 | + next_action, next_log_prob = policy(transitions.next_observation) |
| 99 | + if domain_params is not None: |
| 100 | + next_action = jnp.concatenate([next_action, domain_params], axis=-1) |
| 101 | + next_q = q_fn(transitions.next_observation, next_action) |
| 102 | + next_v = next_q.min(axis=-1) |
| 103 | + next_v -= alpha * next_log_prob |
| 104 | + target_q = jax.lax.stop_gradient( |
| 105 | + transitions.reward * reward_scaling + transitions.discount * gamma * next_v |
| 106 | + ) |
| 107 | + return target_q |
| 108 | + |
| 109 | + |
| 110 | +class SACCost(QTransformation): |
| 111 | + def __call__( |
| 112 | + self, |
| 113 | + transitions: Transition, |
| 114 | + q_fn: Callable[[Params, jax.Array], jax.Array], |
| 115 | + policy: Callable[[jax.Array], tuple[jax.Array, jax.Array]], |
| 116 | + gamma: float, |
| 117 | + domain_params: jax.Array | None = None, |
| 118 | + alpha: jax.Array | None = None, |
| 119 | + reward_scaling: float = 1.0, |
| 120 | + ): |
| 121 | + next_action, _ = policy(transitions.next_observation) |
| 122 | + if domain_params is not None: |
| 123 | + next_action = jnp.concatenate([next_action, domain_params], axis=-1) |
| 124 | + next_q = q_fn(transitions.next_observation, next_action) |
| 125 | + next_v = next_q.mean(axis=-1) |
| 126 | + cost = transitions.extras["state_extras"]["cost"] |
| 127 | + target_q = jax.lax.stop_gradient( |
| 128 | + cost * reward_scaling + transitions.discount * gamma * next_v |
| 129 | + ) |
| 130 | + return target_q |
0 commit comments