diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fe070eb6..a5383beb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,6 @@ repos: - id: mixed-line-ending - id: name-tests-test args: ["--pytest-test-first"] - exclude: '^tests\/.*_helper\.py$' - id: trailing-whitespace - repo: https://github.com/pre-commit/pygrep-hooks diff --git a/src/galax/_galax_interop_gala/potential.py b/src/galax/_galax_interop_gala/potential.py index 20b837e7..c0bb3622 100644 --- a/src/galax/_galax_interop_gala/potential.py +++ b/src/galax/_galax_interop_gala/potential.py @@ -1,18 +1,24 @@ """Interoperability.""" -__all__ = ["gala_to_galax"] +__all__ = ["gala_to_galax", "galax_to_gala"] from functools import singledispatch from typing import TypeVar import gala.potential as gp -from gala.units import DimensionlessUnitSystem as GalaDimensionlessUnitSystem +from astropy.units import Quantity as APYQuantity +from gala.units import ( + DimensionlessUnitSystem as GalaDimensionlessUnitSystem, + UnitSystem as GalaUnitSystem, + dimensionless as gala_dimensionless, +) from packaging.version import Version -from plum import dispatch +from plum import convert, dispatch import coordinax.operators as cxo from coordinax.operators import IdentityOperator from unxt import Quantity +from unxt.unitsystems import AbstractUnitSystem, DimensionlessUnitSystem import galax.potential as gpx from galax.utils._optional_deps import HAS_GALA @@ -21,7 +27,7 @@ # Hook into general dispatcher -@dispatch # type: ignore[misc] +@dispatch def convert_potential( to_: gpx.AbstractPotentialBase | type[gpx.io.GalaxLibrary], # noqa: ARG001 from_: gp.CPotentialBase | gp.PotentialBase, @@ -30,8 +36,60 @@ def convert_potential( return gala_to_galax(from_) +@dispatch +def convert_potential( + to_: gp.CPotentialBase | gp.PotentialBase | type[gpx.io.GalaLibrary], # noqa: ARG001 + from_: gpx.AbstractPotentialBase, + /, +) -> gp.CPotentialBase | gp.PotentialBase: + return galax_to_gala(from_) + + ############################################################################## -# GALA -> GALAX +# GALAX <-> GALA + +# ----------------------- +# Helper functions + +PT = TypeVar("PT", bound=gpx.AbstractPotentialBase) + + +def _get_frame(pot: gp.PotentialBase, /) -> cxo.AbstractOperator: + """Convert a Gala frame to a Galax frame.""" + frame = cxo.GalileanSpatialTranslationOperator( + Quantity(pot.origin, unit=pot.units["length"]) + ) + if pot.R is not None: + frame = cxo.GalileanRotationOperator(pot.R) | frame + return cxo.simplify_op(frame) + + +def _apply_frame(frame: cxo.AbstractOperator, pot: PT, /) -> PT | gpx.PotentialFrame: + """Apply a Galax frame to a potential.""" + # A framed Galax potential never simplifies to a frameless potential. This + # function applies a frame if it is not the identity operator. + return ( + pot if isinstance(frame, IdentityOperator) else gpx.PotentialFrame(pot, frame) + ) + + +def _galax_to_gala_units(units: AbstractUnitSystem, /) -> GalaUnitSystem: + """Convert a Galax unit system to a Gala unit system.""" + # Galax potentials naturally convert Gala unit systems, but Gala potentials + # do not convert Galax unit systems. This function is used for that purpose. + if isinstance(units, DimensionlessUnitSystem): + return gala_dimensionless + return GalaUnitSystem(units) + + +def _all_constant_parameters(pot: gpx.AbstractPotentialBase, *params: str) -> bool: + """Check if all parameters are constant.""" + return all( + isinstance(getattr(pot, name), gpx.params.ConstantParameter) for name in params + ) + + +# ----------------------------------------------------------------------------- @singledispatch @@ -155,29 +213,32 @@ def gala_to_galax(pot: gp.PotentialBase, /) -> gpx.AbstractPotentialBase: raise NotImplementedError(msg) -# ----------------------- -# Helper functions - -PT = TypeVar("PT", bound=gpx.AbstractPotentialBase) - - -def _get_frame(pot: gp.PotentialBase, /) -> cxo.AbstractOperator: - frame = cxo.GalileanSpatialTranslationOperator( - Quantity(pot.origin, unit=pot.units["length"]) - ) - if pot.R is not None: - frame = cxo.GalileanRotationOperator(pot.R) | frame - return cxo.simplify_op(frame) +# TODO: add an argument to specify how to handle time-dependent parameters. +# Gala potentials are not time-dependent, so we need to specify how to +# handle time-dependent Galax parameters. +@singledispatch +def galax_to_gala(pot: gpx.AbstractPotentialBase, /) -> gp.PotentialBase: + """Convert a Galax potential to a Gala potential. + Parameters + ---------- + pot : :class:`~galax.potential.AbstractPotentialBase` + Galax potential. -def _apply_frame(frame: cxo.AbstractOperator, pot: PT, /) -> PT | gpx.PotentialFrame: - return ( - pot if isinstance(frame, IdentityOperator) else gpx.PotentialFrame(pot, frame) + Returns + ------- + gala_pot : :class:`~gala.potential.PotentialBase` + Gala potential. + """ + msg = ( + "`galax_to_gala` does not have a registered function to convert " + f"{pot.__class__.__name__!r} to a galax potential." ) + raise NotImplementedError(msg) # ----------------------------------------------------------------------------- -# General rules +# Composite potentials @gala_to_galax.register @@ -186,6 +247,15 @@ def _gala_to_galax_composite(pot: gp.CompositePotential, /) -> gpx.CompositePote return gpx.CompositePotential(**{k: gala_to_galax(p) for k, p in pot.items()}) +@galax_to_gala.register +def _galax_to_gala_composite(pot: gpx.CompositePotential, /) -> gp.CompositePotential: + """Convert a Galax CompositePotential to a Gala potential.""" + return gp.CompositePotential(**{k: galax_to_gala(p) for k, p in pot.items()}) + + +# ----------------------------------------------------------------------------- +# General rules + _GALA_TO_GALAX_REGISTRY: dict[type[gp.PotentialBase], type[gpx.AbstractPotential]] = { gp.IsochronePotential: gpx.IsochronePotential, gp.KeplerPotential: gpx.KeplerPotential, @@ -219,26 +289,45 @@ def _gala_to_galax_registered( return _apply_frame(_get_frame(gala), pot) -# ----------------------------------------------------------------------------- -# Builtin potentials +_GALAX_TO_GALA_REGISTRY: dict[type[gpx.AbstractPotential], type[gp.PotentialBase]] = { + v: k for k, v in _GALA_TO_GALAX_REGISTRY.items() +} -@gala_to_galax.register -def _gala_to_galax_null(pot: gp.NullPotential, /) -> gpx.NullPotential: - """Convert a Gala NullPotential to a Galax potential. +@galax_to_gala.register(gpx.IsochronePotential) +@galax_to_gala.register(gpx.KeplerPotential) +@galax_to_gala.register(gpx.KuzminPotential) +@galax_to_gala.register(gpx.MiyamotoNagaiPotential) +@galax_to_gala.register(gpx.PlummerPotential) +@galax_to_gala.register(gpx.PowerLawCutoffPotential) +def _galax_to_gala_registered(pot: gpx.AbstractPotential, /) -> gp.PotentialBase: + """Convert a Galax AbstractPotential to a Gala potential.""" + if not _all_constant_parameters(pot, *pot.parameters.keys()): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) - Examples - -------- - >>> import gala.potential as gp - >>> import galax.potential as gpx + # TODO: this is a temporary solution. It would be better to map each + # potential individually. + params = { + k: convert(getattr(pot, k)(0), APYQuantity) + for (k, f) in type(pot).parameters.items() + } + if "m_tot" in params: + params["m"] = params.pop("m_tot") + + return _GALAX_TO_GALA_REGISTRY[type(pot)]( + **params, units=_galax_to_gala_units(pot.units) + ) - >>> gpot = gp.NullPotential() - >>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot) - NullPotential( units=DimensionlessUnitSystem(), - constants=ImmutableDict({'G': ...}) ) - """ - return gpx.NullPotential(units=pot.units) +# ----------------------------------------------------------------------------- +# Builtin potentials + + +@galax_to_gala.register +def _galax_to_gala_bar(_: gpx.BarPotential, /) -> gp.PotentialBase: + """Convert a Galax BarPotential to a Gala potential.""" + raise NotImplementedError # TODO: implement if HAS_GALA and (Version("1.8.2") <= HAS_GALA): @@ -280,6 +369,22 @@ def _gala_to_galax_burkert( ) return _apply_frame(_get_frame(gala), pot) + @galax_to_gala.register + def _galax_to_gala_burkert(pot: gpx.BurkertPotential, /) -> gp.BurkertPotential: + """Convert a Galax BurkertPotential to a Gala potential.""" + if not _all_constant_parameters(pot, "m", "r_s"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + return gp.BurkertPotential( + rho=convert(pot.rho0(0), APYQuantity), + r0=convert(pot.r_s(0), APYQuantity), + units=_galax_to_gala_units(pot.units), + ) + +# --------------------------- +# Hernquist potentials + @gala_to_galax.register def _gala_to_galax_hernquist( @@ -307,6 +412,24 @@ def _gala_to_galax_hernquist( return _apply_frame(_get_frame(gala), pot) +@galax_to_gala.register +def _galax_to_gala_hernquist(pot: gpx.HernquistPotential, /) -> gp.HernquistPotential: + """Convert a Galax HernquistPotential to a Gala potential.""" + if not _all_constant_parameters(pot, "m_tot", "r_s"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + return gp.HernquistPotential( + m=convert(pot.m_tot(0), APYQuantity), + c=convert(pot.r_s(0), APYQuantity), + units=_galax_to_gala_units(pot.units), + ) + + +# --------------------------- +# Jaffe potentials + + @gala_to_galax.register def _gala_to_galax_jaffe( gala: gp.JaffePotential, / @@ -333,6 +456,24 @@ def _gala_to_galax_jaffe( return _apply_frame(_get_frame(gala), pot) +@galax_to_gala.register +def _galax_to_gala_jaffe(pot: gpx.JaffePotential, /) -> gp.JaffePotential: + """Convert a Galax JaffePotential to a Gala potential.""" + if not _all_constant_parameters(pot, "m", "r_s"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + return gp.JaffePotential( + m=convert(pot.m(0), APYQuantity), + c=convert(pot.r_s(0), APYQuantity), + units=_galax_to_gala_units(pot.units), + ) + + +# --------------------------- +# Long & Murali Bar potentials + + @gala_to_galax.register def _gala_to_galax_longmuralibar( gala: gp.LongMuraliBarPotential, / @@ -369,6 +510,58 @@ def _gala_to_galax_longmuralibar( return _apply_frame(_get_frame(gala), pot) +@galax_to_gala.register +def _galax_to_gala_longmuralibar( + pot: gpx.LongMuraliBarPotential, / +) -> gp.LongMuraliBarPotential: + """Convert a Galax LongMuraliBarPotential to a Gala potential.""" + if not _all_constant_parameters(pot, "m_tot", "a", "b", "c", "alpha"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + return gp.LongMuraliBarPotential( + m=convert(pot.m_tot(0), APYQuantity), + a=convert(pot.a(0), APYQuantity), + b=convert(pot.b(0), APYQuantity), + c=convert(pot.c(0), APYQuantity), + alpha=convert(pot.alpha(0), APYQuantity), + units=_galax_to_gala_units(pot.units), + ) + + +# --------------------------- +# Null potentials + + +@gala_to_galax.register +def _gala_to_galax_null(pot: gp.NullPotential, /) -> gpx.NullPotential: + """Convert a Gala NullPotential to a Galax potential. + + Examples + -------- + >>> import gala.potential as gp + >>> import galax.potential as gpx + + >>> gpot = gp.NullPotential() + >>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot) + NullPotential( units=DimensionlessUnitSystem(), + constants=ImmutableDict({'G': ...}) ) + + """ + return gpx.NullPotential(units=pot.units) + + +@galax_to_gala.register +def _galax_to_gala_null(pot: gpx.NullPotential, /) -> gp.NullPotential: + return gp.NullPotential( + units=_galax_to_gala_units(pot.units), + ) + + +# --------------------------- +# Satoh potentials + + @gala_to_galax.register def _gala_to_galax_satoh( gala: gp.SatohPotential, / @@ -398,6 +591,25 @@ def _gala_to_galax_satoh( return _apply_frame(_get_frame(gala), pot) +@galax_to_gala.register +def _galax_to_gala_satoh(pot: gpx.SatohPotential, /) -> gp.SatohPotential: + """Convert a Galax SatohPotential to a Gala potential.""" + if not _all_constant_parameters(pot, "m_tot", "a", "b"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + return gp.SatohPotential( + m=convert(pot.m_tot(0), APYQuantity), + a=convert(pot.a(0), APYQuantity), + b=convert(pot.b(0), APYQuantity), + units=_galax_to_gala_units(pot.units), + ) + + +# --------------------------- +# Stone & Ostriker potentials + + @gala_to_galax.register def _gala_to_galax_stoneostriker15( gala: gp.StonePotential, / @@ -427,6 +639,23 @@ def _gala_to_galax_stoneostriker15( return _apply_frame(_get_frame(gala), pot) +@galax_to_gala.register +def _galax_to_gala_stoneostriker15( + pot: gpx.StoneOstriker15Potential, / +) -> gp.StonePotential: + """Convert a Galax StoneOstriker15Potential to a Gala potential.""" + if not _all_constant_parameters(pot, "m_tot", "r_c", "r_h"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + return gp.StonePotential( + m=convert(pot.m_tot(0), APYQuantity), + r_c=convert(pot.r_c(0), APYQuantity), + r_h=convert(pot.r_h(0), APYQuantity), + units=_galax_to_gala_units(pot.units), + ) + + # ----------------------------------------------------------------------------- # Logarithmic potentials @@ -483,6 +712,42 @@ def _gala_to_galax_logarithmic( return _apply_frame(_get_frame(gala), pot) +@galax_to_gala.register +def _galax_to_gala_logarithmic( + pot: gpx.LogarithmicPotential, / +) -> gp.LogarithmicPotential: + """Convert a Galax LogarithmicPotential to a Gala potential.""" + if not _all_constant_parameters(pot, "v_c", "r_s"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + return gp.LogarithmicPotential( + v_c=convert(pot.v_c(0), APYQuantity), + r_h=convert(pot.r_s(0), APYQuantity), + units=_galax_to_gala_units(pot.units), + ) + + +@galax_to_gala.register +def _galax_to_gala_logarithmic( + pot: gpx.LMJ09LogarithmicPotential, / +) -> gp.LogarithmicPotential: + """Convert a Galax LogarithmicPotential to a Gala potential.""" + if not _all_constant_parameters(pot, "v_c", "r_s", "q1", "q2", "q3", "phi"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + return gp.LogarithmicPotential( + v_c=convert(pot.v_c(0), APYQuantity), + r_h=convert(pot.r_s(0), APYQuantity), + q1=convert(pot.q1(0), APYQuantity), + q2=convert(pot.q2(0), APYQuantity), + q3=convert(pot.q3(0), APYQuantity), + phi=convert(pot.phi(0), APYQuantity), + units=_galax_to_gala_units(pot.units), + ) + + # ----------------------------------------------------------------------------- # NFW potentials @@ -514,6 +779,20 @@ def _gala_to_galax_nfw( return _apply_frame(_get_frame(gala), pot) +@galax_to_gala.register +def _galax_to_gala_nfw(pot: gpx.NFWPotential, /) -> gp.NFWPotential: + """Convert a Galax NFWPotential to a Gala potential.""" + if not _all_constant_parameters(pot, "m", "r_s"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + return gp.NFWPotential( + m=convert(pot.m(0), APYQuantity), + r_s=convert(pot.r_s(0), APYQuantity), + units=_galax_to_gala_units(pot.units), + ) + + @gala_to_galax.register def _gala_to_galax_leesutotriaxialnfw( pot: gp.LeeSutoTriaxialNFWPotential, / @@ -555,9 +834,33 @@ def _gala_to_galax_leesutotriaxialnfw( ) +@galax_to_gala.register +def _galax_to_gala_leesutotriaxialnfw( + pot: gpx.LeeSutoTriaxialNFWPotential, / +) -> gp.LeeSutoTriaxialNFWPotential: + """Convert a Galax LeeSutoTriaxialNFWPotential to a Gala potential.""" + if not _all_constant_parameters(pot, "m", "r_s", "a1", "a2", "a3"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + t = Quantity(0.0, pot.units["time"]) + + return gp.LeeSutoTriaxialNFWPotential( + v_c=convert(xp.sqrt(pot.constants["G"] * pot.m(t) / pot.r_s(t)), APYQuantity), + r_s=convert(pot.r_s(t), APYQuantity), + a=convert(pot.a1(t), APYQuantity), + b=convert(pot.a2(t), APYQuantity), + c=convert(pot.a3(t), APYQuantity), + units=_galax_to_gala_units(pot.units), + ) + + # ----------------------------------------------------------------------------- # MW potentials +# --------------------------- +# Bovy MWPotential2014 + @gala_to_galax.register def _gala_to_galax_bovymw2014( @@ -592,6 +895,31 @@ def _gala_to_galax_bovymw2014( ) +@galax_to_gala.register +def _galax_to_gala_bovymw2014( + pot: gpx.BovyMWPotential2014, / +) -> gp.BovyMWPotential2014: + """Convert a Galax BovyMWPotential2014 to a Gala potential.""" + + def rename(k: str) -> str: + match k: + case "m_tot": + return "m" + case _: + return k + + return gp.BovyMWPotential2014( + **{ + c: {rename(k): getattr(p, k)(0) for k in p.parameters} + for c, p in pot.items() + } + ) + + +# --------------------------- +# LM10 potentials + + @gala_to_galax.register def _gala_to_galax_lm10(pot: gp.LM10Potential, /) -> gpx.LM10Potential: """Convert a Gala LM10Potential to a Galax potential. @@ -615,6 +943,33 @@ def _gala_to_galax_lm10(pot: gp.LM10Potential, /) -> gpx.LM10Potential: ) +@galax_to_gala.register +def _galax_to_gala_lm10(pot: gpx.LM10Potential, /) -> gp.LM10Potential: + """Convert a Galax LM10Potential to a Gala potential.""" + + def rename(c: str, k: str) -> str: + match k: + case "m_tot": + return "m" + case "r_s" if c == "halo": + return "r_h" + case "r_s" if c == "bulge": + return "c" + case _: + return k + + return gp.LM10Potential( + **{ + c: {rename(c, k): getattr(p, k)(0) for k in p.parameters} + for c, p in pot.items() + } + ) + + +# --------------------------- +# Galax MilkyWayPotential + + @gala_to_galax.register def _gala_to_galax_mw(pot: gp.MilkyWayPotential, /) -> gpx.MilkyWayPotential: """Convert a Gala MilkyWayPotential to a Galax potential. @@ -638,3 +993,24 @@ def _gala_to_galax_mw(pot: gp.MilkyWayPotential, /) -> gpx.MilkyWayPotential: bulge=gala_to_galax(pot["bulge"]), nucleus=gala_to_galax(pot["nucleus"]), ) + + +@galax_to_gala.register +def _galax_to_gala_mwpotential(pot: gpx.MilkyWayPotential, /) -> gp.MilkyWayPotential: + """Convert a Galax MilkyWayPotential to a Gala potential.""" + + def rename(c: str, k: str) -> str: + match k: + case "m_tot": + return "m" + case "r_s" if c in ("bulge", "nucleus"): + return "c" + case _: + return k + + return gp.MilkyWayPotential( + **{ + c: {rename(c, k): getattr(p, k)(0) for k in p.parameters} + for c, p in pot.items() + } + ) diff --git a/tests/unit/potential/io/gala_helper.py b/tests/unit/potential/io/gala_helper.py deleted file mode 100644 index 6aaa622a..00000000 --- a/tests/unit/potential/io/gala_helper.py +++ /dev/null @@ -1,367 +0,0 @@ -"""Interoperability.""" - -__all__ = ["galax_to_gala"] - -from functools import singledispatch - -import gala.potential as gp -from astropy.units import Quantity as APYQuantity -from gala.units import UnitSystem as GalaUnitSystem, dimensionless as gala_dimensionless -from packaging.version import Version -from plum import convert - -import quaxed.array_api as xp -from unxt import Quantity -from unxt.unitsystems import AbstractUnitSystem, DimensionlessUnitSystem - -import galax.potential as gpx -from galax._galax_interop_gala.potential import _GALA_TO_GALAX_REGISTRY -from galax.utils._optional_deps import HAS_GALA - -############################################################################## -# UnitSystem - - -def galax_to_gala_units(units: AbstractUnitSystem, /) -> GalaUnitSystem: - if isinstance(units, DimensionlessUnitSystem): - return gala_dimensionless - return GalaUnitSystem(units) - - -############################################################################## -# GALAX -> GALA - - -def _all_constant_parameters( - pot: "gpx.AbstractPotentialBase", - *params: str, -) -> bool: - return all( - isinstance(getattr(pot, name), gpx.params.ConstantParameter) for name in params - ) - - -# TODO: add an argument to specify how to handle time-dependent parameters. -# Gala potentials are not time-dependent, so we need to specify how to -# handle time-dependent Galax parameters. -@singledispatch -def galax_to_gala(pot: gpx.AbstractPotentialBase, /) -> gp.PotentialBase: - """Convert a Galax potential to a Gala potential. - - Parameters - ---------- - pot : :class:`~galax.potential.AbstractPotentialBase` - Galax potential. - - Returns - ------- - gala_pot : :class:`~gala.potential.PotentialBase` - Gala potential. - """ - msg = ( - "`galax_to_gala` does not have a registered function to convert " - f"{pot.__class__.__name__!r} to a galax potential." - ) - raise NotImplementedError(msg) - - -# ----------------------------------------------------------------------------- -# General rules - - -@galax_to_gala.register -def _galax_to_gala_composite(pot: gpx.CompositePotential, /) -> gp.CompositePotential: - """Convert a Galax CompositePotential to a Gala potential.""" - return gp.CompositePotential(**{k: galax_to_gala(p) for k, p in pot.items()}) - - -_GALAX_TO_GALA_REGISTRY: dict[type[gpx.AbstractPotential], type[gp.PotentialBase]] = { - v: k for k, v in _GALA_TO_GALAX_REGISTRY.items() -} - - -@galax_to_gala.register(gpx.IsochronePotential) -@galax_to_gala.register(gpx.KeplerPotential) -@galax_to_gala.register(gpx.KuzminPotential) -@galax_to_gala.register(gpx.MiyamotoNagaiPotential) -@galax_to_gala.register(gpx.PlummerPotential) -@galax_to_gala.register(gpx.PowerLawCutoffPotential) -def _galax_to_gala_abstractpotential(pot: gpx.AbstractPotential, /) -> gp.PotentialBase: - """Convert a Galax AbstractPotential to a Gala potential.""" - if not _all_constant_parameters(pot, *pot.parameters.keys()): - msg = "Gala does not support time-dependent parameters." - raise TypeError(msg) - - # TODO: this is a temporary solution. It would be better to map each - # potential individually. - params = { - k: convert(getattr(pot, k)(0), APYQuantity) - for (k, f) in type(pot).parameters.items() - } - if "m_tot" in params: - params["m"] = params.pop("m_tot") - - return _GALAX_TO_GALA_REGISTRY[type(pot)]( - **params, - units=galax_to_gala_units(pot.units), - ) - - -# ----------------------------------------------------------------------------- -# Builtin potentials - - -@galax_to_gala.register -def _galax_to_gala_bar(pot: gpx.BarPotential, /) -> gp.PotentialBase: - """Convert a Galax BarPotential to a Gala potential.""" - raise NotImplementedError # TODO: implement - - -if HAS_GALA and (Version("1.8.2") <= HAS_GALA): - - @galax_to_gala.register - def _galax_to_gala_burkert(pot: gpx.BurkertPotential, /) -> gp.BurkertPotential: - """Convert a Galax BurkertPotential to a Gala potential.""" - if not _all_constant_parameters(pot, "m", "r_s"): - msg = "Gala does not support time-dependent parameters." - raise TypeError(msg) - - return gp.BurkertPotential( - rho=convert(pot.rho0(0), APYQuantity), - r0=convert(pot.r_s(0), APYQuantity), - units=galax_to_gala_units(pot.units), - ) - - -@galax_to_gala.register -def _galax_to_gala_hernquist(pot: gpx.HernquistPotential, /) -> gp.HernquistPotential: - """Convert a Galax HernquistPotential to a Gala potential.""" - if not _all_constant_parameters(pot, "m_tot", "r_s"): - msg = "Gala does not support time-dependent parameters." - raise TypeError(msg) - - return gp.HernquistPotential( - m=convert(pot.m_tot(0), APYQuantity), - c=convert(pot.r_s(0), APYQuantity), - units=galax_to_gala_units(pot.units), - ) - - -@galax_to_gala.register -def _galax_to_gala_jaffe(pot: gpx.JaffePotential, /) -> gp.JaffePotential: - """Convert a Galax JaffePotential to a Gala potential.""" - if not _all_constant_parameters(pot, "m", "r_s"): - msg = "Gala does not support time-dependent parameters." - raise TypeError(msg) - - return gp.JaffePotential( - m=convert(pot.m(0), APYQuantity), - c=convert(pot.r_s(0), APYQuantity), - units=galax_to_gala_units(pot.units), - ) - - -@galax_to_gala.register -def _galax_to_gala_longmuralibar( - pot: gpx.LongMuraliBarPotential, / -) -> gp.LongMuraliBarPotential: - """Convert a Galax LongMuraliBarPotential to a Gala potential.""" - if not _all_constant_parameters(pot, "m_tot", "a", "b", "c", "alpha"): - msg = "Gala does not support time-dependent parameters." - raise TypeError(msg) - - return gp.LongMuraliBarPotential( - m=convert(pot.m_tot(0), APYQuantity), - a=convert(pot.a(0), APYQuantity), - b=convert(pot.b(0), APYQuantity), - c=convert(pot.c(0), APYQuantity), - alpha=convert(pot.alpha(0), APYQuantity), - units=galax_to_gala_units(pot.units), - ) - - -@galax_to_gala.register -def _galax_to_gala_null(pot: gpx.NullPotential, /) -> gp.NullPotential: - return gp.NullPotential( - units=galax_to_gala_units(pot.units), - ) - - -@galax_to_gala.register -def _galax_to_gala_satoh(pot: gpx.SatohPotential, /) -> gp.SatohPotential: - """Convert a Galax SatohPotential to a Gala potential.""" - if not _all_constant_parameters(pot, "m_tot", "a", "b"): - msg = "Gala does not support time-dependent parameters." - raise TypeError(msg) - - return gp.SatohPotential( - m=convert(pot.m_tot(0), APYQuantity), - a=convert(pot.a(0), APYQuantity), - b=convert(pot.b(0), APYQuantity), - units=galax_to_gala_units(pot.units), - ) - - -@galax_to_gala.register -def _galax_to_gala_stoneostriker15( - pot: gpx.StoneOstriker15Potential, / -) -> gp.StonePotential: - """Convert a Galax StoneOstriker15Potential to a Gala potential.""" - if not _all_constant_parameters(pot, "m_tot", "r_c", "r_h"): - msg = "Gala does not support time-dependent parameters." - raise TypeError(msg) - - return gp.StonePotential( - m=convert(pot.m_tot(0), APYQuantity), - r_c=convert(pot.r_c(0), APYQuantity), - r_h=convert(pot.r_h(0), APYQuantity), - units=galax_to_gala_units(pot.units), - ) - - -# ----------------------------------------------------------------------------- -# Logarithmic potentials - - -@galax_to_gala.register -def _galax_to_gala_logarithmic( - pot: gpx.LogarithmicPotential, / -) -> gp.LogarithmicPotential: - """Convert a Galax LogarithmicPotential to a Gala potential.""" - if not _all_constant_parameters(pot, "v_c", "r_s"): - msg = "Gala does not support time-dependent parameters." - raise TypeError(msg) - - return gp.LogarithmicPotential( - v_c=convert(pot.v_c(0), APYQuantity), - r_h=convert(pot.r_s(0), APYQuantity), - units=galax_to_gala_units(pot.units), - ) - - -@galax_to_gala.register -def _galax_to_gala_logarithmic( - pot: gpx.LMJ09LogarithmicPotential, / -) -> gp.LogarithmicPotential: - """Convert a Galax LogarithmicPotential to a Gala potential.""" - if not _all_constant_parameters(pot, "v_c", "r_s", "q1", "q2", "q3", "phi"): - msg = "Gala does not support time-dependent parameters." - raise TypeError(msg) - - return gp.LogarithmicPotential( - v_c=convert(pot.v_c(0), APYQuantity), - r_h=convert(pot.r_s(0), APYQuantity), - q1=convert(pot.q1(0), APYQuantity), - q2=convert(pot.q2(0), APYQuantity), - q3=convert(pot.q3(0), APYQuantity), - phi=convert(pot.phi(0), APYQuantity), - units=galax_to_gala_units(pot.units), - ) - - -# ----------------------------------------------------------------------------- -# NFW potentials - - -@galax_to_gala.register -def _galax_to_gala_nfw(pot: gpx.NFWPotential, /) -> gp.NFWPotential: - """Convert a Galax NFWPotential to a Gala potential.""" - if not _all_constant_parameters(pot, "m", "r_s"): - msg = "Gala does not support time-dependent parameters." - raise TypeError(msg) - - return gp.NFWPotential( - m=convert(pot.m(0), APYQuantity), - r_s=convert(pot.r_s(0), APYQuantity), - units=galax_to_gala_units(pot.units), - ) - - -@galax_to_gala.register -def _galax_to_gala_leesutotriaxialnfw( - pot: gpx.LeeSutoTriaxialNFWPotential, / -) -> gp.LeeSutoTriaxialNFWPotential: - """Convert a Galax LeeSutoTriaxialNFWPotential to a Gala potential.""" - if not _all_constant_parameters(pot, "m", "r_s", "a1", "a2", "a3"): - msg = "Gala does not support time-dependent parameters." - raise TypeError(msg) - - t = Quantity(0.0, pot.units["time"]) - - return gp.LeeSutoTriaxialNFWPotential( - v_c=convert(xp.sqrt(pot.constants["G"] * pot.m(t) / pot.r_s(t)), APYQuantity), - r_s=convert(pot.r_s(t), APYQuantity), - a=convert(pot.a1(t), APYQuantity), - b=convert(pot.a2(t), APYQuantity), - c=convert(pot.a3(t), APYQuantity), - units=galax_to_gala_units(pot.units), - ) - - -# ----------------------------------------------------------------------------- -# Composite potentials - - -@galax_to_gala.register -def _galax_to_gala_bovymw2014( - pot: gpx.BovyMWPotential2014, / -) -> gp.BovyMWPotential2014: - """Convert a Galax BovyMWPotential2014 to a Gala potential.""" - - def rename(k: str) -> str: - match k: - case "m_tot": - return "m" - case _: - return k - - return gp.BovyMWPotential2014( - **{ - c: {rename(k): getattr(p, k)(0) for k in p.parameters} - for c, p in pot.items() - } - ) - - -@galax_to_gala.register -def _galax_to_gala_lm10(pot: gpx.LM10Potential, /) -> gp.LM10Potential: - """Convert a Galax LM10Potential to a Gala potential.""" - - def rename(c: str, k: str) -> str: - match k: - case "m_tot": - return "m" - case "r_s" if c == "halo": - return "r_h" - case "r_s" if c == "bulge": - return "c" - case _: - return k - - return gp.LM10Potential( - **{ - c: {rename(c, k): getattr(p, k)(0) for k in p.parameters} - for c, p in pot.items() - } - ) - - -@galax_to_gala.register -def _galax_to_gala_mwpotential(pot: gpx.MilkyWayPotential, /) -> gp.MilkyWayPotential: - """Convert a Galax MilkyWayPotential to a Gala potential.""" - - def rename(c: str, k: str) -> str: - match k: - case "m_tot": - return "m" - case "r_s" if c in ("bulge", "nucleus"): - return "c" - case _: - return k - - return gp.MilkyWayPotential( - **{ - c: {rename(c, k): getattr(p, k)(0) for k in p.parameters} - for c, p in pot.items() - } - )