Skip to content

Commit

Permalink
feat: HarmonicOscillatorPotential
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Apr 28, 2024
1 parent 937e917 commit 5316add
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/galax/potential/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ __all__ = [
"ParameterField",
# builtin
"BarPotential",
"HarmonicOscillatorPotential",
"HernquistPotential",
"IsochronePotential",
"KeplerPotential",
Expand All @@ -39,6 +40,7 @@ from ._potential import io
from ._potential.base import AbstractPotentialBase
from ._potential.builtin import (
BarPotential,
HarmonicOscillatorPotential,
HernquistPotential,
IsochronePotential,
KeplerPotential,
Expand Down
32 changes: 32 additions & 0 deletions src/galax/potential/_potential/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

__all__ = [
"BarPotential",
"HarmonicOscillatorPotential",
"HernquistPotential",
"IsochronePotential",
"KeplerPotential",
Expand Down Expand Up @@ -96,6 +97,37 @@ def _potential_energy(self, q: gt.QVec3, t: gt.RealQScalar, /) -> gt.FloatQScala
# -------------------------------------------------------------------


@final
class HarmonicOscillatorPotential(AbstractPotential):
r"""Harmonic Oscillator Potential.
Represents an N-dimensional harmonic oscillator.
.. math::
\Phi = \frac{1}{2} \omega^2 x^2
"""

omega: AbstractParameter = ParameterField(dimensions="frequency") # type: ignore[assignment]
"""The frequency."""

_: KW_ONLY
units: AbstractUnitSystem = eqx.field(converter=unitsystem, static=True)
constants: ImmutableDict[Quantity] = eqx.field(
default=default_constants, converter=ImmutableDict
)

@partial(jax.jit)
def _potential_energy(
self, q: gt.BatchQVec3, /, t: gt.BatchableRealQScalar
) -> gt.BatchFloatQScalar:
return 0.5 * self.omega(t) ** 2 * xp.linalg.vector_norm(q, axis=-1) ** 2


# -------------------------------------------------------------------


@final
class HernquistPotential(AbstractPotential):
"""Hernquist Potential."""
Expand Down
1 change: 1 addition & 0 deletions src/galax/potential/_potential/io/_gala.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def _gala_to_galax_composite(pot: gp.CompositePotential, /) -> gpx.CompositePote


_GALA_TO_GALAX_REGISTRY: dict[type[gp.PotentialBase], type[gpx.AbstractPotential]] = {
gp.HarmonicOscillatorPotential: gpx.HarmonicOscillatorPotential,
gp.HernquistPotential: gpx.HernquistPotential,
gp.IsochronePotential: gpx.IsochronePotential,
gp.KeplerPotential: gpx.KeplerPotential,
Expand Down
103 changes: 103 additions & 0 deletions tests/unit/potential/builtin/test_harmonicoscillator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from typing import Any

import pytest

import quaxed.numpy as qnp
from unxt import Quantity

import galax.typing as gt
from ..param.test_field import ParameterFieldMixin
from ..test_core import TestAbstractPotential as AbstractPotential_Test
from galax.potential import HarmonicOscillatorPotential
from galax.potential._potential.base import AbstractPotentialBase


class ParameterOmegaMixin(ParameterFieldMixin):
"""Test the omega parameter."""

@pytest.fixture(scope="class")
def field_omega(self) -> Quantity["frequency"]:
return Quantity(1.0, "Hz")

# =====================================================

def test_omega_constant(self, pot_cls, fields):
"""Test the `omega` parameter."""
fields["omega"] = Quantity(1.0, "Hz")
pot = pot_cls(**fields)
assert pot.omega(t=0) == Quantity(1.0, "Hz")

@pytest.mark.xfail(reason="TODO: user function doesn't have units")
def test_omega_userfunc(self, pot_cls, fields):
"""Test the `omega` parameter."""
fields["omega"] = lambda t: t * 1.2
pot = pot_cls(**fields)
assert pot.omega(t=0) == 2


class TestHarmonicOscillatorPotential(
AbstractPotential_Test,
# Parameters
ParameterOmegaMixin,
):
@pytest.fixture(scope="class")
def pot_cls(self) -> type[HarmonicOscillatorPotential]:
return HarmonicOscillatorPotential

@pytest.fixture(scope="class")
def fields_(self, field_omega, field_units) -> dict[str, Any]:
return {"omega": field_omega, "units": field_units}

# ==========================================================================

def test_potential_energy(
self, pot: HarmonicOscillatorPotential, x: gt.Vec3
) -> None:
expect = Quantity(-0.94871936, pot.units["specific energy"])
assert qnp.isclose(
pot.potential_energy(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

def test_gradient(self, pot: HarmonicOscillatorPotential, x: gt.Vec3) -> None:
expect = Quantity(
[0.05347411, 0.10694822, 0.16042233], pot.units["acceleration"]
)
assert qnp.allclose(
pot.gradient(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

def test_density(self, pot: HarmonicOscillatorPotential, x: gt.Vec3) -> None:
expect = Quantity(3.989933e08, pot.units["mass density"])
assert qnp.isclose(
pot.density(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

def test_hessian(self, pot: HarmonicOscillatorPotential, x: gt.Vec3) -> None:
expect = Quantity(
[
[0.04362645, -0.01969533, -0.02954299],
[-0.01969533, 0.01408345, -0.05908599],
[-0.02954299, -0.05908599, -0.03515487],
],
"1/Myr2",
)
assert qnp.allclose(
pot.hessian(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

# ---------------------------------
# Convenience methods

def test_tidal_tensor(self, pot: AbstractPotentialBase, x: gt.Vec3) -> None:
"""Test the `AbstractPotentialBase.tidal_tensor` method."""
expect = Quantity(
[
[0.0361081, -0.01969533, -0.02954299],
[-0.01969533, 0.00656511, -0.05908599],
[-0.02954299, -0.05908599, -0.04267321],
],
"1/Myr2",
)
assert qnp.allclose(
pot.tidal_tensor(x, t=0), expect, atol=Quantity(1e-8, expect.unit)
)

0 comments on commit 5316add

Please sign in to comment.