Skip to content

Commit

Permalink
style: cleanup gala_helper
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Apr 28, 2024
1 parent 5316add commit d63f8bd
Showing 1 changed file with 25 additions and 31 deletions.
56 changes: 25 additions & 31 deletions tests/unit/potential/io/gala_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,16 @@

from functools import singledispatch

import gala.potential as gp
from astropy.units import Quantity as APYQuantity
from gala.potential import (
CompositePotential as GalaCompositePotential,
LeeSutoTriaxialNFWPotential as GalaLeeSutoTriaxialNFWPotential,
MilkyWayPotential as GalaMilkyWayPotential,
NFWPotential as GalaNFWPotential,
NullPotential as GalaNullPotential,
PotentialBase as GalaPotentialBase,
)
from gala.units import UnitSystem as GalaUnitSystem, dimensionless as gala_dimensionless
from plum import convert

import quaxed.array_api as xp
from unxt import Quantity
from unxt.unitsystems import AbstractUnitSystem, DimensionlessUnitSystem

import galax.potential as gp
import galax.potential as gpx
from galax.potential._potential.io._gala import _GALA_TO_GALAX_REGISTRY

##############################################################################
Expand All @@ -38,17 +31,17 @@ def galax_to_gala_units(units: AbstractUnitSystem, /) -> GalaUnitSystem:


def _all_constant_parameters(
pot: "gp.AbstractPotentialBase",
pot: "gpx.AbstractPotentialBase",
*params: str,
) -> bool:
return all(isinstance(getattr(pot, name), gp.ConstantParameter) for name in params)
return all(isinstance(getattr(pot, name), gpx.ConstantParameter) for name in params)


