Skip to content

Commit

Permalink
feat: HarmonicOscillatorPotential
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Apr 27, 2024
1 parent 5e66b3a commit 7a288b1
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 22 deletions.
2 changes: 2 additions & 0 deletions src/galax/potential/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ __all__ = [
"ParameterField",
# builtin
"BarPotential",
"HarmonicOscillatorPotential",
"HernquistPotential",
"IsochronePotential",
"KeplerPotential",
Expand All @@ -38,6 +39,7 @@ from ._potential import io
from ._potential.base import AbstractPotentialBase
from ._potential.builtin import (
BarPotential,
HarmonicOscillatorPotential,
HernquistPotential,
IsochronePotential,
KeplerPotential,
Expand Down
32 changes: 32 additions & 0 deletions src/galax/potential/_potential/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

__all__ = [
"BarPotential",
"HarmonicOscillatorPotential",
"HernquistPotential",
"IsochronePotential",
"KeplerPotential",
Expand Down Expand Up @@ -95,6 +96,37 @@ def _potential_energy(self, q: gt.QVec3, t: gt.RealQScalar, /) -> gt.FloatQScala
# -------------------------------------------------------------------


@final
class HarmonicOscillatorPotential(AbstractPotential):
r"""Harmonic Oscillator Potential.
Represents an N-dimensional harmonic oscillator.
.. math::
\Phi = \frac{1}{2} \omega^2 x^2
"""

omega: AbstractParameter = ParameterField(dimensions="frequency") # type: ignore[assignment]
"""The frequency."""

_: KW_ONLY
units: AbstractUnitSystem = eqx.field(converter=unitsystem, static=True)
constants: ImmutableDict[Quantity] = eqx.field(
default=default_constants, converter=ImmutableDict
)

@partial(jax.jit)
def _potential_energy(
self, q: gt.BatchQVec3, /, t: gt.BatchableRealQScalar
) -> gt.BatchFloatQScalar:
return 0.5 * self.omega(t) ** 2 * xp.linalg.vector_norm(q, axis=-1) ** 2


# -------------------------------------------------------------------


@final
class HernquistPotential(AbstractPotential):
"""Hernquist Potential."""
Expand Down
41 changes: 19 additions & 22 deletions src/galax/potential/_potential/io/_gala.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
else:
_ = pytest.importorskip("gala")

import gala.potential as galap
import gala.potential as gp
from gala.units import DimensionlessUnitSystem as GalaDimensionlessUnitSystem

import coordinax.operators as cxo
Expand All @@ -27,7 +27,7 @@


@singledispatch
def gala_to_galax(pot: galap.PotentialBase, /) -> gpx.AbstractPotentialBase:
def gala_to_galax(pot: gp.PotentialBase, /) -> gpx.AbstractPotentialBase:
"""Convert a :mod:`gala` potential to a :mod:`galax` potential.
Parameters
Expand Down Expand Up @@ -153,7 +153,7 @@ def gala_to_galax(pot: galap.PotentialBase, /) -> gpx.AbstractPotentialBase:
PT = TypeVar("PT", bound=gpx.AbstractPotentialBase)


def _get_frame(pot: galap.PotentialBase, /) -> cxo.AbstractOperator:
def _get_frame(pot: gp.PotentialBase, /) -> cxo.AbstractOperator:
frame = cxo.GalileanSpatialTranslationOperator(
Quantity(pot.origin, unit=pot.units["length"])
)
Expand All @@ -173,29 +173,26 @@ def _apply_frame(frame: cxo.AbstractOperator, pot: PT, /) -> PT | gpx.PotentialF


@gala_to_galax.register
def _gala_to_galax_composite(
pot: galap.CompositePotential, /
) -> gpx.CompositePotential:
def _gala_to_galax_composite(pot: gp.CompositePotential, /) -> gpx.CompositePotential:
"""Convert a Gala CompositePotential to a Galax potential."""
return gpx.CompositePotential(**{k: gala_to_galax(p) for k, p in pot.items()})


_GALA_TO_GALAX_REGISTRY: dict[
type[galap.PotentialBase], type[gpx.AbstractPotential]
] = {
galap.HernquistPotential: gpx.HernquistPotential,
galap.IsochronePotential: gpx.IsochronePotential,
galap.KeplerPotential: gpx.KeplerPotential,
galap.MiyamotoNagaiPotential: gpx.MiyamotoNagaiPotential,
_GALA_TO_GALAX_REGISTRY: dict[type[gp.PotentialBase], type[gpx.AbstractPotential]] = {
gp.HarmonicOscillatorPotential: gpx.HarmonicOscillatorPotential,
gp.HernquistPotential: gpx.HernquistPotential,
gp.IsochronePotential: gpx.IsochronePotential,
gp.KeplerPotential: gpx.KeplerPotential,
gp.MiyamotoNagaiPotential: gpx.MiyamotoNagaiPotential,
}


@gala_to_galax.register(galap.HernquistPotential)
@gala_to_galax.register(galap.IsochronePotential)
@gala_to_galax.register(galap.KeplerPotential)
@gala_to_galax.register(galap.MiyamotoNagaiPotential)
@gala_to_galax.register(gp.HernquistPotential)
@gala_to_galax.register(gp.IsochronePotential)
@gala_to_galax.register(gp.KeplerPotential)
@gala_to_galax.register(gp.MiyamotoNagaiPotential)
def _gala_to_galax_registered(
gala: galap.PotentialBase, /
gala: gp.PotentialBase, /
) -> gpx.AbstractPotential | gpx.PotentialFrame:
"""Convert a Gala HernquistPotential to a Galax potential."""
if isinstance(gala.units, GalaDimensionlessUnitSystem):
Expand All @@ -216,7 +213,7 @@ def _gala_to_galax_registered(


@gala_to_galax.register
def _gala_to_galax_null(_: galap.NullPotential, /) -> gpx.NullPotential:
def _gala_to_galax_null(_: gp.NullPotential, /) -> gpx.NullPotential:
"""Convert a Gala NullPotential to a Galax potential.
Examples
Expand All @@ -235,7 +232,7 @@ def _gala_to_galax_null(_: galap.NullPotential, /) -> gpx.NullPotential:

@gala_to_galax.register
def _gala_to_galax_nfw(
gala: galap.NFWPotential, /
gala: gp.NFWPotential, /
) -> gpx.NFWPotential | gpx.PotentialFrame:
"""Convert a Gala NFWPotential to a Galax potential.
Expand All @@ -262,7 +259,7 @@ def _gala_to_galax_nfw(

@gala_to_galax.register
def _gala_to_galax_leesutotriaxialnfw(
pot: galap.LeeSutoTriaxialNFWPotential, /
pot: gp.LeeSutoTriaxialNFWPotential, /
) -> gpx.LeeSutoTriaxialNFWPotential:
"""Convert a Gala LeeSutoTriaxialNFWPotential to a Galax potential.
Expand Down Expand Up @@ -306,7 +303,7 @@ def _gala_to_galax_leesutotriaxialnfw(


@gala_to_galax.register
def _gala_to_galax_mw(pot: galap.MilkyWayPotential, /) -> gpx.MilkyWayPotential:
def _gala_to_galax_mw(pot: gp.MilkyWayPotential, /) -> gpx.MilkyWayPotential:
"""Convert a Gala MilkyWayPotential to a Galax potential.
Examples
Expand Down
103 changes: 103 additions & 0 deletions tests/unit/potential/builtin/test_harmonicoscillator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from typing import Any

import pytest

import quaxed.numpy as qnp
from unxt import Quantity

import galax.typing as gt
from ..param.test_field import ParameterFieldMixin
from ..test_core import TestAbstractPotential as AbstractPotential_Test
from galax.potential import HarmonicOscillatorPotential
from galax.potential._potential.base import AbstractPotentialBase


class ParameterOmegaMixin(ParameterFieldMixin):
"""Test the omega parameter."""

@pytest.fixture(scope="class")
def field_omega(self) -> Quantity["frequency"]:
return Quantity(1.0, "Hz")

# =====================================================

def test_omega_constant(self, pot_cls, fields):
"""Test the `omega` parameter."""
fields["omega"] = Quantity(1.0, "Hz")
pot = pot_cls(**fields)
assert pot.omega(t=0) == Quantity(1.0, "Hz")

@pytest.mark.xfail(reason="TODO: user function doesn't have units")
def test_omega_userfunc(self, pot_cls, fields):
"""Test the `omega` parameter."""
fields["omega"] = lambda t: t * 1.2
pot = pot_cls(**fields)
assert pot.omega(t=0) == 2


class TestHarmonicOscillatorPotential(
AbstractPotential_Test,
# Parameters
ParameterOmegaMixin,
):
@pytest.fixture(scope="class")
def pot_cls(self) -> type[HarmonicOscillatorPotential]:
return HarmonicOscillatorPotential

@pytest.fixture(scope="class")
def fields_(self, field_omega, field_units) -> dict[str, Any]:
return {"omega": field_omega, "units": field_units}

# ==========================================================================

def test_potential_energy(
self, pot: HarmonicOscillatorPotential, x: gt.Vec3
) -> None:
expect = Quantity(-0.94871936, pot.units["specific energy"])
assert qnp.isclose(
pot.potential_energy(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

def test_gradient(self, pot: HarmonicOscillatorPotential, x: gt.Vec3) -> None:
expect = Quantity(
[0.05347411, 0.10694822, 0.16042233], pot.units["acceleration"]
)
assert qnp.allclose(
pot.gradient(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

def test_density(self, pot: HarmonicOscillatorPotential, x: gt.Vec3) -> None:
expect = Quantity(3.989933e08, pot.units["mass density"])
assert qnp.isclose(
pot.density(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

def test_hessian(self, pot: HarmonicOscillatorPotential, x: gt.Vec3) -> None:
expect = Quantity(
[
[0.04362645, -0.01969533, -0.02954299],
[-0.01969533, 0.01408345, -0.05908599],
[-0.02954299, -0.05908599, -0.03515487],
],
"1/Myr2",
)
assert qnp.allclose(
pot.hessian(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

# ---------------------------------
# Convenience methods

def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.Vec3) -> None:
"""Test the `AbstractPotentialBase.tidal_tensor` method."""
expect = Quantity(
[
[0.0361081, -0.01969533, -0.02954299],
[-0.01969533, 0.00656511, -0.05908599],
[-0.02954299, -0.05908599, -0.04267321],
],
"1/Myr2",
)
assert qnp.allclose(
pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

0 comments on commit 7a288b1

Please sign in to comment.