Skip to content

Commit

Permalink
feat: PotentialFrame is parametric
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Jul 4, 2024
1 parent 607b525 commit a54bffb
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 14 deletions.
27 changes: 17 additions & 10 deletions src/galax/_galax_interop_gala/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ def _get_frame(pot: gp.PotentialBase, /) -> cxo.AbstractOperator:
return cxo.simplify_op(frame)


def _apply_frame(frame: cxo.AbstractOperator, pot: PT, /) -> PT | gpx.PotentialFrame:
def _apply_frame(
frame: cxo.AbstractOperator, pot: PT, /
) -> PT | gpx.PotentialFrame[PT]:
return (
pot if isinstance(frame, IdentityOperator) else gpx.PotentialFrame(pot, frame)
)
Expand Down Expand Up @@ -205,7 +207,7 @@ def _gala_to_galax_composite(pot: gp.CompositePotential, /) -> gpx.CompositePote
@gala_to_galax.register(gp.PowerLawCutoffPotential)
def _gala_to_galax_registered(
gala: gp.PotentialBase, /
) -> gpx.AbstractPotential | gpx.PotentialFrame:
) -> gpx.AbstractPotential | gpx.PotentialFrame[gpx.AbstractPotential]:
"""Convert a Gala potential to a Galax potential."""
if isinstance(gala.units, GalaDimensionlessUnitSystem):
msg = "Galax does not support converting dimensionless units."
Expand Down Expand Up @@ -247,7 +249,7 @@ def _gala_to_galax_null(pot: gp.NullPotential, /) -> gpx.NullPotential:
@gala_to_galax.register
def _gala_to_galax_burkert(
gala: gp.BurkertPotential, /
) -> gpx.BurkertPotential | gpx.PotentialFrame:
) -> gpx.BurkertPotential | gpx.PotentialFrame[gpx.BurkertPotential]:
"""Convert a Gala BurkertPotential to a Galax potential.
Examples
Expand Down Expand Up @@ -285,7 +287,7 @@ def _gala_to_galax_burkert(
@gala_to_galax.register
def _gala_to_galax_hernquist(
gala: gp.HernquistPotential, /
) -> gpx.HernquistPotential | gpx.PotentialFrame:
) -> gpx.HernquistPotential | gpx.PotentialFrame[gpx.HernquistPotential]:
r"""Convert a Gala HernquistPotential to a Galax potential.
Examples
Expand All @@ -311,7 +313,7 @@ def _gala_to_galax_hernquist(
@gala_to_galax.register
def _gala_to_galax_jaffe(
gala: gp.JaffePotential, /
) -> gpx.JaffePotential | gpx.PotentialFrame:
) -> gpx.JaffePotential | gpx.PotentialFrame[gpx.JaffePotential]:
"""Convert a Gala JaffePotential to a Galax potential.
Examples
Expand All @@ -337,7 +339,7 @@ def _gala_to_galax_jaffe(
@gala_to_galax.register
def _gala_to_galax_longmuralibar(
gala: gp.LongMuraliBarPotential, /
) -> gpx.LongMuraliBarPotential | gpx.PotentialFrame:
) -> gpx.LongMuraliBarPotential | gpx.PotentialFrame[gpx.LongMuraliBarPotential]:
"""Convert a Gala LongMuraliBarPotential to a Galax potential.
Examples
Expand Down Expand Up @@ -373,7 +375,7 @@ def _gala_to_galax_longmuralibar(
@gala_to_galax.register
def _gala_to_galax_satoh(
gala: gp.SatohPotential, /
) -> gpx.SatohPotential | gpx.PotentialFrame:
) -> gpx.SatohPotential | gpx.PotentialFrame[gpx.SatohPotential]:
"""Convert a Gala SatohPotential to a Galax potential.
Examples
Expand Down Expand Up @@ -402,7 +404,7 @@ def _gala_to_galax_satoh(
@gala_to_galax.register
def _gala_to_galax_stoneostriker15(
gala: gp.StonePotential, /
) -> gpx.StoneOstriker15Potential | gpx.PotentialFrame:
) -> gpx.StoneOstriker15Potential | gpx.PotentialFrame[gpx.StoneOstriker15Potential]:
"""Convert a Gala StonePotential to a Galax potential.
Examples
Expand Down Expand Up @@ -435,7 +437,12 @@ def _gala_to_galax_stoneostriker15(
@gala_to_galax.register
def _gala_to_galax_logarithmic(
gala: gp.LogarithmicPotential, /
) -> gpx.LogarithmicPotential | gpx.LMJ09LogarithmicPotential | gpx.PotentialFrame:
) -> (
gpx.LogarithmicPotential
| gpx.LMJ09LogarithmicPotential
| gpx.PotentialFrame[gpx.LogarithmicPotential]
| gpx.PotentialFrame[gpx.LMJ09LogarithmicPotential]
):
"""Convert a Gala LogarithmicPotential to a Galax potential.
If the flattening or rotation 'phi' is non-zero, the potential is a
Expand Down Expand Up @@ -491,7 +498,7 @@ def _gala_to_galax_logarithmic(
@gala_to_galax.register
def _gala_to_galax_nfw(
gala: gp.NFWPotential, /
) -> gpx.NFWPotential | gpx.PotentialFrame:
) -> gpx.NFWPotential | gpx.PotentialFrame[gpx.NFWPotential]:
"""Convert a Gala NFWPotential to a Galax potential.
Examples
Expand Down
10 changes: 6 additions & 4 deletions src/galax/potential/_potential/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


from dataclasses import replace
from typing import cast, final
from typing import Generic, TypeVar, cast, final

import equinox as eqx

Expand All @@ -15,9 +15,11 @@
from .base import AbstractPotentialBase
from galax.utils import ImmutableDict

FramedPotT = TypeVar("FramedPotT", bound=AbstractPotentialBase)


@final
class PotentialFrame(AbstractPotentialBase):
class PotentialFrame(AbstractPotentialBase, Generic[FramedPotT]):
"""Reference frame of the potential.
Examples
Expand Down Expand Up @@ -184,7 +186,7 @@ class PotentialFrame(AbstractPotentialBase):
Array(-2.23568166, dtype=float64)
""" # noqa: E501

original_potential: AbstractPotentialBase
original_potential: FramedPotT

operator: OperatorSequence = eqx.field(default=(), converter=OperatorSequence)
"""Transformation to reference frame of the potential.
Expand Down Expand Up @@ -249,6 +251,6 @@ def _potential(


@simplify_op.register # type: ignore[misc]
def _simplify_op(frame: PotentialFrame, /) -> PotentialFrame:
def _simplify_op(frame: PotentialFrame[FramedPotT], /) -> PotentialFrame[FramedPotT]:
"""Simplify the operators in an PotentialFrame."""
return replace(frame, operator=simplify_op(frame.operator))

0 comments on commit a54bffb

Please sign in to comment.