diff --git a/src/jaxsim/terrain/terrain.py b/src/jaxsim/terrain/terrain.py index 9f7378476..e9161da27 100644 --- a/src/jaxsim/terrain/terrain.py +++ b/src/jaxsim/terrain/terrain.py @@ -5,6 +5,7 @@ import jax.numpy as jnp import jax_dataclasses +import numpy as np import jaxsim.typing as jtp @@ -149,10 +150,10 @@ def __eq__(self, other: PlaneTerrain) -> bool: return False if not ( - jnp.allclose(self.z, other.z) - and jnp.allclose( - jnp.array(self.plane_normal, dtype=float), - jnp.array(other.plane_normal, dtype=float), + np.allclose(self.z, other.z) + and np.allclose( + np.array(self.plane_normal, dtype=float), + np.array(other.plane_normal, dtype=float), ) ): return False