From 1d4957a9b525b02f7d5b0180b90126c1a7117d99 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 13 Jun 2024 11:52:31 +0200 Subject: [PATCH 01/12] Add HashedNumpyArray.hash_of_array static method --- src/jaxsim/utils/wrappers.py | 42 +++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/src/jaxsim/utils/wrappers.py b/src/jaxsim/utils/wrappers.py index fd0a29dd8..00b055cc5 100644 --- a/src/jaxsim/utils/wrappers.py +++ b/src/jaxsim/utils/wrappers.py @@ -56,6 +56,10 @@ class HashedNumpyArray: array: jax.Array | npt.NDArray + precision: float | None = dataclasses.field( + default=1e-9, repr=False, compare=False, hash=False + ) + large_array: jax_dataclasses.Static[bool] = dataclasses.field( default=False, repr=False, compare=False, hash=False ) @@ -65,7 +69,9 @@ def get(self) -> jax.Array | npt.NDArray: def __hash__(self) -> int: - return hash(tuple(np.atleast_1d(self.array).flatten().tolist())) + return HashedNumpyArray.hash_of_array( + array=self.array, precision=self.precision + ) def __eq__(self, other: HashedNumpyArray) -> bool: @@ -76,3 +82,37 @@ def __eq__(self, other: HashedNumpyArray) -> bool: return np.array_equal(self.array, other.array) return hash(self) == hash(other) + + @staticmethod + def hash_of_array( + array: jax.Array | npt.NDArray, precision: float | None = 1e-9 + ) -> int: + """ + Calculate the hash of a NumPy array. + + Args: + array: The array to hash. + precision: Optionally limit the precision over which the hash is computed. + + Returns: + The hash of the array. + """ + + array = np.array(array).flatten() + + array = np.where(array == np.nan, hash(np.nan), array) + array = np.where(array == np.inf, hash(np.inf), array) + array = np.where(array == -np.inf, hash(-np.inf), array) + + if precision is not None: + + integer1 = (array * precision).astype(int) + integer2 = (array - integer1 / precision).astype(int) + + decimal_array = ((array - integer1 * 1e9 - integer2) / precision).astype( + int + ) + + array = np.hstack([integer1, integer2, decimal_array]).astype(int) + + return hash(tuple(array.tolist())) From 8a41d4cb8de0cd19c9a752a7870e826ca1581c88 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 13 Jun 2024 13:38:20 +0200 Subject: [PATCH 02/12] Use more robust albeit slower logic on arrays in hash methods --- src/jaxsim/api/kin_dyn_parameters.py | 4 +--- src/jaxsim/api/model.py | 10 ++++----- src/jaxsim/parsers/descriptions/joint.py | 26 +++++++++++++----------- src/jaxsim/parsers/descriptions/link.py | 10 +++++---- src/jaxsim/parsers/descriptions/model.py | 4 ++-- 5 files changed, 28 insertions(+), 26 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index c185b83fc..d018fdba9 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -6,7 +6,6 @@ import jax.numpy as jnp import jax_dataclasses import jaxlie -import numpy as np from jax_dataclasses import Static import jaxsim.typing as jtp @@ -15,7 +14,7 @@ from jaxsim.utils import HashedNumpyArray, JaxsimDataclass -@jax_dataclasses.pytree_dataclass +@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) class KynDynParameters(JaxsimDataclass): r""" Class storing the kinematic and dynamic parameters of a model. @@ -221,7 +220,6 @@ def __hash__(self) -> int: ( hash(self.number_of_links()), hash(self.number_of_joints()), - hash(tuple(np.atleast_1d(self.parent_array).flatten().tolist())), hash(self._parent_array), hash(self._support_body_array_bool), ) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 3fafaa1ee..96649c8f8 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -22,7 +22,7 @@ from .common import VelRepr -@jax_dataclasses.pytree_dataclass +@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) class JaxSimModel(JaxsimDataclass): """ The JaxSim model defining the kinematics and dynamics of a robot. @@ -31,19 +31,19 @@ class JaxSimModel(JaxsimDataclass): model_name: Static[str] terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field( - default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False + default=jaxsim.terrain.FlatTerrain(), repr=False ) kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = ( - dataclasses.field(default=None, repr=False, compare=False, hash=False) + dataclasses.field(default=None, repr=False) ) built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field( - default=None, repr=False, compare=False, hash=False + default=None, repr=False ) description: Static[jaxsim.parsers.descriptions.ModelDescription | None] = ( - dataclasses.field(default=None, repr=False, compare=False, hash=False) + dataclasses.field(default=None, repr=False) ) def __eq__(self, other: JaxSimModel) -> bool: diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index 1aa4eef84..c6539bffb 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -41,7 +41,7 @@ def __eq__(self, other: JointGenericAxis) -> bool: return hash(self) == hash(other) -@jax_dataclasses.pytree_dataclass +@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) class JointDescription(JaxsimDataclass): """ In-memory description of a robot link. @@ -97,23 +97,25 @@ def __post_init__(self) -> None: def __hash__(self) -> int: + from jaxsim.utils.wrappers import HashedNumpyArray + return hash( ( hash(self.name), - hash(tuple(self.axis.tolist())), - hash(tuple(self.pose.flatten().tolist())), + HashedNumpyArray.hash_of_array(self.axis), + HashedNumpyArray.hash_of_array(self.pose), hash(int(self.jtype)), hash(self.child), hash(self.parent), hash(int(self.index)) if self.index is not None else 0, - hash(float(self.friction_static)), - hash(float(self.friction_viscous)), - hash(float(self.position_limit_damper)), - hash(float(self.position_limit_spring)), - hash((float(el) for el in self.position_limit)), - hash(tuple(np.atleast_1d(self.initial_position).tolist())), - hash(float(self.motor_inertia)), - hash(float(self.motor_viscous_friction)), - hash(float(self.motor_gear_ratio)), + HashedNumpyArray.hash_of_array(np.array(self.friction_static)), + HashedNumpyArray.hash_of_array(np.array(self.friction_viscous)), + HashedNumpyArray.hash_of_array(np.array(self.position_limit_damper)), + HashedNumpyArray.hash_of_array(np.array(self.position_limit_spring)), + HashedNumpyArray.hash_of_array(np.array(self.position_limit)), + HashedNumpyArray.hash_of_array(self.initial_position), + HashedNumpyArray.hash_of_array(np.array(self.motor_inertia)), + HashedNumpyArray.hash_of_array(np.array(self.motor_viscous_friction)), + HashedNumpyArray.hash_of_array(np.array(self.motor_gear_ratio)), ), ) diff --git a/src/jaxsim/parsers/descriptions/link.py b/src/jaxsim/parsers/descriptions/link.py index 859aa7122..59e46f702 100644 --- a/src/jaxsim/parsers/descriptions/link.py +++ b/src/jaxsim/parsers/descriptions/link.py @@ -12,7 +12,7 @@ from jaxsim.utils import JaxsimDataclass -@jax_dataclasses.pytree_dataclass +@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False) class LinkDescription(JaxsimDataclass): """ In-memory description of a robot link. @@ -40,13 +40,15 @@ class LinkDescription(JaxsimDataclass): def __hash__(self) -> int: + from jaxsim.utils.wrappers import HashedNumpyArray + return hash( ( hash(self.name), hash(float(self.mass)), - hash(tuple(np.atleast_1d(self.inertia).flatten().tolist())), - hash(int(self.index)) if self.index is not None else 0, - hash(tuple(np.atleast_1d(self.pose).flatten().tolist())), + HashedNumpyArray.hash_of_array(self.inertia), + hash(int(self.index)) if self.index is not None else self.index, + HashedNumpyArray.hash_of_array(self.pose), hash(tuple(self.children)), # Here only using the name to prevent circular recursion: hash(self.parent.name) if self.parent is not None else 0, diff --git a/src/jaxsim/parsers/descriptions/model.py b/src/jaxsim/parsers/descriptions/model.py index 14f294ac2..6c4baf1a7 100644 --- a/src/jaxsim/parsers/descriptions/model.py +++ b/src/jaxsim/parsers/descriptions/model.py @@ -12,7 +12,7 @@ from .link import LinkDescription -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, eq=False, unsafe_hash=False) class ModelDescription(KinematicGraph): """ Intermediate representation representing the kinematic graph of a robot model. @@ -28,7 +28,7 @@ class ModelDescription(KinematicGraph): fixed_base: bool = True collision_shapes: tuple[CollisionShape, ...] = dataclasses.field( - default_factory=list, repr=False, hash=False + default_factory=list, repr=False ) @staticmethod From 7112e0f12906471a40c4d7162cdf7a2954d9e663 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 13 Jun 2024 15:39:31 +0200 Subject: [PATCH 03/12] Speed up eq methods by not using hash --- src/jaxsim/api/model.py | 15 ++++++- src/jaxsim/parsers/descriptions/joint.py | 55 ++++++++++++++++++++++++ src/jaxsim/parsers/descriptions/link.py | 29 ++++++++++++- src/jaxsim/parsers/descriptions/model.py | 20 ++++++++- 4 files changed, 114 insertions(+), 5 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 96649c8f8..5fa276cac 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -51,15 +51,26 @@ def __eq__(self, other: JaxSimModel) -> bool: if not isinstance(other, JaxSimModel): return False - return hash(self) == hash(other) + if self.model_name != other.model_name: + return False + + if self.kin_dyn_parameters != other.kin_dyn_parameters: + return False + + # Here we compare only the static quantities of ModelDescription + # that are actually used by our APIs. + if self.description.frames != other.description.frames: + return False + + return True def __hash__(self) -> int: return hash( ( hash(self.model_name), - hash(self.description), hash(self.kin_dyn_parameters), + hash(self.description), ) ) diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index c6539bffb..a5503e1b1 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -95,6 +95,61 @@ def __post_init__(self) -> None: norm_of_axis = np.linalg.norm(self.axis) self.axis = self.axis / norm_of_axis + def __eq__(self, other: JointDescription) -> bool: + + if not isinstance(other, JointDescription): + return False + + if self.name != other.name: + return False + + if not np.allclose(self.axis, other.axis): + return False + + if not np.allclose(self.pose, other.pose): + return False + + if self.jtype != other.jtype: + return False + + if self.child != other.child: + return False + + if self.parent != other.parent: + return False + + if self.index != other.index: + return False + + if not np.allclose(self.friction_static, other.friction_static): + return False + + if not np.allclose(self.friction_viscous, other.friction_viscous): + return False + + if not np.allclose(self.position_limit_damper, other.position_limit_damper): + return False + + if not np.allclose(self.position_limit_spring, other.position_limit_spring): + return False + + if not np.allclose(self.position_limit, other.position_limit): + return False + + if not np.allclose(self.initial_position, other.initial_position): + return False + + if not np.allclose(self.motor_inertia, other.motor_inertia): + return False + + if not np.allclose(self.motor_viscous_friction, other.motor_viscous_friction): + return False + + if not np.allclose(self.motor_gear_ratio, other.motor_gear_ratio): + return False + + return True + def __hash__(self) -> int: from jaxsim.utils.wrappers import HashedNumpyArray diff --git a/src/jaxsim/parsers/descriptions/link.py b/src/jaxsim/parsers/descriptions/link.py index 59e46f702..37955bbed 100644 --- a/src/jaxsim/parsers/descriptions/link.py +++ b/src/jaxsim/parsers/descriptions/link.py @@ -31,7 +31,7 @@ class LinkDescription(JaxsimDataclass): mass: float = dataclasses.field(repr=False) inertia: jtp.Matrix = dataclasses.field(repr=False) index: int | None = None - parent: LinkDescription = dataclasses.field(default=None, repr=False) + parent: LinkDescription | None = dataclasses.field(default=None, repr=False) pose: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.eye(4), repr=False) children: Static[tuple[LinkDescription]] = dataclasses.field( @@ -60,7 +60,32 @@ def __eq__(self, other: LinkDescription) -> bool: if not isinstance(other, LinkDescription): return False - return hash(self) == hash(other) + if self.name != other.name: + return False + + if not np.allclose(self.mass, other.mass): + return False + + if not np.allclose(self.inertia, other.inertia): + return False + + if self.index != other.index: + return False + + if not np.allclose(self.pose, other.pose): + return False + + if self.children != other.children: + return False + + # Here only using the name to prevent circular recursion + if self.parent is not None and self.parent.name != other.parent.name: + return False + + if self.parent is None and other.parent is not None: + return False + + return True @property def name_and_index(self) -> str: diff --git a/src/jaxsim/parsers/descriptions/model.py b/src/jaxsim/parsers/descriptions/model.py index 6c4baf1a7..1faf9ebe0 100644 --- a/src/jaxsim/parsers/descriptions/model.py +++ b/src/jaxsim/parsers/descriptions/model.py @@ -249,7 +249,25 @@ def __eq__(self, other: ModelDescription) -> bool: if not isinstance(other, ModelDescription): return False - return hash(self) == hash(other) + if self.name != other.name: + return False + + if self.fixed_base != other.fixed_base: + return False + + if self.root != other.root: + return False + + if self.joints != other.joints: + return False + + if self.frames != other.frames: + return False + + if self.root_pose != other.root_pose: + return False + + return True def __hash__(self) -> int: From d1f5f59a9e9de0ef3dff4af3a70f550d87ecdd5d Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 13 Jun 2024 13:37:43 +0200 Subject: [PATCH 04/12] Make RootPose a dataclass --- src/jaxsim/parsers/kinematic_graph.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 7588af660..1e289d931 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -3,7 +3,7 @@ import copy import dataclasses import functools -from typing import Any, Callable, Iterable, NamedTuple, Sequence +from typing import Any, Callable, Iterable, Sequence import numpy as np import numpy.typing as npt @@ -15,7 +15,8 @@ from . import descriptions -class RootPose(NamedTuple): +@dataclasses.dataclass +class RootPose: """ Represents the root pose in a kinematic graph. @@ -28,15 +29,20 @@ class RootPose(NamedTuple): The root link of the kinematic graph is the base link. """ - root_position: npt.NDArray = np.zeros(3) - root_quaternion: npt.NDArray = np.array([1.0, 0, 0, 0]) + root_position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3)) + + root_quaternion: npt.NDArray = dataclasses.field( + default_factory=lambda: np.array([1.0, 0, 0, 0]) + ) def __hash__(self) -> int: + from jaxsim.utils.wrappers import HashedNumpyArray + return hash( ( - hash(tuple(self.root_position.tolist())), - hash(tuple(self.root_quaternion.tolist())), + HashedNumpyArray.hash_of_array(self.root_position), + HashedNumpyArray.hash_of_array(self.root_quaternion), ) ) @@ -45,7 +51,13 @@ def __eq__(self, other: RootPose) -> bool: if not isinstance(other, RootPose): return False - return hash(self) == hash(other) + if not np.allclose(self.root_position, other.root_position): + return False + + if not np.allclose(self.root_quaternion, other.root_quaternion): + return False + + return True @dataclasses.dataclass(frozen=True) From 240102744b8cc7039a5553d0764f231be02d5d2b Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 13 Jun 2024 15:43:19 +0200 Subject: [PATCH 05/12] Add wrappers.CustomHashedObject --- src/jaxsim/utils/wrappers.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/jaxsim/utils/wrappers.py b/src/jaxsim/utils/wrappers.py index 00b055cc5..1750d2c37 100644 --- a/src/jaxsim/utils/wrappers.py +++ b/src/jaxsim/utils/wrappers.py @@ -1,7 +1,7 @@ from __future__ import annotations import dataclasses -from typing import Generic, TypeVar +from typing import Callable, Generic, TypeVar import jax import jax_dataclasses @@ -40,6 +40,33 @@ def __eq__(self, other: HashlessObject[T]) -> bool: return hash(self) == hash(other) +@dataclasses.dataclass +class CustomHashedObject(Generic[T]): + """ + A class that wraps an object and computes its hash with a custom hash function. + """ + + obj: T + + hash_function: Callable[[T], int] = dataclasses.field(default=lambda obj: hash(obj)) + + def get(self: CustomHashedObject[T]) -> T: + return self.obj + + def __hash__(self) -> int: + + return self.hash_function(self.obj) + + def __eq__(self, other: CustomHashedObject[T]) -> bool: + + if not isinstance(other, CustomHashedObject) and isinstance( + other.get(), type(self.get()) + ): + return False + + return hash(self) == hash(other) + + @jax_dataclasses.pytree_dataclass class HashedNumpyArray: """ From 9857439074a22679e6b3898f742e5cfcbd462e0d Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 13 Jun 2024 15:42:27 +0200 Subject: [PATCH 06/12] Improve speed of calling JIT-compiled methods on compatible models --- src/jaxsim/api/model.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 5fa276cac..8a5e9bde3 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -17,7 +17,7 @@ import jaxsim.parsers.descriptions import jaxsim.typing as jtp from jaxsim.math import Cross -from jaxsim.utils import JaxsimDataclass, Mutability +from jaxsim.utils import JaxsimDataclass, Mutability, wrappers from .common import VelRepr @@ -42,9 +42,13 @@ class JaxSimModel(JaxsimDataclass): default=None, repr=False ) - description: Static[jaxsim.parsers.descriptions.ModelDescription | None] = ( - dataclasses.field(default=None, repr=False) - ) + _description: Static[ + wrappers.CustomHashedObject[jaxsim.parsers.descriptions.ModelDescription] | None + ] = dataclasses.field(default=None, repr=False) + + @property + def description(self) -> jaxsim.parsers.descriptions.ModelDescription: + return self._description.get() def __eq__(self, other: JaxSimModel) -> bool: @@ -163,10 +167,13 @@ def build( # Set the model name (if not provided, use the one from the model description) model_name = model_name if model_name is not None else model_description.name - # Build the model + # Build the model. model = JaxSimModel( model_name=model_name, - description=model_description, + _description=wrappers.CustomHashedObject( + obj=model_description, + hash_function=lambda desc: hash(tuple(desc.frames)), + ), kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build( model_description=model_description ), From a781e2e614655369b28d74c05df4d6714addbfdf Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 14 Jun 2024 10:52:41 +0200 Subject: [PATCH 07/12] Copy frame-related data from ModelDescription to KinDynParameters In this way, JaxSimModel.description is never used by JIT-compiled functions, and can be treated as an ignored static leaf of the pytree (otherwise computing its hash would be expensive). --- src/jaxsim/api/frame.py | 56 +++++++++--------- src/jaxsim/api/kin_dyn_parameters.py | 85 ++++++++++++++++++++++++++++ src/jaxsim/api/model.py | 19 +++---- 3 files changed, 122 insertions(+), 38 deletions(-) diff --git a/src/jaxsim/api/frame.py b/src/jaxsim/api/frame.py index 0db020bf5..760ed874e 100644 --- a/src/jaxsim/api/frame.py +++ b/src/jaxsim/api/frame.py @@ -17,7 +17,9 @@ # ======================= -def idx_of_parent_link(model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike) -> int: +def idx_of_parent_link( + model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike +) -> jtp.Int: """ Get the index of the link to which the frame is rigidly attached. @@ -29,17 +31,13 @@ def idx_of_parent_link(model: js.model.JaxSimModel, *, frame_idx: jtp.IntLike) - The index of the frame's parent link. """ - # Get the intermediate representation parsed from the model description. - ir = model.description + return model.kin_dyn_parameters.frame_parameters.body[ + frame_idx - model.number_of_links() + ] - # Extract the indices of the frame and the link it is attached to. - F = ir.frames[frame_idx - model.number_of_links()] - L = ir.links_dict[F.parent.name].index - return int(L) - - -def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> int: +@functools.partial(jax.jit, static_argnames="frame_name") +def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> jtp.Int: """ Convert the name of a frame to its index. @@ -51,13 +49,19 @@ def name_to_idx(model: js.model.JaxSimModel, *, frame_name: str) -> int: The index of the frame. """ - frame_names = np.array([frame.name for frame in model.description.frames]) + if frame_name in model.kin_dyn_parameters.frame_parameters.name: + return ( + jnp.array( + np.argwhere( + np.array(model.kin_dyn_parameters.frame_parameters.name) + == frame_name + ) + ) + .squeeze() + .astype(int) + ) + model.number_of_links() - if frame_name in frame_names: - idx_in_list = np.argwhere(frame_names == frame_name) - return int(idx_in_list.squeeze().tolist()) + model.number_of_links() - - return -1 + return jnp.array(-1).astype(int) def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str: @@ -72,7 +76,9 @@ def idx_to_name(model: js.model.JaxSimModel, *, frame_index: jtp.IntLike) -> str The name of the frame. """ - return model.description.frames[frame_index - model.number_of_links()].name + return model.kin_dyn_parameters.frame_parameters.name[ + frame_index - model.number_of_links() + ] @functools.partial(jax.jit, static_argnames=["frame_names"]) @@ -91,7 +97,7 @@ def names_to_idxs( """ return jnp.array( - [name_to_idx(model=model, frame_name=frame_name) for frame_name in frame_names] + [name_to_idx(model=model, frame_name=name) for name in frame_names] ).astype(int) @@ -109,10 +115,7 @@ def idxs_to_names( The names of the frames. """ - return tuple( - idx_to_name(model=model, frame_index=frame_index) - for frame_index in frame_indices - ) + return tuple(idx_to_name(model=model, frame_index=idx) for idx in frame_indices) # ========== @@ -120,7 +123,7 @@ def idxs_to_names( # ========== -@functools.partial(jax.jit, static_argnames=["frame_index"]) +@jax.jit def transform( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, @@ -144,14 +147,15 @@ def transform( W_H_L = js.link.transform(model=model, data=data, link_index=L) # Get the static frame pose wrt the parent link. - frame = model.description.frames[frame_index - model.number_of_links()] - L_H_F = frame.pose + L_H_F = model.kin_dyn_parameters.frame_parameters.transform[ + frame_index - model.number_of_links() + ] # Combine the transforms computing the frame pose. return W_H_L @ L_H_F -@functools.partial(jax.jit, static_argnames=["frame_index", "output_vel_repr"]) +@functools.partial(jax.jit, static_argnames=["output_vel_repr"]) def jacobian( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index d018fdba9..defd01f58 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -25,6 +25,7 @@ class KynDynParameters(JaxsimDataclass): support_body_array_bool: The boolean support parent array :math:`\kappa_{b}(i)` of the model. link_parameters: The parameters of the links. + frame_parameters: The parameters of the frames. contact_parameters: The parameters of the collidable points. joint_model: The joint model of the model. joint_parameters: The parameters of the joints. @@ -41,6 +42,9 @@ class KynDynParameters(JaxsimDataclass): # Contacts contact_parameters: ContactParameters + # Frames + frame_parameters: FrameParameters + # Joints joint_model: JointModel joint_parameters: JointParameters | None @@ -140,6 +144,19 @@ def build(model_description: ModelDescription) -> KynDynParameters: model_description=model_description ) + # ================= + # Frames properties + # ================= + + # Create the object storing the parameters of frames. + # Note that, contrarily to LinkParameters and JointsParameters, this object + # is not created with vmap. This is because the "name" attribute of the object + # must be Static for JIT-related reasons, and tree_map would not consider it + # as a leaf. + frame_parameters = FrameParameters.build_from( + model_description=model_description + ) + # =============== # Tree properties # =============== @@ -205,6 +222,7 @@ def scan_body(carry: tuple, i: jtp.Int) -> tuple[tuple, None]: joint_model=joint_model, joint_parameters=joint_parameters, contact_parameters=contact_parameters, + frame_parameters=frame_parameters, ) def __eq__(self, other: KynDynParameters) -> bool: @@ -220,6 +238,8 @@ def __hash__(self) -> int: ( hash(self.number_of_links()), hash(self.number_of_joints()), + hash(self.frame_parameters.name), + hash(tuple(self.frame_parameters.body.tolist())), hash(self._parent_array), hash(self._support_body_array_bool), ) @@ -778,3 +798,68 @@ def build_from(model_description: ModelDescription) -> ContactParameters: assert cp.point.shape[0] == len(cp.body) return cp + + +@jax_dataclasses.pytree_dataclass +class FrameParameters(JaxsimDataclass): + """ + Class storing the frame parameters of a model. + + Attributes: + name: A tuple of strings defining the frame names. + body: + A vector of integers representing, for each frame, the index of + the body (link) to which it is rigidly attached to. + transform: The transforms of the frames w.r.t. their parent link. + + Note: + Contrarily to LinkParameters and JointParameters, this class is not meant + to be created with vmap. This is because the `name` attribute must be `Static`. + """ + + name: Static[tuple[str, ...]] = dataclasses.field(default_factory=tuple) + + body: jtp.Vector = dataclasses.field(default_factory=lambda: jnp.array([])) + + transform: jtp.Array = dataclasses.field(default_factory=lambda: jnp.array([])) + + @staticmethod + def build_from(model_description: ModelDescription) -> FrameParameters: + """ + Build a FrameParameters object from a model description. + + Args: + model_description: The model description to consider. + + Returns: + The FrameParameters object. + """ + + if len(model_description.frames) == 0: + return FrameParameters() + + # Extract the frame names. + names = tuple(frame.name for frame in model_description.frames) + + # For each frame, extract the index of the link to which it is attached to. + parent_link_index_of_frames = tuple( + model_description.links_dict[frame.parent.name].index + for frame in model_description.frames + ) + + # For each frame, extract the transform w.r.t. its parent link. + transforms = jnp.atleast_3d( + jnp.stack([frame.pose for frame in model_description.frames]) + ) + + # Build the FrameParameters object. + fp = FrameParameters( + name=names, + transform=transforms.astype(float), + body=jnp.array(parent_link_index_of_frames).astype(int), + ) + + assert fp.transform.shape[1:] == (4, 4), fp.transform.shape[1:] + assert fp.transform.shape[0] == len(fp.body), fp.transform.shape[0] + + return fp diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 8a5e9bde3..4e86e9e6f 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -43,7 +43,7 @@ class JaxSimModel(JaxsimDataclass): ) _description: Static[ - wrappers.CustomHashedObject[jaxsim.parsers.descriptions.ModelDescription] | None + wrappers.HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None] ] = dataclasses.field(default=None, repr=False) @property @@ -61,11 +61,6 @@ def __eq__(self, other: JaxSimModel) -> bool: if self.kin_dyn_parameters != other.kin_dyn_parameters: return False - # Here we compare only the static quantities of ModelDescription - # that are actually used by our APIs. - if self.description.frames != other.description.frames: - return False - return True def __hash__(self) -> int: @@ -74,7 +69,6 @@ def __hash__(self) -> int: ( hash(self.model_name), hash(self.kin_dyn_parameters), - hash(self.description), ) ) @@ -170,10 +164,7 @@ def build( # Build the model. model = JaxSimModel( model_name=model_name, - _description=wrappers.CustomHashedObject( - obj=model_description, - hash_function=lambda desc: hash(tuple(desc.frames)), - ), + _description=wrappers.HashlessObject(obj=model_description), kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build( model_description=model_description ), @@ -288,6 +279,10 @@ def link_names(self) -> tuple[str, ...]: return self.kin_dyn_parameters.link_names + # ===================== + # Frame-related methods + # ===================== + def frame_names(self) -> tuple[str, ...]: """ Return the names of the links in the model. @@ -296,7 +291,7 @@ def frame_names(self) -> tuple[str, ...]: The names of the links in the model. """ - return tuple(frame.name for frame in self.description.frames) + return self.kin_dyn_parameters.frame_parameters.name # ===================== From 750e4f35b9eb9904f1c9218139d682bb427d48e3 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 14 Jun 2024 10:53:47 +0200 Subject: [PATCH 08/12] Update frames test --- tests/test_api_frame.py | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/tests/test_api_frame.py b/tests/test_api_frame.py index eb67b7bf3..af5dcd5fc 100644 --- a/tests/test_api_frame.py +++ b/tests/test_api_frame.py @@ -1,5 +1,5 @@ import jax -import numpy as np +import jax.numpy as jnp import pytest import jaxsim.api as js @@ -16,27 +16,33 @@ def test_frame_index(jaxsim_models_types: js.model.JaxSimModel): # Tests # ===== - frame_indices = tuple( - frame.index for frame in model.description.frames if frame.index is not None - ) - - frame_names = np.array([frame.name for frame in model.description.frames]) + n_l = model.number_of_links() + n_f = len(model.frame_names()) - for frame_idx, frame_name in zip(frame_indices, frame_names): - assert js.frame.name_to_idx(model=model, frame_name=frame_name) == frame_idx - assert js.frame.idx_to_name(model=model, frame_index=frame_idx) == frame_name + for idx, frame_name in enumerate(model.frame_names()): + frame_index = n_l + idx + assert js.frame.name_to_idx(model=model, frame_name=frame_name) == frame_index + assert js.frame.idx_to_name(model=model, frame_index=frame_index) == frame_name assert ( - js.frame.idx_of_parent_link(model=model, frame_idx=frame_idx) + js.frame.idx_of_parent_link(model=model, frame_idx=frame_index) < model.number_of_links() ) assert js.frame.names_to_idxs( - model=model, frame_names=tuple(frame_names) - ) == pytest.approx(frame_indices) + model=model, frame_names=model.frame_names() + ) == pytest.approx(jnp.arange(n_l, n_l + n_f)) - assert js.frame.idxs_to_names( - model=model, frame_indices=frame_indices - ) == pytest.approx(frame_names) + assert ( + js.frame.idxs_to_names( + model=model, + frame_indices=tuple( + js.frame.names_to_idxs( + model=model, frame_names=model.frame_names() + ).tolist() + ), + ) + == model.frame_names() + ) def test_frame_transforms( From 50dc320e3e2b5bcf6c62dae7862f0ce480411a94 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 14 Jun 2024 11:29:04 +0200 Subject: [PATCH 09/12] Minor changes --- src/jaxsim/api/kin_dyn_parameters.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index defd01f58..f9ffbc0a7 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -748,7 +748,7 @@ class ContactParameters(JaxsimDataclass): A tuple of integers representing, for each collidable point, the index of the body (link) to which it is rigidly attached to. point: - The translation between the link frame and the collidable point, expressed + The translations between the link frame and the collidable point, expressed in the coordinates of the parent link frame. Note: @@ -791,11 +791,11 @@ def build_from(model_description: ModelDescription) -> ContactParameters: links_dict[cp.parent_link.name].index for cp in collidable_points ) - # Build the GroundContact object. + # Build the ContactParameters object. cp = ContactParameters(point=points, body=link_index_of_points) # noqa - assert cp.point.shape[1] == 3 - assert cp.point.shape[0] == len(cp.body) + assert cp.point.shape[1] == 3, cp.point.shape[1] + assert cp.point.shape[0] == len(cp.body), cp.point.shape[0] return cp From ec6764845f186d8c2fe742b9b014520a6c920881 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 14 Jun 2024 11:38:37 +0200 Subject: [PATCH 10/12] Update hash and comparison of JaxSimModelData --- src/jaxsim/api/data.py | 6 ++++-- src/jaxsim/api/ode_data.py | 18 +++++++++++------- src/jaxsim/rbda/soft_contacts.py | 8 +++++--- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 7dc3de827..0539aa0ec 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -45,12 +45,14 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): def __hash__(self) -> int: + from jaxsim.utils.wrappers import HashedNumpyArray + return hash( ( hash(self.state), - hash(tuple(self.gravity.flatten().tolist())), + HashedNumpyArray.hash_of_array(self.gravity), hash(self.soft_contacts_params), - hash(jnp.atleast_1d(self.time_ns).flatten().tolist()), + hash(tuple(self.time_ns.flatten().tolist())), ) ) diff --git a/src/jaxsim/api/ode_data.py b/src/jaxsim/api/ode_data.py index aec0d7277..d90a2e672 100644 --- a/src/jaxsim/api/ode_data.py +++ b/src/jaxsim/api/ode_data.py @@ -283,12 +283,16 @@ class PhysicsModelState(JaxsimDataclass): def __hash__(self) -> int: + from jaxsim.utils.wrappers import HashedNumpyArray + return hash( ( - hash(tuple(jnp.atleast_1d(self.joint_positions.flatten().tolist()))), - hash(tuple(jnp.atleast_1d(self.joint_velocities.flatten().tolist()))), - hash(tuple(self.base_position.flatten().tolist())), - hash(tuple(self.base_quaternion.flatten().tolist())), + HashedNumpyArray.hash_of_array(self.joint_positions), + HashedNumpyArray.hash_of_array(self.joint_velocities), + HashedNumpyArray.hash_of_array(self.base_position), + HashedNumpyArray.hash_of_array(self.base_quaternion), + HashedNumpyArray.hash_of_array(self.base_linear_velocity), + HashedNumpyArray.hash_of_array(self.base_angular_velocity), ) ) @@ -613,9 +617,9 @@ class SoftContactsState(JaxsimDataclass): def __hash__(self) -> int: - return hash( - tuple(jnp.atleast_1d(self.tangential_deformation.flatten()).tolist()) - ) + from jaxsim.utils.wrappers import HashedNumpyArray + + return HashedNumpyArray.hash_of_array(self.tangential_deformation) def __eq__(self, other: SoftContactsState) -> bool: diff --git a/src/jaxsim/rbda/soft_contacts.py b/src/jaxsim/rbda/soft_contacts.py index e20722178..d10371082 100644 --- a/src/jaxsim/rbda/soft_contacts.py +++ b/src/jaxsim/rbda/soft_contacts.py @@ -31,11 +31,13 @@ class SoftContactsParams(JaxsimDataclass): def __hash__(self) -> int: + from jaxsim.utils.wrappers import HashedNumpyArray + return hash( ( - hash(tuple(jnp.atleast_1d(self.K).flatten().tolist())), - hash(tuple(jnp.atleast_1d(self.D).flatten().tolist())), - hash(tuple(jnp.atleast_1d(self.mu).flatten().tolist())), + HashedNumpyArray.hash_of_array(self.K), + HashedNumpyArray.hash_of_array(self.D), + HashedNumpyArray.hash_of_array(self.mu), ) ) From b7ff488d24bc348b506cdd1799135692100aec6c Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 14 Jun 2024 11:46:48 +0200 Subject: [PATCH 11/12] Extend pytree test --- tests/test_pytree.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_pytree.py b/tests/test_pytree.py index 712b61441..8d063dc87 100644 --- a/tests/test_pytree.py +++ b/tests/test_pytree.py @@ -3,6 +3,7 @@ from contextlib import redirect_stdout import jax +import jax.numpy as jnp import jaxsim.api as js @@ -45,3 +46,21 @@ def test_call_jit_compiled_function_passing_different_objects( f"Compiling {js.contact.estimate_good_soft_contacts_parameters.__name__}" not in stdout ) + + # Define a new JIT-compiled function and check that is not recompiled for + # different model objects having the same pytree structure. + @jax.jit + def my_jit_function(model: js.model.JaxSimModel, data: js.data.JaxSimModelData): + # Return random elements from model and data, just to have something returned. + return ( + jnp.sum(model.kin_dyn_parameters.link_parameters.mass), + data.base_position(), + ) + + data1 = js.data.JaxSimModelData.build(model=model1) + + _ = my_jit_function(model=model1, data=data1) + assert my_jit_function._cache_size() == 1 + + _ = my_jit_function(model=model2, data=data1) + assert my_jit_function._cache_size() == 1 From a959a0009a277435ba27ed1a257a1f9deb157677 Mon Sep 17 00:00:00 2001 From: Diego Ferigo Date: Fri, 14 Jun 2024 15:56:45 +0200 Subject: [PATCH 12/12] Apply suggestions from code review Co-authored-by: Filippo Luca Ferretti --- src/jaxsim/api/data.py | 2 +- src/jaxsim/parsers/descriptions/joint.py | 85 +++++++++--------------- src/jaxsim/parsers/descriptions/link.py | 38 ++++------- src/jaxsim/parsers/descriptions/model.py | 24 +++---- 4 files changed, 54 insertions(+), 95 deletions(-) diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 0539aa0ec..42b5d2562 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -52,7 +52,7 @@ def __hash__(self) -> int: hash(self.state), HashedNumpyArray.hash_of_array(self.gravity), hash(self.soft_contacts_params), - hash(tuple(self.time_ns.flatten().tolist())), + HashedNumpyArray.hash_of_array(self.time_ns), ) ) diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index a5503e1b1..610b8c4fc 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -100,52 +100,29 @@ def __eq__(self, other: JointDescription) -> bool: if not isinstance(other, JointDescription): return False - if self.name != other.name: - return False - - if not np.allclose(self.axis, other.axis): - return False - - if not np.allclose(self.pose, other.pose): - return False - - if self.jtype != other.jtype: - return False - - if self.child != other.child: - return False - - if self.parent != other.parent: - return False - - if self.index != other.index: - return False - - if not np.allclose(self.friction_static, other.friction_static): - return False - - if not np.allclose(self.friction_viscous, other.friction_viscous): - return False - - if not np.allclose(self.position_limit_damper, other.position_limit_damper): - return False - - if not np.allclose(self.position_limit_spring, other.position_limit_spring): - return False - - if not np.allclose(self.position_limit, other.position_limit): - return False - - if not np.allclose(self.initial_position, other.initial_position): - return False - - if not np.allclose(self.motor_inertia, other.motor_inertia): - return False - - if not np.allclose(self.motor_viscous_friction, other.motor_viscous_friction): - return False - - if not np.allclose(self.motor_gear_ratio, other.motor_gear_ratio): + if not ( + self.name == other.name + and self.jtype == other.jtype + and self.child == other.child + and self.parent == other.parent + and self.index == other.index + and all( + np.allclose(getattr(self, attr), getattr(other, attr)) + for attr in [ + "axis", + "pose", + "friction_static", + "friction_viscous", + "position_limit_damper", + "position_limit_spring", + "position_limit", + "initial_position", + "motor_inertia", + "motor_viscous_friction", + "motor_gear_ratio", + ] + ), + ): return False return True @@ -163,14 +140,14 @@ def __hash__(self) -> int: hash(self.child), hash(self.parent), hash(int(self.index)) if self.index is not None else 0, - HashedNumpyArray.hash_of_array(np.array(self.friction_static)), - HashedNumpyArray.hash_of_array(np.array(self.friction_viscous)), - HashedNumpyArray.hash_of_array(np.array(self.position_limit_damper)), - HashedNumpyArray.hash_of_array(np.array(self.position_limit_spring)), - HashedNumpyArray.hash_of_array(np.array(self.position_limit)), + HashedNumpyArray.hash_of_array(self.friction_static), + HashedNumpyArray.hash_of_array(self.friction_viscous), + HashedNumpyArray.hash_of_array(self.position_limit_damper), + HashedNumpyArray.hash_of_array(self.position_limit_spring), + HashedNumpyArray.hash_of_array(self.position_limit), HashedNumpyArray.hash_of_array(self.initial_position), - HashedNumpyArray.hash_of_array(np.array(self.motor_inertia)), - HashedNumpyArray.hash_of_array(np.array(self.motor_viscous_friction)), - HashedNumpyArray.hash_of_array(np.array(self.motor_gear_ratio)), + HashedNumpyArray.hash_of_array(self.motor_inertia), + HashedNumpyArray.hash_of_array(self.motor_viscous_friction), + HashedNumpyArray.hash_of_array(self.motor_gear_ratio), ), ) diff --git a/src/jaxsim/parsers/descriptions/link.py b/src/jaxsim/parsers/descriptions/link.py index 37955bbed..41f5399df 100644 --- a/src/jaxsim/parsers/descriptions/link.py +++ b/src/jaxsim/parsers/descriptions/link.py @@ -47,7 +47,7 @@ def __hash__(self) -> int: hash(self.name), hash(float(self.mass)), HashedNumpyArray.hash_of_array(self.inertia), - hash(int(self.index)) if self.index is not None else self.index, + hash(int(self.index)) if self.index is not None else 0, HashedNumpyArray.hash_of_array(self.pose), hash(tuple(self.children)), # Here only using the name to prevent circular recursion: @@ -60,29 +60,19 @@ def __eq__(self, other: LinkDescription) -> bool: if not isinstance(other, LinkDescription): return False - if self.name != other.name: - return False - - if not np.allclose(self.mass, other.mass): - return False - - if not np.allclose(self.inertia, other.inertia): - return False - - if self.index != other.index: - return False - - if not np.allclose(self.pose, other.pose): - return False - - if self.children != other.children: - return False - - # Here only using the name to prevent circular recursion - if self.parent is not None and self.parent.name != other.parent.name: - return False - - if self.parent is None and other.parent is not None: + if not ( + self.name == other.name + and np.allclose(self.mass, other.mass) + and np.allclose(self.inertia, other.inertia) + and self.index == other.index + and np.allclose(self.pose, other.pose) + and self.children == other.children + and ( + (self.parent is not None and self.parent.name == other.parent.name) + if self.parent is not None + else other.parent is None + ), + ): return False return True diff --git a/src/jaxsim/parsers/descriptions/model.py b/src/jaxsim/parsers/descriptions/model.py index 1faf9ebe0..ac488104d 100644 --- a/src/jaxsim/parsers/descriptions/model.py +++ b/src/jaxsim/parsers/descriptions/model.py @@ -249,22 +249,14 @@ def __eq__(self, other: ModelDescription) -> bool: if not isinstance(other, ModelDescription): return False - if self.name != other.name: - return False - - if self.fixed_base != other.fixed_base: - return False - - if self.root != other.root: - return False - - if self.joints != other.joints: - return False - - if self.frames != other.frames: - return False - - if self.root_pose != other.root_pose: + if not ( + self.name == other.name + and self.fixed_base == other.fixed_base + and self.root == other.root + and self.joints == other.joints + and self.frames == other.frames + and self.root_pose == other.root_pose + ): return False return True