Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Diego Ferigo <[email protected]>
  • Loading branch information
flferretti and diegoferigo committed Jun 18, 2024
1 parent 908ca8b commit 2cac06c
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 19 deletions.
4 changes: 4 additions & 0 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,10 +365,12 @@ def jacobian(

# Adjust the output representation.
match output_vel_repr:

case VelRepr.Inertial:
O_J_WC = W_J_WC

case VelRepr.Body:

W_H_C = transforms(model=model, data=data)

def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
Expand All @@ -381,9 +383,11 @@ def body_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:
O_J_WC = jax.vmap(body_jacobian)(W_H_C, W_J_WC)

case VelRepr.Mixed:

W_H_C = transforms(model=model, data=data)

def mixed_jacobian(W_H_C: jtp.Matrix, W_J_WC: jtp.Matrix) -> jtp.Matrix:

W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))

CW_X_W = jaxsim.math.Adjoint.from_transform(
Expand Down
10 changes: 5 additions & 5 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def κb(link_index: jtp.IntLike) -> jtp.Vector:
carry0 = κb, link_index

def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:

κb, active_link_index = carry

κb, active_link_index = jax.lax.cond(
Expand Down Expand Up @@ -225,12 +226,14 @@ def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]:
)

def __eq__(self, other: KynDynParameters) -> bool:

if not isinstance(other, KynDynParameters):
return False

return hash(self) == hash(other)

def __hash__(self) -> int:

return hash(
(
hash(self.number_of_links()),
Expand Down Expand Up @@ -640,6 +643,7 @@ def build_from_inertial_parameters(
def build_from_flat_parameters(
index: jtp.IntLike, parameters: jtp.VectorLike
) -> LinkParameters:

index = jnp.array(index).squeeze().astype(int)

m = jnp.array(parameters[0]).squeeze().astype(float)
Expand All @@ -664,11 +668,7 @@ def flat_parameters(params: LinkParameters) -> jtp.Vector:

return (
jnp.hstack(
[
params.mass,
params.center_of_mass.squeeze(),
params.inertia_elements,
]
[params.mass, params.center_of_mass.squeeze(), params.inertia_elements]
)
.squeeze()
.astype(float)
Expand Down
26 changes: 12 additions & 14 deletions src/jaxsim/rbda/contacts/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,27 @@ class ContactsState(abc.ABC):
def build(cls, **kwargs) -> ContactsState:
"""
Build the contact state object.
Returns:
The contact state object.
"""

return cls(**kwargs)

@classmethod
@abc.abstractmethod
def zero(cls, **kwargs) -> ContactsState:
"""
Build a zero contact state.
Returns:
The zero contact state.
"""

return cls.build(**kwargs)

@abc.abstractmethod
def valid(self, **kwargs) -> bool:
"""
Check if the contacts state is valid.
"""

return True


class ContactsParams(abc.ABC):
"""
Expand All @@ -57,21 +53,19 @@ def build(cls) -> ContactsParams:
The `ContactsParams` instance.
"""

raise NotImplementedError

@abc.abstractmethod
def valid(self, *args, **kwargs) -> bool:
"""
Check if the parameters are valid.
Returns:
True if the parameters are valid, False otherwise.
"""

return True


class ContactModel(abc.ABC):
"""
Abstract class representing a contact model.
Attributes:
parameters: The parameters of the contact model.
terrain: The terrain model.
Expand All @@ -86,12 +80,16 @@ def compute_contact_forces(
position: jtp.Vector,
velocity: jtp.Vector,
**kwargs,
) -> tuple[Any, ...]:
) -> tuple[jtp.Vector, tuple[Any, ...]]:
"""
Compute the contact forces.
Args:
position: The position of the collidable point.
velocity: The velocity of the collidable point.
position: The position of the collidable point w.r.t. the world frame.
velocity:
The linear velocity of the collidable point (linear component of the mixed 6D velocity).
Returns:
A tuple containing the contact force and additional information.
A tuple containing as first element the computed 6D contact force applied to the contact point and expressed in the world frame,
and as second element a tuple of optional additional information.
"""
16 changes: 16 additions & 0 deletions src/jaxsim/rbda/contacts/soft.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class SoftContactsParams(ContactsParams):
)

def __hash__(self) -> int:

from jaxsim.utils.wrappers import HashedNumpyArray

return hash(
Expand All @@ -42,6 +43,7 @@ def __hash__(self) -> int:
)

def __eq__(self, other: SoftContactsParams) -> bool:

if not isinstance(other, SoftContactsParams):
return NotImplemented

Expand Down Expand Up @@ -126,6 +128,20 @@ def build_default_from_jaxsim_model(

return SoftContactsParams.build(K=K, D=D, mu=μc)

def valid(self) -> bool:
"""
Check if the parameters are valid.
Returns:
`True` if the parameters are valid, `False` otherwise.
"""

return (
jnp.all(self.K >= 0.0)
and jnp.all(self.D >= 0.0)
and jnp.all(self.mu >= 0.0)
)


@jax_dataclasses.pytree_dataclass
class SoftContacts(ContactModel):
Expand Down

0 comments on commit 2cac06c

Please sign in to comment.