Skip to content

Commit

Permalink
Move SoftContactsState to soft_contacts
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jun 14, 2024
1 parent 9f5e265 commit af896be
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 135 deletions.
133 changes: 1 addition & 132 deletions src/jaxsim/api/ode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.api.soft_contacts import SoftContactsState
from jaxsim.utils import JaxsimDataclass

# =============================================================================
Expand Down Expand Up @@ -595,135 +596,3 @@ def valid(self, model: js.model.JaxSimModel) -> bool:
return False

return True


# ===========================================
# Define the state of the soft-contacts model
# ===========================================


@jax_dataclasses.pytree_dataclass
class SoftContactsState(JaxsimDataclass):
"""
Class storing the state of the soft contacts model.
Attributes:
tangential_deformation:
The matrix of 3D tangential material deformations corresponding to
each collidable point.
"""

tangential_deformation: jtp.Matrix

def __hash__(self) -> int:

from jaxsim.utils.wrappers import HashedNumpyArray

return HashedNumpyArray.hash_of_array(self.tangential_deformation)

def __eq__(self, other: SoftContactsState) -> bool:

if not isinstance(other, SoftContactsState):
return False

return hash(self) == hash(other)

@staticmethod
def build_from_jaxsim_model(
model: js.model.JaxSimModel | None = None,
tangential_deformation: jtp.Matrix | None = None,
) -> SoftContactsState:
"""
Build a `SoftContactsState` from a `JaxSimModel`.
Args:
model: The `JaxSimModel` associated with the soft contacts state.
tangential_deformation: The matrix of 3D tangential material deformations.
Returns:
The `SoftContactsState` built from the `JaxSimModel`.
Note:
If any of the state components are not provided, they are built from the
`JaxSimModel` and initialized to zero.
"""

return SoftContactsState.build(
tangential_deformation=tangential_deformation,
number_of_collidable_points=len(
model.kin_dyn_parameters.contact_parameters.body
),
)

@staticmethod
def build(
tangential_deformation: jtp.Matrix | None = None,
number_of_collidable_points: int | None = None,
) -> SoftContactsState:
"""
Create a `SoftContactsState`.
Args:
tangential_deformation:
The matrix of 3D tangential material deformations corresponding to
each collidable point.
number_of_collidable_points: The number of collidable points.
Returns:
A `SoftContactsState` instance.
"""

tangential_deformation = (
tangential_deformation
if tangential_deformation is not None
else jnp.zeros(shape=(number_of_collidable_points, 3))
)

if tangential_deformation.shape[1] != 3:
raise RuntimeError("The tangential deformation matrix must have 3 columns.")

if (
number_of_collidable_points is not None
and tangential_deformation.shape[0] != number_of_collidable_points
):
msg = "The number of collidable points must match the number of rows "
msg += "in the tangential deformation matrix."
raise RuntimeError(msg)

return SoftContactsState(
tangential_deformation=jnp.array(tangential_deformation).astype(float)
)

@staticmethod
def zero(model: js.model.JaxSimModel) -> SoftContactsState:
"""
Build a zero `SoftContactsState` from a `JaxSimModel`.
Args:
model: The `JaxSimModel` associated with the soft contacts state.
Returns:
A zero `SoftContactsState` instance.
"""

return SoftContactsState.build_from_jaxsim_model(model=model)

def valid(self, model: js.model.JaxSimModel) -> bool:
"""
Check if the `SoftContactsState` is valid for a given `JaxSimModel`.
Args:
model: The `JaxSimModel` to validate the `SoftContactsState` against.
Returns:
`True` if the soft contacts state is valid for the given `JaxSimModel`,
`False` otherwise.
"""

shape = self.tangential_deformation.shape
expected = (len(model.kin_dyn_parameters.contact_parameters.body), 3)

if shape != expected:
return False

return True
134 changes: 131 additions & 3 deletions src/jaxsim/api/soft_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
import jaxsim.typing as jtp
from jaxsim.math import Skew, StandardGravity
from jaxsim.terrain import FlatTerrain, Terrain
from jaxsim.utils import JaxsimDataclass

