diff --git a/src/galax/_galax_interop_gala/potential.py b/src/galax/_galax_interop_gala/potential.py index 96cd5132..818e53aa 100644 --- a/src/galax/_galax_interop_gala/potential.py +++ b/src/galax/_galax_interop_gala/potential.py @@ -171,7 +171,9 @@ def _get_frame(pot: gp.PotentialBase, /) -> cxo.AbstractOperator: return cxo.simplify_op(frame) -def _apply_frame(frame: cxo.AbstractOperator, pot: PT, /) -> PT | gpx.PotentialFrame: +def _apply_frame( + frame: cxo.AbstractOperator, pot: PT, / +) -> PT | gpx.PotentialFrame[PT]: return ( pot if isinstance(frame, IdentityOperator) else gpx.PotentialFrame(pot, frame) ) @@ -205,7 +207,7 @@ def _gala_to_galax_composite(pot: gp.CompositePotential, /) -> gpx.CompositePote @gala_to_galax.register(gp.PowerLawCutoffPotential) def _gala_to_galax_registered( gala: gp.PotentialBase, / -) -> gpx.AbstractPotential | gpx.PotentialFrame: +) -> gpx.AbstractPotential | gpx.PotentialFrame[gpx.AbstractPotential]: """Convert a Gala potential to a Galax potential.""" if isinstance(gala.units, GalaDimensionlessUnitSystem): msg = "Galax does not support converting dimensionless units." @@ -247,7 +249,7 @@ def _gala_to_galax_null(pot: gp.NullPotential, /) -> gpx.NullPotential: @gala_to_galax.register def _gala_to_galax_burkert( gala: gp.BurkertPotential, / - ) -> gpx.BurkertPotential | gpx.PotentialFrame: + ) -> gpx.BurkertPotential | gpx.PotentialFrame[gpx.BurkertPotential]: """Convert a Gala BurkertPotential to a Galax potential. Examples @@ -285,7 +287,7 @@ def _gala_to_galax_burkert( @gala_to_galax.register def _gala_to_galax_hernquist( gala: gp.HernquistPotential, / -) -> gpx.HernquistPotential | gpx.PotentialFrame: +) -> gpx.HernquistPotential | gpx.PotentialFrame[gpx.HernquistPotential]: r"""Convert a Gala HernquistPotential to a Galax potential. Examples @@ -311,7 +313,7 @@ def _gala_to_galax_hernquist( @gala_to_galax.register def _gala_to_galax_jaffe( gala: gp.JaffePotential, / -) -> gpx.JaffePotential | gpx.PotentialFrame: +) -> gpx.JaffePotential | gpx.PotentialFrame[gpx.JaffePotential]: """Convert a Gala JaffePotential to a Galax potential. Examples @@ -337,7 +339,7 @@ def _gala_to_galax_jaffe( @gala_to_galax.register def _gala_to_galax_longmuralibar( gala: gp.LongMuraliBarPotential, / -) -> gpx.LongMuraliBarPotential | gpx.PotentialFrame: +) -> gpx.LongMuraliBarPotential | gpx.PotentialFrame[gpx.LongMuraliBarPotential]: """Convert a Gala LongMuraliBarPotential to a Galax potential. Examples @@ -373,7 +375,7 @@ def _gala_to_galax_longmuralibar( @gala_to_galax.register def _gala_to_galax_satoh( gala: gp.SatohPotential, / -) -> gpx.SatohPotential | gpx.PotentialFrame: +) -> gpx.SatohPotential | gpx.PotentialFrame[gpx.SatohPotential]: """Convert a Gala SatohPotential to a Galax potential. Examples @@ -402,7 +404,7 @@ def _gala_to_galax_satoh( @gala_to_galax.register def _gala_to_galax_stoneostriker15( gala: gp.StonePotential, / -) -> gpx.StoneOstriker15Potential | gpx.PotentialFrame: +) -> gpx.StoneOstriker15Potential | gpx.PotentialFrame[gpx.StoneOstriker15Potential]: """Convert a Gala StonePotential to a Galax potential. Examples @@ -435,7 +437,12 @@ def _gala_to_galax_stoneostriker15( @gala_to_galax.register def _gala_to_galax_logarithmic( gala: gp.LogarithmicPotential, / -) -> gpx.LogarithmicPotential | gpx.LMJ09LogarithmicPotential | gpx.PotentialFrame: +) -> ( + gpx.LogarithmicPotential + | gpx.LMJ09LogarithmicPotential + | gpx.PotentialFrame[gpx.LogarithmicPotential] + | gpx.PotentialFrame[gpx.LMJ09LogarithmicPotential] +): """Convert a Gala LogarithmicPotential to a Galax potential. If the flattening or rotation 'phi' is non-zero, the potential is a @@ -491,7 +498,7 @@ def _gala_to_galax_logarithmic( @gala_to_galax.register def _gala_to_galax_nfw( gala: gp.NFWPotential, / -) -> gpx.NFWPotential | gpx.PotentialFrame: +) -> gpx.NFWPotential | gpx.PotentialFrame[gpx.NFWPotential]: """Convert a Gala NFWPotential to a Galax potential. Examples diff --git a/src/galax/potential/_potential/frame.py b/src/galax/potential/_potential/frame.py index efba12fa..362faa0f 100644 --- a/src/galax/potential/_potential/frame.py +++ b/src/galax/potential/_potential/frame.py @@ -4,7 +4,7 @@ from dataclasses import replace -from typing import cast, final +from typing import Generic, TypeVar, cast, final import equinox as eqx @@ -15,9 +15,11 @@ from .base import AbstractPotentialBase from galax.utils import ImmutableDict +FramedPotT = TypeVar("FramedPotT", bound=AbstractPotentialBase) + @final -class PotentialFrame(AbstractPotentialBase): +class PotentialFrame(AbstractPotentialBase, Generic[FramedPotT]): """Reference frame of the potential. Examples @@ -184,7 +186,7 @@ class PotentialFrame(AbstractPotentialBase): Array(-2.23568166, dtype=float64) """ # noqa: E501 - original_potential: AbstractPotentialBase + original_potential: FramedPotT operator: OperatorSequence = eqx.field(default=(), converter=OperatorSequence) """Transformation to reference frame of the potential. @@ -249,6 +251,6 @@ def _potential( @simplify_op.register # type: ignore[misc] -def _simplify_op(frame: PotentialFrame, /) -> PotentialFrame: +def _simplify_op(frame: PotentialFrame[FramedPotT], /) -> PotentialFrame[FramedPotT]: """Simplify the operators in an PotentialFrame.""" return replace(frame, operator=simplify_op(frame.operator))