Skip to content

Commit

Permalink
feat(psp): constructors (#270)
Browse files Browse the repository at this point in the history
* feat(psp): constructors

Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Apr 27, 2024
1 parent 2b2d303 commit 6b2264b
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 3 deletions.
12 changes: 11 additions & 1 deletion src/galax/coordinates/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

__all__: list[str] = []

from typing import cast

try: # TODO: less hacky way of supporting optional dependencies
import pytest
except ImportError:
except ImportError: # pragma: no cover
pass
else:
_ = pytest.importorskip("gala")
Expand All @@ -23,3 +25,11 @@
def gala_psp_to_galax_psp(obj: gd.PhaseSpacePosition, /) -> gcx.PhaseSpacePosition:
"""`gala.dynamics.PhaseSpacePosition` -> `galax.coordinates.PhaseSpacePosition`."""
return gcx.PhaseSpacePosition(q=obj.pos, p=obj.vel, t=None)


@gcx.PhaseSpacePosition.constructor._f.register # type: ignore[misc] # noqa: SLF001
def constructor(
_: type[gcx.PhaseSpacePosition], obj: gd.PhaseSpacePosition, /
) -> gcx.PhaseSpacePosition:
"""Construct a :mod:`galax` PhaseSpacePosition from a :mod:`gala` one."""
return cast(gcx.PhaseSpacePosition, gala_psp_to_galax_psp(obj))
113 changes: 113 additions & 0 deletions src/galax/coordinates/_psp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
__all__ = ["AbstractPhaseSpacePosition", "ComponentShapeTuple"]

from abc import abstractmethod
from collections.abc import Mapping
from dataclasses import replace
from functools import partial
from typing import TYPE_CHECKING, Any, NamedTuple, cast
Expand All @@ -18,6 +19,7 @@

import galax.typing as gt
from .utils import getitem_broadscalartime_index
from galax.utils.dataclasses import dataclass_items

if TYPE_CHECKING:
from typing import Self
Expand Down Expand Up @@ -65,6 +67,50 @@ class AbstractPhaseSpacePosition(eqx.Module, strict=True): # type: ignore[call-
t: eqx.AbstractVar[gt.BroadBatchFloatQScalar]
"""Time corresponding to the positions and momenta."""

# ---------------------------------------------------------------
# Constructors

@classmethod
@dispatch # type: ignore[misc]
def constructor(
cls: "type[AbstractPhaseSpacePosition]", obj: Mapping[str, Any], /
) -> "AbstractPhaseSpacePosition":
"""Construct from a mapping.
Parameters
----------
cls : type[:class:`~galax.coordinates.AbstractPhaseSpacePosition`]
The class to construct.
obj : Mapping[str, Any]
The mapping from which to construct.
Returns
-------
:class:`~galax.coordinates.AbstractPhaseSpacePosition`
The constructed phase-space position.
Examples
--------
With the following imports:
>>> from unxt import Quantity
>>> from galax.coordinates import PhaseSpacePosition
We can create a phase-space position from a mapping:
>>> obj = {"q": Quantity([1, 2, 3], "kpc"),
... "p": Quantity([4, 5, 6], "km/s"),
... "t": Quantity(0, "Gyr")}
>>> PhaseSpacePosition.constructor(obj)
PhaseSpacePosition(
q=Cartesian3DVector( ... ),
p=CartesianDifferential3D( ... ),
t=Quantity[...](value=f64[], unit=Unit("Gyr"))
)
"""
return cls(**obj)

# ==========================================================================
# Array properties

Expand Down Expand Up @@ -534,7 +580,74 @@ def angular_momentum(self) -> gt.BatchQVec3:
# =============================================================================
# helper functions

# -----------------------------------------------
# Register additional constructors


@AbstractPhaseSpacePosition.constructor._f.register # type: ignore[misc] # noqa: SLF001
def constructor(
cls: type[AbstractPhaseSpacePosition], obj: AbstractPhaseSpacePosition, /
) -> AbstractPhaseSpacePosition:
"""Construct from a `AbstractPhaseSpacePosition`.
Parameters
----------
cls : type[:class:`~galax.coordinates.AbstractPhaseSpacePosition`]
The class to construct.
obj : :class:`~galax.coordinates.AbstractPhaseSpacePosition`
The phase-space position object from which to construct.
Returns
-------
:class:`~galax.coordinates.AbstractPhaseSpacePosition`
The constructed phase-space position.
Raises
------
TypeError
If the input object is not an instance of the target class.
Examples
--------
With the following imports:
>>> from unxt import Quantity
>>> import coordinax as cx
>>> from galax.coordinates import PhaseSpacePosition
We can create a phase-space position and construct a new one from it:
>>> psp = PhaseSpacePosition(q=Quantity([1, 2, 3], "kpc"),
... p=Quantity([4, 5, 6], "km/s"),
... t=Quantity(0, "Gyr"))
>>> PhaseSpacePosition.constructor(psp) is psp
True
Note that the constructed object is the same as the input object because
the types are the same. If we define a new class that inherits from
:class:`~galax.coordinates.PhaseSpacePosition`, we can construct a
new object from the input object that is an instance of the new class:
>>> class NewPhaseSpacePosition(PhaseSpacePosition): pass
>>> new_psp = NewPhaseSpacePosition.constructor(psp)
>>> new_psp is psp
False
>>> isinstance(new_psp, NewPhaseSpacePosition)
True
"""
# TODO: add isinstance checks

# Avoid copying if the types are the same. Isinstance is not strict
# enough, so we use type() instead.
if type(obj) is cls: # pylint: disable=unidiomatic-typecheck
return obj

return cls(**dict(dataclass_items(obj)))


# -----------------------------------------------
# Register AbstractPhaseSpacePosition with `coordinax.represent_as`
@dispatch # type: ignore[misc]
def represent_as(
psp: AbstractPhaseSpacePosition,
Expand Down
23 changes: 22 additions & 1 deletion src/galax/dynamics/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

__all__: list[str] = []

from typing import cast

try: # TODO: less hacky way of supporting optional dependencies
import pytest
except ImportError:
except ImportError: # pragma: no cover
pass
else:
_ = pytest.importorskip("gala")
Expand All @@ -18,14 +20,33 @@

import galax.dynamics as gdx

# =============================================================================
# Orbit


@conversion_method(type_from=gd.Orbit, type_to=gdx.Orbit) # type: ignore[misc]
def gala_orbit_to_galax_orbit(obj: gd.Orbit, /) -> gdx.Orbit:
"""`gala.dynamics.Orbit` -> `galax.dynamics.Orbit`."""
return gdx.Orbit(q=obj.pos, p=obj.vel, t=obj.t)


@gdx.Orbit.constructor._f.register # type: ignore[misc] # noqa: SLF001
def constructor(_: type[gdx.Orbit], obj: gd.Orbit, /) -> gdx.Orbit:
"""Construct a :mod:`galax` Orbit from a :mod:`gala` one."""
return cast(gdx.Orbit, gala_orbit_to_galax_orbit(obj))


# =============================================================================
# MockStream


@conversion_method(type_from=gd.MockStream, type_to=gdx.MockStream) # type: ignore[misc]
def gala_mockstream_to_galax_mockstream(obj: gd.MockStream, /) -> gdx.MockStream:
"""`gala.dynamics.MockStream` -> `galax.dynamics.MockStream`."""
return gdx.MockStream(q=obj.pos, p=obj.vel, release_time=obj.release_time)


@gdx.MockStream.constructor._f.register # type: ignore[misc] # noqa: SLF001
def constructor(_: type[gdx.MockStream], obj: gd.MockStream, /) -> gdx.MockStream:
"""Construct a :mod:`galax` MockStream from a :mod:`gala` one."""
return cast(gdx.MockStream, gala_mockstream_to_galax_mockstream(obj))
2 changes: 1 addition & 1 deletion src/galax/potential/_potential/io/_gala.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

try: # TODO: less hacky way of supporting optional dependencies
import pytest
except ImportError:
except ImportError: # pragma: no cover
pass
else:
_ = pytest.importorskip("gala")
Expand Down

0 comments on commit 6b2264b

Please sign in to comment.