From 4556dbb2b89dbe91b1989903973d93e448e2e8bb Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 17 Jun 2024 15:46:38 +0200 Subject: [PATCH] Rename `ODEState.contact_state` to `ODEState.contact` Co-authored-by: Diego Ferigo --- src/jaxsim/api/contact.py | 2 +- src/jaxsim/api/data.py | 9 ++++----- src/jaxsim/api/ode.py | 2 +- src/jaxsim/api/ode_data.py | 24 +++++++++--------------- tests/test_automatic_differentiation.py | 8 +++----- 5 files changed, 18 insertions(+), 27 deletions(-) diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index f72812efe..2ff0ba568 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -150,7 +150,7 @@ def collidable_point_dynamics( # Note that the material deformation rate is always returned in the mixed frame # C[W] = (W_p_C, [W]). This is convenient for integration purpose. W_f_Ci, CW_ṁ = jax.vmap(soft_contacts.compute_contact_forces)( - W_p_Ci, W_ṗ_Ci, data.state.contacts_state.tangential_deformation + W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation ) case _: diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 9bbb48bcd..364eae7fc 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -113,7 +113,7 @@ def build( base_angular_velocity: jtp.Vector | None = None, joint_velocities: jtp.Vector | None = None, standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity, - contacts_state: jaxsim.rbda.ContactsState | None = None, + contact: jaxsim.rbda.ContactsState | None = None, contacts_params: jaxsim.rbda.ContactsParams | None = None, velocity_representation: VelRepr = VelRepr.Inertial, time: jtp.FloatLike | None = None, @@ -132,7 +132,7 @@ def build( The base angular velocity in the selected representation. joint_velocities: The joint velocities. standard_gravity: The standard gravity constant. - contacts_state: The state of the soft contacts. + contact: The state of the soft contacts. contacts_params: The parameters of the soft contacts. velocity_representation: The velocity representation to use. time: The time at which the state is created. @@ -213,9 +213,8 @@ def build( base_angular_velocity=v_WB[3:6].astype(float), joint_velocities=joint_velocities.astype(float), tangential_deformation=( - contacts_state.tangential_deformation - if contacts_state is not None - and isinstance(model.contact_model, SoftContacts) + contact.tangential_deformation + if contact is not None and isinstance(model.contact_model, SoftContacts) else None ), ) diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 86f617d1d..88cc4fddf 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -132,7 +132,7 @@ def system_velocity_dynamics( W_f_Ci = None # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}. - ṁ = jnp.zeros_like(data.state.contacts_state.tangential_deformation).astype(float) + ṁ = jnp.zeros_like(data.state.contact.tangential_deformation).astype(float) if len(model.kin_dyn_parameters.contact_parameters.body) > 0: # Compute the 6D forces applied to each collidable point and the diff --git a/src/jaxsim/api/ode_data.py b/src/jaxsim/api/ode_data.py index d6a74b902..d573da05a 100644 --- a/src/jaxsim/api/ode_data.py +++ b/src/jaxsim/api/ode_data.py @@ -121,11 +121,11 @@ class ODEState(JaxsimDataclass): Attributes: physics_model: The state of the physics model. - contacts_state: The state of the contacts model. + contact: The state of the contacts model. """ physics_model: PhysicsModelState - contacts_state: ContactsState + contact: ContactsState @staticmethod def build_from_jaxsim_model( @@ -183,7 +183,7 @@ def build_from_jaxsim_model( base_linear_velocity=base_linear_velocity, base_angular_velocity=base_angular_velocity, ), - contacts_state=getattr( + contact=getattr( importlib.import_module(f"jaxsim.rbda.contacts.{module_name}"), class_name, ).build_from_jaxsim_model( @@ -199,7 +199,7 @@ def build_from_jaxsim_model( @staticmethod def build( physics_model_state: PhysicsModelState | None = None, - contacts_state: ContactsState | None = None, + contact: ContactsState | None = None, model: js.model.JaxSimModel | None = None, ) -> ODEState: """ @@ -207,7 +207,7 @@ def build( Args: physics_model_state: The state of the physics model. - contacts_state: The state of the contacts model. + contact: The state of the contacts model. model: The `JaxSimModel` associated with the ODE state. Returns: @@ -240,15 +240,11 @@ def build( except ImportError as e: raise e - contacts_state = ( - contacts_state - if contacts_state is not None - else SoftContactsState.zero(model=model) + contact = ( + contact if contact is not None else SoftContactsState.zero(model=model) ) - return ODEState( - physics_model=physics_model_state, contacts_state=contacts_state - ) + return ODEState(physics_model=physics_model_state, contact=contact) @staticmethod def zero(model: js.model.JaxSimModel) -> ODEState: @@ -277,9 +273,7 @@ def valid(self, model: js.model.JaxSimModel) -> bool: `True` if the ODE state is valid for the given model, `False` otherwise. """ - return self.physics_model.valid(model=model) and self.contacts_state.valid( - model=model - ) + return self.physics_model.valid(model=model) and self.contact.valid(model=model) # ================================================== diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index d6ab792ea..c29fa785e 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -342,7 +342,7 @@ def test_ad_integration( s = data.joint_positions(model=model) W_v_WB = data.base_velocity() ṡ = data.joint_velocities(model=model) - m = data.state.contacts_state.tangential_deformation + m = data.state.contact.tangential_deformation # Inputs. W_f_L = references.link_forces(model=model) @@ -396,9 +396,7 @@ def step( base_angular_velocity=W_v_WB[3:6], joint_velocities=ṡ, ), - contacts_state=js.ode_data.SoftContactsState.build( - tangential_deformation=m - ), + contact=js.ode_data.SoftContactsState.build(tangential_deformation=m), ), ) @@ -417,7 +415,7 @@ def step( xf_s = data_xf.joint_positions(model=model) xf_W_v_WB = data_xf.base_velocity() xf_ṡ = data_xf.joint_velocities(model=model) - xf_m = data_xf.state.contacts_state.tangential_deformation + xf_m = data_xf.state.contact.tangential_deformation return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ, xf_m