Skip to content

Commit

Permalink
Rename ODEState.contact_state to ODEState.contact
Browse files Browse the repository at this point in the history
Co-authored-by: Diego Ferigo <[email protected]>
  • Loading branch information
flferretti and diegoferigo committed Jun 17, 2024
1 parent eeaf61d commit 4556dbb
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def collidable_point_dynamics(
# Note that the material deformation rate is always returned in the mixed frame
# C[W] = (W_p_C, [W]). This is convenient for integration purpose.
W_f_Ci, CW_ṁ = jax.vmap(soft_contacts.compute_contact_forces)(
W_p_Ci, W_ṗ_Ci, data.state.contacts_state.tangential_deformation
W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation
)

case _:
Expand Down
9 changes: 4 additions & 5 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def build(
base_angular_velocity: jtp.Vector | None = None,
joint_velocities: jtp.Vector | None = None,
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
contacts_state: jaxsim.rbda.ContactsState | None = None,
contact: jaxsim.rbda.ContactsState | None = None,
contacts_params: jaxsim.rbda.ContactsParams | None = None,
velocity_representation: VelRepr = VelRepr.Inertial,
time: jtp.FloatLike | None = None,
Expand All @@ -132,7 +132,7 @@ def build(
The base angular velocity in the selected representation.
joint_velocities: The joint velocities.
standard_gravity: The standard gravity constant.
contacts_state: The state of the soft contacts.
contact: The state of the soft contacts.
contacts_params: The parameters of the soft contacts.
velocity_representation: The velocity representation to use.
time: The time at which the state is created.
Expand Down Expand Up @@ -213,9 +213,8 @@ def build(
base_angular_velocity=v_WB[3:6].astype(float),
joint_velocities=joint_velocities.astype(float),
tangential_deformation=(
contacts_state.tangential_deformation
if contacts_state is not None
and isinstance(model.contact_model, SoftContacts)
contact.tangential_deformation
if contact is not None and isinstance(model.contact_model, SoftContacts)
else None
),
)
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def system_velocity_dynamics(
W_f_Ci = None

# Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}.
= jnp.zeros_like(data.state.contacts_state.tangential_deformation).astype(float)
= jnp.zeros_like(data.state.contact.tangential_deformation).astype(float)

if len(model.kin_dyn_parameters.contact_parameters.body) > 0:
# Compute the 6D forces applied to each collidable point and the
Expand Down
24 changes: 9 additions & 15 deletions src/jaxsim/api/ode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ class ODEState(JaxsimDataclass):
Attributes:
physics_model: The state of the physics model.
contacts_state: The state of the contacts model.
contact: The state of the contacts model.
"""

physics_model: PhysicsModelState
contacts_state: ContactsState
contact: ContactsState

@staticmethod
def build_from_jaxsim_model(
Expand Down Expand Up @@ -183,7 +183,7 @@ def build_from_jaxsim_model(
base_linear_velocity=base_linear_velocity,
base_angular_velocity=base_angular_velocity,
),
contacts_state=getattr(
contact=getattr(
importlib.import_module(f"jaxsim.rbda.contacts.{module_name}"),
class_name,
).build_from_jaxsim_model(
Expand All @@ -199,15 +199,15 @@ def build_from_jaxsim_model(
@staticmethod
def build(
physics_model_state: PhysicsModelState | None = None,
contacts_state: ContactsState | None = None,
contact: ContactsState | None = None,
model: js.model.JaxSimModel | None = None,
) -> ODEState:
"""
Build an `ODEState` from a `PhysicsModelState` and a `ContactsState`.
Args:
physics_model_state: The state of the physics model.
contacts_state: The state of the contacts model.
contact: The state of the contacts model.
model: The `JaxSimModel` associated with the ODE state.
Returns:
Expand Down Expand Up @@ -240,15 +240,11 @@ def build(
except ImportError as e:
raise e

contacts_state = (
contacts_state
if contacts_state is not None
else SoftContactsState.zero(model=model)
contact = (
contact if contact is not None else SoftContactsState.zero(model=model)
)

return ODEState(
physics_model=physics_model_state, contacts_state=contacts_state
)
return ODEState(physics_model=physics_model_state, contact=contact)

@staticmethod
def zero(model: js.model.JaxSimModel) -> ODEState:
Expand Down Expand Up @@ -277,9 +273,7 @@ def valid(self, model: js.model.JaxSimModel) -> bool:
`True` if the ODE state is valid for the given model, `False` otherwise.
"""

return self.physics_model.valid(model=model) and self.contacts_state.valid(
model=model
)
return self.physics_model.valid(model=model) and self.contact.valid(model=model)


# ==================================================
Expand Down
8 changes: 3 additions & 5 deletions tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def test_ad_integration(
s = data.joint_positions(model=model)
W_v_WB = data.base_velocity()
= data.joint_velocities(model=model)
m = data.state.contacts_state.tangential_deformation
m = data.state.contact.tangential_deformation

# Inputs.
W_f_L = references.link_forces(model=model)
Expand Down Expand Up @@ -396,9 +396,7 @@ def step(
base_angular_velocity=W_v_WB[3:6],
joint_velocities=,
),
contacts_state=js.ode_data.SoftContactsState.build(
tangential_deformation=m
),
contact=js.ode_data.SoftContactsState.build(tangential_deformation=m),
),
)

Expand All @@ -417,7 +415,7 @@ def step(
xf_s = data_xf.joint_positions(model=model)
xf_W_v_WB = data_xf.base_velocity()
xf_ṡ = data_xf.joint_velocities(model=model)
xf_m = data_xf.state.contacts_state.tangential_deformation
xf_m = data_xf.state.contact.tangential_deformation

return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ, xf_m

Expand Down

0 comments on commit 4556dbb

Please sign in to comment.