from .contact import ContactModel, ContactsParams, ContactsState


@jax_dataclasses.pytree_dataclass
class SoftContactsParams(JaxsimDataclass):
class SoftContactsParams(ContactsParams):
"""Parameters of the soft contacts model."""

K: jtp.Float = dataclasses.field(
Expand Down Expand Up @@ -129,7 +130,7 @@ def build_default_from_jaxsim_model(


@jax_dataclasses.pytree_dataclass
class SoftContacts:
class SoftContacts(ContactModel):
"""Soft contacts model."""

parameters: SoftContactsParams = dataclasses.field(
Expand Down Expand Up @@ -313,3 +314,130 @@ def slipping_contact():
false_fun=lambda _: with_friction(),
operand=None,
)


@jax_dataclasses.pytree_dataclass
class SoftContactsState(ContactsState):
"""
Class storing the state of the soft contacts model.
Attributes:
tangential_deformation:
The matrix of 3D tangential material deformations corresponding to
each collidable point.
"""

tangential_deformation: jtp.Matrix

def __hash__(self) -> int:

return hash(
tuple(jnp.atleast_1d(self.tangential_deformation.flatten()).tolist())
)

def __eq__(self, other: SoftContactsState) -> bool:

if not isinstance(other, SoftContactsState):
return False

return hash(self) == hash(other)

@staticmethod
def build_from_jaxsim_model(
model: js.model.JaxSimModel | None = None,
tangential_deformation: jtp.Matrix | None = None,
) -> SoftContactsState:
"""
Build a `SoftContactsState` from a `JaxSimModel`.
Args:
model: The `JaxSimModel` associated with the soft contacts state.
tangential_deformation: The matrix of 3D tangential material deformations.
Returns:
The `SoftContactsState` built from the `JaxSimModel`.
Note:
If any of the state components are not provided, they are built from the
`JaxSimModel` and initialized to zero.
"""

return SoftContactsState.build(
tangential_deformation=tangential_deformation,
number_of_collidable_points=len(
model.kin_dyn_parameters.contact_parameters.body
),
)

@staticmethod
def build(
tangential_deformation: jtp.Matrix | None = None,
number_of_collidable_points: int | None = None,
) -> SoftContactsState:
"""
Create a `SoftContactsState`.
Args:
tangential_deformation:
The matrix of 3D tangential material deformations corresponding to
each collidable point.
number_of_collidable_points: The number of collidable points.
Returns:
A `SoftContactsState` instance.
"""

tangential_deformation = (
tangential_deformation
if tangential_deformation is not None
else jnp.zeros(shape=(number_of_collidable_points, 3))
)

if tangential_deformation.shape[1] != 3:
raise RuntimeError("The tangential deformation matrix must have 3 columns.")

if (
number_of_collidable_points is not None
and tangential_deformation.shape[0] != number_of_collidable_points
):
msg = "The number of collidable points must match the number of rows "
msg += "in the tangential deformation matrix."
raise RuntimeError(msg)

return SoftContactsState(
tangential_deformation=jnp.array(tangential_deformation).astype(float)
)

@staticmethod
def zero(model: js.model.JaxSimModel) -> SoftContactsState:
"""
Build a zero `SoftContactsState` from a `JaxSimModel`.
Args:
model: The `JaxSimModel` associated with the soft contacts state.
Returns:
A zero `SoftContactsState` instance.
"""

return SoftContactsState.build_from_jaxsim_model(model=model)

def valid(self, model: js.model.JaxSimModel) -> bool:
"""
Check if the `SoftContactsState` is valid for a given `JaxSimModel`.
Args:
model: The `JaxSimModel` to validate the `SoftContactsState` against.
Returns:
`True` if the soft contacts state is valid for the given `JaxSimModel`,
`False` otherwise.
"""

shape = self.tangential_deformation.shape
expected = (len(model.kin_dyn_parameters.contact_parameters.body), 3)

if shape != expected:
return False

return True

0 comments on commit af896be

Please sign in to comment.