Skip to content

Commit

Permalink
test: add gala test
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Apr 28, 2024
1 parent d63f8bd commit 6268a48
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
34 changes: 34 additions & 0 deletions tests/unit/potential/builtin/test_harmonicoscillator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any

import astropy.units as u
import pytest
from plum import convert

import quaxed.numpy as qnp
from unxt import Quantity
Expand All @@ -10,6 +12,7 @@
from ..test_core import TestAbstractPotential as AbstractPotential_Test
from galax.potential import HarmonicOscillatorPotential
from galax.potential._potential.base import AbstractPotentialBase
from galax.utils._optional_deps import HAS_GALA


class ParameterOmegaMixin(ParameterFieldMixin):
Expand Down Expand Up @@ -101,3 +104,34 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.Vec3) -> None:
assert qnp.allclose(
pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

# ---------------------------------
# Interoperability

@pytest.mark.skipif(not HAS_GALA, reason="requires gala")
@pytest.mark.parametrize(
("method0", "method1", "atol"),
[
("potential_energy", "energy", 1e-8),
("gradient", "gradient", 1e-8),
("density", "density", 5e-7), # TODO: why is this different?
# ("hessian", "hessian", 1e-8), # TODO: why is gala's 0?
],
)
def test_potential_energy_gala(
self,
pot: HarmonicOscillatorPotential,
method0: str,
method1: str,
x: gt.QVec3,
atol: float,
) -> None:
from ..io.gala_helper import galax_to_gala

galax = getattr(pot, method0)(x, t=0)
gala = getattr(galax_to_gala(pot), method1)(convert(x, u.Quantity), t=0 * u.Myr)
assert qnp.allclose(
qnp.ravel(galax),
qnp.ravel(convert(gala, Quantity)),
atol=Quantity(atol, galax.unit),
)
4 changes: 2 additions & 2 deletions tests/unit/potential/builtin/test_kuzmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from unxt import AbstractUnitSystem, Quantity

import galax.potential as gp
import galax.typing as gt
from ..test_core import TestAbstractPotential as AbstractPotential_Test
from .test_common import ParameterMTotMixin, ShapeAParameterMixin
from galax.potential import AbstractPotentialBase, KuzminPotential
from galax.typing import Vec3
from galax.utils._optional_deps import HAS_GALA


Expand Down Expand Up @@ -100,7 +100,7 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None:
],
)
def test_potential_energy_gala(
self, pot: KuzminPotential, method0: str, method1: str, x: Vec3, atol: float
self, pot: KuzminPotential, method0: str, method1: str, x: gt.QVec3, atol: float
) -> None:
from ..io.gala_helper import galax_to_gala

Expand Down

0 comments on commit 6268a48

Please sign in to comment.