diff --git a/tests/unit/potential/builtin/test_harmonicoscillator.py b/tests/unit/potential/builtin/test_harmonicoscillator.py index afbf3115..1eb0c7ad 100644 --- a/tests/unit/potential/builtin/test_harmonicoscillator.py +++ b/tests/unit/potential/builtin/test_harmonicoscillator.py @@ -1,6 +1,8 @@ from typing import Any +import astropy.units as u import pytest +from plum import convert import quaxed.numpy as qnp from unxt import Quantity @@ -10,6 +12,7 @@ from ..test_core import TestAbstractPotential as AbstractPotential_Test from galax.potential import HarmonicOscillatorPotential from galax.potential._potential.base import AbstractPotentialBase +from galax.utils._optional_deps import HAS_GALA class ParameterOmegaMixin(ParameterFieldMixin): @@ -101,3 +104,34 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.Vec3) -> None: assert qnp.allclose( pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit) ) + + # --------------------------------- + # Interoperability + + @pytest.mark.skipif(not HAS_GALA, reason="requires gala") + @pytest.mark.parametrize( + ("method0", "method1", "atol"), + [ + ("potential_energy", "energy", 1e-8), + ("gradient", "gradient", 1e-8), + ("density", "density", 5e-7), # TODO: why is this different? + # ("hessian", "hessian", 1e-8), # TODO: why is gala's 0? + ], + ) + def test_potential_energy_gala( + self, + pot: HarmonicOscillatorPotential, + method0: str, + method1: str, + x: gt.QVec3, + atol: float, + ) -> None: + from ..io.gala_helper import galax_to_gala + + galax = getattr(pot, method0)(x, t=0) + gala = getattr(galax_to_gala(pot), method1)(convert(x, u.Quantity), t=0 * u.Myr) + assert qnp.allclose( + qnp.ravel(galax), + qnp.ravel(convert(gala, Quantity)), + atol=Quantity(atol, galax.unit), + ) diff --git a/tests/unit/potential/builtin/test_kuzmin.py b/tests/unit/potential/builtin/test_kuzmin.py index 0a7ad0a3..2aa930df 100644 --- a/tests/unit/potential/builtin/test_kuzmin.py +++ b/tests/unit/potential/builtin/test_kuzmin.py @@ -8,10 +8,10 @@ from unxt import AbstractUnitSystem, Quantity import galax.potential as gp +import galax.typing as gt from ..test_core import TestAbstractPotential as AbstractPotential_Test from .test_common import ParameterMTotMixin, ShapeAParameterMixin from galax.potential import AbstractPotentialBase, KuzminPotential -from galax.typing import Vec3 from galax.utils._optional_deps import HAS_GALA @@ -100,7 +100,7 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None: ], ) def test_potential_energy_gala( - self, pot: KuzminPotential, method0: str, method1: str, x: Vec3, atol: float + self, pot: KuzminPotential, method0: str, method1: str, x: gt.QVec3, atol: float ) -> None: from ..io.gala_helper import galax_to_gala