Skip to content

Commit

Permalink
Simplify contact model indentification logic
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jun 17, 2024
1 parent 5640547 commit 9675c49
Showing 1 changed file with 25 additions and 44 deletions.
69 changes: 25 additions & 44 deletions src/jaxsim/api/ode_data.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from __future__ import annotations

import importlib

import jax.numpy as jnp
import jax_dataclasses

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.rbda import ContactsState
from jaxsim.rbda.contacts.soft_contacts import SoftContactsState
from jaxsim.rbda.contacts.soft_contacts import SoftContacts, SoftContactsState
from jaxsim.utils import JaxsimDataclass

# =============================================================================
Expand Down Expand Up @@ -164,13 +161,18 @@ def build_from_jaxsim_model(
"""

# Get the contact model from the `JaxSimModel`
prefix = type(model.contact_model).__name__.split("Contact")[0]

if prefix:
module_name = f"{prefix.lower()}_contacts"
class_name = f"{prefix.capitalize()}ContactsState"
else:
raise ValueError("Unable to determine contact state class prefix.")
match model.contact_model:
case SoftContacts():
contact = SoftContactsState.build_from_jaxsim_model(
model=model,
**(
dict(tangential_deformation=tangential_deformation)
if tangential_deformation is not None
else dict()
),
)
case _:
raise ValueError("Unable to determine contact state class prefix.")

return ODEState.build(
model=model,
Expand All @@ -183,17 +185,7 @@ def build_from_jaxsim_model(
base_linear_velocity=base_linear_velocity,
base_angular_velocity=base_angular_velocity,
),
contact=getattr(
importlib.import_module(f"jaxsim.rbda.contacts.{module_name}"),
class_name,
).build_from_jaxsim_model(
model=model,
**(
dict(tangential_deformation=tangential_deformation)
if tangential_deformation is not None
else dict()
),
),
contact=contact,
)

@staticmethod
Expand Down Expand Up @@ -221,28 +213,17 @@ def build(
)

# Get the contact model from the `JaxSimModel`
try:
prefix = type(model.contact_model).__name__.split("Contact")[0]
except AttributeError:
logging.warning(
"Unable to determine contact state class prefix. Using default soft contacts."
)
prefix = "Soft"

module_name = f"{prefix.lower()}_contacts"
class_name = f"{prefix.capitalize()}ContactsState"

try:
state_cls = getattr(
importlib.import_module(f"jaxsim.rbda.contacts.{module_name}"),
class_name,
)
except ImportError as e:
raise e

contact = (
contact if contact is not None else SoftContactsState.zero(model=model)
)
match model.contact_model:
case SoftContacts():
pass
case None:
contact = (
contact
if contact is not None
else SoftContactsState.zero(model=model)
)
case _:
raise ValueError("Unable to determine contact state class prefix.")

return ODEState(physics_model=physics_model_state, contact=contact)

Expand Down

0 comments on commit 9675c49

Please sign in to comment.