From b7e7f6535663c482aed3d6f715a3346ea7fdd819 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sat, 27 Apr 2024 22:15:23 -0400 Subject: [PATCH] test: add gala test Signed-off-by: nstarman --- .../builtin/test_harmonicoscillator.py | 34 +++++++++++++++++++ tests/unit/potential/builtin/test_kuzmin.py | 14 ++++---- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/tests/unit/potential/builtin/test_harmonicoscillator.py b/tests/unit/potential/builtin/test_harmonicoscillator.py index afbf3115..1eb0c7ad 100644 --- a/tests/unit/potential/builtin/test_harmonicoscillator.py +++ b/tests/unit/potential/builtin/test_harmonicoscillator.py @@ -1,6 +1,8 @@ from typing import Any +import astropy.units as u import pytest +from plum import convert import quaxed.numpy as qnp from unxt import Quantity @@ -10,6 +12,7 @@ from ..test_core import TestAbstractPotential as AbstractPotential_Test from galax.potential import HarmonicOscillatorPotential from galax.potential._potential.base import AbstractPotentialBase +from galax.utils._optional_deps import HAS_GALA class ParameterOmegaMixin(ParameterFieldMixin): @@ -101,3 +104,34 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.Vec3) -> None: assert qnp.allclose( pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit) ) + + # --------------------------------- + # Interoperability + + @pytest.mark.skipif(not HAS_GALA, reason="requires gala") + @pytest.mark.parametrize( + ("method0", "method1", "atol"), + [ + ("potential_energy", "energy", 1e-8), + ("gradient", "gradient", 1e-8), + ("density", "density", 5e-7), # TODO: why is this different? + # ("hessian", "hessian", 1e-8), # TODO: why is gala's 0? + ], + ) + def test_potential_energy_gala( + self, + pot: HarmonicOscillatorPotential, + method0: str, + method1: str, + x: gt.QVec3, + atol: float, + ) -> None: + from ..io.gala_helper import galax_to_gala + + galax = getattr(pot, method0)(x, t=0) + gala = getattr(galax_to_gala(pot), method1)(convert(x, u.Quantity), t=0 * u.Myr) + assert qnp.allclose( + qnp.ravel(galax), + qnp.ravel(convert(gala, Quantity)), + atol=Quantity(atol, galax.unit), + ) diff --git a/tests/unit/potential/builtin/test_kuzmin.py b/tests/unit/potential/builtin/test_kuzmin.py index 0a7ad0a3..44b62c00 100644 --- a/tests/unit/potential/builtin/test_kuzmin.py +++ b/tests/unit/potential/builtin/test_kuzmin.py @@ -8,10 +8,10 @@ from unxt import AbstractUnitSystem, Quantity import galax.potential as gp +import galax.typing as gt from ..test_core import TestAbstractPotential as AbstractPotential_Test from .test_common import ParameterMTotMixin, ShapeAParameterMixin from galax.potential import AbstractPotentialBase, KuzminPotential -from galax.typing import Vec3 from galax.utils._optional_deps import HAS_GALA @@ -38,25 +38,25 @@ def fields_( # ========================================================================== - def test_potential_energy(self, pot: KuzminPotential, x: Vec3) -> None: + def test_potential_energy(self, pot: KuzminPotential, x: gt.QVec3) -> None: expect = Quantity(-0.98165365, unit="kpc2 / Myr2") assert qnp.isclose( pot.potential_energy(x, t=0), expect, atol=Quantity(1e-8, expect.unit) ) - def test_gradient(self, pot: KuzminPotential, x: Vec3) -> None: + def test_gradient(self, pot: KuzminPotential, x: gt.QVec3) -> None: expect = Quantity([0.04674541, 0.09349082, 0.18698165], "kpc / Myr2") assert qnp.allclose( pot.gradient(x, t=0), expect, atol=Quantity(1e-8, expect.unit) ) - def test_density(self, pot: KuzminPotential, x: Vec3) -> None: + def test_density(self, pot: KuzminPotential, x: gt.QVec3) -> None: expect = Quantity(2.45494884e-07, "solMass / kpc3") assert qnp.isclose( pot.density(x, t=0), expect, atol=Quantity(1e-8, expect.unit) ) - def test_hessian(self, pot: KuzminPotential, x: Vec3) -> None: + def test_hessian(self, pot: KuzminPotential, x: gt.QVec3) -> None: expect = Quantity( [ [0.0400675, -0.01335583, -0.02671166], @@ -72,7 +72,7 @@ def test_hessian(self, pot: KuzminPotential, x: Vec3) -> None: # --------------------------------- # Convenience methods - def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None: + def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.QVec3) -> None: """Test the `AbstractPotentialBase.tidal_tensor` method.""" expect = Quantity( [ @@ -100,7 +100,7 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None: ], ) def test_potential_energy_gala( - self, pot: KuzminPotential, method0: str, method1: str, x: Vec3, atol: float + self, pot: KuzminPotential, method0: str, method1: str, x: gt.QVec3, atol: float ) -> None: from ..io.gala_helper import galax_to_gala