diff --git a/src/jaxsim/api/__init__.py b/src/jaxsim/api/__init__.py index a85176dd4..4f2e5d968 100644 --- a/src/jaxsim/api/__init__.py +++ b/src/jaxsim/api/__init__.py @@ -1,2 +1,2 @@ from . import model, data # isort:skip -from . import common, contact, joint, link, ode, references +from . import common, contact, joint, kin_dyn_parameters, link, ode, references diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 47650cf95..745353e0a 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -3,16 +3,14 @@ import jax import jax.numpy as jnp +import jaxsim.api as js import jaxsim.typing as jtp from jaxsim.physics.algos import soft_contacts -from . import data as Data -from . import model as Model - @jax.jit def collidable_point_kinematics( - model: Model.JaxSimModel, data: Data.JaxSimModelData + model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> tuple[jtp.Matrix, jtp.Matrix]: """ Compute the position and 3D velocity of the collidable points in the world frame. @@ -44,7 +42,7 @@ def collidable_point_kinematics( @jax.jit def collidable_point_positions( - model: Model.JaxSimModel, data: Data.JaxSimModelData + model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: """ Compute the position of the collidable points in the world frame. @@ -62,7 +60,7 @@ def collidable_point_positions( @jax.jit def collidable_point_velocities( - model: Model.JaxSimModel, data: Data.JaxSimModelData + model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> jtp.Matrix: """ Compute the 3D velocity of the collidable points in the world frame. @@ -80,8 +78,8 @@ def collidable_point_velocities( @functools.partial(jax.jit, static_argnames=["link_names"]) def in_contact( - model: Model.JaxSimModel, - data: Data.JaxSimModelData, + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, *, link_names: tuple[str, ...] | None = None, ) -> jtp.Vector: @@ -131,7 +129,7 @@ def in_contact( @jax.jit def estimate_good_soft_contacts_parameters( - model: Model.JaxSimModel, + model: js.model.JaxSimModel, static_friction_coefficient: jtp.FloatLike = 0.5, number_of_active_collidable_points_steady_state: jtp.IntLike = 1, damping_ratio: jtp.FloatLike = 1.0, @@ -160,14 +158,14 @@ def estimate_good_soft_contacts_parameters( specific application. """ - def estimate_model_height(model: Model.JaxSimModel) -> jtp.Float: + def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: """""" - zero_data = Data.JaxSimModelData.build( + zero_data = js.data.JaxSimModelData.build( model=model, soft_contacts_params=soft_contacts.SoftContactsParams() ) - W_pz_CoM = Model.com_position(model=model, data=zero_data)[2] + W_pz_CoM = js.model.com_position(model=model, data=zero_data)[2] if model.physics_model.is_floating_base: W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1] diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 5a018840d..bec336c71 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -10,7 +10,7 @@ import jaxlie import numpy as np -import jaxsim.api +import jaxsim.api as js import jaxsim.physics.algos.aba import jaxsim.physics.algos.crba import jaxsim.physics.algos.forward_kinematics @@ -48,7 +48,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): default_factory=lambda: jnp.array(0, dtype=jnp.uint64) ) - def valid(self, model: jaxsim.api.model.JaxSimModel | None = None) -> bool: + def valid(self, model: js.model.JaxSimModel | None = None) -> bool: """ Check if the current state is valid for the given model. @@ -68,7 +68,7 @@ def valid(self, model: jaxsim.api.model.JaxSimModel | None = None) -> bool: @staticmethod def zero( - model: jaxsim.api.model.JaxSimModel, + model: js.model.JaxSimModel, velocity_representation: VelRepr = VelRepr.Inertial, ) -> JaxSimModelData: """ @@ -88,7 +88,7 @@ def zero( @staticmethod def build( - model: jaxsim.api.model.JaxSimModel, + model: js.model.JaxSimModel, base_position: jtp.Vector | None = None, base_quaternion: jtp.Vector | None = None, joint_positions: jtp.Vector | None = None, @@ -167,7 +167,7 @@ def build( soft_contacts_params = ( soft_contacts_params if soft_contacts_params is not None - else jaxsim.api.contact.estimate_good_soft_contacts_parameters(model=model) + else js.contact.estimate_good_soft_contacts_parameters(model=model) ) W_H_B = jaxlie.SE3.from_rotation_and_translation( @@ -225,7 +225,7 @@ def time(self) -> jtp.Float: @functools.partial(jax.jit, static_argnames=["joint_names"]) def joint_positions( self, - model: jaxsim.api.model.JaxSimModel | None = None, + model: js.model.JaxSimModel | None = None, joint_names: tuple[str, ...] | None = None, ) -> jtp.Vector: """ @@ -259,13 +259,13 @@ def joint_positions( joint_names = joint_names if joint_names is not None else model.joint_names() return self.state.physics_model.joint_positions[ - jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model) + js.joint.names_to_idxs(joint_names=joint_names, model=model) ] @functools.partial(jax.jit, static_argnames=["joint_names"]) def joint_velocities( self, - model: jaxsim.api.model.JaxSimModel | None = None, + model: js.model.JaxSimModel | None = None, joint_names: tuple[str, ...] | None = None, ) -> jtp.Vector: """ @@ -299,7 +299,7 @@ def joint_velocities( joint_names = joint_names if joint_names is not None else model.joint_names() return self.state.physics_model.joint_velocities[ - jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model) + js.joint.names_to_idxs(joint_names=joint_names, model=model) ] @jax.jit @@ -430,7 +430,7 @@ def generalized_velocity(self) -> jtp.Vector: def reset_joint_positions( self, positions: jtp.VectorLike, - model: jaxsim.api.model.JaxSimModel | None = None, + model: js.model.JaxSimModel | None = None, joint_names: tuple[str, ...] | None = None, ) -> Self: """ @@ -468,7 +468,7 @@ def replace(s: jtp.VectorLike) -> JaxSimModelData: return replace( s=self.state.physics_model.joint_positions.at[ - jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model) + js.joint.names_to_idxs(joint_names=joint_names, model=model) ].set(positions) ) @@ -476,7 +476,7 @@ def replace(s: jtp.VectorLike) -> JaxSimModelData: def reset_joint_velocities( self, velocities: jtp.VectorLike, - model: jaxsim.api.model.JaxSimModel | None = None, + model: js.model.JaxSimModel | None = None, joint_names: tuple[str, ...] | None = None, ) -> Self: """ @@ -514,7 +514,7 @@ def replace(ṡ: jtp.VectorLike) -> JaxSimModelData: return replace( ṡ=self.state.physics_model.joint_velocities.at[ - jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model) + js.joint.names_to_idxs(joint_names=joint_names, model=model) ].set(velocities) ) @@ -692,7 +692,7 @@ def reset_base_velocity( def random_model_data( - model: jaxsim.api.model.JaxSimModel, + model: js.model.JaxSimModel, *, key: jax.Array | None = None, velocity_representation: VelRepr | None = None, @@ -762,8 +762,8 @@ def random_model_data( ).as_quaternion_xyzw()[np.array([3, 0, 1, 2])] if model.number_of_joints() > 0: - physics_model_state.joint_positions = ( - jaxsim.api.joint.random_joint_positions(model=model, key=k3) + physics_model_state.joint_positions = js.joint.random_joint_positions( + model=model, key=k3 ) physics_model_state.joint_velocities = jax.random.uniform( diff --git a/src/jaxsim/api/joint.py b/src/jaxsim/api/joint.py index 7c5668e8c..e70020496 100644 --- a/src/jaxsim/api/joint.py +++ b/src/jaxsim/api/joint.py @@ -3,17 +3,18 @@ import jax import jax.numpy as jnp +import numpy as np +import jaxsim.api as js import jaxsim.typing as jtp -from . import model as Model - # ======================= # Index-related functions # ======================= -def name_to_idx(model: Model.JaxSimModel, *, joint_name: str) -> jtp.Int: +@functools.partial(jax.jit, static_argnames="joint_name") +def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int: """ Convert the name of a joint to its index. @@ -25,12 +26,25 @@ def name_to_idx(model: Model.JaxSimModel, *, joint_name: str) -> jtp.Int: The index of the joint. """ - return jnp.array( - model.physics_model.description.joints_dict[joint_name].index - 1, dtype=int - ) - - -def idx_to_name(model: Model.JaxSimModel, *, joint_index: jtp.IntLike) -> str: + if joint_name in model.kin_dyn_parameters.joint_model.joint_names: + # Note: the index of the joint for RBDAs starts from 1, but + # the index for accessing the right element starts from 0. + # Therefore, there is a -1. + return ( + jnp.array( + np.argwhere( + np.array(model.kin_dyn_parameters.joint_model.joint_names) + == joint_name + ) + - 1 + ) + .squeeze() + .astype(int) + ) + return jnp.array(-1).astype(int) + + +def idx_to_name(model: js.model.JaxSimModel, *, joint_index: jtp.IntLike) -> str: """ Convert the index of a joint to its name. @@ -42,11 +56,13 @@ def idx_to_name(model: Model.JaxSimModel, *, joint_index: jtp.IntLike) -> str: The name of the joint. """ - d = {j.index: j.name for j in model.physics_model.description.joints_dict.values()} - return d[joint_index] + return model.kin_dyn_parameters.joint_model.joint_names[joint_index + 1] -def names_to_idxs(model: Model.JaxSimModel, *, joint_names: Sequence[str]) -> jax.Array: +@functools.partial(jax.jit, static_argnames="joint_names") +def names_to_idxs( + model: js.model.JaxSimModel, *, joint_names: Sequence[str] +) -> jax.Array: """ Convert a sequence of joint names to their corresponding indices. @@ -59,19 +75,14 @@ def names_to_idxs(model: Model.JaxSimModel, *, joint_names: Sequence[str]) -> ja """ return jnp.array( - [ - # Note: the index of the joint for RBDAs starts from 1, but - # the index for accessing the right element starts from 0. - # Therefore, there is a -1. - model.physics_model.description.joints_dict[name].index - 1 - for name in joint_names - ], - dtype=int, - ) + [name_to_idx(model=model, joint_name=name) for name in joint_names], + ).astype(int) def idxs_to_names( - model: Model.JaxSimModel, *, joint_indices: Sequence[jtp.IntLike] | jtp.VectorLike + model: js.model.JaxSimModel, + *, + joint_indices: Sequence[jtp.IntLike] | jtp.VectorLike, ) -> tuple[str, ...]: """ Convert a sequence of joint indices to their corresponding names. @@ -84,12 +95,7 @@ def idxs_to_names( The names of the joints. """ - d = { - j.index - 1: j.name - for j in model.physics_model.description.joints_dict.values() - } - - return tuple(d[i] for i in joint_indices) + return tuple(idx_to_name(model=model, joint_index=idx) for idx in joint_indices) # ============ @@ -99,26 +105,48 @@ def idxs_to_names( @jax.jit def position_limit( - model: Model.JaxSimModel, *, joint_index: jtp.IntLike + model: js.model.JaxSimModel, *, joint_index: jtp.IntLike ) -> tuple[jtp.Float, jtp.Float]: - """""" + """ + Get the position limits of a joint. - if model.physics_model.NB <= 1: - return jnp.array([]), jnp.array([]) + Args: + model: The model to consider. + joint_index: The index of the joint. + + Returns: + The position limits of the joint. + """ - s_min = model.physics_model._joint_position_limits_min[joint_index] - s_max = model.physics_model._joint_position_limits_max[joint_index] + if model.number_of_joints() <= 1: + return jnp.empty(0).astype(float), jnp.empty(0).astype(float) + + s_min = model.kin_dyn_parameters.joint_parameters.position_limits_min[joint_index] + s_max = model.kin_dyn_parameters.joint_parameters.position_limits_max[joint_index] return s_min.astype(float), s_max.astype(float) @functools.partial(jax.jit, static_argnames=["joint_names"]) def position_limits( - model: Model.JaxSimModel, *, joint_names: Sequence[str] | None = None + model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None ) -> tuple[jtp.Vector, jtp.Vector]: + """ + Get the position limits of a list of joint. + + Args: + model: The model to consider. + joint_names: The names of the joints. + + Returns: + The position limits of the joints. + """ joint_names = joint_names if joint_names is not None else model.joint_names() + if len(joint_names) == 0: + return jnp.empty(0).astype(float), jnp.empty(0).astype(float) + joint_idxs = names_to_idxs(joint_names=joint_names, model=model) return jax.vmap(lambda i: position_limit(model=model, joint_index=i))(joint_idxs) @@ -130,12 +158,22 @@ def position_limits( @functools.partial(jax.jit, static_argnames=["joint_names"]) def random_joint_positions( - model: Model.JaxSimModel, + model: js.model.JaxSimModel, *, joint_names: Sequence[str] | None = None, key: jax.Array | None = None, ) -> jtp.Vector: - """""" + """ + Generate random joint positions. + + Args: + model: The model to consider. + joint_names: The names of the joints. + key: The random key. + + Returns: + The random joint positions. + """ key = key if key is not None else jax.random.PRNGKey(seed=0) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py new file mode 100644 index 000000000..7cbc73ea7 --- /dev/null +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -0,0 +1,512 @@ +from __future__ import annotations + +import jax.lax +import jax.numpy as jnp +import jax_dataclasses +import jaxlie +from jax_dataclasses import Static + +import jaxsim.typing as jtp +from jaxsim.math.inertia import Inertia +from jaxsim.math.joint_model import JointModel, supported_joint_motion +from jaxsim.parsers.descriptions import JointDescription, ModelDescription +from jaxsim.physics.model.ground_contact import GroundContact as ContactParameters +from jaxsim.utils import JaxsimDataclass + + +@jax_dataclasses.pytree_dataclass +class KynDynParameters(JaxsimDataclass): + + # Static + link_names: Static[tuple[str]] + parent_array: Static[jtp.Vector] + support_body_array_bool: Static[jtp.Matrix] + + # Links + link_parameters: LinkParameters + + # Contacts + contact_parameters: ContactParameters + + # Joints + joint_model: JointModel + joint_parameters: JointParameters | None + + @staticmethod + def build(model_description: ModelDescription) -> KynDynParameters: + """ + Construct the kinematic and dynamic parameters of the model. + + Args: + model_description: The parsed model description to consider. + + Returns: + The kinematic and dynamic parameters of the model. + + Note: + This class is meant to ease the management of parametric models in + an automatic differentiation context. + """ + + # Extract the links ordered by their index. + # The link index corresponds to the body index ∈ [0, num_bodies - 1]. + ordered_links = sorted( + list(model_description.links_dict.values()), + key=lambda l: l.index, + ) + + # Extract the joints ordered by their index. + # The joint index matches the index of its child link, therefore it starts + # from 1. Keep this in mind since this 1-indexing might introduce bugs. + ordered_joints = sorted( + list(model_description.joints_dict.values()), + key=lambda j: j.index, + ) + + # ================ + # Links properties + # ================ + + # Create a list of link parameters objects. + link_parameters_list = [ + LinkParameters.build_from_spatial_inertia(M=link.inertia) + for link in ordered_links + ] + + # Create a vectorized object of link parameters. + link_parameters = jax.tree_util.tree_map( + lambda *l: jnp.stack(l), *link_parameters_list + ) + + # ================= + # Joints properties + # ================= + + # Create a list of joint parameters objects. + joint_parameters_list = [ + JointParameters.build_from_joint_description(joint_description=joint) + for joint in ordered_joints + ] + + # Create a vectorized object of joint parameters. + joint_parameters = ( + jax.tree_util.tree_map(lambda *l: jnp.stack(l), *joint_parameters_list) + if len(ordered_joints) > 0 + else None + ) + + # Create an object that defines the joint model (parent-to-child transforms). + joint_model = JointModel.build(description=model_description) + + # =============== + # Tree properties + # =============== + + # Build the parent array λ(i) of the model. + # Note: the parent of the base link is not set since it's not defined. + parent_array_dict = { + link.index: link.parent.index + for link in ordered_links + if link.parent is not None + } + parent_array = jnp.array([-1] + list(parent_array_dict.values()), dtype=int) + + # Instead of building the support parent array κ(i) of the model, having a + # variable length that depends on the number of links connecting the root to + # the i-th link, we build the corresponding boolean version. + # Given a link index i, the boolean support parent array κb(i) is an array + # with the same number of elements of λ(i) having the i-th element set to True + # if the i-th link is in the support parent array κ(i), False otherwise. + def κb(link_index: jtp.IntLike) -> jtp.Vector: + κb = jnp.zeros(len(ordered_links), dtype=bool) + + carry0 = κb, link_index + + def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]: + + κb, active_link_index = carry + + κb, active_link_index = jax.lax.cond( + pred=(i == active_link_index), + false_fun=lambda: (κb, active_link_index), + true_fun=lambda: ( + κb.at[active_link_index].set(True), + parent_array[active_link_index], + ), + ) + + return (κb, active_link_index), None + + (κb, _), _ = jax.lax.scan( + f=scan_body, + init=carry0, + xs=jnp.flip(jnp.arange(start=0, stop=len(ordered_links))), + ) + + return κb + + support_body_array_bool = jax.vmap(κb)( + jnp.arange(start=0, stop=len(ordered_links)) + ) + + return KynDynParameters( + link_names=tuple(l.name for l in ordered_links), + parent_array=parent_array, + support_body_array_bool=support_body_array_bool, + link_parameters=link_parameters, + joint_model=joint_model, + joint_parameters=joint_parameters, + contact_parameters=ContactParameters.build_from( + model_description=model_description + ), + ) + + def __eq__(self, other: KynDynParameters) -> bool: + + if not isinstance(other, KynDynParameters): + return False + + equal = True + equal = equal and self.number_of_links() == other.number_of_links() + equal = equal and self.number_of_joints() == other.number_of_joints() + equal = equal and jnp.allclose(self.parent_array, other.parent_array) + + return equal + + def __hash__(self) -> int: + + h = 0 + h += hash(self.number_of_links()) + h += hash(self.number_of_joints()) + h += hash(tuple(self.parent_array.tolist())) + + return h + + # ============================= + # Helpers to extract parameters + # ============================= + + def number_of_links(self) -> int: + """ + Return the number of links of the model. + + Returns: + The number of links of the model. + """ + + return len(self.link_names) + + def number_of_joints(self) -> int: + """ + Return the number of joints of the model. + + Returns: + The number of joints of the model. + """ + + return len(self.joint_model.joint_names) - 1 + + def support_body_array(self, link_index: jtp.IntLike) -> jtp.Vector: + """ + Return the support parent array κ(i) of a link belonging to the model. + + Args: + link_index: The index of the link. + + Returns: + The support parent array κ(i) of the link. + + Note: + This method returns a variable-length vector. In jit-compiled functions, + it's better to use the (static) boolean version `support_body_array_bool`. + """ + + return jnp.array( + jnp.where(self.support_body_array_bool[link_index])[0], dtype=int + ) + + # ======================== + # Quantities used by RBDAs + # ======================== + + @jax.jit + def links_spatial_inertia(self) -> jtp.Array: + """ + Return the spatial inertia of all links of the model. + + Returns: + The spatial inertia of all links of the model. + """ + + return jax.vmap(LinkParameters.spatial_inertia)(self.link_parameters) + + @jax.jit + def tree_transforms(self) -> jtp.Array: + """ + Return the tree transforms of the model. + + Returns: + The transforms + :math:`{}^{\text{pre}(\text{i})} H_{\lambda(\text{i})}` + of all joints of the model. + """ + + pre_Xi_λ = jax.vmap( + lambda i: self.joint_model.parent_H_predecessor(joint_index=i) + .inverse() + .adjoint() + )(jnp.arange(1, self.number_of_joints() + 1)) + + return jnp.vstack( + [ + jnp.zeros(shape=(1, 6, 6), dtype=float), + pre_Xi_λ, + ] + ) + + @jax.jit + def joint_transforms(self, joint_positions: jtp.VectorLike) -> jtp.Array: + """ + Return the transforms of the joints. + + Args: + joint_positions: The joint positions. + + Returns: + The stacked transforms + :math:`{}^{\text{i}} \mathbf{H}_{\lambda(\text{i})}(s)` + of each joint. + """ + + return self.joint_transforms_and_motion_subspaces(joint_positions)[0] + + @jax.jit + def joint_motion_subspaces(self, joint_positions: jtp.VectorLike) -> jtp.Array: + """ + Return the motion subspaces of the joints. + + Args: + joint_positions: The joint positions. + + Returns: + The stacked motion subspaces :math:`\mathbf{S}(s)` of each joint. + """ + + return self.joint_transforms_and_motion_subspaces(joint_positions)[1] + + @jax.jit + def joint_transforms_and_motion_subspaces( + self, joint_positions: jtp.VectorLike + ) -> tuple[jtp.Array, jtp.Array]: + """ + Return the transforms and the motion subspaces of the joints. + + Args: + joint_positions: The joint positions. + + Returns: + A tuple containing the stacked transforms + :math:`{}^{\text{i}} \mathbf{H}_{\lambda(\text{i})}(s)` + and the stacked motion subspaces :math:`\mathbf{S}(s)` of each joint. + """ + + λ_H_pre = jax.vmap( + lambda i: self.joint_model.parent_H_predecessor(joint_index=i) + )(jnp.arange(1, 1 + self.number_of_joints())) + + pre_H_suc_and_S = [ + supported_joint_motion( + joint_type=self.joint_model.joint_types[index + 1], + joint_position=s, + ) + for index, s in enumerate(joint_positions) + ] + + pre_H_suc = jnp.stack([jnp.eye(4)] + [H for H, _ in pre_H_suc_and_S]) + S = jnp.stack([jnp.vstack(jnp.zeros(6))] + [S for _, S in pre_H_suc_and_S]) + + suc_H_i = jax.vmap(lambda i: self.joint_model.successor_H_child(joint_index=i))( + jnp.arange(1, 1 + self.number_of_joints()) + ) + + i_X_λ = jax.vmap( + lambda λ_Hi_pre, pre_Hi_suc, suc_Hi_i: jaxlie.SE3.from_matrix( + λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i + ) + .inverse() + .adjoint() + )(λ_H_pre, pre_H_suc, suc_H_i) + + return i_X_λ, S + + # ============================ + # Helpers to update parameters + # ============================ + + def set_link_mass(self, link_index: int, mass: jtp.FloatLike) -> KynDynParameters: + """ + Set the mass of a link. + + Args: + link_index: The index of the link. + mass: The mass of the link. + + Returns: + The updated kinematic and dynamic parameters of the model. + """ + + link_parameters = self.link_parameters.replace( + mass=self.link_parameters.mass.at[link_index].set(mass) + ) + + return self.replace(link_parameters=link_parameters) + + def set_link_inertia( + self, link_index: int, inertia: jtp.MatrixLike + ) -> KynDynParameters: + """ + Set the inertia tensor of a link. + + Args: + link_index: The index of the link. + inertia: The 3×3 inertia tensor of the link. + + Returns: + The updated kinematic and dynamic parameters of the model. + """ + + inertia_elements = LinkParameters.flatten_inertia_tensor(I=inertia) + + link_parameters = self.link_parameters.replace( + mass=self.link_parameters.inertia_elements.at[link_index].set( + inertia_elements + ) + ) + + return self.replace(link_parameters=link_parameters) + + +@jax_dataclasses.pytree_dataclass +class JointParameters(JaxsimDataclass): + + friction_static: jtp.Float + friction_viscous: jtp.Float + + position_limits_min: jtp.Float + position_limits_max: jtp.Float + + position_limit_spring: jtp.Float + position_limit_damper: jtp.Float + + @staticmethod + def build_from_joint_description( + joint_description: JointDescription, + ) -> JointParameters: + """""" + + s_min = joint_description.position_limit[0] + s_max = joint_description.position_limit[1] + + position_limits_min = jnp.minimum(s_min, s_max) + position_limits_max = jnp.maximum(s_min, s_max) + + friction_static = jnp.array(joint_description.friction_static).squeeze() + friction_viscous = jnp.array(joint_description.friction_viscous).squeeze() + + position_limit_spring = jnp.array( + joint_description.position_limit_spring + ).squeeze() + + position_limit_damper = jnp.array( + joint_description.position_limit_damper + ).squeeze() + + return JointParameters( + friction_static=friction_static.astype(float), + friction_viscous=friction_viscous.astype(float), + position_limits_min=position_limits_min.astype(float), + position_limits_max=position_limits_max.astype(float), + position_limit_spring=position_limit_spring.astype(float), + position_limit_damper=position_limit_damper.astype(float), + ) + + +@jax_dataclasses.pytree_dataclass +class LinkParameters(JaxsimDataclass): + + mass: jtp.Float + inertia_elements: jtp.Vector + + # The following is L_p_CoM, that is the translation between the link frame and + # the link's center of mass, expressed in the coordinates of the link frame L. + center_of_mass: jtp.Vector + + @staticmethod + def build_from_spatial_inertia(M: jtp.Matrix) -> LinkParameters: + """""" + + m, L_p_CoM, I = Inertia.to_params(M=M) + + return LinkParameters( + mass=jnp.array(m).squeeze().astype(float), + center_of_mass=jnp.atleast_1d(jnp.array(L_p_CoM).squeeze()).astype(float), + inertia_elements=jnp.atleast_1d(I[jnp.triu_indices(3)].squeeze()).astype( + float + ), + ) + + @staticmethod + def build_from_inertial_parameters( + m: jtp.FloatLike, I: jtp.MatrixLike, c: jtp.VectorLike + ) -> LinkParameters: + + return LinkParameters( + mass=jnp.array(m).squeeze().astype(float), + inertia_elements=jnp.atleast_1d(I[jnp.triu_indices(3)].squeeze()).astype( + float + ), + center_of_mass=jnp.atleast_1d(c.squeeze()).astype(float), + ) + + @staticmethod + def build_from_flat_parameters(parameters: jtp.VectorLike) -> LinkParameters: + + m = jnp.array(parameters[0]).squeeze().astype(float) + c = jnp.atleast_1d(parameters[1:4].squeeze()).astype(float) + I = jnp.atleast_1d(parameters[4:].squeeze()).astype(float) + + return LinkParameters( + mass=m, inertia_elements=I[jnp.triu_indices(3)], center_of_mass=c + ) + + @staticmethod + def parameters(params: LinkParameters) -> jtp.Vector: + + return jnp.hstack( + [params.mass, params.center_of_mass.squeeze(), params.inertia_elements] + ) + + @staticmethod + def inertia_tensor(params: LinkParameters) -> jtp.Matrix: + + return LinkParameters.unflatten_inertia_tensor( + inertia_elements=params.inertia_elements + ) + + @staticmethod + def spatial_inertia(params: LinkParameters) -> jtp.Matrix: + + return Inertia.to_sixd( + mass=params.mass, + I=LinkParameters.inertia_tensor(params), + com=params.center_of_mass, + ) + + @staticmethod + def flatten_inertia_tensor(I: jtp.Matrix) -> jtp.Vector: + return jnp.atleast_1d(I[jnp.triu_indices(3)].squeeze()) + + @staticmethod + def unflatten_inertia_tensor(inertia_elements: jtp.Vector) -> jtp.Matrix: + I = jnp.zeros([3, 3]).at[jnp.triu_indices(3)].set(inertia_elements.squeeze()) + return jnp.atleast_2d(jnp.where(I, I, I.T)).astype(float) diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 616074994..eef6724a1 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -4,20 +4,20 @@ import jax import jax.numpy as jnp import jaxlie +import numpy as np +import jaxsim.api as js import jaxsim.physics.algos.jacobian import jaxsim.typing as jtp from jaxsim.high_level.common import VelRepr -from . import data as Data -from . import model as Model - # ======================= # Index-related functions # ======================= -def name_to_idx(model: Model.JaxSimModel, *, link_name: str) -> jtp.Int: +@functools.partial(jax.jit, static_argnames="link_name") +def name_to_idx(model: js.model.JaxSimModel, *, link_name: str) -> jtp.Int: """ Convert the name of a link to its index. @@ -29,12 +29,18 @@ def name_to_idx(model: Model.JaxSimModel, *, link_name: str) -> jtp.Int: The index of the link. """ - return jnp.array( - model.physics_model.description.links_dict[link_name].index, dtype=int - ) + if link_name in model.kin_dyn_parameters.link_names: + return ( + jnp.array( + np.argwhere(np.array(model.kin_dyn_parameters.link_names) == link_name) + ) + .squeeze() + .astype(int) + ) + return jnp.array(-1).astype(int) -def idx_to_name(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> str: +def idx_to_name(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> str: """ Convert the index of a link to its name. @@ -46,11 +52,13 @@ def idx_to_name(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> str: The name of the link. """ - d = {l.index: l.name for l in model.physics_model.description.links_dict.values()} - return d[link_index] + return model.kin_dyn_parameters.link_names[link_index] -def names_to_idxs(model: Model.JaxSimModel, *, link_names: Sequence[str]) -> jax.Array: +@functools.partial(jax.jit, static_argnames="link_names") +def names_to_idxs( + model: js.model.JaxSimModel, *, link_names: Sequence[str] +) -> jax.Array: """ Convert a sequence of link names to their corresponding indices. @@ -63,13 +71,12 @@ def names_to_idxs(model: Model.JaxSimModel, *, link_names: Sequence[str]) -> jax """ return jnp.array( - [model.physics_model.description.links_dict[name].index for name in link_names], - dtype=int, - ) + [name_to_idx(model=model, link_name=name) for name in link_names], + ).astype(int) def idxs_to_names( - model: Model.JaxSimModel, *, link_indices: Sequence[jtp.IntLike] | jtp.VectorLike + model: js.model.JaxSimModel, *, link_indices: Sequence[jtp.IntLike] | jtp.VectorLike ) -> tuple[str, ...]: """ Convert a sequence of link indices to their corresponding names. @@ -82,8 +89,7 @@ def idxs_to_names( The names of the links. """ - d = {l.index: l.name for l in model.physics_model.description.links_dict.values()} - return tuple(d[i] for i in link_indices) + return tuple(idx_to_name(model=model, link_index=idx) for idx in link_indices) # ========= @@ -91,21 +97,32 @@ def idxs_to_names( # ========= -def mass(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float: +@jax.jit +def mass(model: js.model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Float: """""" - return model.physics_model._link_masses[link_index].astype(float) + return model.kin_dyn_parameters.link_parameters.mass[link_index].astype(float) -def spatial_inertia(model: Model.JaxSimModel, *, link_index: jtp.IntLike) -> jtp.Matrix: +@jax.jit +def spatial_inertia( + model: js.model.JaxSimModel, *, link_index: jtp.IntLike +) -> jtp.Matrix: """""" - return model.physics_model._link_spatial_inertias[link_index] + link_parameters = jax.tree_util.tree_map( + lambda l: l[link_index], model.kin_dyn_parameters.link_parameters + ) + + return js.kin_dyn_parameters.LinkParameters.spatial_inertia(link_parameters) @jax.jit def transform( - model: Model.JaxSimModel, data: Data.JaxSimModelData, *, link_index: jtp.IntLike + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + link_index: jtp.IntLike, ) -> jtp.Matrix: """ Compute the SE(3) transform from the world frame to the link frame. @@ -119,13 +136,13 @@ def transform( The 4x4 matrix representing the transform. """ - return Model.forward_kinematics(model=model, data=data)[link_index] + return js.model.forward_kinematics(model=model, data=data)[link_index] @jax.jit def com_position( - model: Model.JaxSimModel, - data: Data.JaxSimModelData, + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, in_link_frame: jtp.BoolLike = True, @@ -168,8 +185,8 @@ def com_in_inertial_frame(): @functools.partial(jax.jit, static_argnames=["output_vel_repr"]) def jacobian( - model: Model.JaxSimModel, - data: Data.JaxSimModelData, + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, *, link_index: jtp.IntLike, output_vel_repr: VelRepr | None = None, diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 0cbc8dd0a..f7e94b9a3 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -13,6 +13,7 @@ from jax_dataclasses import Static import jaxsim.api as js +import jaxsim.parsers.descriptions import jaxsim.physics.algos.aba import jaxsim.physics.algos.crba import jaxsim.physics.algos.forward_kinematics @@ -21,7 +22,7 @@ import jaxsim.typing as jtp from jaxsim.high_level.common import VelRepr from jaxsim.physics.algos.terrain import FlatTerrain, Terrain -from jaxsim.utils import JaxsimDataclass, Mutability +from jaxsim.utils import HashlessObject, JaxsimDataclass, Mutability @jax_dataclasses.pytree_dataclass @@ -33,34 +34,25 @@ class JaxSimModel(JaxsimDataclass): model_name: Static[str] physics_model: jaxsim.physics.model.physics_model.PhysicsModel = dataclasses.field( - repr=False + repr=False, compare=False, hash=False ) - terrain: Static[Terrain] = dataclasses.field(default=FlatTerrain(), repr=False) + terrain: Static[Terrain] = dataclasses.field( + default=FlatTerrain(), repr=False, compare=False, hash=False + ) built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field( - repr=False, default=None + default=None, repr=False, compare=False, hash=False ) - _number_of_links: Static[int] = dataclasses.field( - init=False, repr=False, default=None - ) + description: Static[ + HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None] + ] = dataclasses.field(default=None, repr=False, compare=False, hash=False) - _number_of_joints: Static[int] = dataclasses.field( - init=False, repr=False, default=None + kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = ( + dataclasses.field(default=None, repr=False, compare=False, hash=False) ) - def __post_init__(self): - - # These attributes are Static so that we can use `jax.vmap` and `jax.lax.scan` - # over the all links and joints - with self.mutable_context( - mutability=Mutability.MUTABLE_NO_VALIDATION, - restore_after_exception=False, - ): - self._number_of_links = len(self.physics_model.description.links_dict) - self._number_of_joints = len(self.physics_model.description.joints_dict) - # ======================== # Initialization and state # ======================== @@ -146,7 +138,14 @@ def build( ) # Build the model - model = JaxSimModel(physics_model=physics_model, model_name=model_name) # noqa + model = JaxSimModel( + physics_model=physics_model, + model_name=model_name, + description=HashlessObject(obj=physics_model.description), + kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build( + model_description=physics_model.description + ), + ) return model @@ -175,7 +174,7 @@ def number_of_links(self) -> jtp.Int: The base link is included in the count and its index is always 0. """ - return self._number_of_links + return self.kin_dyn_parameters.number_of_links() def number_of_joints(self) -> jtp.Int: """ @@ -185,7 +184,7 @@ def number_of_joints(self) -> jtp.Int: The number of joints in the model. """ - return self._number_of_joints + return self.kin_dyn_parameters.number_of_joints() # ================= # Base link methods @@ -199,7 +198,7 @@ def floating_base(self) -> bool: True if the model is floating-base, False otherwise. """ - return self.physics_model.is_floating_base + return bool(self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6) def base_link(self) -> str: """ @@ -207,9 +206,12 @@ def base_link(self) -> str: Returns: The name of the base link. + + Note: + By default, the base link is the root of the kinematic tree. """ - return self.physics_model.description.root.name + return self.link_names()[0] # ===================== # Joint-related methods @@ -227,7 +229,7 @@ def dofs(self) -> int: the number of joints. In the future, this could be different. """ - return len(self.physics_model.description.joints_dict) + return int(sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:])) def joint_names(self) -> tuple[str, ...]: """ @@ -237,7 +239,7 @@ def joint_names(self) -> tuple[str, ...]: The names of the joints in the model. """ - return tuple(self.physics_model.description.joints_dict.keys()) + return self.kin_dyn_parameters.joint_model.joint_names[1:] # ==================== # Link-related methods @@ -251,7 +253,7 @@ def link_names(self) -> tuple[str, ...]: The names of the links in the model. """ - return tuple(self.physics_model.description.links_dict.keys()) + return self.kin_dyn_parameters.link_names # ===================== @@ -279,7 +281,7 @@ def reduce(model: JaxSimModel, considered_joints: tuple[str, ...]) -> JaxSimMode # Reduce the model description. # If considered_joints contains joints not existing in the model, the method # will raise an exception. - reduced_intermediate_description = model.physics_model.description.reduce( + reduced_intermediate_description = model.description.obj.reduce( considered_joints=list(considered_joints) ) @@ -297,6 +299,7 @@ def reduce(model: JaxSimModel, considered_joints: tuple[str, ...]) -> JaxSimMode # Store the origin of the model, in case downstream logic needs it with reduced_model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): reduced_model.built_from = model.built_from + reduced_model.description = HashlessObject(obj=physics_model.description) return reduced_model diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 8432a6806..7cb8ad03c 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -2,34 +2,30 @@ import jax import jax.numpy as jnp -import jaxlie +import jaxsim.api as js import jaxsim.physics.algos.soft_contacts import jaxsim.typing as jtp -from jaxsim import VelRepr, integrators +from jaxsim import VelRepr from jaxsim.integrators.common import Time from jaxsim.math.quaternion import Quaternion from jaxsim.physics.algos.soft_contacts import SoftContactsState from jaxsim.physics.model.physics_model_state import PhysicsModelState from jaxsim.simulation.ode_data import ODEState -from . import contact as Contact -from . import data as Data -from . import model as Model - class SystemDynamicsFromModelAndData(Protocol): def __call__( self, - model: Model.JaxSimModel, - data: Data.JaxSimModelData, + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, **kwargs: dict[str, Any], ) -> tuple[ODEState, dict[str, Any]]: ... def wrap_system_dynamics_for_integration( - model: Model.JaxSimModel, - data: Data.JaxSimModelData, + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, *, system_dynamics: SystemDynamicsFromModelAndData, **kwargs, @@ -72,8 +68,8 @@ def f(x: ODEState, t: Time, **kwargs) -> tuple[ODEState, dict[str, Any]]: @jax.jit def system_velocity_dynamics( - model: Model.JaxSimModel, - data: Data.JaxSimModelData, + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, *, joint_forces: jtp.Vector | None = None, external_forces: jtp.Vector | None = None, @@ -123,10 +119,10 @@ def system_velocity_dynamics( # Initialize the derivative of the tangential deformation ṁ ∈ ℝ^{n_c × 3}. ṁ = jnp.zeros_like(data.state.soft_contacts.tangential_deformation).astype(float) - if len(model.physics_model.gc.body) > 0: + if len(model.kin_dyn_parameters.contact_parameters.body) > 0: # Compute the position and linear velocities (mixed representation) of # all collidable points belonging to the robot. - W_p_Ci, W_ṗ_Ci = Contact.collidable_point_kinematics(model=model, data=data) + W_p_Ci, W_ṗ_Ci = js.contact.collidable_point_kinematics(model=model, data=data) # Compute the 3D forces applied to each collidable point. W_f_Ci, ṁ = jax.vmap( @@ -142,7 +138,10 @@ def system_velocity_dynamics( lambda nc: ( jnp.vstack( jnp.equal( - jnp.array(model.physics_model.gc.body, dtype=int), nc + jnp.array( + model.kin_dyn_parameters.contact_parameters.body, dtype=int + ), + nc, ).astype(int) ) * W_f_Ci @@ -164,8 +163,12 @@ def system_velocity_dynamics( if model.dofs() > 0: # Static and viscous joint friction parameters - kc = jnp.array(list(model.physics_model._joint_friction_static.values())) - kv = jnp.array(list(model.physics_model._joint_friction_viscous.values())) + kc = jnp.array( + model.kin_dyn_parameters.joint_parameters.friction_static + ).astype(float) + kv = jnp.array( + model.kin_dyn_parameters.joint_parameters.friction_viscous + ).astype(float) # Compute the joint friction torque τ_friction = -( @@ -186,7 +189,7 @@ def system_velocity_dynamics( # - Joint accelerations: s̈ ∈ ℝⁿ # - Base inertial-fixed acceleration: W_v̇_WB = (W_p̈_B, W_ω̇_B) ∈ ℝ⁶ with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial): - W_v̇_WB, s̈ = Model.forward_dynamics_aba( + W_v̇_WB, s̈ = js.model.forward_dynamics_aba( model=model, data=data, joint_forces=τ_total, @@ -198,7 +201,7 @@ def system_velocity_dynamics( @jax.jit def system_position_dynamics( - model: Model.JaxSimModel, data: Data.JaxSimModelData + model: js.model.JaxSimModel, data: js.data.JaxSimModelData ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]: """ Compute the dynamics of the system position. @@ -232,8 +235,8 @@ def system_position_dynamics( @jax.jit def system_dynamics( - model: Model.JaxSimModel, - data: Data.JaxSimModelData, + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, *, joint_forces: jtp.Vector | None = None, external_forces: jtp.Vector | None = None, diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index 916ed4321..3a82a4e97 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -7,7 +7,6 @@ import jax_dataclasses import jaxsim.api as js -import jaxsim.physics.model.physics_model_state import jaxsim.typing as jtp from jaxsim import VelRepr from jaxsim.simulation.ode_data import ODEInput @@ -188,7 +187,7 @@ def link_forces( # If we have the model, we can extract the link names, if not provided. link_names = link_names if link_names is not None else model.link_names() - link_idxs = jaxsim.api.link.names_to_idxs(link_names=link_names, model=model) + link_idxs = js.link.names_to_idxs(link_names=link_names, model=model) # In inertial-fixed representation, we already have the link forces. if self.velocity_representation is VelRepr.Inertial: @@ -379,7 +378,7 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: # If we have the model, we can extract the link names if not provided. link_names = link_names if link_names is not None else model.link_names() - link_idxs = jaxsim.api.link.names_to_idxs(link_names=link_names, model=model) + link_idxs = js.link.names_to_idxs(link_names=link_names, model=model) # Compute the bias depending on whether we either set or add the link forces. W_f0_L = ( diff --git a/src/jaxsim/math/joint_model.py b/src/jaxsim/math/joint_model.py new file mode 100644 index 000000000..a141045e4 --- /dev/null +++ b/src/jaxsim/math/joint_model.py @@ -0,0 +1,319 @@ +from __future__ import annotations + +import functools + +import jax +import jax.numpy as jnp +import jax_dataclasses +import jaxlie +from jax_dataclasses import Static + +import jaxsim.typing as jtp +from jaxsim.math.rotation import Rotation +from jaxsim.parsers.descriptions import ( + JointDescriptor, + JointGenericAxis, + JointType, + ModelDescription, +) + + +@jax_dataclasses.pytree_dataclass +class JointModel: + """ + Class describing the joint kinematics of a robot model. + """ + + λ_H_pre: jax.Array + suc_H_i: jax.Array + + joint_dofs: Static[tuple[int, ...]] + joint_names: Static[tuple[str, ...]] + joint_types: Static[tuple[JointType | JointDescriptor, ...]] + + @staticmethod + def build(description: ModelDescription) -> JointModel: + """ + Build the joint model of a model description. + + Args: + description: The model description to consider. + + Returns: + The joint model of the considered model description. + """ + + # The link index is equal to its body index: [0, number_of_bodies - 1]. + ordered_links = sorted( + list(description.links_dict.values()), + key=lambda l: l.index, + ) + + # Note: the joint index is equal to its child link index, therefore it + # starts from 1. + ordered_joints = sorted( + list(description.joints_dict.values()), + key=lambda j: j.index, + ) + + # Allocate the parent-to-predecessor and successor-to-child transforms. + λ_H_pre = jnp.zeros(shape=(1 + len(ordered_joints), 4, 4), dtype=float) + suc_H_i = jnp.zeros(shape=(1 + len(ordered_joints), 4, 4), dtype=float) + + # Initialize an identical parent-to-predecessor transform for the joint + # between the world frame W and the base link B. + λ_H_pre = λ_H_pre.at[0].set(jnp.eye(4)) + + # Initialize the successor-to-child transform of the joint between the + # world frame W and the base link B. + # We store here the optional transform between the root frame of the model + # and the base link frame (this is needed only if the pose of the link frame + # w.r.t. the implicit __model__ SDF frame is not the identity). + suc_H_i = suc_H_i.at[0].set(ordered_links[0].pose) + + # Compute the parent-to-predecessor and successor-to-child transforms for + # each joint belonging to the model. + # Note that the joint indices starts from i=1 given our joint model, + # therefore the entries at index 0 are not updated. + for joint in ordered_joints: + λ_H_pre = λ_H_pre.at[joint.index].set( + description.relative_transform( + relative_to=joint.parent.name, + name=joint.name, + ) + ) + suc_H_i = suc_H_i.at[joint.index].set( + description.relative_transform( + relative_to=joint.name, name=joint.child.name + ) + ) + + # Define the DoFs of the base link. + base_dofs = 0 if description.fixed_base else 6 + + # We always add a dummy fixed joint between world and base. + # TODO: Port floating-base support also at this level, not only in RBDAs. + return JointModel( + λ_H_pre=λ_H_pre, + suc_H_i=suc_H_i, + # Static attributes + joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]), + joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]), + joint_types=tuple([JointType.F] + [j.jtype for j in ordered_joints]), + ) + + def parent_H_child( + self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike + ) -> tuple[jtp.Matrix, jtp.Array]: + """ + Compute the homogeneous transformation between the parent link and + the child link of a joint, and the corresponding motion subspace. + + Args: + joint_index: The index of the joint. + joint_position: The position of the joint. + + Returns: + A tuple containing the homogeneous transformation + :math:`{}^{\lambda(\text{i})} \mathbf{H}_\text{i}(s)` + and the motion subspace :math:`\mathbf{S}(s)`. + """ + + i = joint_index + s = joint_position + + # Get the components of the joint model. + λ_Hi_pre = self.parent_H_predecessor(joint_index=i) + pre_Hi_suc, S = self.predecessor_H_successor(joint_index=i, joint_position=s) + suc_Hi_i = self.successor_H_child(joint_index=i) + + # Compose all the transforms. + return λ_Hi_pre @ pre_Hi_suc @ suc_Hi_i, S + + @jax.jit + def child_H_parent( + self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike + ) -> tuple[jtp.Matrix, jtp.Array]: + """ + Compute the homogeneous transformation between the child link and + the parent link of a joint, and the corresponding motion subspace. + + Args: + joint_index: The index of the joint. + joint_position: The position of the joint. + + Returns: + A tuple containing the homogeneous transformation + :math:`{}^{\text{i}} \mathbf{H}_{\lambda(\text{i})}(s)` + and the motion subspace :math:`\mathbf{S}(s)`. + """ + + λ_Hi_i, S = self.parent_H_child( + joint_index=joint_index, joint_position=joint_position + ) + + i_Hi_λ = jaxlie.SE3.from_matrix(λ_Hi_i).inverse().as_matrix() + + return i_Hi_λ, S + + def parent_H_predecessor(self, joint_index: jtp.IntLike) -> jtp.Matrix: + """ + Return the homogeneous transformation between the parent link and + the predecessor frame of a joint. + + Args: + joint_index: The index of the joint. + + Returns: + The homogeneous transformation + :math:`{}^{\lambda(\text{i})} \mathbf{H}_{\text{pre}(\text{i})}`. + """ + + return self.λ_H_pre[joint_index] + + def predecessor_H_successor( + self, joint_index: jtp.IntLike, joint_position: jtp.VectorLike + ) -> tuple[jtp.Matrix, jtp.Array]: + """ + Compute the homogeneous transformation between the predecessor and + the successor frame of a joint, and the corresponding motion subspace. + + Args: + joint_index: The index of the joint. + joint_position: The position of the joint. + + Returns: + A tuple containing the homogeneous transformation + :math:`{}^{\text{pre}(\text{i})} \mathbf{H}_{\text{suc}(\text{i})}(s)` + and the motion subspace :math:`\mathbf{S}(s)`. + """ + + pre_H_suc, S = supported_joint_motion( + joint_type=self.joint_types[joint_index], + joint_position=joint_position, + ) + + return pre_H_suc, S + + def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix: + """ + Return the homogeneous transformation between the successor frame and + the child link of a joint. + + Args: + joint_index: The index of the joint. + + Returns: + The homogeneous transformation + :math:`{}^{\text{suc}(\text{i})} \mathbf{H}_{\text{i}}`. + """ + + return self.suc_H_i[joint_index] + + +@functools.partial(jax.jit, static_argnames=["joint_type"]) +def supported_joint_motion( + joint_type: JointType | JointDescriptor, joint_position: jtp.VectorLike +) -> tuple[jtp.Matrix, jtp.Array]: + """ + Compute the homogeneous transformation and motion subspace of a joint. + + Args: + joint_type: The type of the joint. + joint_position: The position of the joint. + + Returns: + A tuple containing the homogeneous transformation and the motion subspace. + """ + + if isinstance(joint_type, JointType): + code = joint_type + elif isinstance(joint_type, JointDescriptor): + code = joint_type.code + else: + raise ValueError(joint_type) + + # Prepare the joint position + s = jnp.array(joint_position).astype(float) + + match code: + + case JointType.R: + joint_type: JointGenericAxis + + pre_H_suc = jaxlie.SE3.from_rotation( + rotation=jaxlie.SO3.from_matrix( + Rotation.from_axis_angle(vector=s * joint_type.axis) + ) + ) + + S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_type.axis.squeeze()])) + + case JointType.P: + joint_type: JointGenericAxis + + pre_H_suc = jaxlie.SE3.from_rotation_and_translation( + rotation=jaxlie.SO3.identity(), + translation=jnp.array(s * joint_type.axis), + ) + + S = jnp.vstack(jnp.hstack([joint_type.axis.squeeze(), jnp.zeros(3)])) + + case JointType.F: + raise ValueError("Fixed joints shouldn't be here") + + case JointType.Rx: + + pre_H_suc = jaxlie.SE3.from_rotation( + rotation=jaxlie.SO3.from_x_radians(theta=s) + ) + + S = jnp.vstack([0, 0, 0, 1.0, 0, 0]) + + case JointType.Ry: + + pre_H_suc = jaxlie.SE3.from_rotation( + rotation=jaxlie.SO3.from_y_radians(theta=s) + ) + + S = jnp.vstack([0, 0, 0, 0, 1.0, 0]) + + case JointType.Rz: + + pre_H_suc = jaxlie.SE3.from_rotation( + rotation=jaxlie.SO3.from_z_radians(theta=s) + ) + + S = jnp.vstack([0, 0, 0, 0, 0, 1.0]) + + case JointType.Px: + + pre_H_suc = jaxlie.SE3.from_rotation_and_translation( + rotation=jaxlie.SO3.identity(), + translation=jnp.array([s, 0.0, 0.0]), + ) + + S = jnp.vstack([1.0, 0, 0, 0, 0, 0]) + + case JointType.Py: + + pre_H_suc = jaxlie.SE3.from_rotation_and_translation( + rotation=jaxlie.SO3.identity(), + translation=jnp.array([0.0, s, 0.0]), + ) + + S = jnp.vstack([0, 1.0, 0, 0, 0, 0]) + + case JointType.Pz: + + pre_H_suc = jaxlie.SE3.from_rotation_and_translation( + rotation=jaxlie.SO3.identity(), + translation=jnp.array([0.0, 0.0, s]), + ) + + S = jnp.vstack([0, 0, 1.0, 0, 0, 0]) + + case _: + raise ValueError(joint_type) + + return pre_H_suc.as_matrix(), S diff --git a/src/jaxsim/physics/algos/soft_contacts.py b/src/jaxsim/physics/algos/soft_contacts.py index ab86b3554..f4eeadf9b 100644 --- a/src/jaxsim/physics/algos/soft_contacts.py +++ b/src/jaxsim/physics/algos/soft_contacts.py @@ -247,11 +247,11 @@ def process_point_kinematics( return W_p_Ci, CW_vl_WCi # Process all the collidable points in parallel - W_p_Ci, CW_v_WC = jax.vmap(process_point_kinematics)( - model.gc.point.T, np.array(model.gc.body, dtype=int) + W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)( + model.gc.point, jnp.array(model.gc.body) ) - return W_p_Ci.transpose(), CW_v_WC.transpose() + return W_p_Ci.transpose(), CW_vl_WC.transpose() @jax_dataclasses.pytree_dataclass diff --git a/src/jaxsim/physics/model/ground_contact.py b/src/jaxsim/physics/model/ground_contact.py index beea198de..20c3f6829 100644 --- a/src/jaxsim/physics/model/ground_contact.py +++ b/src/jaxsim/physics/model/ground_contact.py @@ -1,53 +1,72 @@ +from __future__ import annotations + import dataclasses import jax.numpy as jnp import jax_dataclasses -import numpy as np -import numpy.typing as npt from jax_dataclasses import Static +import jaxsim.typing as jtp from jaxsim.parsers.descriptions import ModelDescription @jax_dataclasses.pytree_dataclass class GroundContact: """ - A class for managing collidable points in a robot model. + Class storing the collidable points of a robot model. - This class is used to store and manage information about collidable points on a robot model, - such as their positions and the corresponding bodies (links) they are associated with. + This class is used to store and manage information about collidable points + of a robot model, such as their positions and the corresponding bodies (links) + they are rigidly attached to. Attributes: - point (npt.NDArray): An array of shape (3, N) representing the 3D positions of collidable points. - body (Static[npt.NDArray]): An array of integers representing the indices of the bodies (links) associated with each collidable point. + point: + An array of shape (N, 3) representing the displacement of collidable points + w.r.t the origin of their parent body. + body: + An array of integers representing the indices of the bodies (links) + associated to each collidable point. """ - point: npt.NDArray = dataclasses.field(default_factory=lambda: jnp.array([])) - body: Static[list] = dataclasses.field(default_factory=lambda: []) + body: Static[tuple[int, ...]] = dataclasses.field(default_factory=lambda: []) + + point: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.array([])) @staticmethod - def build_from( - model_description: ModelDescription, - ) -> "GroundContact": + def build_from(model_description: ModelDescription) -> GroundContact: + """ + Build a GroundContact object from a model description. + + Args: + model_description: The model description to consider. + + Returns: + The GroundContact object. + """ + if len(model_description.collision_shapes) == 0: return GroundContact() - # Get all the links so that we can take their updated index + # Get all the links so that we can take their updated index. links_dict = {link.name: link for link in model_description} - # Get all the enabled collidable points of the model + # Get all the enabled collidable points of the model. collidable_points = model_description.all_enabled_collidable_points() - # Build the GroundContact attributes - points = jnp.vstack([cp.position for cp in collidable_points]).T + # Extract the positions L_p_C of the collidable points w.r.t. the link frames + # they are rigidly attached to. + points = jnp.vstack([cp.position for cp in collidable_points]) + + # Extract the indices of the links to which the collidable points are rigidly + # attached to. link_index_of_points = [ links_dict[cp.parent_link.name].index for cp in collidable_points ] - # Build the object - gc = GroundContact(point=points, body=link_index_of_points) + # Build the GroundContact object. + gc = GroundContact(point=points, body=tuple(link_index_of_points)) # noqa - assert gc.point.shape[0] == 3 - assert gc.point.shape[1] == len(gc.body) + assert gc.point.shape[1] == 3 + assert gc.point.shape[0] == len(gc.body) return gc diff --git a/src/jaxsim/utils/__init__.py b/src/jaxsim/utils/__init__.py index b79fd990f..0e9509c29 100644 --- a/src/jaxsim/utils/__init__.py +++ b/src/jaxsim/utils/__init__.py @@ -1,5 +1,6 @@ from jax_dataclasses._copy_and_mutate import _Mutability as Mutability +from .hashless import HashlessObject from .jaxsim_dataclass import JaxsimDataclass from .tracing import not_tracing, tracing from .vmappable import Vmappable diff --git a/src/jaxsim/utils/hashless.py b/src/jaxsim/utils/hashless.py new file mode 100644 index 000000000..9a48fb437 --- /dev/null +++ b/src/jaxsim/utils/hashless.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +import dataclasses +from typing import Generic, TypeVar + +T = TypeVar("T") + + +@dataclasses.dataclass +class HashlessObject(Generic[T]): + + obj: T + + def get(self: HashlessObject[T]) -> T: + return self.obj + + def __hash__(self) -> int: + return 0