Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: PotentialFrame is parametric #377

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions src/galax/_galax_interop_gala/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ def _get_frame(pot: gp.PotentialBase, /) -> cxo.AbstractOperator:
return cxo.simplify_op(frame)


def _apply_frame(frame: cxo.AbstractOperator, pot: PT, /) -> PT | gpx.PotentialFrame:
def _apply_frame(
frame: cxo.AbstractOperator, pot: PT, /
) -> PT | gpx.PotentialFrame[PT]:
return (
pot if isinstance(frame, IdentityOperator) else gpx.PotentialFrame(pot, frame)
)
Expand Down Expand Up @@ -205,7 +207,7 @@ def _gala_to_galax_composite(pot: gp.CompositePotential, /) -> gpx.CompositePote
@gala_to_galax.register(gp.PowerLawCutoffPotential)
def _gala_to_galax_registered(
gala: gp.PotentialBase, /
) -> gpx.AbstractPotential | gpx.PotentialFrame:
) -> gpx.AbstractPotential | gpx.PotentialFrame[gpx.AbstractPotential]:
"""Convert a Gala potential to a Galax potential."""
if isinstance(gala.units, GalaDimensionlessUnitSystem):
msg = "Galax does not support converting dimensionless units."
Expand Down Expand Up @@ -247,7 +249,7 @@ def _gala_to_galax_null(pot: gp.NullPotential, /) -> gpx.NullPotential:
@gala_to_galax.register
def _gala_to_galax_burkert(
gala: gp.BurkertPotential, /
) -> gpx.BurkertPotential | gpx.PotentialFrame:
) -> gpx.BurkertPotential | gpx.PotentialFrame[gpx.BurkertPotential]:
"""Convert a Gala BurkertPotential to a Galax potential.

Examples
Expand Down Expand Up @@ -285,7 +287,7 @@ def _gala_to_galax_burkert(
@gala_to_galax.register
def _gala_to_galax_hernquist(
gala: gp.HernquistPotential, /
) -> gpx.HernquistPotential | gpx.PotentialFrame:
) -> gpx.HernquistPotential | gpx.PotentialFrame[gpx.HernquistPotential]:
r"""Convert a Gala HernquistPotential to a Galax potential.

Examples
Expand All @@ -311,7 +313,7 @@ def _gala_to_galax_hernquist(
@gala_to_galax.register
def _gala_to_galax_jaffe(
gala: gp.JaffePotential, /
) -> gpx.JaffePotential | gpx.PotentialFrame:
) -> gpx.JaffePotential | gpx.PotentialFrame[gpx.JaffePotential]:
"""Convert a Gala JaffePotential to a Galax potential.

Examples
Expand All @@ -337,7 +339,7 @@ def _gala_to_galax_jaffe(
@gala_to_galax.register
def _gala_to_galax_longmuralibar(
gala: gp.LongMuraliBarPotential, /
) -> gpx.LongMuraliBarPotential | gpx.PotentialFrame:
) -> gpx.LongMuraliBarPotential | gpx.PotentialFrame[gpx.LongMuraliBarPotential]:
"""Convert a Gala LongMuraliBarPotential to a Galax potential.

Examples
Expand Down Expand Up @@ -373,7 +375,7 @@ def _gala_to_galax_longmuralibar(
@gala_to_galax.register
def _gala_to_galax_satoh(
gala: gp.SatohPotential, /
) -> gpx.SatohPotential | gpx.PotentialFrame:
) -> gpx.SatohPotential | gpx.PotentialFrame[gpx.SatohPotential]:
"""Convert a Gala SatohPotential to a Galax potential.

Examples
Expand Down Expand Up @@ -402,7 +404,7 @@ def _gala_to_galax_satoh(
@gala_to_galax.register
def _gala_to_galax_stoneostriker15(
gala: gp.StonePotential, /
) -> gpx.StoneOstriker15Potential | gpx.PotentialFrame:
) -> gpx.StoneOstriker15Potential | gpx.PotentialFrame[gpx.StoneOstriker15Potential]:
"""Convert a Gala StonePotential to a Galax potential.

Examples
Expand Down Expand Up @@ -435,7 +437,12 @@ def _gala_to_galax_stoneostriker15(
@gala_to_galax.register
def _gala_to_galax_logarithmic(
gala: gp.LogarithmicPotential, /
) -> gpx.LogarithmicPotential | gpx.LMJ09LogarithmicPotential | gpx.PotentialFrame:
) -> (
gpx.LogarithmicPotential
| gpx.LMJ09LogarithmicPotential
| gpx.PotentialFrame[gpx.LogarithmicPotential]
| gpx.PotentialFrame[gpx.LMJ09LogarithmicPotential]
):
"""Convert a Gala LogarithmicPotential to a Galax potential.

If the flattening or rotation 'phi' is non-zero, the potential is a
Expand Down Expand Up @@ -491,7 +498,7 @@ def _gala_to_galax_logarithmic(
@gala_to_galax.register
def _gala_to_galax_nfw(
gala: gp.NFWPotential, /
) -> gpx.NFWPotential | gpx.PotentialFrame:
) -> gpx.NFWPotential | gpx.PotentialFrame[gpx.NFWPotential]:
"""Convert a Gala NFWPotential to a Galax potential.

Examples
Expand Down
10 changes: 6 additions & 4 deletions src/galax/potential/_potential/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


from dataclasses import replace
from typing import cast, final
from typing import Generic, TypeVar, cast, final

import equinox as eqx

Expand All @@ -15,9 +15,11 @@
from .base import AbstractPotentialBase
from galax.utils import ImmutableDict

FramedPotT = TypeVar("FramedPotT", bound=AbstractPotentialBase)


@final
class PotentialFrame(AbstractPotentialBase):
class PotentialFrame(AbstractPotentialBase, Generic[FramedPotT]):
"""Reference frame of the potential.

Examples
Expand Down Expand Up @@ -184,7 +186,7 @@ class PotentialFrame(AbstractPotentialBase):
Array(-2.23568166, dtype=float64)
""" # noqa: E501

original_potential: AbstractPotentialBase
original_potential: FramedPotT

operator: OperatorSequence = eqx.field(default=(), converter=OperatorSequence)
"""Transformation to reference frame of the potential.
Expand Down Expand Up @@ -249,6 +251,6 @@ def _potential(


@simplify_op.register # type: ignore[misc]
def _simplify_op(frame: PotentialFrame, /) -> PotentialFrame:
def _simplify_op(frame: PotentialFrame[FramedPotT], /) -> PotentialFrame[FramedPotT]:
"""Simplify the operators in an PotentialFrame."""
return replace(frame, operator=simplify_op(frame.operator))
Loading