Commit

sq

diegoferigo committed Feb 12, 2024
1 parent 4429442 commit 10cc1d6
Showing 1 changed file with 66 additions and 34 deletions.
100 changes: 66 additions & 34 deletions src/jaxsim/simulation/integrators_variable_step.py
@@ -8,14 +8,7 @@

 from jaxsim import typing as jtp
 from jaxsim.physics.model.physics_model import PhysicsModel
-from jaxsim.simulation.integrators import (
-    State,
-    StateDerivative,
-    StateDerivativeCallable,
-    Time,
-    TimeHorizon,
-    TimeStep,
-)
+from jaxsim.simulation.integrators import Time, TimeHorizon, TimeStep
 from jaxsim.simulation.ode_data import ODEState
 from jaxsim.sixd import so3
 
@@ -27,6 +20,15 @@
 BETA_MAX_DEFAULT = 2.5
 MAX_STEP_REJECTIONS_DEFAULT = 5
 
+# Unlike the fixed-step integrators that operate on generic PyTrees, these
+# variable-step integrators operate only on arrays (that could be the
+# flattened PyTree).
+State = jtp.Vector
+StateDerivative = jtp.Vector
+StateDerivativeCallable = Callable[
+    [State, Time], tuple[StateDerivative, dict[str, Any]]
+]
+
 
 class AdaptiveIntegratorType(enum.IntEnum):
     HeunEuler = enum.auto()
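As context for the flattened-PyTree comment introduced above: a generic PyTree state can be turned into the 1D arrays these integrators expect. A minimal sketch, assuming `jax.flatten_util.ravel_pytree` (the toy state below is illustrative, not part of jaxsim):

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Toy PyTree state (illustrative): positions and velocities.
pytree_state = {"q": jnp.zeros(7), "v": jnp.ones(6)}

# Flatten to a single 1D array; `unravel` restores the original PyTree.
x0, unravel = ravel_pytree(pytree_state)

assert x0.shape == (13,)  # Compatible with State = jtp.Vector.
restored = unravel(x0)    # Back to {"q": ..., "v": ...}.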
@@ -63,7 +65,7 @@ def initial_step_size(
     x0: State,
     t0: Time,
     f: StateDerivativeCallable,
-    order: jtp.IntLike = 2,
+    order: jtp.IntLike,
     rtol: jtp.FloatLike = RTOL_DEFAULT,
     atol: jtp.FloatLike = ATOL_DEFAULT,
 ) -> tuple[jtp.Float, StateDerivative]:
@@ -89,6 +91,7 @@ def initial_step_size(
         E. Hairer, S. P. Nørsett, G. Wanner.
     """
 
+    # Compute the state derivative at the initial state.
     ẋ0 = f(x0, t0)[0]
 
     # Scale the initial state and its derivative.
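For reference, the heuristic this function implements (from Hairer, Nørsett, Wanner, Solving Ordinary Differential Equations I, Sec. II.4) can be sketched as follows, assuming `sc = scale_array(x1=x0, rtol=rtol, atol=atol)` and the function's own `x0`, `ẋ0`, `f`, `t0`, and `order`; this is a sketch of the textbook algorithm, not the file's exact code:

# Compare the magnitudes of the scaled state and its derivative.
d0 = jnp.linalg.norm(x0 / sc, ord=jnp.inf)
d1 = jnp.linalg.norm(ẋ0 / sc, ord=jnp.inf)
h0 = jnp.where((d0 < 1e-5) | (d1 < 1e-5), 1e-6, 0.01 * d0 / d1)

# Take one explicit Euler step to probe the second derivative.
ẋ1 = f(x0 + h0 * ẋ0, t0 + h0)[0]
d2 = jnp.linalg.norm((ẋ1 - ẋ0) / sc, ord=jnp.inf) / h0

# Balance the derivative magnitudes against the integrator order.
h1 = jnp.where(
    jnp.maximum(d1, d2) <= 1e-15,
    jnp.maximum(1e-6, h0 * 1e-3),
    (0.01 / jnp.maximum(d1, d2)) ** (1.0 / (order + 1)),
)

dt0 = jnp.minimum(100.0 * h0, h1)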
@@ -140,9 +143,10 @@ def scale_array(
         the local integration error.
     """
 
-    # Use a zeroed second state if not provided
+    # Use a zeroed second state if not provided.
     x2 = x2 if x2 is not None else jnp.zeros_like(x1)
 
+    # Return: atol + max(|x1|, |x2|) * rtol.
     return (
         atol
         + jnp.vstack(
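The return statement is truncated by the diff view; for 1D inputs the computed per-component tolerance scale is equivalent to this one-liner (a sketch, not the file's exact code):

sc = atol + jnp.maximum(jnp.abs(x1), jnp.abs(x2)) * rtol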
@@ -169,8 +173,8 @@ def error_local(
     Args:
         x0: The initial state $x(t_0)$.
         xf: The final state $x(t_f)$.
-        error_estimate: The optional error estimate. In not given, it is computed
-            as the difference between the final and initial states.
+        error_estimate: The optional error estimate. If not given, it is computed as the
+            absolute value of the difference between the final and initial states.
         rtol: The relative tolerance to scale the state.
         atol: The absolute tolerance to scale the state.
         norm_ord: The norm to use to compute the error. Default is the infinity norm.
@@ -183,18 +187,13 @@ def error_local(
     sc = scale_array(x1=x0, x2=xf, rtol=rtol, atol=atol)
 
     # Compute the error estimate if not given.
-    error_estimate = error_estimate if error_estimate is not None else xf - x0
+    error_estimate = error_estimate if error_estimate is not None else jnp.abs(xf - x0)
 
     # Then, compute the local error by properly scaling the given error estimate and
     # applying the desired norm (default is infinity norm, i.e. the maximum absolute value).
     return jnp.linalg.norm(error_estimate / sc, ord=norm_ord)
 
 
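A usage sketch of `error_local` in a step acceptance test; the tolerances are illustrative and the call signature is assumed from the docstring above:

x0 = jnp.array([1.0, 2.0])
xf = jnp.array([1.001, 2.002])

# A step is typically accepted when the scaled local error is <= 1.
err = error_local(x0=x0, xf=xf, rtol=1e-3, atol=1e-6)
step_accepted = err <= 1.0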
-# =======================
-# Runge-Kutta Integrators
-# =======================
-
-
 @functools.partial(jax.jit, static_argnames=["f"])
 def runge_kutta_from_butcher_tableau(
     x0: State,
@@ -205,35 +204,59 @@
     c: jax.Array,
     b: jax.Array,
     A: jax.Array,
-    f0: StateDerivative | None = None,
+    dxdt0: StateDerivative | None = None,
 ) -> tuple[jax.Array, jax.Array, jax.Array | float, dict[str, Any]]:
-    """"""
+    """
+    Advance a state vector by integrating the system dynamics with a Runge-Kutta integrator.
+
+    Args:
+        x0: The initial state.
+        t0: The initial time.
+        dt: The integration time step.
+        f: The state derivative function :math:`f(x, t)`.
+        c: The :math:`\mathbf{c}` parameter of the Butcher tableau.
+        b: The :math:`\mathbf{b}` parameter of the Butcher tableau.
+        A: The :math:`\mathbf{A}` parameter of the Butcher tableau.
+        dxdt0: The optional pre-computed state derivative at the
+            initial :math:`(x_0, t_0)`, useful for FSAL schemes.
+
+    Returns:
+        A tuple containing the next state, the intermediate states :math:`\mathbf{k}_i`,
+        the error estimate, and the auxiliary dictionary returned by `f`.
+
+    Note:
+        If `b.T` has multiple rows (used e.g. in embedded Runge-Kutta methods), the first
+        returned argument is a 2D array having as many rows as `b.T`. Each i-th row
+        corresponds to the solution computed with the coefficients of the i-th row of `b.T`.
+    """
 
+    # Adjust sizes of Butcher tableau arrays.
     c = jnp.atleast_1d(c.squeeze())
     b = jnp.atleast_2d(b.squeeze())
     A = jnp.atleast_2d(A.squeeze())
 
-    h = dt
+    # Use a symbol for the time step.
+    Δt = dt
 
     # Initialize the carry of the for loop with the stacked kᵢ vectors.
     carry0 = jnp.zeros(shape=(c.size, x0.size), dtype=float)

-    # Allow FSAL (first-same-as-last) property by passing f0 = f(x0, t0) from
+    # Allow FSAL (first-same-as-last) property by passing ẋ0 = f(x0, t0) from
     # the previous iteration.
-    get_ẋ0 = lambda: f0 if f0 is not None else f(x0, t0)[0]
+    get_ẋ0 = lambda: dxdt0 if dxdt0 is not None else f(x0, t0)[0]
 
-    # We use a `jax.lax.scan` to have only a single instance of the compiled f function.
+    # We use a `jax.lax.scan` to have only a single instance of the compiled `f` function.
     # Otherwise, if we compute e.g. for RK4 sequentially, the jit-compiled code
-    # would include 4 repetitions of the f logic, making everything extremely slow.
+    # would include 4 repetitions of the `f` logic, making everything extremely slow.
     def scan_body(carry: jax.Array, i: int | jax.Array) -> tuple[Any, None]:
         """"""
 
         # Unpack the carry
         k = carry
 
         def compute_ki():
-            xi = x0 + h * jnp.dot(A[i, :], k)
-            ti = t0 + c[i] * h
+            xi = x0 + Δt * jnp.dot(A[i, :], k)
+            ti = t0 + c[i] * Δt
             return f(xi, ti)[0]
 
         # This selector enables FSAL property in the first iteration (i=0).
@@ -253,11 +276,18 @@ def compute_ki():
         xs=jnp.arange(c.size),
     )
 
-    # Compute the output state and the error estimate.
-    z = x0 + h * jnp.dot(b.T, k)
-    error_estimate = dt * jnp.dot(b.T[-1] - b.T[0], k)
+    # Compute the output state.
+    # Note that z contains as many new states as the rows of `b.T`.
+    z = x0 + Δt * jnp.dot(b.T, k)
 
+    # Compute the error estimate if `b.T` has multiple rows, otherwise return 0.
+    error_estimate = jax.lax.select(
+        pred=b.T.shape[0] == 1,
+        on_true=jnp.array(0.0, dtype=float),
+        on_false=dt * jnp.dot(b.T[-1] - b.T[0], k),
+    )
+
+    # TODO: populate the auxiliary dictionary
     return z, k, error_estimate, dict()


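For illustration, a sketch of how the function above can be invoked with the standard embedded Heun–Euler 2(1) pair; the tableau values are textbook, while the call follows the signature shown in the diff:

import jax.numpy as jnp

# Heun–Euler 2(1): two stages, two embedded solutions (columns of b).
A = jnp.array([[0.0, 0.0],
               [1.0, 0.0]])
c = jnp.array([0.0, 1.0])
b = jnp.array([[0.5, 1.0],
               [0.5, 0.0]])  # b.T[0]: order 2 (Heun), b.T[1]: order 1 (Euler).

# Simple linear dynamics ẋ = -x, with an empty auxiliary dictionary.
f = lambda x, t: (-x, dict())

z, k, error_estimate, aux = runge_kutta_from_butcher_tableau(
    x0=jnp.array([1.0]), t0=0.0, dt=0.1, f=f, c=c, b=b, A=A
)
# z[0] is the 2nd-order solution, z[1] the embedded 1st-order one.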
@@ -405,9 +435,9 @@ def bogacki_shampine(
     return x_next, z_order, error, aux_dict


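For reference, the standard Bogacki–Shampine 3(2) tableau that a method with this name typically encodes; the values come from the literature, and whether the file lays them out exactly like this is an assumption:

# Four stages with the FSAL property: the last stage is the next step's first.
c = jnp.array([0.0, 0.5, 0.75, 1.0])
A = jnp.array([
    [0.0,   0.0,   0.0,   0.0],
    [0.5,   0.0,   0.0,   0.0],
    [0.0,   0.75,  0.0,   0.0],
    [2 / 9, 1 / 3, 4 / 9, 0.0],
])
b = jnp.array([
    [2 / 9, 7 / 24],  # First column: 3rd-order weights.
    [1 / 3, 1 / 4],   # Second column: 2nd-order embedded weights.
    [4 / 9, 1 / 3],
    [0.0,   1 / 8],
])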
-# =======================================
-# Variable-step integrators (single step)
-# =======================================
+# ==========================================
+# Variable-step RK integrators (single step)
+# ==========================================
 
 
 @functools.partial(
@@ -460,7 +490,7 @@ def odeint_embedded_rk_one_step(
     Δt0, ẋ0 = jax.lax.cond(
         pred=jnp.where(dt0 is None, 0.0, dt0) == 0.0,
         true_fun=lambda _: initial_step_size(
-            x0=x0, t0=t0, f=f, order=q, atol=atol, rtol=rtol
+            x0=x0, t0=t0, f=f, order=p, atol=atol, rtol=rtol
         ),
         false_fun=lambda _: (dt0, f(x0, t0)[0]),
         operand=None,
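Note that both branches of `jax.lax.cond` must return outputs with identical structure and shapes, which is why the false branch above also evaluates `f(x0, t0)`. A minimal sketch of the pattern; `estimate_dt` and `dt_guess` are hypothetical names, not the file's:

# Both branches return a (step_size, derivative) tuple of matching shapes.
dt_and_derivative = jax.lax.cond(
    pred=jnp.asarray(dt_guess) == 0.0,
    true_fun=lambda _: (estimate_dt(), f(x0, t0)[0]),  # estimate_dt: hypothetical helper.
    false_fun=lambda _: (dt_guess, f(x0, t0)[0]),
    operand=None,
)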
@@ -744,7 +774,9 @@ def tf_next_state(x0: State, xf: State, t0: Time, dt: TimeStep) -> State:
     static_argnames=[
         "f",
         "odeint_adaptive_one_step",
+        "integrator_type",
         "debug_buffers_size_per_step",
+        "tf_next_state",
     ],
 )
 def _ode_integration_adaptive_template(

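As a side note on the growing `static_argnames` list above: callables and other non-array arguments must be declared static for `jax.jit`, otherwise tracing fails. A self-contained sketch of the pattern used throughout this file:

import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames=["f"])
def euler_step(x: jax.Array, dt: float, f) -> jax.Array:
    # `f` is a Python callable: jit treats it as a compile-time constant,
    # recompiling if a different function object is passed.
    return x + dt * f(x)

x_next = euler_step(jnp.array([1.0]), 0.1, f=lambda x: -x)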