Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial support of parametric hardware models #101

Merged
merged 18 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/jaxsim/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from . import model, data # isort:skip
from . import common, contact, joint, link, ode, references
from . import common, contact, joint, kin_dyn_parameters, link, ode, references
22 changes: 10 additions & 12 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
import jax
import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.physics.algos import soft_contacts

from . import data as Data
from . import model as Model


@jax.jit
def collidable_point_kinematics(
model: Model.JaxSimModel, data: Data.JaxSimModelData
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> tuple[jtp.Matrix, jtp.Matrix]:
"""
Compute the position and 3D velocity of the collidable points in the world frame.
Expand Down Expand Up @@ -44,7 +42,7 @@ def collidable_point_kinematics(

@jax.jit
def collidable_point_positions(
model: Model.JaxSimModel, data: Data.JaxSimModelData
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
"""
Compute the position of the collidable points in the world frame.
Expand All @@ -62,7 +60,7 @@ def collidable_point_positions(

@jax.jit
def collidable_point_velocities(
model: Model.JaxSimModel, data: Data.JaxSimModelData
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
"""
Compute the 3D velocity of the collidable points in the world frame.
Expand All @@ -80,8 +78,8 @@ def collidable_point_velocities(

@functools.partial(jax.jit, static_argnames=["link_names"])
def in_contact(
model: Model.JaxSimModel,
data: Data.JaxSimModelData,
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
link_names: tuple[str, ...] | None = None,
) -> jtp.Vector:
Expand Down Expand Up @@ -131,7 +129,7 @@ def in_contact(

@jax.jit
def estimate_good_soft_contacts_parameters(
model: Model.JaxSimModel,
model: js.model.JaxSimModel,
static_friction_coefficient: jtp.FloatLike = 0.5,
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
damping_ratio: jtp.FloatLike = 1.0,
Expand Down Expand Up @@ -160,14 +158,14 @@ def estimate_good_soft_contacts_parameters(
specific application.
"""

def estimate_model_height(model: Model.JaxSimModel) -> jtp.Float:
def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
""""""

zero_data = Data.JaxSimModelData.build(
zero_data = js.data.JaxSimModelData.build(
model=model, soft_contacts_params=soft_contacts.SoftContactsParams()
)

W_pz_CoM = Model.com_position(model=model, data=zero_data)[2]
W_pz_CoM = js.model.com_position(model=model, data=zero_data)[2]

if model.physics_model.is_floating_base:
W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
Expand Down
32 changes: 16 additions & 16 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import jaxlie
import numpy as np

import jaxsim.api
import jaxsim.api as js
import jaxsim.physics.algos.aba
import jaxsim.physics.algos.crba
import jaxsim.physics.algos.forward_kinematics
Expand Down Expand Up @@ -48,7 +48,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
)

def valid(self, model: jaxsim.api.model.JaxSimModel | None = None) -> bool:
def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
"""
Check if the current state is valid for the given model.

Expand All @@ -68,7 +68,7 @@ def valid(self, model: jaxsim.api.model.JaxSimModel | None = None) -> bool:

@staticmethod
def zero(
model: jaxsim.api.model.JaxSimModel,
model: js.model.JaxSimModel,
velocity_representation: VelRepr = VelRepr.Inertial,
) -> JaxSimModelData:
"""
Expand All @@ -88,7 +88,7 @@ def zero(

@staticmethod
def build(
model: jaxsim.api.model.JaxSimModel,
model: js.model.JaxSimModel,
base_position: jtp.Vector | None = None,
base_quaternion: jtp.Vector | None = None,
joint_positions: jtp.Vector | None = None,
Expand Down Expand Up @@ -167,7 +167,7 @@ def build(
soft_contacts_params = (
soft_contacts_params
if soft_contacts_params is not None
else jaxsim.api.contact.estimate_good_soft_contacts_parameters(model=model)
else js.contact.estimate_good_soft_contacts_parameters(model=model)
)

W_H_B = jaxlie.SE3.from_rotation_and_translation(
Expand Down Expand Up @@ -225,7 +225,7 @@ def time(self) -> jtp.Float:
@functools.partial(jax.jit, static_argnames=["joint_names"])
def joint_positions(
self,
model: jaxsim.api.model.JaxSimModel | None = None,
model: js.model.JaxSimModel | None = None,
joint_names: tuple[str, ...] | None = None,
) -> jtp.Vector:
"""
Expand Down Expand Up @@ -259,13 +259,13 @@ def joint_positions(
joint_names = joint_names if joint_names is not None else model.joint_names()

return self.state.physics_model.joint_positions[
jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
js.joint.names_to_idxs(joint_names=joint_names, model=model)
]

@functools.partial(jax.jit, static_argnames=["joint_names"])
def joint_velocities(
self,
model: jaxsim.api.model.JaxSimModel | None = None,
model: js.model.JaxSimModel | None = None,
joint_names: tuple[str, ...] | None = None,
) -> jtp.Vector:
"""
Expand Down Expand Up @@ -299,7 +299,7 @@ def joint_velocities(
joint_names = joint_names if joint_names is not None else model.joint_names()

return self.state.physics_model.joint_velocities[
jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
js.joint.names_to_idxs(joint_names=joint_names, model=model)
]

@jax.jit
Expand Down Expand Up @@ -430,7 +430,7 @@ def generalized_velocity(self) -> jtp.Vector:
def reset_joint_positions(
self,
positions: jtp.VectorLike,
model: jaxsim.api.model.JaxSimModel | None = None,
model: js.model.JaxSimModel | None = None,
joint_names: tuple[str, ...] | None = None,
) -> Self:
"""
Expand Down Expand Up @@ -468,15 +468,15 @@ def replace(s: jtp.VectorLike) -> JaxSimModelData:

return replace(
s=self.state.physics_model.joint_positions.at[
jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
js.joint.names_to_idxs(joint_names=joint_names, model=model)
].set(positions)
)

@functools.partial(jax.jit, static_argnames=["joint_names"])
def reset_joint_velocities(
self,
velocities: jtp.VectorLike,
model: jaxsim.api.model.JaxSimModel | None = None,
model: js.model.JaxSimModel | None = None,
joint_names: tuple[str, ...] | None = None,
) -> Self:
"""
Expand Down Expand Up @@ -514,7 +514,7 @@ def replace(ṡ: jtp.VectorLike) -> JaxSimModelData:

return replace(
ṡ=self.state.physics_model.joint_velocities.at[
jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
js.joint.names_to_idxs(joint_names=joint_names, model=model)
].set(velocities)
)

Expand Down Expand Up @@ -692,7 +692,7 @@ def reset_base_velocity(


def random_model_data(
model: jaxsim.api.model.JaxSimModel,
model: js.model.JaxSimModel,
*,
key: jax.Array | None = None,
velocity_representation: VelRepr | None = None,
Expand Down Expand Up @@ -762,8 +762,8 @@ def random_model_data(
).as_quaternion_xyzw()[np.array([3, 0, 1, 2])]

if model.number_of_joints() > 0:
physics_model_state.joint_positions = (
jaxsim.api.joint.random_joint_positions(model=model, key=k3)
physics_model_state.joint_positions = js.joint.random_joint_positions(
model=model, key=k3
)

physics_model_state.joint_velocities = jax.random.uniform(
Expand Down
Loading
Loading