diff --git a/src/galax/_interop/galax_interop_gala/potential.py b/src/galax/_interop/galax_interop_gala/potential.py index ccd9b4ed..450285b4 100644 --- a/src/galax/_interop/galax_interop_gala/potential.py +++ b/src/galax/_interop/galax_interop_gala/potential.py @@ -340,6 +340,63 @@ def galax_to_gala(pot: gpx.BurkertPotential, /) -> gp.BurkertPotential: units=_galax_to_gala_units(pot.units), ) +# --------------------------- +# Harmonic oscillator potentials + + +@dispatch # type: ignore[misc] +def gala_to_galax( + gala: gp.HarmonicOscillatorPotential, / +) -> gpx.HarmonicOscillatorPotential | gpx.PotentialFrame: + r"""Convert a `gala.potential.HarmonicOscillatorPotential` to a `galax.potential.HarmonicOscillatorPotential`. + + Examples + -------- + >>> import gala.potential as galap + >>> from gala.units import galactic + >>> import galax.potential as gp + + >>> pot = galap.HarmonicOscillatorPotential(omega=1, units=galactic) + >>> gp.io.convert_potential(gp.io.GalaxLibrary, pot) + HarmonicOscillatorPotential( + units=LTMAUnitSystem( length=Unit("kpc"), ...), + constants=ImmutableMap({'G': ...}), + omega=ConstantParameter( ... ) + ) + + """ # noqa: E501 + params = gala.parameters + pot = gpx.HarmonicOscillatorPotential( + omega=params["omega"], units=_check_gala_units(gala.units) + ) + return _apply_frame(_get_frame(gala), pot) + + +@dispatch # type: ignore[misc] +def galax_to_gala( + pot: gpx.HarmonicOscillatorPotential, / +) -> gp.HarmonicOscillatorPotential: + """Convert a `galax.potential.HarmonicOscillatorPotential` to a `gala.potential.HarmonicOscillatorPotential`. + + Examples + -------- + >>> import gala.potential as galap + >>> from unxt import Quantity + >>> import galax.potential as gp + + >>> pot = gp.HarmonicOscillatorPotential(omega=Quantity(1, "1/Myr"), units="galactic") + >>> gp.io.convert_potential(gp.io.GalaLibrary, pot) + + + """ # noqa: E501 + _error_if_not_all_constant_parameters(pot, *pot.parameters.keys()) + + return gp.HarmonicOscillatorPotential( + omega=convert(pot.omega(0), APYQuantity), + units=_galax_to_gala_units(pot.units), + ) + + # --------------------------- # Hernquist potentials diff --git a/src/galax/potential/__init__.py b/src/galax/potential/__init__.py index 3f9192c6..9ebf0e7e 100644 --- a/src/galax/potential/__init__.py +++ b/src/galax/potential/__init__.py @@ -14,6 +14,7 @@ "CompositePotential", # builtin "BurkertPotential", + "HarmonicOscillatorPotential", "HernquistPotential", "IsochronePotential", "JaffePotential", @@ -69,6 +70,7 @@ from ._src.builtin.bars import BarPotential, LongMuraliBarPotential from ._src.builtin.builtin import ( BurkertPotential, + HarmonicOscillatorPotential, HernquistPotential, IsochronePotential, JaffePotential, diff --git a/src/galax/potential/_src/builtin/builtin.py b/src/galax/potential/_src/builtin/builtin.py index 7307da92..7c45974f 100644 --- a/src/galax/potential/_src/builtin/builtin.py +++ b/src/galax/potential/_src/builtin/builtin.py @@ -2,6 +2,7 @@ __all__ = [ "BurkertPotential", + "HarmonicOscillatorPotential", "HernquistPotential", "IsochronePotential", "JaffePotential", @@ -158,6 +159,72 @@ def from_central_density( # ------------------------------------------------------------------- +@final +class HarmonicOscillatorPotential(AbstractPotential): + r"""Harmonic Oscillator Potential. + + Represents an N-dimensional harmonic oscillator. + + .. math:: + + \Phi(\mathbf{q}, t) = \frac{1}{2} |\omega(t) \cdot \mathbf{q}|^2 + + Examples + -------- + >>> from unxt import Quantity + >>> import galax.potential as gp + + >>> pot = gp.HarmonicOscillatorPotential(omega=Quantity(1, "1 / Myr"), + ... units="galactic") + >>> pot + HarmonicOscillatorPotential( + units=LTMAUnitSystem( ... ), + constants=ImmutableMap({'G': ...}), + omega=ConstantParameter( value=Quantity[...](value=f64[], unit=Unit("1 / Myr")) ) + ) + + >>> q = Quantity([1.0, 0, 0], "kpc") + >>> t = Quantity(0, "Gyr") + + >>> pot.potential(q, t) + Quantity[...](Array(0.5, dtype=float64), unit='kpc2 / Myr2') + + >>> pot.density(q, t) + Quantity[...](Array(1.76897707e+10, dtype=float64), unit='solMass / kpc3') + + """ + + # TODO: enable omega to be a 3D vector + omega: AbstractParameter = ParameterField(dimensions="frequency") # type: ignore[assignment] + """The frequency.""" + + _: KW_ONLY + units: AbstractUnitSystem = eqx.field(converter=unitsystem, static=True) + constants: ImmutableMap[str, Quantity] = eqx.field( + default=default_constants, converter=ImmutableMap + ) + + @partial(jax.jit, inline=True) + def _potential( + self, q: gt.BatchQVec3, t: gt.BatchableRealQScalar, / + ) -> gt.SpecificEnergyBatchScalar: + # \Phi(\mathbf{q}, t) = \frac{1}{2} |\omega(t) \cdot \mathbf{q}|^2 + omega = jnp.atleast_1d(self.omega(t)) + return 0.5 * jnp.sum(jnp.square(omega * q), axis=-1) + + @partial(jax.jit, inline=True) + def _density( + self, _: gt.BatchQVec3, t: gt.BatchRealQScalar | gt.RealQScalar, / + ) -> gt.BatchFloatQScalar: + # \rho(\mathbf{q}, t) = \frac{1}{4 \pi G} \sum_i \omega_i^2 + omega = jnp.atleast_1d(self.omega(t)) + denom = 4 * jnp.pi * self.constants["G"] + return jnp.sum(omega**2, axis=-1) / denom + + +# ------------------------------------------------------------------- + + @final class HernquistPotential(AbstractPotential): """Hernquist Potential.""" diff --git a/tests/unit/potential/builtin/misc/test_harmonicoscillator.py b/tests/unit/potential/builtin/misc/test_harmonicoscillator.py new file mode 100644 index 00000000..c036b920 --- /dev/null +++ b/tests/unit/potential/builtin/misc/test_harmonicoscillator.py @@ -0,0 +1,163 @@ +"""Unit tests for the `HarmonicOscillatorPotential` class.""" + +from typing import Any +from typing_extensions import override + +import astropy.units as u +import pytest +from plum import convert + +import quaxed.numpy as jnp +from unxt import Quantity + +import galax.potential as gp +import galax.typing as gt +from ...param.test_field import ParameterFieldMixin +from ...test_core import AbstractPotential_Test +from galax._interop.optional_deps import OptDeps +from galax.potential._src.base import AbstractPotentialBase + + +class ParameterOmegaMixin(ParameterFieldMixin): + """Test the omega parameter.""" + + @pytest.fixture(scope="class") + def field_omega(self) -> Quantity["frequency"]: + return Quantity(1.0, "Hz") + + # ===================================================== + + def test_omega_constant(self, pot_cls, fields): + """Test the `omega` parameter.""" + fields["omega"] = Quantity(1.0, "Hz") + pot = pot_cls(**fields) + assert pot.omega(t=0) == Quantity(1.0, "Hz") + + @pytest.mark.xfail(reason="TODO: user function doesn't have units") + def test_omega_userfunc(self, pot_cls, fields): + """Test the `omega` parameter.""" + fields["omega"] = lambda t: t * 1.2 + pot = pot_cls(**fields) + assert pot.omega(t=0) == 2 + + +################################################################################ + + +class TestHarmonicOscillatorPotential( + AbstractPotential_Test, + # Parameters + ParameterOmegaMixin, +): + @pytest.fixture(scope="class") + @override + def pot_cls(self) -> type[gp.HarmonicOscillatorPotential]: + return gp.HarmonicOscillatorPotential + + @pytest.fixture(scope="class") + @override + def fields_(self, field_omega, field_units) -> dict[str, Any]: + return {"omega": field_omega, "units": field_units} + + # ========================================================================== + + def test_potential(self, pot: gp.HarmonicOscillatorPotential, x: gt.QVec3) -> None: + got = pot.potential(x, t=0) + expect = Quantity(6.97117482e27, pot.units["specific energy"]) + assert jnp.isclose(got, expect, atol=Quantity(1e-8, expect.unit)) + + def test_gradient(self, pot: gp.HarmonicOscillatorPotential, x: gt.Vec3) -> None: + got = convert(pot.gradient(x, t=0), Quantity) + expect = Quantity([9.95882118e26, 1.99176424e27, 2.98764635e27], "kpc / Myr2") + assert jnp.allclose(got, expect, atol=Quantity(1e-8, expect.unit)) + + def test_density(self, pot: gp.HarmonicOscillatorPotential, x: gt.QVec3) -> None: + got = pot.density(x, t=0) + expect = Quantity(1.76169263e37, unit="solMass / kpc3") + assert jnp.isclose(got, expect, atol=Quantity(1e-8, expect.unit)) + + def test_hessian(self, pot: gp.HarmonicOscillatorPotential, x: gt.QVec3) -> None: + got = pot.hessian(x, t=0) + expect = Quantity( + [ + [9.95882118e26, 0.00000000e00, 0.00000000e00], + [0.00000000e00, 9.95882118e26, 0.00000000e00], + [0.00000000e00, 0.00000000e00, 9.95882118e26], + ], + "1/Myr2", + ) + assert jnp.allclose(got, expect, atol=Quantity(1e-8, expect.unit)) + + # --------------------------------- + # Convenience methods + + def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.Vec3) -> None: + """Test the `AbstractPotentialBase.tidal_tensor` method.""" + expect = Quantity( + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + "1/Myr2", + ) + assert jnp.allclose( + pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + # --------------------------------- + # Interoperability + + @pytest.mark.skipif(not OptDeps.GALA.installed, reason="requires gala") + @pytest.mark.parametrize( + ("method0", "method1", "atol"), + [ + ("potential", "energy", 1e-8), + ("gradient", "gradient", 1e-8), + # ("density", "density", 5e-7), # Gala doesn't have the density + # ("hessian", "hessian", 1e-8), # TODO: why doesn't this match? + ], + ) + def test_method_gala( + self, + pot: gp.HarmonicOscillatorPotential, + method0: str, + method1: str, + x: gt.QVec3, + atol: float, + ) -> None: + """Test the equivalence of methods between gala and galax. + + This test only runs if the potential can be mapped to gala. + """ + # First we need to check that the potential is gala-compatible + if not self.HAS_GALA_COUNTERPART: + pytest.skip("potential does not have a gala counterpart") + + # Evaluate the galax method. Gala is in 1D, so we take the norm. + galax = convert(getattr(pot, method0)(x, t=0), Quantity) + galax1d = jnp.linalg.vector_norm(jnp.atleast_1d(galax), axis=-1) + + # Evaluate the gala method. This works in 1D on Astropy quantities. + galap = gp.io.convert_potential(gp.io.GalaLibrary, pot) + r = convert(jnp.linalg.vector_norm(x, axis=-1), u.Quantity) + gala = getattr(galap, method1)(r, t=0 * u.Myr) + + assert jnp.allclose( + jnp.ravel(galax1d), + jnp.ravel(convert(gala, Quantity)), + atol=Quantity(atol, galax.unit), + ) + + # ========================================================================== + # TODO: Implement these tests + + @pytest.mark.skip("TODO") + def test_evaluate_orbit(self, pot: gp.AbstractPotentialBase, xv: gt.Vec6) -> None: + """Test the `AbstractPotentialBase.evaluate_orbit` method.""" + + @pytest.mark.skip("TODO") + def test_evaluate_orbit_batch( + self, pot: gp.AbstractPotentialBase, xv: gt.Vec6 + ) -> None: + """Test the `AbstractPotentialBase.evaluate_orbit` method."""