# TODO: add an argument to specify how to handle time-dependent parameters.
# Gala potentials are not time-dependent, so we need to specify how to
# handle time-dependent Galax parameters.
@singledispatch
def galax_to_gala(pot: gp.AbstractPotentialBase, /) -> GalaPotentialBase:
def galax_to_gala(pot: gpx.AbstractPotentialBase, /) -> gp.PotentialBase:
"""Convert a Galax potential to a Gala potential.
Parameters
Expand All @@ -73,22 +66,23 @@ def galax_to_gala(pot: gp.AbstractPotentialBase, /) -> GalaPotentialBase:


@galax_to_gala.register
def _galax_to_gala_composite(pot: gp.CompositePotential, /) -> GalaCompositePotential:
def _galax_to_gala_composite(pot: gpx.CompositePotential, /) -> gp.CompositePotential:
"""Convert a Galax CompositePotential to a Gala potential."""
return GalaCompositePotential(**{k: galax_to_gala(p) for k, p in pot.items()})
return gp.CompositePotential(**{k: galax_to_gala(p) for k, p in pot.items()})


_GALAX_TO_GALA_REGISTRY: dict[type[gp.AbstractPotential], type[GalaPotentialBase]] = {
_GALAX_TO_GALA_REGISTRY: dict[type[gpx.AbstractPotential], type[gp.PotentialBase]] = {
v: k for k, v in _GALA_TO_GALAX_REGISTRY.items()
}


@galax_to_gala.register(gp.HernquistPotential)
@galax_to_gala.register(gp.IsochronePotential)
@galax_to_gala.register(gp.KeplerPotential)
@galax_to_gala.register(gp.KuzminPotential)
@galax_to_gala.register(gp.MiyamotoNagaiPotential)
def _galax_to_gala_abstractpotential(pot: gp.AbstractPotential, /) -> GalaPotentialBase:
@galax_to_gala.register(gpx.HarmonicOscillatorPotential)
@galax_to_gala.register(gpx.HernquistPotential)
@galax_to_gala.register(gpx.IsochronePotential)
@galax_to_gala.register(gpx.KeplerPotential)
@galax_to_gala.register(gpx.KuzminPotential)
@galax_to_gala.register(gpx.MiyamotoNagaiPotential)
def _galax_to_gala_abstractpotential(pot: gpx.AbstractPotential, /) -> gp.PotentialBase:
"""Convert a Galax AbstractPotential to a Gala potential."""
if not _all_constant_parameters(pot, *pot.parameters.keys()):
msg = "Gala does not support time-dependent parameters."
Expand All @@ -113,24 +107,24 @@ def _galax_to_gala_abstractpotential(pot: gp.AbstractPotential, /) -> GalaPotent


@galax_to_gala.register
def _galax_to_gala_bar(pot: gp.BarPotential, /) -> GalaPotentialBase:
def _galax_to_gala_bar(pot: gpx.BarPotential, /) -> gp.PotentialBase:
"""Convert a Galax BarPotential to a Gala potential."""
raise NotImplementedError # TODO: implement


@galax_to_gala.register
def _galax_to_gala_null(_: gp.NullPotential, /) -> GalaNullPotential:
return GalaNullPotential(units=gala_dimensionless)
def _galax_to_gala_null(_: gpx.NullPotential, /) -> gp.NullPotential:
return gp.NullPotential(units=gala_dimensionless)


@galax_to_gala.register
def _galax_to_gala_nfw(pot: gp.NFWPotential, /) -> GalaNFWPotential:
def _galax_to_gala_nfw(pot: gpx.NFWPotential, /) -> gp.NFWPotential:
"""Convert a Galax NFWPotential to a Gala potential."""
if not _all_constant_parameters(pot, "m", "r_s"):
msg = "Gala does not support time-dependent parameters."
raise TypeError(msg)

return GalaNFWPotential(
return gp.NFWPotential(
m=convert(pot.m(0), APYQuantity),
r_s=convert(pot.r_s(0), APYQuantity),
units=galax_to_gala_units(pot.units),
Expand All @@ -139,16 +133,16 @@ def _galax_to_gala_nfw(pot: gp.NFWPotential, /) -> GalaNFWPotential:

@galax_to_gala.register
def _galax_to_gala_leesutotriaxialnfw(
pot: gp.LeeSutoTriaxialNFWPotential, /
) -> GalaLeeSutoTriaxialNFWPotential:
pot: gpx.LeeSutoTriaxialNFWPotential, /
) -> gp.LeeSutoTriaxialNFWPotential:
"""Convert a Galax LeeSutoTriaxialNFWPotential to a Gala potential."""
if not _all_constant_parameters(pot, "m", "r_s", "a1", "a2", "a3"):
msg = "Gala does not support time-dependent parameters."
raise TypeError(msg)

t = Quantity(0.0, pot.units["time"])

return GalaLeeSutoTriaxialNFWPotential(
return gp.LeeSutoTriaxialNFWPotential(
v_c=convert(xp.sqrt(pot.constants["G"] * pot.m(t) / pot.r_s(t)), APYQuantity),
r_s=convert(pot.r_s(t), APYQuantity),
a=convert(pot.a1(t), APYQuantity),
Expand All @@ -159,7 +153,7 @@ def _galax_to_gala_leesutotriaxialnfw(


@galax_to_gala.register
def _galax_to_gala_mwpotential(pot: gp.MilkyWayPotential, /) -> GalaMilkyWayPotential:
def _galax_to_gala_mwpotential(pot: gpx.MilkyWayPotential, /) -> gp.MilkyWayPotential:
"""Convert a Gala MilkyWayPotential to a Galax potential."""

def rename(k: str) -> str:
Expand All @@ -169,7 +163,7 @@ def rename(k: str) -> str:
case _:
return k

return GalaMilkyWayPotential(
return gp.MilkyWayPotential(
disk={rename(k): getattr(pot["disk"], k)(0) for k in ("m_tot", "a", "b")},
halo={rename(k): getattr(pot["halo"], k)(0) for k in ("m", "r_s")},
bulge={rename(k): getattr(pot["bulge"], k)(0) for k in ("m_tot", "c")},
Expand Down

0 comments on commit d63f8bd

Please sign in to comment.