diff --git a/src/galax/potential/__init__.pyi b/src/galax/potential/__init__.pyi index bcdfee2e..137920c3 100644 --- a/src/galax/potential/__init__.pyi +++ b/src/galax/potential/__init__.pyi @@ -20,6 +20,7 @@ __all__ = [ "ParameterField", # builtin "BarPotential", + "HarmonicOscillatorPotential", "HernquistPotential", "IsochronePotential", "KeplerPotential", @@ -38,6 +39,7 @@ from ._potential import io from ._potential.base import AbstractPotentialBase from ._potential.builtin import ( BarPotential, + HarmonicOscillatorPotential, HernquistPotential, IsochronePotential, KeplerPotential, diff --git a/src/galax/potential/_potential/builtin.py b/src/galax/potential/_potential/builtin.py index e0ec41af..6ac26fa0 100644 --- a/src/galax/potential/_potential/builtin.py +++ b/src/galax/potential/_potential/builtin.py @@ -2,6 +2,7 @@ __all__ = [ "BarPotential", + "HarmonicOscillatorPotential", "HernquistPotential", "IsochronePotential", "KeplerPotential", @@ -95,6 +96,37 @@ def _potential_energy(self, q: gt.QVec3, t: gt.RealQScalar, /) -> gt.FloatQScala # ------------------------------------------------------------------- +@final +class HarmonicOscillatorPotential(AbstractPotential): + r"""Harmonic Oscillator Potential. + + Represents an N-dimensional harmonic oscillator. + + .. math:: + + \Phi = \frac{1}{2} \omega^2 x^2 + + """ + + omega: AbstractParameter = ParameterField(dimensions="frequency") # type: ignore[assignment] + """The frequency.""" + + _: KW_ONLY + units: AbstractUnitSystem = eqx.field(converter=unitsystem, static=True) + constants: ImmutableDict[Quantity] = eqx.field( + default=default_constants, converter=ImmutableDict + ) + + @partial(jax.jit) + def _potential_energy( + self, q: gt.BatchQVec3, /, t: gt.BatchableRealQScalar + ) -> gt.BatchFloatQScalar: + return 0.5 * self.omega(t) ** 2 * xp.linalg.vector_norm(q, axis=-1) ** 2 + + +# ------------------------------------------------------------------- + + @final class HernquistPotential(AbstractPotential): """Hernquist Potential.""" diff --git a/src/galax/potential/_potential/io/_gala.py b/src/galax/potential/_potential/io/_gala.py index fe89d867..26de850c 100644 --- a/src/galax/potential/_potential/io/_gala.py +++ b/src/galax/potential/_potential/io/_gala.py @@ -14,6 +14,7 @@ from gala.potential import ( CompositePotential as GalaCompositePotential, + HarmonicOscillatorPotential as GalaHarmonicOscillatorPotential, HernquistPotential as GalaHernquistPotential, IsochronePotential as GalaIsochronePotential, KeplerPotential as GalaKeplerPotential, @@ -33,6 +34,7 @@ from galax.potential._potential.base import AbstractPotentialBase from galax.potential._potential.builtin import ( + HarmonicOscillatorPotential, HernquistPotential, IsochronePotential, KeplerPotential, @@ -201,6 +203,7 @@ def _gala_to_galax_composite(pot: GalaCompositePotential, /) -> CompositePotenti _GALA_TO_GALAX_REGISTRY: dict[type[GalaPotentialBase], type[AbstractPotential]] = { + GalaHarmonicOscillatorPotential: HarmonicOscillatorPotential, GalaHernquistPotential: HernquistPotential, GalaIsochronePotential: IsochronePotential, GalaKeplerPotential: KeplerPotential, diff --git a/tests/unit/potential/builtin/test_harmonicoscillator.py b/tests/unit/potential/builtin/test_harmonicoscillator.py new file mode 100644 index 00000000..afbf3115 --- /dev/null +++ b/tests/unit/potential/builtin/test_harmonicoscillator.py @@ -0,0 +1,103 @@ +from typing import Any + +import pytest + +import quaxed.numpy as qnp +from unxt import Quantity + +import galax.typing as gt +from ..param.test_field import ParameterFieldMixin +from ..test_core import TestAbstractPotential as AbstractPotential_Test +from galax.potential import HarmonicOscillatorPotential +from galax.potential._potential.base import AbstractPotentialBase + + +class ParameterOmegaMixin(ParameterFieldMixin): + """Test the omega parameter.""" + + @pytest.fixture(scope="class") + def field_omega(self) -> Quantity["frequency"]: + return Quantity(1.0, "Hz") + + # ===================================================== + + def test_omega_constant(self, pot_cls, fields): + """Test the `omega` parameter.""" + fields["omega"] = Quantity(1.0, "Hz") + pot = pot_cls(**fields) + assert pot.omega(t=0) == Quantity(1.0, "Hz") + + @pytest.mark.xfail(reason="TODO: user function doesn't have units") + def test_omega_userfunc(self, pot_cls, fields): + """Test the `omega` parameter.""" + fields["omega"] = lambda t: t * 1.2 + pot = pot_cls(**fields) + assert pot.omega(t=0) == 2 + + +class TestHarmonicOscillatorPotential( + AbstractPotential_Test, + # Parameters + ParameterOmegaMixin, +): + @pytest.fixture(scope="class") + def pot_cls(self) -> type[HarmonicOscillatorPotential]: + return HarmonicOscillatorPotential + + @pytest.fixture(scope="class") + def fields_(self, field_omega, field_units) -> dict[str, Any]: + return {"omega": field_omega, "units": field_units} + + # ========================================================================== + + def test_potential_energy( + self, pot: HarmonicOscillatorPotential, x: gt.Vec3 + ) -> None: + expect = Quantity(-0.94871936, pot.units["specific energy"]) + assert qnp.isclose( + pot.potential_energy(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_gradient(self, pot: HarmonicOscillatorPotential, x: gt.Vec3) -> None: + expect = Quantity( + [0.05347411, 0.10694822, 0.16042233], pot.units["acceleration"] + ) + assert qnp.allclose( + pot.gradient(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_density(self, pot: HarmonicOscillatorPotential, x: gt.Vec3) -> None: + expect = Quantity(3.989933e08, pot.units["mass density"]) + assert qnp.isclose( + pot.density(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + def test_hessian(self, pot: HarmonicOscillatorPotential, x: gt.Vec3) -> None: + expect = Quantity( + [ + [0.04362645, -0.01969533, -0.02954299], + [-0.01969533, 0.01408345, -0.05908599], + [-0.02954299, -0.05908599, -0.03515487], + ], + "1/Myr2", + ) + assert qnp.allclose( + pot.hessian(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + ) + + # --------------------------------- + # Convenience methods + + def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.Vec3) -> None: + """Test the `AbstractPotentialBase.tidal_tensor` method.""" + expect = Quantity( + [ + [0.0361081, -0.01969533, -0.02954299], + [-0.01969533, 0.00656511, -0.05908599], + [-0.02954299, -0.05908599, -0.04267321], + ], + "1/Myr2", + ) + assert qnp.allclose( + pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit) + )