Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: multipole potentials #357

Merged
merged 5 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ filterwarnings = [
"ignore:ast\\.Str is deprecated:DeprecationWarning",
"ignore:numpy\\.ndarray size changed:RuntimeWarning",
"ignore:Passing arguments 'a':DeprecationWarning", # TODO: from diffrax
"ignore:jax\\.core\\.pp_eqn_rules is deprecated:DeprecationWarning",
]
log_cli_level = "INFO"
markers = [
Expand Down
137 changes: 136 additions & 1 deletion src/galax/_interop/galax_interop_gala/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import equinox as eqx
import gala.potential as gp
import jax.numpy as jnp
from astropy.units import Quantity as APYQuantity
from gala.units import (
DimensionlessUnitSystem as GalaDimensionlessUnitSystem,
Expand Down Expand Up @@ -1170,10 +1171,144 @@


# -----------------------------------------------------------------------------
# NFW potentials
# Multipole potentials


@dispatch # type: ignore[misc]
def gala_to_galax(
gala: gp.MultipolePotential, /
) -> gpx.MultipoleInnerPotential | gpx.MultipoleOuterPotential | gpx.PotentialFrame:
params = gala.parameters
cls = (
gpx.MultipoleInnerPotential
if params["inner"] == 1
else gpx.MultipoleOuterPotential
)

l_max = gala._lmax # noqa: SLF001
Slm = jnp.zeros((l_max + 1, l_max + 1), dtype=float)
Tlm = jnp.zeros_like(Slm)

for l, m in zip(*jnp.tril_indices(l_max + 1), strict=True):
skey = f"S{l}{m}"
if skey in params:
Slm = Slm.at[l, m].set(params[skey])

tkey = f"T{l}{m}"
if tkey in params:
Tlm = Tlm.at[l, m].set(params[tkey])

pot = cls(
m_tot=params["m"],
r_s=params["r_s"],
l_max=l_max,
Slm=Slm,
Tlm=Tlm,
units=gala.units,
)
return _apply_frame(_get_frame(gala), pot)


@dispatch.multi((gpx.MultipoleInnerPotential,), (gpx.MultipoleOuterPotential,)) # type: ignore[misc]
def galax_to_gala(
pot: gpx.MultipoleInnerPotential | gpx.MultipoleOuterPotential, /
) -> gp.MultipolePotential:
"""Convert a Galax Multipole to a Gala potential."""
_error_if_not_all_constant_parameters(pot, "m_tot", "r_s", "Slm", "Tlm")

Slm, Tlm = pot.Slm(0).value, pot.Tlm(0).value
ls, ms = jnp.tril_indices(pot.l_max + 1)

return gp.MultipolePotential(
m=convert(pot.m_tot(0), APYQuantity),
r_s=convert(pot.r_s(0), APYQuantity),
lmax=pot.l_max,
**{
f"S{l}{m}": Slm[l, m] for l, m in zip(ls, ms, strict=True) if Slm[l, m] != 0
},
**{
f"T{l}{m}": Tlm[l, m] for l, m in zip(ls, ms, strict=True) if Tlm[l, m] != 0
},
inner=isinstance(pot, gpx.MultipoleInnerPotential),
units=_galax_to_gala_units(pot.units),
)


# -----------------------------------------------------------------------------
# NFW potentials


@dispatch
def gala_to_galax(gala: gp.NFWPotential, /) -> gpx.NFWPotential | gpx.PotentialFrame:
"""Convert a Gala NFWPotential to a Galax potential.

Examples
--------
>>> import gala.potential as gp
>>> import gala.units as gu
>>> import galax.potential as gpx

>>> gpot = gp.NFWPotential(m=1e12, r_s=20, units=gu.galactic)
>>> gpx.io.convert_potential(gpx.io.GalaxLibrary, gpot)
NFWPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableMap({'G': ...}),
m=ConstantParameter( ... ),
r_s=ConstantParameter( ... )
)

"""
params = gala.parameters
pot = gpx.NFWPotential(m=params["m"], r_s=params["r_s"], units=gala.units)
return _apply_frame(_get_frame(gala), pot)

Check warning on line 1263 in src/galax/_interop/galax_interop_gala/potential.py

View check run for this annotation

Codecov / codecov/patch

src/galax/_interop/galax_interop_gala/potential.py#L1261-L1263

Added lines #L1261 - L1263 were not covered by tests


@dispatch
def gala_to_galax(
pot: gp.LeeSutoTriaxialNFWPotential, /
) -> gpx.LeeSutoTriaxialNFWPotential:
"""Convert a :class:`gala.potential.LeeSutoTriaxialNFWPotential` to a :class:`galax.potential.LeeSutoTriaxialNFWPotential`.

Examples
--------
>>> import gala.potential as gp
>>> import gala.units as gu
>>> import galax.potential as gpx

>>> gpot = gp.LeeSutoTriaxialNFWPotential(
... v_c=220, r_s=20, a=1, b=0.9, c=0.8, units=gu.galactic )
>>> gpx.io.convert_potential(gpx.io.GalaxLibrary, gpot)
LeeSutoTriaxialNFWPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableMap({'G': ...}),
m=ConstantParameter( ... ),
r_s=ConstantParameter( ... ),
a1=ConstantParameter( ... ),
a2=ConstantParameter( ... ),
a3=ConstantParameter( ... )
)

""" # noqa: E501
units = pot.units
params = pot.parameters
G = Quantity(pot.G, units["length"] ** 3 / units["time"] ** 2 / units["mass"])

Check warning on line 1294 in src/galax/_interop/galax_interop_gala/potential.py

View check run for this annotation

Codecov / codecov/patch

src/galax/_interop/galax_interop_gala/potential.py#L1292-L1294

Added lines #L1292 - L1294 were not covered by tests

return gpx.LeeSutoTriaxialNFWPotential(

Check warning on line 1296 in src/galax/_interop/galax_interop_gala/potential.py

View check run for this annotation

Codecov / codecov/patch

src/galax/_interop/galax_interop_gala/potential.py#L1296

Added line #L1296 was not covered by tests
m=params["v_c"] ** 2 * params["r_s"] / G,
r_s=params["r_s"],
a1=params["a"],
a2=params["b"],
a3=params["c"],
units=units,
constants={"G": G},
)


# -----------------------------------------------------------------------------
# NFW potentials


@dispatch
def gala_to_galax(gala: gp.NFWPotential, /) -> gpx.NFWPotential | gpx.PotentialFrame:
"""Convert a Gala NFWPotential to a Galax potential.

Expand Down
11 changes: 11 additions & 0 deletions src/galax/potential/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
# logarithmic
"LogarithmicPotential",
"LMJ09LogarithmicPotential",
# multipole
"AbstractMultipolePotential",
"MultipoleInnerPotential",
"MultipoleOuterPotential",
"MultipolePotential",
# nfw
"NFWPotential",
"LeeSutoTriaxialNFWPotential",
Expand Down Expand Up @@ -79,6 +84,12 @@
LMJ09LogarithmicPotential,
LogarithmicPotential,
)
from ._potential.builtin.multipole import (
AbstractMultipolePotential,
MultipoleInnerPotential,
MultipoleOuterPotential,
MultipolePotential,
)
from ._potential.builtin.nfw import (
LeeSutoTriaxialNFWPotential,
NFWPotential,
Expand Down
4 changes: 3 additions & 1 deletion src/galax/potential/_potential/builtin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
"""``galax`` Potentials."""
# ruff:noqa: F401

from . import bars, builtin, logarithmic, nfw, special
from . import bars, builtin, logarithmic, multipole, nfw, special
from .bars import *
from .builtin import *
from .logarithmic import *
from .multipole import *
from .nfw import *
from .special import *

__all__: list[str] = []
__all__ += builtin.__all__
__all__ += bars.__all__
__all__ += logarithmic.__all__
__all__ += multipole.__all__
__all__ += nfw.__all__
__all__ += special.__all__
Loading