Skip to content

Commit

Permalink
test: add gala test
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Apr 28, 2024
1 parent 7bf72ec commit b7e7f65
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 7 deletions.
34 changes: 34 additions & 0 deletions tests/unit/potential/builtin/test_harmonicoscillator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any

import astropy.units as u
import pytest
from plum import convert

import quaxed.numpy as qnp
from unxt import Quantity
Expand All @@ -10,6 +12,7 @@
from ..test_core import TestAbstractPotential as AbstractPotential_Test
from galax.potential import HarmonicOscillatorPotential
from galax.potential._potential.base import AbstractPotentialBase
from galax.utils._optional_deps import HAS_GALA


class ParameterOmegaMixin(ParameterFieldMixin):
Expand Down Expand Up @@ -101,3 +104,34 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.Vec3) -> None:
assert qnp.allclose(
pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

# ---------------------------------
# Interoperability

@pytest.mark.skipif(not HAS_GALA, reason="requires gala")
@pytest.mark.parametrize(
("method0", "method1", "atol"),
[
("potential_energy", "energy", 1e-8),
("gradient", "gradient", 1e-8),
("density", "density", 5e-7), # TODO: why is this different?
# ("hessian", "hessian", 1e-8), # TODO: why is gala's 0?
],
)
def test_potential_energy_gala(
self,
pot: HarmonicOscillatorPotential,
method0: str,
method1: str,
x: gt.QVec3,
atol: float,
) -> None:
from ..io.gala_helper import galax_to_gala

galax = getattr(pot, method0)(x, t=0)
gala = getattr(galax_to_gala(pot), method1)(convert(x, u.Quantity), t=0 * u.Myr)
assert qnp.allclose(
qnp.ravel(galax),
qnp.ravel(convert(gala, Quantity)),
atol=Quantity(atol, galax.unit),
)
14 changes: 7 additions & 7 deletions tests/unit/potential/builtin/test_kuzmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from unxt import AbstractUnitSystem, Quantity

import galax.potential as gp
import galax.typing as gt
from ..test_core import TestAbstractPotential as AbstractPotential_Test
from .test_common import ParameterMTotMixin, ShapeAParameterMixin
from galax.potential import AbstractPotentialBase, KuzminPotential
from galax.typing import Vec3
from galax.utils._optional_deps import HAS_GALA


Expand All @@ -38,25 +38,25 @@ def fields_(

# ==========================================================================

def test_potential_energy(self, pot: KuzminPotential, x: Vec3) -> None:
def test_potential_energy(self, pot: KuzminPotential, x: gt.QVec3) -> None:
expect = Quantity(-0.98165365, unit="kpc2 / Myr2")
assert qnp.isclose(
pot.potential_energy(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

def test_gradient(self, pot: KuzminPotential, x: Vec3) -> None:
def test_gradient(self, pot: KuzminPotential, x: gt.QVec3) -> None:
expect = Quantity([0.04674541, 0.09349082, 0.18698165], "kpc / Myr2")
assert qnp.allclose(
pot.gradient(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

def test_density(self, pot: KuzminPotential, x: Vec3) -> None:
def test_density(self, pot: KuzminPotential, x: gt.QVec3) -> None:
expect = Quantity(2.45494884e-07, "solMass / kpc3")
assert qnp.isclose(
pot.density(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

def test_hessian(self, pot: KuzminPotential, x: Vec3) -> None:
def test_hessian(self, pot: KuzminPotential, x: gt.QVec3) -> None:
expect = Quantity(
[
[0.0400675, -0.01335583, -0.02671166],
Expand All @@ -72,7 +72,7 @@ def test_hessian(self, pot: KuzminPotential, x: Vec3) -> None:
# ---------------------------------
# Convenience methods

def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None:
def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.QVec3) -> None:
"""Test the `AbstractPotentialBase.tidal_tensor` method."""
expect = Quantity(
[
Expand Down Expand Up @@ -100,7 +100,7 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None:
],
)
def test_potential_energy_gala(
self, pot: KuzminPotential, method0: str, method1: str, x: Vec3, atol: float
self, pot: KuzminPotential, method0: str, method1: str, x: gt.QVec3, atol: float
) -> None:
from ..io.gala_helper import galax_to_gala

Expand Down

0 comments on commit b7e7f65

Please sign in to comment.