diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 0539aa0ec..ae6a27e40 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -52,7 +52,7 @@ def __hash__(self) -> int: hash(self.state), HashedNumpyArray.hash_of_array(self.gravity), hash(self.soft_contacts_params), - hash(tuple(self.time_ns.flatten().tolist())), + HashedNumpyArray.hash_of_array(self.time_ns.flatten()), ) ) diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index a5503e1b1..f1a87dd2d 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -100,52 +100,29 @@ def __eq__(self, other: JointDescription) -> bool: if not isinstance(other, JointDescription): return False - if self.name != other.name: - return False - - if not np.allclose(self.axis, other.axis): - return False - - if not np.allclose(self.pose, other.pose): - return False - - if self.jtype != other.jtype: - return False - - if self.child != other.child: - return False - - if self.parent != other.parent: - return False - - if self.index != other.index: - return False - - if not np.allclose(self.friction_static, other.friction_static): - return False - - if not np.allclose(self.friction_viscous, other.friction_viscous): - return False - - if not np.allclose(self.position_limit_damper, other.position_limit_damper): - return False - - if not np.allclose(self.position_limit_spring, other.position_limit_spring): - return False - - if not np.allclose(self.position_limit, other.position_limit): - return False - - if not np.allclose(self.initial_position, other.initial_position): - return False - - if not np.allclose(self.motor_inertia, other.motor_inertia): - return False - - if not np.allclose(self.motor_viscous_friction, other.motor_viscous_friction): - return False - - if not np.allclose(self.motor_gear_ratio, other.motor_gear_ratio): + if not all( + self.name == other.name + and self.jtype == other.jtype + and self.child == other.child + and self.parent == other.parent + and self.index == other.index + and all( + np.allclose(getattr(self, attr), getattr(other, attr)) + for attr in [ + "axis", + "pose", + "friction_static", + "friction_viscous", + "position_limit_damper", + "position_limit_spring", + "position_limit", + "initial_position", + "motor_inertia", + "motor_viscous_friction", + "motor_gear_ratio", + ] + ) + ): return False return True @@ -163,14 +140,14 @@ def __hash__(self) -> int: hash(self.child), hash(self.parent), hash(int(self.index)) if self.index is not None else 0, - HashedNumpyArray.hash_of_array(np.array(self.friction_static)), - HashedNumpyArray.hash_of_array(np.array(self.friction_viscous)), - HashedNumpyArray.hash_of_array(np.array(self.position_limit_damper)), - HashedNumpyArray.hash_of_array(np.array(self.position_limit_spring)), - HashedNumpyArray.hash_of_array(np.array(self.position_limit)), + HashedNumpyArray.hash_of_array(self.friction_static), + HashedNumpyArray.hash_of_array(self.friction_viscous), + HashedNumpyArray.hash_of_array(self.position_limit_damper), + HashedNumpyArray.hash_of_array(self.position_limit_spring), + HashedNumpyArray.hash_of_array(self.position_limit), HashedNumpyArray.hash_of_array(self.initial_position), - HashedNumpyArray.hash_of_array(np.array(self.motor_inertia)), - HashedNumpyArray.hash_of_array(np.array(self.motor_viscous_friction)), - HashedNumpyArray.hash_of_array(np.array(self.motor_gear_ratio)), + HashedNumpyArray.hash_of_array(self.motor_inertia), + HashedNumpyArray.hash_of_array(self.motor_viscous_friction), + HashedNumpyArray.hash_of_array(self.motor_gear_ratio), ), ) diff --git a/src/jaxsim/parsers/descriptions/link.py b/src/jaxsim/parsers/descriptions/link.py index 37955bbed..f8b4f8e1b 100644 --- a/src/jaxsim/parsers/descriptions/link.py +++ b/src/jaxsim/parsers/descriptions/link.py @@ -60,29 +60,21 @@ def __eq__(self, other: LinkDescription) -> bool: if not isinstance(other, LinkDescription): return False - if self.name != other.name: - return False - - if not np.allclose(self.mass, other.mass): - return False - - if not np.allclose(self.inertia, other.inertia): - return False - - if self.index != other.index: - return False - - if not np.allclose(self.pose, other.pose): - return False - - if self.children != other.children: - return False - - # Here only using the name to prevent circular recursion - if self.parent is not None and self.parent.name != other.parent.name: - return False - - if self.parent is None and other.parent is not None: + if not all( + [ + self.name == other.name, + np.allclose(self.mass, other.mass), + np.allclose(self.inertia, other.inertia), + self.index == other.index, + np.allclose(self.pose, other.pose), + self.children == other.children, + ( + (self.parent is not None and self.parent.name == other.parent.name) + if self.parent is not None + else other.parent is None + ), + ] + ): return False return True diff --git a/src/jaxsim/parsers/descriptions/model.py b/src/jaxsim/parsers/descriptions/model.py index 1faf9ebe0..e1023485f 100644 --- a/src/jaxsim/parsers/descriptions/model.py +++ b/src/jaxsim/parsers/descriptions/model.py @@ -249,22 +249,16 @@ def __eq__(self, other: ModelDescription) -> bool: if not isinstance(other, ModelDescription): return False - if self.name != other.name: - return False - - if self.fixed_base != other.fixed_base: - return False - - if self.root != other.root: - return False - - if self.joints != other.joints: - return False - - if self.frames != other.frames: - return False - - if self.root_pose != other.root_pose: + if not all( + [ + self.name == other.name, + self.fixed_base == other.fixed_base, + self.root == other.root, + self.joints == other.joints, + self.frames == other.frames, + self.root_pose == other.root_pose, + ] + ): return False return True