Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: HarmonicOscillatorPotential #262

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/galax/potential/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ __all__ = [
"ParameterField",
# builtin
"BarPotential",
"HarmonicOscillatorPotential",
"HernquistPotential",
"IsochronePotential",
"KeplerPotential",
Expand All @@ -39,6 +40,7 @@ from ._potential import io
from ._potential.base import AbstractPotentialBase
from ._potential.builtin import (
BarPotential,
HarmonicOscillatorPotential,
HernquistPotential,
IsochronePotential,
KeplerPotential,
Expand Down
32 changes: 32 additions & 0 deletions src/galax/potential/_potential/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

__all__ = [
"BarPotential",
"HarmonicOscillatorPotential",
"HernquistPotential",
"IsochronePotential",
"KeplerPotential",
Expand Down Expand Up @@ -96,6 +97,37 @@ def _potential_energy(self, q: gt.QVec3, t: gt.RealQScalar, /) -> gt.FloatQScala
# -------------------------------------------------------------------


@final
class HarmonicOscillatorPotential(AbstractPotential):
r"""Harmonic Oscillator Potential.

Represents an N-dimensional harmonic oscillator.

.. math::

\Phi = \frac{1}{2} \omega^2 x^2

"""

omega: AbstractParameter = ParameterField(dimensions="frequency") # type: ignore[assignment]
"""The frequency."""

_: KW_ONLY
units: AbstractUnitSystem = eqx.field(converter=unitsystem, static=True)
constants: ImmutableDict[Quantity] = eqx.field(
default=default_constants, converter=ImmutableDict
)

@partial(jax.jit)
def _potential_energy(
self, q: gt.BatchQVec3, /, t: gt.BatchableRealQScalar
) -> gt.BatchFloatQScalar:
return 0.5 * self.omega(t) ** 2 * xp.linalg.vector_norm(q, axis=-1) ** 2


# -------------------------------------------------------------------


@final
class HernquistPotential(AbstractPotential):
"""Hernquist Potential."""
Expand Down
1 change: 1 addition & 0 deletions src/galax/potential/_potential/io/_gala.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def _gala_to_galax_composite(pot: gp.CompositePotential, /) -> gpx.CompositePote


_GALA_TO_GALAX_REGISTRY: dict[type[gp.PotentialBase], type[gpx.AbstractPotential]] = {
gp.HarmonicOscillatorPotential: gpx.HarmonicOscillatorPotential,
gp.HernquistPotential: gpx.HernquistPotential,
gp.IsochronePotential: gpx.IsochronePotential,
gp.KeplerPotential: gpx.KeplerPotential,
Expand Down
137 changes: 137 additions & 0 deletions tests/unit/potential/builtin/test_harmonicoscillator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from typing import Any

import astropy.units as u
import pytest
from plum import convert

import quaxed.numpy as qnp
from unxt import Quantity

import galax.typing as gt
from ..param.test_field import ParameterFieldMixin
from ..test_core import TestAbstractPotential as AbstractPotential_Test
from galax.potential import HarmonicOscillatorPotential
from galax.potential._potential.base import AbstractPotentialBase
from galax.utils._optional_deps import HAS_GALA


class ParameterOmegaMixin(ParameterFieldMixin):
"""Test the omega parameter."""

@pytest.fixture(scope="class")
def field_omega(self) -> Quantity["frequency"]:
return Quantity(1.0, "Hz")

# =====================================================

def test_omega_constant(self, pot_cls, fields):
"""Test the `omega` parameter."""
fields["omega"] = Quantity(1.0, "Hz")
pot = pot_cls(**fields)
assert pot.omega(t=0) == Quantity(1.0, "Hz")

@pytest.mark.xfail(reason="TODO: user function doesn't have units")
def test_omega_userfunc(self, pot_cls, fields):
"""Test the `omega` parameter."""
fields["omega"] = lambda t: t * 1.2
pot = pot_cls(**fields)
assert pot.omega(t=0) == 2


class TestHarmonicOscillatorPotential(
AbstractPotential_Test,
# Parameters
ParameterOmegaMixin,
):
@pytest.fixture(scope="class")
def pot_cls(self) -> type[HarmonicOscillatorPotential]:
return HarmonicOscillatorPotential

@pytest.fixture(scope="class")
def fields_(self, field_omega, field_units) -> dict[str, Any]:
return {"omega": field_omega, "units": field_units}

# ==========================================================================

def test_potential_energy(
self, pot: HarmonicOscillatorPotential, x: gt.Vec3
) -> None:
expect = Quantity(-0.94871936, pot.units["specific energy"])
assert qnp.isclose(
pot.potential_energy(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

def test_gradient(self, pot: HarmonicOscillatorPotential, x: gt.Vec3) -> None:
expect = Quantity(
[0.05347411, 0.10694822, 0.16042233], pot.units["acceleration"]
)
assert qnp.allclose(
pot.gradient(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

def test_density(self, pot: HarmonicOscillatorPotential, x: gt.Vec3) -> None:
expect = Quantity(3.989933e08, pot.units["mass density"])
assert qnp.isclose(
pot.density(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

def test_hessian(self, pot: HarmonicOscillatorPotential, x: gt.Vec3) -> None:
expect = Quantity(
[
[0.04362645, -0.01969533, -0.02954299],
[-0.01969533, 0.01408345, -0.05908599],
[-0.02954299, -0.05908599, -0.03515487],
],
"1/Myr2",
)
assert qnp.allclose(
pot.hessian(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

# ---------------------------------
# Convenience methods

def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.Vec3) -> None:
"""Test the `AbstractPotentialBase.tidal_tensor` method."""
expect = Quantity(
[
[0.0361081, -0.01969533, -0.02954299],
[-0.01969533, 0.00656511, -0.05908599],
[-0.02954299, -0.05908599, -0.04267321],
],
"1/Myr2",
)
assert qnp.allclose(
pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

# ---------------------------------
# Interoperability

@pytest.mark.skipif(not HAS_GALA, reason="requires gala")
@pytest.mark.parametrize(
("method0", "method1", "atol"),
[
("potential_energy", "energy", 1e-8),
("gradient", "gradient", 1e-8),
("density", "density", 5e-7), # TODO: why is this different?
# ("hessian", "hessian", 1e-8), # TODO: why is gala's 0?
],
)
def test_potential_energy_gala(
self,
pot: HarmonicOscillatorPotential,
method0: str,
method1: str,
x: gt.QVec3,
atol: float,
) -> None:
from ..io.gala_helper import galax_to_gala

galax = getattr(pot, method0)(x, t=0)
gala = getattr(galax_to_gala(pot), method1)(convert(x, u.Quantity), t=0 * u.Myr)
assert qnp.allclose(
qnp.ravel(galax),
qnp.ravel(convert(gala, Quantity)),
atol=Quantity(atol, galax.unit),
)
14 changes: 7 additions & 7 deletions tests/unit/potential/builtin/test_kuzmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from unxt import AbstractUnitSystem, Quantity

import galax.potential as gp
import galax.typing as gt
from ..test_core import TestAbstractPotential as AbstractPotential_Test
from .test_common import ParameterMTotMixin, ShapeAParameterMixin
from galax.potential import AbstractPotentialBase, KuzminPotential
from galax.typing import Vec3
from galax.utils._optional_deps import HAS_GALA


Expand All @@ -38,25 +38,25 @@ def fields_(

# ==========================================================================

def test_potential_energy(self, pot: KuzminPotential, x: Vec3) -> None:
def test_potential_energy(self, pot: KuzminPotential, x: gt.QVec3) -> None:
expect = Quantity(-0.98165365, unit="kpc2 / Myr2")
assert qnp.isclose(
pot.potential_energy(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

def test_gradient(self, pot: KuzminPotential, x: Vec3) -> None:
def test_gradient(self, pot: KuzminPotential, x: gt.QVec3) -> None:
expect = Quantity([0.04674541, 0.09349082, 0.18698165], "kpc / Myr2")
assert qnp.allclose(
pot.gradient(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

def test_density(self, pot: KuzminPotential, x: Vec3) -> None:
def test_density(self, pot: KuzminPotential, x: gt.QVec3) -> None:
expect = Quantity(2.45494884e-07, "solMass / kpc3")
assert qnp.isclose(
pot.density(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

def test_hessian(self, pot: KuzminPotential, x: Vec3) -> None:
def test_hessian(self, pot: KuzminPotential, x: gt.QVec3) -> None:
expect = Quantity(
[
[0.0400675, -0.01335583, -0.02671166],
Expand All @@ -72,7 +72,7 @@ def test_hessian(self, pot: KuzminPotential, x: Vec3) -> None:
# ---------------------------------
# Convenience methods

def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None:
def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.QVec3) -> None:
"""Test the `AbstractPotentialBase.tidal_tensor` method."""
expect = Quantity(
[
Expand Down Expand Up @@ -100,7 +100,7 @@ def test_tidal_tensor(self, pot: AbstractPotentialBase, x: Vec3) -> None:
],
)
def test_potential_energy_gala(
self, pot: KuzminPotential, method0: str, method1: str, x: Vec3, atol: float
self, pot: KuzminPotential, method0: str, method1: str, x: gt.QVec3, atol: float
) -> None:
from ..io.gala_helper import galax_to_gala

Expand Down
Loading
Loading