Skip to content

Commit

Permalink
build: use ImmutableMap (#382)
Browse files Browse the repository at this point in the history
* build: use ImmutableMap

Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Jul 11, 2024
1 parent 4e29e25 commit ffb7f56
Show file tree
Hide file tree
Showing 22 changed files with 112 additions and 337 deletions.
8 changes: 4 additions & 4 deletions docs/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ input::
>>> mw
MilkyWayPotential({'disk': MiyamotoNagaiPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': Quantity['m3 kg-1 s-2'](Array(4.49850215e-12, dtype=float64), unit='kpc3 / (solMass Myr2)')}),
constants=ImmutableMap({'G': Quantity['m3 kg-1 s-2'](Array(4.49850215e-12, dtype=float64), unit='kpc3 / (solMass Myr2)')}),
m_tot=ConstantParameter(
unit=Unit("solMass"),
value=Quantity[PhysicalType('mass')](value=f64[], unit=Unit("solMass"))
Expand All @@ -67,7 +67,7 @@ input::
)
), 'halo': NFWPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': Quantity['m3 kg-1 s-2'](Array(4.49850215e-12, dtype=float64), unit='kpc3 / (solMass Myr2)')}),
constants=ImmutableMap({'G': Quantity['m3 kg-1 s-2'](Array(4.49850215e-12, dtype=float64), unit='kpc3 / (solMass Myr2)')}),
m=ConstantParameter(
unit=Unit("solMass"),
value=Quantity[PhysicalType('mass')](value=f64[], unit=Unit("solMass"))
Expand All @@ -78,7 +78,7 @@ input::
)
), 'bulge': HernquistPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': Quantity['m3 kg-1 s-2'](Array(4.49850215e-12, dtype=float64), unit='kpc3 / (solMass Myr2)')}),
constants=ImmutableMap({'G': Quantity['m3 kg-1 s-2'](Array(4.49850215e-12, dtype=float64), unit='kpc3 / (solMass Myr2)')}),
m_tot=ConstantParameter(
unit=Unit("solMass"),
value=Quantity[PhysicalType('mass')](value=f64[], unit=Unit("solMass"))
Expand All @@ -89,7 +89,7 @@ input::
)
), 'nucleus': HernquistPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': Quantity['m3 kg-1 s-2'](Array(4.49850215e-12, dtype=float64), unit='kpc3 / (solMass Myr2)')}),
constants=ImmutableMap({'G': Quantity['m3 kg-1 s-2'](Array(4.49850215e-12, dtype=float64), unit='kpc3 / (solMass Myr2)')}),
m_tot=ConstantParameter(
unit=Unit("solMass"),
value=Quantity[PhysicalType('mass')](value=f64[], unit=Unit("solMass"))
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies = [
"quaxed >= 0.4",
"typing_extensions >= 4.11",
"unxt",
"immutable_map_jax @ git+https://github.com/GalacticDynamics/immutable_map_jax.git",
]
description = "Galactic Dynamics in Jax."
dynamic = ["version"]
Expand Down Expand Up @@ -247,7 +248,7 @@ exempt-modules = []

[tool.ruff.lint.isort]
combine-as-imports = true
known-first-party = ["quaxed", "unxt", "coordinax"]
known-first-party = ["immutable_map_jax", "quaxed", "unxt", "coordinax"]
known-local-folder = ["galax"]

[tool.ruff.lint.pydocstyle]
Expand Down
34 changes: 17 additions & 17 deletions src/galax/_galax_interop_gala/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def gala_to_galax(pot: gp.PotentialBase, /) -> gpx.AbstractPotentialBase:
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
HernquistPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': ...}),
constants=ImmutableMap({'G': ...}),
m_tot=ConstantParameter( unit=Unit("solMass"), value=Quantity[...](value=f64[], unit=Unit("solMass")) ),
r_s=ConstantParameter( unit=Unit("kpc"), value=Quantity[...](value=f64[], unit=Unit("kpc")) ) )
Expand All @@ -86,7 +86,7 @@ def gala_to_galax(pot: gp.PotentialBase, /) -> gpx.AbstractPotentialBase:
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
IsochronePotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': ...}),
constants=ImmutableMap({'G': ...}),
m_tot=ConstantParameter( unit=Unit("solMass"), value=Quantity[...](value=f64[], unit=Unit("solMass")) ),
b=ConstantParameter( unit=Unit("kpc"), value=Quantity[...](value=f64[], unit=Unit("kpc")) ) )
Expand All @@ -96,15 +96,15 @@ def gala_to_galax(pot: gp.PotentialBase, /) -> gpx.AbstractPotentialBase:
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
KeplerPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': ...}),
constants=ImmutableMap({'G': ...}),
m_tot=ConstantParameter( unit=Unit("solMass"), value=Quantity[...](value=f64[], unit=Unit("solMass")) ) )
>>> gpot = gp.LeeSutoTriaxialNFWPotential(
... v_c=220, r_s=20, a=1, b=0.9, c=0.8, units=gu.galactic )
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
LeeSutoTriaxialNFWPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': ...}),
constants=ImmutableMap({'G': ...}),
m=ConstantParameter( unit=Unit("solMass"), value=Quantity[...](value=f64[], unit=Unit("solMass")) ),
r_s=ConstantParameter( unit=Unit("kpc"), value=Quantity[...](value=f64[], unit=Unit("kpc")) ),
a1=ConstantParameter( unit=Unit(dimensionless), value=Quantity[...]( value=f64[], unit=Unit(dimensionless) ) ),
Expand All @@ -127,7 +127,7 @@ def gala_to_galax(pot: gp.PotentialBase, /) -> gpx.AbstractPotentialBase:
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
MiyamotoNagaiPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': ...}),
constants=ImmutableMap({'G': ...}),
m_tot=ConstantParameter( unit=Unit("solMass"), value=Quantity[...](value=f64[], unit=Unit("solMass")) ),
a=ConstantParameter( unit=Unit("kpc"), value=Quantity[...](value=f64[], unit=Unit("kpc")) ),
b=ConstantParameter( unit=Unit("kpc"), value=Quantity[...](value=f64[], unit=Unit("kpc")) ) )
Expand All @@ -138,7 +138,7 @@ def gala_to_galax(pot: gp.PotentialBase, /) -> gpx.AbstractPotentialBase:
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
NFWPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': ...}),
constants=ImmutableMap({'G': ...}),
m=ConstantParameter( unit=Unit("solMass"), value=Quantity[...](value=f64[], unit=Unit("solMass")) ),
r_s=ConstantParameter( unit=Unit("kpc"), value=Quantity[...](value=f64[], unit=Unit("kpc")) ) )
Expand All @@ -147,7 +147,7 @@ def gala_to_galax(pot: gp.PotentialBase, /) -> gpx.AbstractPotentialBase:
>>> gpot = gp.NullPotential()
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
NullPotential( units=DimensionlessUnitSystem(),
constants=ImmutableDict({'G': ...}) )
constants=ImmutableMap({'G': ...}) )
""" # noqa: E501
msg = (
"`gala_to_galax` does not have a registered function to convert "
Expand Down Expand Up @@ -236,7 +236,7 @@ def _gala_to_galax_null(pot: gp.NullPotential, /) -> gpx.NullPotential:
>>> gpot = gp.NullPotential()
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
NullPotential( units=DimensionlessUnitSystem(),
constants=ImmutableDict({'G': ...}) )
constants=ImmutableMap({'G': ...}) )
"""
return gpx.NullPotential(units=pot.units)
Expand Down Expand Up @@ -268,7 +268,7 @@ def _gala_to_galax_burkert(
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
BurkertPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': ...}),
constants=ImmutableMap({'G': ...}),
m=ConstantParameter( ... ),
r_s=ConstantParameter( ... )
)
Expand Down Expand Up @@ -298,7 +298,7 @@ def _gala_to_galax_hernquist(
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
HernquistPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': ...}),
constants=ImmutableMap({'G': ...}),
m_tot=ConstantParameter( ... ),
r_s=ConstantParameter( ... )
)
Expand All @@ -324,7 +324,7 @@ def _gala_to_galax_jaffe(
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
JaffePotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': ...}),
constants=ImmutableMap({'G': ...}),
m=ConstantParameter( ... ),
r_s=ConstantParameter( ... )
)
Expand All @@ -350,7 +350,7 @@ def _gala_to_galax_longmuralibar(
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
LongMuraliBarPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': Quantity...}),
constants=ImmutableMap({'G': Quantity...}),
m_tot=ConstantParameter( ... ),
a=ConstantParameter( ... ),
b=ConstantParameter( ... ),
Expand Down Expand Up @@ -386,7 +386,7 @@ def _gala_to_galax_satoh(
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
SatohPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': ...}),
constants=ImmutableMap({'G': ...}),
m_tot=ConstantParameter( ... ),
a=ConstantParameter( ... ),
b=ConstantParameter( ... )
Expand Down Expand Up @@ -415,7 +415,7 @@ def _gala_to_galax_stoneostriker15(
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
StoneOstriker15Potential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': ...}),
constants=ImmutableMap({'G': ...}),
m_tot=ConstantParameter( ... ),
r_c=ConstantParameter( ... ),
r_h=ConstantParameter( ... )
Expand Down Expand Up @@ -454,7 +454,7 @@ def _gala_to_galax_logarithmic(
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
LogarithmicPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': ...}),
constants=ImmutableMap({'G': ...}),
v_c=ConstantParameter( ... ),
r_s=ConstantParameter( ... )
)
Expand Down Expand Up @@ -504,7 +504,7 @@ def _gala_to_galax_nfw(
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
NFWPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': ...}),
constants=ImmutableMap({'G': ...}),
m=ConstantParameter( unit=Unit("solMass"), value=Quantity[...](value=f64[], unit=Unit("solMass")) ),
r_s=ConstantParameter( unit=Unit("kpc"), value=Quantity[...](value=f64[], unit=Unit("kpc")) )
)
Expand Down Expand Up @@ -532,7 +532,7 @@ def _gala_to_galax_leesutotriaxialnfw(
>>> gpx.io.convert_potential(gpx.io.GalaLibrary, gpot)
LeeSutoTriaxialNFWPotential(
units=UnitSystem(kpc, Myr, solMass, rad),
constants=ImmutableDict({'G': ...}),
constants=ImmutableMap({'G': ...}),
m=ConstantParameter( unit=Unit("solMass"), value=Quantity[...](value=f64[], unit=Unit("solMass")) ),
r_s=ConstantParameter( unit=Unit("kpc"), value=Quantity[...](value=f64[], unit=Unit("kpc")) ),
a1=ConstantParameter( unit=Unit(dimensionless), value=Quantity[...]( value=f64[], unit=Unit(dimensionless) ) ),
Expand Down
12 changes: 6 additions & 6 deletions src/galax/coordinates/_psp/base_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@

import coordinax as cx
import quaxed.numpy as jnp
from immutable_map_jax import ImmutableMap
from unxt import Quantity

import galax.typing as gt
from .base import AbstractBasePhaseSpacePosition, ComponentShapeTuple
from galax.utils import ImmutableDict
from galax.utils._misc import zeroth
from galax.utils.dataclasses import dataclass_items

if TYPE_CHECKING:
from typing import Self


# Note: cannot have `strict=True` because of inheriting from ImmutableDict.
# Note: cannot have `strict=True` because of inheriting from ImmutableMap.
class AbstractCompositePhaseSpacePosition(
ImmutableDict[AbstractBasePhaseSpacePosition], # TODO: as a TypeVar
ImmutableMap[str, AbstractBasePhaseSpacePosition], # type: ignore[misc]
AbstractBasePhaseSpacePosition,
strict=False, # type: ignore[call-arg]
):
Expand All @@ -38,8 +38,8 @@ class AbstractCompositePhaseSpacePosition(
represents a component of the system.
The input signature matches that of :class:`dict` (and
:class:`~galax.utils.ImmutableDict`), so you can pass in the components as
keyword arguments or as a dictionary.
:class:`~immutable_map_jax.ImmutableMap`), so you can pass in the components
as keyword arguments or as a dictionary.
The components are stored as a dictionary and can be key accessed. However,
the composite phase-space position itself acts as a single
Expand Down Expand Up @@ -106,7 +106,7 @@ def __init__(
/,
**kwargs: AbstractBasePhaseSpacePosition,
) -> None:
super().__init__(psps, **kwargs) # <- ImmutableDict.__init__
super().__init__(psps, **kwargs) # <- ImmutableMap.__init__

@property
@abstractmethod
Expand Down
6 changes: 3 additions & 3 deletions src/galax/dynamics/_dynamics/integrate/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
from diffrax import DenseInterpolation
from jax._src.numpy.vectorize import _parse_gufunc_signature, _parse_input_dimensions

from immutable_map_jax import ImmutableMap
from unxt import AbstractUnitSystem, Quantity, unitsystem

import galax.coordinates as gc
import galax.typing as gt
from .api import VectorField
from .base import AbstractIntegrator
from galax.utils import ImmutableDict

P = ParamSpec("P")
R = TypeVar("R")
Expand Down Expand Up @@ -266,10 +266,10 @@ class DiffraxIntegrator(AbstractIntegrator):
diffeq_kw: Mapping[str, Any] = eqx.field(
default=(("max_steps", None), ("discrete_terminating_event", None)),
static=True,
converter=ImmutableDict,
converter=ImmutableMap,
)
solver_kw: Mapping[str, Any] = eqx.field(
default=(("scan_kind", "bounded"),), static=True, converter=ImmutableDict
default=(("scan_kind", "bounded"),), static=True, converter=ImmutableMap
)

InterpolantClass: ClassVar[type[gc.PhaseSpacePositionInterpolant]] = ( # type: ignore[misc]
Expand Down
8 changes: 4 additions & 4 deletions src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import quaxed.array_api as xp
import quaxed.numpy as qnp
import unxt
from immutable_map_jax import ImmutableMap
from unxt import AbstractUnitSystem, Quantity

import galax.typing as gt
Expand All @@ -29,7 +30,6 @@
from galax.coordinates import PhaseSpacePosition
from galax.potential._potential.params.attr import ParametersAttribute
from galax.potential._potential.params.utils import all_parameters, all_vars
from galax.utils._collections import ImmutableDict
from galax.utils._jax import vectorize_method
from galax.utils.dataclasses import ModuleMeta

Expand All @@ -41,7 +41,7 @@
HessianVec: TypeAlias = Shaped[Quantity["1/s^2"], "*#shape 3 3"] # TODO: shape -> batch


default_constants = ImmutableDict({"G": Quantity(_CONST_G.value, _CONST_G.unit)})
default_constants = ImmutableMap({"G": Quantity(_CONST_G.value, _CONST_G.unit)})


##############################################################################
Expand All @@ -56,7 +56,7 @@ class AbstractPotentialBase(eqx.Module, metaclass=ModuleMeta, strict=True): # t
units: eqx.AbstractVar[AbstractUnitSystem]
"""The unit system of the potential."""

constants: eqx.AbstractVar[ImmutableDict[Quantity]]
constants: eqx.AbstractVar[ImmutableMap[str, Quantity]]
"""The constants used by the potential."""

def __init_subclass__(cls, **kwargs: Any) -> None:
Expand Down Expand Up @@ -95,7 +95,7 @@ def _init_units(self) -> None:

# Do unit conversion for the constants
if self.units != unxt.unitsystems.dimensionless:
constants = ImmutableDict(
constants = ImmutableMap(
{k: v.decompose(self.units) for k, v in self.constants.items()}
)
object.__setattr__(self, "constants", constants)
Expand Down
6 changes: 3 additions & 3 deletions src/galax/potential/_potential/builtin/bars.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
import jax

import quaxed.array_api as xp
from immutable_map_jax import ImmutableMap
from unxt import AbstractUnitSystem, Quantity, unitsystem

import galax.typing as gt
from galax.potential._potential.base import default_constants
from galax.potential._potential.core import AbstractPotential
from galax.potential._potential.params.core import AbstractParameter
from galax.potential._potential.params.field import ParameterField
from galax.utils import ImmutableDict
from galax.utils._jax import vectorize_method

# -------------------------------------------------------------------
Expand All @@ -44,8 +44,8 @@ class BarPotential(AbstractPotential):

_: KW_ONLY
units: AbstractUnitSystem = eqx.field(converter=unitsystem, static=True)
constants: ImmutableDict[Quantity] = eqx.field(
default=default_constants, converter=ImmutableDict
constants: ImmutableMap[str, Quantity] = eqx.field(
default=default_constants, converter=ImmutableMap
)

@partial(jax.jit)
Expand Down
Loading

0 comments on commit ffb7f56

Please sign in to comment.