Skip to content

Commit

Permalink
Abstract contact state in api.ode_data.ODEState
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jun 13, 2024
1 parent fee82b9 commit c202825
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def collidable_point_dynamics(
# Note that the material deformation rate is always returned in the mixed frame
# C[W] = (W_p_C, [W]). This is convenient for integration purpose.
W_f_Ci, CW_ṁ = jax.vmap(soft_contacts.contact_model)(
W_p_Ci, W_ṗ_Ci, data.state.contact_state.tangential_deformation
W_p_Ci, W_ṗ_Ci, data.state.contacts_state.tangential_deformation
)

case _:
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def system_velocity_dynamics(
W_f_Ci = None

# Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
= jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float)
= jnp.zeros_like(data.state.contacts_state.tangential_deformation).astype(float)

if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
# Compute the 6D forces applied to each collidable point and the
Expand Down
61 changes: 49 additions & 12 deletions src/jaxsim/api/ode_data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import importlib

import jax.numpy as jnp
import jax_dataclasses

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.api.soft_contacts import SoftContactsState
from jaxsim.utils import JaxsimDataclass

Expand Down Expand Up @@ -117,11 +120,11 @@ class ODEState(JaxsimDataclass):
Attributes:
physics_model: The state of the physics model.
soft_contacts: The state of the soft-contacts model.
contacts_state: The state of the contacts model.
"""

physics_model: PhysicsModelState
soft_contacts: SoftContactsState
contacts_state: js.contact.ContactsState

@staticmethod
def build_from_jaxsim_model(
Expand Down Expand Up @@ -159,6 +162,15 @@ def build_from_jaxsim_model(
`JaxSimModel` and initialized to zero.
"""

# Get the contact model from the `JaxSimModel`
prefix = type(model.contact_model).__name__.split("Contact")[0]

if prefix:
module_name = f"{prefix.lower()}_contacts"
class_name = f"{prefix.capitalize()}ContactsState"
else:
raise ValueError("Unable to determine contact state class prefix.")

return ODEState.build(
model=model,
physics_model_state=PhysicsModelState.build_from_jaxsim_model(
Expand All @@ -170,24 +182,30 @@ def build_from_jaxsim_model(
base_linear_velocity=base_linear_velocity,
base_angular_velocity=base_angular_velocity,
),
soft_contacts_state=SoftContactsState.build_from_jaxsim_model(
contacts_state=getattr(
importlib.import_module(f"jaxsim.api.{module_name}"), class_name
).build_from_jaxsim_model(
model=model,
tangential_deformation=tangential_deformation,
**(
dict(tangential_deformation=tangential_deformation)
if tangential_deformation is not None
else dict()
),
),
)

@staticmethod
def build(
physics_model_state: PhysicsModelState | None = None,
soft_contacts_state: SoftContactsState | None = None,
contacts_state: js.contact.ContactsState | None = None,
model: js.model.JaxSimModel | None = None,
) -> ODEState:
"""
Build an `ODEState` from a `PhysicsModelState` and a `SoftContactsState`.
Build an `ODEState` from a `PhysicsModelState` and a `ContactsState`.
Args:
physics_model_state: The state of the physics model.
soft_contacts_state: The state of the soft-contacts model.
contacts_state: The state of the contacts model.
model: The `JaxSimModel` associated with the ODE state.
Returns:
Expand All @@ -200,14 +218,33 @@ def build(
else PhysicsModelState.zero(model=model)
)

soft_contacts_state = (
soft_contacts_state
if soft_contacts_state is not None
# Get the contact model from the `JaxSimModel`
try:
prefix = type(model.contact_model).__name__.split("Contact")[0]
except AttributeError:
logging.warning(
"Unable to determine contact state class prefix. Using default soft contacts."
)
prefix = "Soft"

module_name = f"{prefix.lower()}_contacts"
class_name = f"{prefix.capitalize()}ContactsState"

try:
state_cls = getattr(
importlib.import_module(f"jaxsim.api.{module_name}"), class_name
)
except ImportError as e:
raise e

contacts_state = (
contacts_state
if contacts_state is not None
else SoftContactsState.zero(model=model)
)

return ODEState(
physics_model=physics_model_state, soft_contacts=soft_contacts_state
physics_model=physics_model_state, contacts_state=contacts_state
)

@staticmethod
Expand Down Expand Up @@ -237,7 +274,7 @@ def valid(self, model: js.model.JaxSimModel) -> bool:
`True` if the ODE state is valid for the given model, `False` otherwise.
"""

return self.physics_model.valid(model=model) and self.soft_contacts.valid(
return self.physics_model.valid(model=model) and self.contacts_state.valid(
model=model
)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def test_ad_integration(
s = data.joint_positions(model=model)
W_v_WB = data.base_velocity()
= data.joint_velocities(model=model)
m = data.state.soft_contacts.tangential_deformation
m = data.state.contacts_state.tangential_deformation

# Inputs.
W_f_L = references.link_forces(model=model)
Expand Down Expand Up @@ -417,7 +417,7 @@ def step(
xf_s = data_xf.joint_positions(model=model)
xf_W_v_WB = data_xf.base_velocity()
xf_ṡ = data_xf.joint_velocities(model=model)
xf_m = data_xf.state.soft_contacts.tangential_deformation
xf_m = data_xf.state.contacts_state.tangential_deformation

return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ, xf_m

Expand Down

0 comments on commit c202825

Please sign in to comment.