Skip to content

Commit

Permalink
Add explicit __eq__ operators for classes that have ndarray attribute…
Browse files Browse the repository at this point in the history
…s and could be used as Static jax_dataclasses attributes
  • Loading branch information
traversaro committed Mar 10, 2024
1 parent 9a2f26b commit 1c34033
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/jaxsim/parsers/descriptions/collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ def change_link(
enabled=self.enabled,
)

def __eq__(self,other):
retval = (self.parent_link == other.parent_link and
(self.position == other.position).all() and
self.enabled == other.enabled)
return retval

def __str__(self):
return (
f"{self.__class__.__name__}("
Expand Down Expand Up @@ -93,6 +99,9 @@ class BoxCollision(CollisionShape):

center: npt.NDArray

def __eq__(self,other):
return ((self.center == other.center).all() and
super().__eq__(other))

@dataclasses.dataclass
class SphereCollision(CollisionShape):
Expand All @@ -105,3 +114,7 @@ class SphereCollision(CollisionShape):
"""

center: npt.NDArray

def __eq__(self,other):
return ((self.center == other.center).all() and
super().__eq__(other))
9 changes: 9 additions & 0 deletions src/jaxsim/parsers/descriptions/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ class LinkDescription(JaxsimDataclass):
def __hash__(self) -> int:
return hash(self.__repr__())

def __eq__(self,other) -> bool:
return (self.name == other.name and
self.mass == other.mass and
(self.inertia == other.inertia).all() and
self.index == other.index and
self.parent == other.parent and
(self.pose == other.pose).all() and
self.children == other.children)

@property
def name_and_index(self) -> str:
"""
Expand Down
4 changes: 4 additions & 0 deletions src/jaxsim/parsers/kinematic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class RootPose(NamedTuple):
root_position: npt.NDArray = np.zeros(3)
root_quaternion: npt.NDArray = np.array([1.0, 0, 0, 0])

def __eq__(self, other):
return ((self.root_position == other.root_position).all() and
(self.root_quaternion == other.root_quaternion).all())


@dataclasses.dataclass(frozen=True)
class KinematicGraph:
Expand Down

0 comments on commit 1c34033

Please sign in to comment.