diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 15ea7df18..5de7e43ce 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -365,10 +365,12 @@ def jacobian( # Adjust the output representation. match output_vel_repr: + case VelRepr.Inertial: O_J_WC = W_J_WC case VelRepr.Body: + W_H_C = transforms(model=model, data=data) def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: @@ -381,9 +383,11 @@ def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC) case VelRepr.Mixed: + W_H_C = transforms(model=model, data=data) def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix: + W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3)) CW_X_W = jaxsim.math.Adjoint.from_transform( diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 579bee865..f9ffbc0a7 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -184,6 +184,7 @@ def κb(link_index: jtp.IntLike) -> jtp.Vector: carry0 = κb, link_index def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]: + κb, active_link_index = carry κb, active_link_index = jax.lax.cond( @@ -225,12 +226,14 @@ def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]: ) def __eq__(self, other: KynDynParameters) -> bool: + if not isinstance(other, KynDynParameters): return False return hash(self) == hash(other) def __hash__(self) -> int: + return hash( ( hash(self.number_of_links()), @@ -640,6 +643,7 @@ def build_from_inertial_parameters( def build_from_flat_parameters( index: jtp.IntLike, parameters: jtp.VectorLike ) -> LinkParameters: + index = jnp.array(index).squeeze().astype(int) m = jnp.array(parameters[0]).squeeze().astype(float) @@ -664,11 +668,7 @@ def flat_parameters(params: LinkParameters) -> jtp.Vector: return ( jnp.hstack( - [ - params.mass, - params.center_of_mass.squeeze(), - params.inertia_elements, - ] + [params.mass, params.center_of_mass.squeeze(), params.inertia_elements] ) .squeeze() .astype(float) diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index 9ff071071..072317f50 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -17,31 +17,27 @@ class ContactsState(abc.ABC): def build(cls, **kwargs) -> ContactsState: """ Build the contact state object. + Returns: The contact state object. """ - return cls(**kwargs) - @classmethod @abc.abstractmethod def zero(cls, **kwargs) -> ContactsState: """ Build a zero contact state. + Returns: The zero contact state. """ - return cls.build(**kwargs) - @abc.abstractmethod def valid(self, **kwargs) -> bool: """ Check if the contacts state is valid. """ - return True - class ContactsParams(abc.ABC): """ @@ -57,8 +53,7 @@ def build(cls) -> ContactsParams: The `ContactsParams` instance. """ - raise NotImplementedError - + @abc.abstractmethod def valid(self, *args, **kwargs) -> bool: """ Check if the parameters are valid. @@ -66,12 +61,11 @@ def valid(self, *args, **kwargs) -> bool: True if the parameters are valid, False otherwise. """ - return True - class ContactModel(abc.ABC): """ Abstract class representing a contact model. + Attributes: parameters: The parameters of the contact model. terrain: The terrain model. @@ -86,12 +80,16 @@ def compute_contact_forces( position: jtp.Vector, velocity: jtp.Vector, **kwargs, - ) -> tuple[Any, ...]: + ) -> tuple[jtp.Vector, tuple[Any, ...]]: """ Compute the contact forces. + Args: - position: The position of the collidable point. - velocity: The velocity of the collidable point. + position: The position of the collidable point w.r.t. the world frame. + velocity: + The linear velocity of the collidable point (linear component of the mixed 6D velocity). + Returns: - A tuple containing the contact force and additional information. + A tuple containing as first element the computed 6D contact force applied to the contact point and expressed in the world frame, + and as second element a tuple of optional additional information. """ diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 1362d68f9..f199efdb6 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -31,6 +31,7 @@ class SoftContactsParams(ContactsParams): ) def __hash__(self) -> int: + from jaxsim.utils.wrappers import HashedNumpyArray return hash( @@ -42,6 +43,7 @@ def __hash__(self) -> int: ) def __eq__(self, other: SoftContactsParams) -> bool: + if not isinstance(other, SoftContactsParams): return NotImplemented @@ -126,6 +128,20 @@ def build_default_from_jaxsim_model( return SoftContactsParams.build(K=K, D=D, mu=μc) + def valid(self) -> bool: + """ + Check if the parameters are valid. + + Returns: + `True` if the parameters are valid, `False` otherwise. + """ + + return ( + jnp.all(self.K >= 0.0) + and jnp.all(self.D >= 0.0) + and jnp.all(self.mu >= 0.0) + ) + @jax_dataclasses.pytree_dataclass class SoftContacts(ContactModel):