From a0fe6993d4506675ff55fd6e6bc109e404ea3df6 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Fri, 5 Jul 2024 16:02:34 +0200 Subject: [PATCH] Fix calling PlaneTerrain.__eq__ from jit-compiled functions --- src/jaxsim/terrain/terrain.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/terrain/terrain.py b/src/jaxsim/terrain/terrain.py index 9f7378476..e9161da27 100644 --- a/src/jaxsim/terrain/terrain.py +++ b/src/jaxsim/terrain/terrain.py @@ -5,6 +5,7 @@ import jax.numpy as jnp import jax_dataclasses +import numpy as np import jaxsim.typing as jtp @@ -149,10 +150,10 @@ def __eq__(self, other: PlaneTerrain) -> bool: return False if not ( - jnp.allclose(self.z, other.z) - and jnp.allclose( - jnp.array(self.plane_normal, dtype=float), - jnp.array(other.plane_normal, dtype=float), + np.allclose(self.z, other.z) + and np.allclose( + np.array(self.plane_normal, dtype=float), + np.array(other.plane_normal, dtype=float), ) ): return False