diff --git a/src/galax/potential/__init__.pyi b/src/galax/potential/__init__.pyi index bcdfee2e..137920c3 100644 --- a/src/galax/potential/__init__.pyi +++ b/src/galax/potential/__init__.pyi @@ -20,6 +20,7 @@ __all__ = [ "ParameterField", # builtin "BarPotential", + "HarmonicOscillatorPotential", "HernquistPotential", "IsochronePotential", "KeplerPotential", @@ -38,6 +39,7 @@ from ._potential import io from ._potential.base import AbstractPotentialBase from ._potential.builtin import ( BarPotential, + HarmonicOscillatorPotential, HernquistPotential, IsochronePotential, KeplerPotential, diff --git a/src/galax/potential/_potential/builtin.py b/src/galax/potential/_potential/builtin.py index e0ec41af..6ac26fa0 100644 --- a/src/galax/potential/_potential/builtin.py +++ b/src/galax/potential/_potential/builtin.py @@ -2,6 +2,7 @@ __all__ = [ "BarPotential", + "HarmonicOscillatorPotential", "HernquistPotential", "IsochronePotential", "KeplerPotential", @@ -95,6 +96,37 @@ def _potential_energy(self, q: gt.QVec3, t: gt.RealQScalar, /) -> gt.FloatQScala # ------------------------------------------------------------------- +@final +class HarmonicOscillatorPotential(AbstractPotential): + r"""Harmonic Oscillator Potential. + + Represents an N-dimensional harmonic oscillator. + + .. math:: + + \Phi = \frac{1}{2} \omega^2 x^2 + + """ + + omega: AbstractParameter = ParameterField(dimensions="frequency") # type: ignore[assignment] + """The frequency.""" + + _: KW_ONLY + units: AbstractUnitSystem = eqx.field(converter=unitsystem, static=True) + constants: ImmutableDict[Quantity] = eqx.field( + default=default_constants, converter=ImmutableDict + ) + + @partial(jax.jit) + def _potential_energy( + self, q: gt.BatchQVec3, /, t: gt.BatchableRealQScalar + ) -> gt.BatchFloatQScalar: + return 0.5 * self.omega(t) ** 2 * xp.linalg.vector_norm(q, axis=-1) ** 2 + + +# ------------------------------------------------------------------- + + @final class HernquistPotential(AbstractPotential): """Hernquist Potential.""" diff --git a/src/galax/potential/_potential/io/_gala.py b/src/galax/potential/_potential/io/_gala.py index 6f19b29e..9cf5db0c 100644 --- a/src/galax/potential/_potential/io/_gala.py +++ b/src/galax/potential/_potential/io/_gala.py @@ -12,7 +12,7 @@ else: _ = pytest.importorskip("gala") -import gala.potential as galap +import gala.potential as gp from gala.units import DimensionlessUnitSystem as GalaDimensionlessUnitSystem import coordinax.operators as cxo @@ -27,7 +27,7 @@ @singledispatch -def gala_to_galax(pot: galap.PotentialBase, /) -> gpx.AbstractPotentialBase: +def gala_to_galax(pot: gp.PotentialBase, /) -> gpx.AbstractPotentialBase: """Convert a :mod:`gala` potential to a :mod:`galax` potential. Parameters @@ -153,7 +153,7 @@ def gala_to_galax(pot: galap.PotentialBase, /) -> gpx.AbstractPotentialBase: PT = TypeVar("PT", bound=gpx.AbstractPotentialBase) -def _get_frame(pot: galap.PotentialBase, /) -> cxo.AbstractOperator: +def _get_frame(pot: gp.PotentialBase, /) -> cxo.AbstractOperator: frame = cxo.GalileanSpatialTranslationOperator( Quantity(pot.origin, unit=pot.units["length"]) ) @@ -173,29 +173,26 @@ def _apply_frame(frame: cxo.AbstractOperator, pot: PT, /) -> PT | gpx.PotentialF @gala_to_galax.register -def _gala_to_galax_composite( - pot: galap.CompositePotential, / -) -> gpx.CompositePotential: +def _gala_to_galax_composite(pot: gp.CompositePotential, /) -> gpx.CompositePotential: """Convert a Gala CompositePotential to a Galax potential.""" return gpx.CompositePotential(**{k: gala_to_galax(p) for k, p in pot.items()}) -_GALA_TO_GALAX_REGISTRY: dict[ - type[galap.PotentialBase], type[gpx.AbstractPotential] -] = { - galap.HernquistPotential: gpx.HernquistPotential, - galap.IsochronePotential: gpx.IsochronePotential, - galap.KeplerPotential: gpx.KeplerPotential, - galap.MiyamotoNagaiPotential: gpx.MiyamotoNagaiPotential, +_GALA_TO_GALAX_REGISTRY: dict[type[gp.PotentialBase], type[gpx.AbstractPotential]] = { + gp.HarmonicOscillatorPotential: gpx.HarmonicOscillatorPotential, + gp.HernquistPotential: gpx.HernquistPotential, + gp.IsochronePotential: gpx.IsochronePotential, + gp.KeplerPotential: gpx.KeplerPotential, + gp.MiyamotoNagaiPotential: gpx.MiyamotoNagaiPotential, } -@gala_to_galax.register(galap.HernquistPotential) -@gala_to_galax.register(galap.IsochronePotential) -@gala_to_galax.register(galap.KeplerPotential) -@gala_to_galax.register(galap.MiyamotoNagaiPotential) +@gala_to_galax.register(gp.HernquistPotential) +@gala_to_galax.register(gp.IsochronePotential) +@gala_to_galax.register(gp.KeplerPotential) +@gala_to_galax.register(gp.MiyamotoNagaiPotential) def _gala_to_galax_registered( - gala: galap.PotentialBase, / + gala: gp.PotentialBase, / ) -> gpx.AbstractPotential | gpx.PotentialFrame: """Convert a Gala HernquistPotential to a Galax potential.""" if isinstance(gala.units, GalaDimensionlessUnitSystem): @@ -216,7 +213,7 @@ def _gala_to_galax_registered( @gala_to_galax.register -def _gala_to_galax_null(_: galap.NullPotential, /) -> gpx.NullPotential: +def _gala_to_galax_null(_: gp.NullPotential, /) -> gpx.NullPotential: """Convert a Gala NullPotential to a Galax potential. Examples @@ -235,7 +232,7 @@ def _gala_to_galax_null(_: galap.NullPotential, /) -> gpx.NullPotential: @gala_to_galax.register def _gala_to_galax_nfw( - gala: galap.NFWPotential, / + gala: gp.NFWPotential, / ) -> gpx.NFWPotential | gpx.PotentialFrame: """Convert a Gala NFWPotential to a Galax potential. @@ -262,7 +259,7 @@ def _gala_to_galax_nfw( @gala_to_galax.register def _gala_to_galax_leesutotriaxialnfw( - pot: galap.LeeSutoTriaxialNFWPotential, / + pot: gp.LeeSutoTriaxialNFWPotential, / ) -> gpx.LeeSutoTriaxialNFWPotential: """Convert a Gala LeeSutoTriaxialNFWPotential to a Galax potential. @@ -306,7 +303,7 @@ def _gala_to_galax_leesutotriaxialnfw( @gala_to_galax.register -def _gala_to_galax_mw(pot: galap.MilkyWayPotential, /) -> gpx.MilkyWayPotential: +def _gala_to_galax_mw(pot: gp.MilkyWayPotential, /) -> gpx.MilkyWayPotential: """Convert a Gala MilkyWayPotential to a Galax potential. Examples diff --git a/tests/unit/potential/builtin/test_harmonicoscillator.py b/tests/unit/potential/builtin/test_harmonicoscillator.py new file mode 100644 index 00000000..afbf3115 --- /dev/null +++ b/tests/unit/potential/builtin/test_harmonicoscillator.py @@ -0,0 +1,103 @@ +from typing import Any + +import pytest + +import quaxed.numpy as qnp +from unxt import Quantity + +import galax.typing as gt +from ..param.test_field import ParameterFieldMixin +from ..test_core import TestAbstractPotential as AbstractPotential_Test +from galax.potential import HarmonicOscillatorPotential +from galax.potential._potential.base import AbstractPotentialBase + + +class ParameterOmegaMixin(ParameterFieldMixin): + """Test the omega parameter.""" + + @pytest.fixture(scope="class") + def field_omega(self) -> Quantity["frequency"]: + return Quantity(1.0, "Hz") + + # ===================================================== + + def test_omega_constant(self, pot_cls, fields): + """Test the `omega` parameter.""" + fields["omega"] = Quantity(1.0, "Hz") + pot = pot_cls(**fields) + assert pot.omega(t=0) == Quantity(1.0, "Hz") + + @pytest.mark.xfail(reason="TODO: user function doesn't have units") + def test_omega_userfunc(self, pot_cls, fields): + """Test the `omega` parameter.""" + fields["omega"] = lambda t: t * 1.2 + pot = pot_cls(**fields) + assert pot.omega(t=0) == 2 + + +class TestHarmonicOscillatorPotential( + AbstractPotential_Test, + # Parameters + ParameterOmegaMixin, +): + @pytest.fixture(scope="class") + def pot_cls(self) -> type[HarmonicOscillatorPotential]: + return HarmonicOscillatorPotential + + @pytest.fixture(scope="class") + def fields_(self, field_omega, field_units) -> dict[str, Any]: + return {"omega": field_omega, "units": field_units} + + # ========================================================================== + + def test_potential_energy( + self, pot: HarmonicOscillatorPotential, x: gt.Vec3 + ) -> None: + expect = Quantity(-0.94871936, pot.units["specific energy"]) + assert qnp.isclose( + pot.potential_energy(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_gradient(self, pot: HarmonicOscillatorPotential, x: gt.Vec3) -> None: + expect = Quantity( + [0.05347411, 0.10694822, 0.16042233], pot.units["acceleration"] + ) + assert qnp.allclose( + pot.gradient(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_density(self, pot: HarmonicOscillatorPotential, x: gt.Vec3) -> None: + expect = Quantity(3.989933e08, pot.units["mass density"]) + assert qnp.isclose( + pot.density(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_hessian(self, pot: HarmonicOscillatorPotential, x: gt.Vec3) -> None: + expect = Quantity( + [ + [0.04362645, -0.01969533, -0.02954299], + [-0.01969533, 0.01408345, -0.05908599], + [-0.02954299, -0.05908599, -0.03515487], + ], + "1/Myr2", + ) + assert qnp.allclose( + pot.hessian(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + # --------------------------------- + # Convenience methods + + def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.Vec3) -> None: + """Test the `AbstractPotentialBase.tidal_tensor` method.""" + expect = Quantity( + [ + [0.0361081, -0.01969533, -0.02954299], + [-0.01969533, 0.00656511, -0.05908599], + [-0.02954299, -0.05908599, -0.04267321], + ], + "1/Myr2", + ) + assert qnp.allclose( + pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + )