diff --git a/src/galax/_interop/galax_interop_gala/potential.py b/src/galax/_interop/galax_interop_gala/potential.py index 450285b4..c6e43f77 100644 --- a/src/galax/_interop/galax_interop_gala/potential.py +++ b/src/galax/_interop/galax_interop_gala/potential.py @@ -119,7 +119,9 @@ def _get_frame(pot: gp.PotentialBase, /) -> cxo.AbstractOperator: return cxo.simplify_op(frame) -def _apply_frame(frame: cxo.AbstractOperator, pot: PT, /) -> PT | gpx.PotentialFrame: +def _apply_frame( + frame: cxo.AbstractOperator, pot: PT, / +) -> PT | gpx.PotentialFrame[PT]: """Apply a Galax frame to a potential.""" # A framed Galax potential never simplifies to a frameless potential. This # function applies a frame if it is not the identity operator. @@ -272,7 +274,7 @@ def galax_to_gala(_: gpx.BarPotential, /) -> gp.PotentialBase: @dispatch # type: ignore[misc] def gala_to_galax( gala: gp.BurkertPotential, / - ) -> gpx.BurkertPotential | gpx.PotentialFrame: + ) -> gpx.BurkertPotential | gpx.PotentialFrame[gpx.BurkertPotential]: """Convert a `gala.potential.BurkertPotential` to a galax.potential.BurkertPotential. Examples @@ -347,7 +349,10 @@ def galax_to_gala(pot: gpx.BurkertPotential, /) -> gp.BurkertPotential: @dispatch # type: ignore[misc] def gala_to_galax( gala: gp.HarmonicOscillatorPotential, / -) -> gpx.HarmonicOscillatorPotential | gpx.PotentialFrame: +) -> ( + gpx.HarmonicOscillatorPotential + | gpx.PotentialFrame[gpx.HarmonicOscillatorPotential] +): r"""Convert a `gala.potential.HarmonicOscillatorPotential` to a `galax.potential.HarmonicOscillatorPotential`. Examples @@ -404,7 +409,7 @@ def galax_to_gala( @dispatch # type: ignore[misc] def gala_to_galax( gala: gp.HernquistPotential, / -) -> gpx.HernquistPotential | gpx.PotentialFrame: +) -> gpx.HernquistPotential | gpx.PotentialFrame[gpx.HernquistPotential]: r"""Convert a `gala.potential.HernquistPotential` to a `galax.potential.HernquistPotential`. Examples @@ -460,7 +465,7 @@ def galax_to_gala(pot: gpx.HernquistPotential, /) -> gp.HernquistPotential: @dispatch # type: ignore[misc] def gala_to_galax( gala: gp.IsochronePotential, / -) -> gpx.IsochronePotential | gpx.PotentialFrame: +) -> gpx.IsochronePotential | gpx.PotentialFrame[gpx.IsochronePotential]: """Convert a `gala.potential.IsochronePotential` to a `galax.potential.IsochronePotential`. Examples @@ -519,7 +524,7 @@ def galax_to_gala(pot: gpx.IsochronePotential, /) -> gp.IsochronePotential: @dispatch # type: ignore[misc] def gala_to_galax( gala: gp.JaffePotential, / -) -> gpx.JaffePotential | gpx.PotentialFrame: +) -> gpx.JaffePotential | gpx.PotentialFrame[gpx.JaffePotential]: """Convert a Gala JaffePotential to a Galax potential. Examples @@ -575,7 +580,7 @@ def galax_to_gala(pot: gpx.JaffePotential, /) -> gp.JaffePotential: @dispatch # type: ignore[misc] def gala_to_galax( gala: gp.KeplerPotential, / -) -> gpx.KeplerPotential | gpx.PotentialFrame: +) -> gpx.KeplerPotential | gpx.PotentialFrame[gpx.KeplerPotential]: """Convert a `gala.potential.KeplerPotential` to a `galax.potential.KeplerPotential`. Examples @@ -632,7 +637,7 @@ def galax_to_gala(pot: gpx.KeplerPotential, /) -> gp.KeplerPotential: @dispatch # type: ignore[misc] def gala_to_galax( gala: gp.KuzminPotential, / -) -> gpx.KuzminPotential | gpx.PotentialFrame: +) -> gpx.KuzminPotential | gpx.PotentialFrame[gpx.KuzminPotential]: """Convert a `gala.potential.KuzminPotential` to a `galax.potential.KuzminPotential`. Examples @@ -692,7 +697,7 @@ def galax_to_gala(pot: gpx.KuzminPotential, /) -> gp.KuzminPotential: @dispatch # type: ignore[misc] def gala_to_galax( gala: gp.LongMuraliBarPotential, / -) -> gpx.LongMuraliBarPotential | gpx.PotentialFrame: +) -> gpx.LongMuraliBarPotential | gpx.PotentialFrame[gpx.LongMuraliBarPotential]: """Convert a Gala LongMuraliBarPotential to a Galax potential. Examples @@ -766,7 +771,7 @@ def galax_to_gala(pot: gpx.LongMuraliBarPotential, /) -> gp.LongMuraliBarPotenti @dispatch # type: ignore[misc] def gala_to_galax( gala: gp.MiyamotoNagaiPotential, / -) -> gpx.MiyamotoNagaiPotential | gpx.PotentialFrame: +) -> gpx.MiyamotoNagaiPotential | gpx.PotentialFrame[gpx.MiyamotoNagaiPotential]: """Convert a `gala.potential.MiyamotoNagaiPotential` to a `galax.potential.MiyamotoNagaiPotential`. Examples @@ -870,7 +875,7 @@ def galax_to_gala(pot: gpx.NullPotential, /) -> gp.NullPotential: @dispatch # type: ignore[misc] def gala_to_galax( gala: gp.PlummerPotential, / -) -> gpx.PlummerPotential | gpx.PotentialFrame: +) -> gpx.PlummerPotential | gpx.PotentialFrame[gpx.PlummerPotential]: """Convert a `gala.potential.PlummerPotential` to a `galax.potential.PlummerPotential`. Examples @@ -929,7 +934,7 @@ def galax_to_gala(pot: gpx.PlummerPotential, /) -> gp.PlummerPotential: @dispatch # type: ignore[misc] def gala_to_galax( gala: gp.PowerLawCutoffPotential, / -) -> gpx.PowerLawCutoffPotential | gpx.PotentialFrame: +) -> gpx.PowerLawCutoffPotential | gpx.PotentialFrame[gpx.PowerLawCutoffPotential]: """Convert a `gala.potential.PowerLawCutoffPotential` to a `galax.potential.PowerLawCutoffPotential`. Examples @@ -1005,7 +1010,7 @@ def galax_to_gala(pot: gpx.PowerLawCutoffPotential, /) -> gp.PowerLawCutoffPoten @dispatch # type: ignore[misc] def gala_to_galax( gala: gp.SatohPotential, / -) -> gpx.SatohPotential | gpx.PotentialFrame: +) -> gpx.SatohPotential | gpx.PotentialFrame[gpx.SatohPotential]: """Convert a Gala SatohPotential to a Galax potential. Examples @@ -1063,7 +1068,7 @@ def galax_to_gala(pot: gpx.SatohPotential, /) -> gp.SatohPotential: @dispatch # type: ignore[misc] def gala_to_galax( gala: gp.StonePotential, / -) -> gpx.StoneOstriker15Potential | gpx.PotentialFrame: +) -> gpx.StoneOstriker15Potential | gpx.PotentialFrame[gpx.StoneOstriker15Potential]: """Convert a `gala.potential.StonePotential` to a `galax.potential.StoneOstriker15Potential`. Examples @@ -1121,7 +1126,12 @@ def galax_to_gala(pot: gpx.StoneOstriker15Potential, /) -> gp.StonePotential: @dispatch # type: ignore[misc] def gala_to_galax( gala: gp.LogarithmicPotential, / -) -> gpx.LogarithmicPotential | gpx.LMJ09LogarithmicPotential | gpx.PotentialFrame: +) -> ( + gpx.LogarithmicPotential + | gpx.LMJ09LogarithmicPotential + | gpx.PotentialFrame[gpx.LogarithmicPotential] + | gpx.PotentialFrame[gpx.LMJ09LogarithmicPotential] +): """Convert a Gala LogarithmicPotential to a Galax potential. If the flattening or rotation 'phi' is non-zero, the potential is a @@ -1235,7 +1245,12 @@ def galax_to_gala(pot: gpx.LMJ09LogarithmicPotential, /) -> gp.LogarithmicPotent @dispatch # type: ignore[misc] def gala_to_galax( gala: gp.MultipolePotential, / -) -> gpx.MultipoleInnerPotential | gpx.MultipoleOuterPotential | gpx.PotentialFrame: +) -> ( + gpx.MultipoleInnerPotential + | gpx.MultipoleOuterPotential + | gpx.PotentialFrame[gpx.MultipoleInnerPotential] + | gpx.PotentialFrame[gpx.MultipoleOuterPotential] +): params = gala.parameters cls = ( gpx.MultipoleInnerPotential @@ -1297,7 +1312,9 @@ def galax_to_gala( @dispatch -def gala_to_galax(gala: gp.NFWPotential, /) -> gpx.NFWPotential | gpx.PotentialFrame: +def gala_to_galax( + gala: gp.NFWPotential, / +) -> gpx.NFWPotential | gpx.PotentialFrame[gpx.NFWPotential]: """Convert a Gala NFWPotential to a Galax potential. Examples @@ -1367,7 +1384,9 @@ def gala_to_galax( @dispatch -def gala_to_galax(gala: gp.NFWPotential, /) -> gpx.NFWPotential | gpx.PotentialFrame: +def gala_to_galax( + gala: gp.NFWPotential, / +) -> gpx.NFWPotential | gpx.PotentialFrame[gpx.NFWPotential]: """Convert a Gala NFWPotential to a Galax potential. Examples diff --git a/src/galax/potential/_src/frame.py b/src/galax/potential/_src/frame.py index 4b11c9ce..20af9731 100644 --- a/src/galax/potential/_src/frame.py +++ b/src/galax/potential/_src/frame.py @@ -4,7 +4,7 @@ from dataclasses import replace -from typing import cast, final +from typing import Generic, TypeVar, cast, final import equinox as eqx @@ -14,10 +14,19 @@ import galax.typing as gt from .base import AbstractPotentialBase +from galax.utils.dataclasses import ModuleMeta + +FramedPotT = TypeVar("FramedPotT", bound=AbstractPotentialBase) + + +class WrapsPotential(eqx.Module, Generic[FramedPotT], metaclass=ModuleMeta): # type: ignore[misc] + """Protocol for a class that wraps a potential.""" + + original_potential: FramedPotT @final -class PotentialFrame(AbstractPotentialBase): +class PotentialFrame(AbstractPotentialBase, WrapsPotential[FramedPotT]): """Reference frame of the potential. Examples @@ -184,7 +193,7 @@ class PotentialFrame(AbstractPotentialBase): Array(-2.23568166, dtype=float64) """ # noqa: E501 - original_potential: AbstractPotentialBase + original_potential: FramedPotT operator: OperatorSequence = eqx.field(default=(), converter=OperatorSequence) """Transformation to reference frame of the potential. @@ -248,7 +257,7 @@ def _potential( ##################################################################### -@simplify_op.register # type: ignore[misc] -def _simplify_op(frame: PotentialFrame, /) -> PotentialFrame: +@simplify_op.register(PotentialFrame) # type: ignore[misc] +def _simplify_op(frame: PotentialFrame[FramedPotT], /) -> PotentialFrame[FramedPotT]: """Simplify the operators in an PotentialFrame.""" return replace(frame, operator=simplify_op(frame.operator)) diff --git a/src/galax/potential/_src/params/attr.py b/src/galax/potential/_src/params/attr.py index 2a132125..412e06cb 100644 --- a/src/galax/potential/_src/params/attr.py +++ b/src/galax/potential/_src/params/attr.py @@ -31,11 +31,9 @@ class AbstractParametersAttribute: >>> pot = gp.KeplerPotential(m_tot=1e12, units="galactic") >>> pot.parameters - mappingproxy({'m_tot': ConstantParameter( - value=Quantity[...](value=f64[], unit=Unit("solMass")) - )}) + mappingproxy({'m_tot': ConstantParameter(Quantity['mass'](Array(1.e+12, dtype=float64), unit='solMass'))}) - """ + """ # noqa: E501 parameters: "MappingProxyType[str, ParameterField]" # TODO: specify type hint """Class attribute name on Potential.""" @@ -77,11 +75,9 @@ class ParametersAttribute(AbstractParametersAttribute): >>> kepler = gp.KeplerPotential(m_tot=1e12, units="galactic") >>> kepler.parameters - mappingproxy({'m_tot': ConstantParameter( - value=Quantity[...](value=f64[], unit=Unit("solMass")) - )}) + mappingproxy({'m_tot': ConstantParameter(Quantity['mass'](Array(1.e+12, dtype=float64), unit='solMass'))}) - """ + """ # noqa: E501 def __get__( self, @@ -122,11 +118,9 @@ class CompositeParametersAttribute(AbstractParametersAttribute): >>> kepler = gp.KeplerPotential(m_tot=1e12, units="galactic") >>> composite = gp.CompositePotential(kepler=kepler) >>> composite.parameters - mappingproxy({'kepler': mappingproxy({'m_tot': ConstantParameter( - value=Quantity[PhysicalType('mass')](value=f64[], unit=Unit("solMass")) - )})}) + mappingproxy({'kepler': mappingproxy({'m_tot': ConstantParameter(Quantity['mass'](Array(1.e+12, dtype=float64), unit='solMass'))})}) - """ + """ # noqa: E501 def __get__( self, diff --git a/src/galax/potential/_src/params/core.py b/src/galax/potential/_src/params/core.py index ca92c0d1..589c2c21 100644 --- a/src/galax/potential/_src/params/core.py +++ b/src/galax/potential/_src/params/core.py @@ -116,6 +116,23 @@ def __call__(self, t: BatchableRealQScalar = t0, **_: Any) -> FloatQAnyShape: # ------------------------------------------- + def __repr__(self) -> str: + """Return the string representation. + + Examples + -------- + >>> from galax.potential.params import ConstantParameter + >>> from unxt import Quantity + + >>> cp = ConstantParameter(value=Quantity(1e9, "Msun")) + >>> cp + ConstantParameter(Quantity['mass'](Array(1.e+09, dtype=float64), unit='solMass')) + + """ # noqa: E501 + return f"{self.__class__.__name__}({self.value!r})" + + # ------------------------------------------- + def __mul__(self, other: Any) -> "Self": return replace(self, value=self.value * other) @@ -127,6 +144,7 @@ def __rmul__(self, other: Any) -> "Self": # Linear time dependence Parameter +@final class LinearParameter(AbstractParameter): """Linear time dependence Parameter.