Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: orbit metadata #259

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/galax/dynamics/_dynamics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__all__ = ["AbstractOrbit"]

from dataclasses import replace
from dataclasses import KW_ONLY, replace
from functools import partial
from typing import TYPE_CHECKING, Any, overload

Expand All @@ -11,6 +11,7 @@
import jax.numpy as jnp

from coordinax import AbstractPosition3D, AbstractVelocity3D
from immutable_map_jax import ImmutableMap
from unxt import Quantity

import galax.coordinates as gc
Expand Down Expand Up @@ -50,6 +51,17 @@ class AbstractOrbit(gc.AbstractPhaseSpacePosition):
potential: AbstractPotentialBase
"""Potential in which the orbit was integrated."""

_: KW_ONLY

meta: ImmutableMap[Any] = eqx.field(
default_factory=dict,
converter=ImmutableMap,
static=True,
repr=False,
compare=False,
)
"""Metadata about the orbit."""

def __post_init__(self) -> None:
"""Post-initialization."""
# Need to ensure t shape is correct. Can be Vec0.
Expand Down
23 changes: 18 additions & 5 deletions src/galax/dynamics/_dynamics/integrate/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typing import Any, Literal

import jax
import jax.numpy as jnp
from jax.lax import stop_gradient
from jax.numpy import vectorize as jax_vectorize

import quaxed.array_api as xp
Expand All @@ -29,14 +31,15 @@
_select_w0 = jax_vectorize(jax.lax.select, signature="(),(6),(6)->(6)")


# @partial(jax.jit, static_argnames=("integrator", "interpolated"))
# @partial(jax.jit, static_argnames=("integrator", "interpolated", "include_meta"))
def evaluate_orbit(
pot: gp.AbstractPotentialBase,
w0: gc.PhaseSpacePosition | gt.BatchVec6,
t: Any,
*,
integrator: Integrator | None = None,
interpolated: Literal[True, False] = False,
include_meta: Literal[True, False] = False,
) -> Orbit | InterpolatedOrbit:
"""Compute an orbit in a potential.

Expand Down Expand Up @@ -83,10 +86,16 @@ def evaluate_orbit(
is used twice: once to integrate from `w0.t` to `t[0]` and then from
`t[0]` to `t[1]`.

interpolated: bool, optional keyword-only
interpolated : bool, optional keyword-only
If `True`, return an interpolated orbit. If `False`, return the orbit
at the requested times. Default is `False`.

include_meta : bool, optional keyword-only
Metadata is attached as an :class:`~immutable_map_jax.ImmutableMap`.
If `True`, the metadata is populated with:

- `'has_t0'`: Whether `w0` has time information.

Returns
-------
orbit : :class:`~galax.dynamics.Orbit`
Expand Down Expand Up @@ -189,18 +198,21 @@ def evaluate_orbit(
t = jnp.atleast_1d(Quantity.constructor(t, units["time"]))

# Parse w0
has_t0: bool
psp0t: Quantity
if isinstance(w0, gc.PhaseSpacePosition):
# TODO: warn if w0.t is None?
psp0 = w0
psp0t = t[0] if w0.t is None else w0.t
has_t0 = w0.t is not None
else:
psp0 = gc.PhaseSpacePosition(
q=Quantity(w0[..., 0:3], units["length"]),
p=Quantity(w0[..., 3:6], units["speed"]),
t=t[0],
)
psp0t = t[0]
has_t0 = False

# -------------
# Initial integration
Expand Down Expand Up @@ -236,12 +248,13 @@ def evaluate_orbit(
wt = t

# Construct the orbit object
# TODO: easier construction from the (Interpolated)PhaseSpacePosition
# "integrator": stop_gradient(integrator)
meta = {"has_t0": stop_gradient(has_t0)} if include_meta else {}
if interpolated:
out = InterpolatedOrbit(
q=ws.q, p=ws.p, t=wt, interpolant=ws.interpolant, potential=pot
q=ws.q, p=ws.p, t=wt, interpolant=ws.interpolant, potential=pot, meta=meta
)
else:
out = Orbit(q=ws.q, p=ws.p, t=wt, potential=pot)
out = Orbit(q=ws.q, p=ws.p, t=wt, potential=pot, meta=meta)

return out
19 changes: 19 additions & 0 deletions src/galax/dynamics/_dynamics/mockstream/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import coordinax as cx
import quaxed.array_api as xp
from coordinax import AbstractPosition3D, AbstractVelocity3D
from immutable_map_jax import ImmutableMap
from unxt import Quantity

import galax.typing as gt
Expand Down Expand Up @@ -60,6 +61,15 @@ class MockStreamArm(AbstractPhaseSpacePosition):
release_time: gt.QVecTime = eqx.field(converter=Quantity["time"].constructor)
"""Release time of the stream particles [Myr]."""

meta: ImmutableMap[Any] = eqx.field(
default_factory=dict,
converter=ImmutableMap,
static=True,
repr=False,
compare=False,
)
"""Metadata about the mock-stream arm."""

# ==========================================================================
# Array properties

Expand Down Expand Up @@ -93,6 +103,15 @@ def __getitem__(self, index: Any) -> "Self":
class MockStream(AbstractCompositePhaseSpacePosition):
_time_sorter: Shaped[Array, "alltimes"]

meta: ImmutableMap[Any] = eqx.field(
default_factory=dict,
converter=ImmutableMap,
static=True,
repr=False,
compare=False,
)
"""Metadata about the mock stream."""

def __init__(
self,
psps: dict[str, MockStreamArm] | tuple[tuple[str, MockStreamArm], ...] = (),
Expand Down
26 changes: 22 additions & 4 deletions src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@ def one_pt_intg(
def integ_ics(ics: gt.Vec6) -> gt.VecN:
# TODO: only return the final state
return evaluate_orbit(
self.potential, ics, tstep, integrator=self.stream_integrator
self.potential,
ics,
tstep,
integrator=self.stream_integrator,
include_meta=False,
).w(units=self.units)[-1]

# vmap integration over leading and trailing arm
Expand Down Expand Up @@ -138,10 +142,18 @@ def one_pt_intg(
) -> tuple[gt.Vec6, gt.Vec6]:
tstep = xp.asarray([ts[i], t_f])
w_lead = evaluate_orbit(
self.potential, w0_l_i, tstep, integrator=self.stream_integrator
self.potential,
w0_l_i,
tstep,
integrator=self.stream_integrator,
include_meta=False,
).w(units=self.potential.units)[-1]
w_trail = evaluate_orbit(
self.potential, w0_t_i, tstep, integrator=self.stream_integrator
self.potential,
w0_t_i,
tstep,
integrator=self.stream_integrator,
include_meta=False,
).w(units=self.potential.units)[-1]
return w_lead, w_trail

Expand All @@ -151,7 +163,7 @@ def one_pt_intg(
lead_arm_w, trail_arm_w = jax.vmap(one_pt_intg)(pt_ids, w0_lead, w0_trail)
return lead_arm_w, trail_arm_w

@partial(jax.jit, static_argnames=("vmapped",))
@partial(jax.jit, static_argnames=("vmapped", "include_meta"))
def run(
self,
rng: PRNGKeyArray,
Expand All @@ -160,6 +172,7 @@ def run(
prog_mass: gt.FloatQScalar | ProgenitorMassCallable,
*,
vmapped: bool | None = None,
include_meta: bool = False,
) -> tuple[MockStream, gc.PhaseSpacePosition]:
"""Generate mock stellar stream.

Expand Down Expand Up @@ -197,6 +210,9 @@ def run(
`None` (default), then `jax.vmap` is used on GPU and `jax.lax.scan`
otherwise.

include_meta : bool, optional keyword-only
Whether to include metadata in the output.

Returns
-------
mockstream : :class:`galax.dynamcis.MockStreamArm`
Expand Down Expand Up @@ -246,12 +262,14 @@ def run(
p=Quantity(lead_arm_w[:, 3:6], self.units["speed"]),
t=t,
release_time=mock0["lead"].release_time,
meta={"use_vmap": use_vmap} if include_meta else {},
)
comps["trail"] = MockStreamArm(
q=Quantity(trail_arm_w[:, 0:3], self.units["length"]),
p=Quantity(trail_arm_w[:, 3:6], self.units["speed"]),
t=t,
release_time=mock0["trail"].release_time,
meta={"use_vmap": use_vmap} if include_meta else {},
)

return MockStream(comps), prog_o[-1]
26 changes: 25 additions & 1 deletion src/galax/dynamics/_dynamics/orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

__all__ = ["Orbit", "InterpolatedOrbit"]

from typing import final
from dataclasses import KW_ONLY
from typing import Any, final

import equinox as eqx

from coordinax import AbstractPosition3D, AbstractVelocity3D
from immutable_map_jax import ImmutableMap
from unxt import Quantity

import galax.potential as gp
Expand Down Expand Up @@ -38,6 +40,17 @@ class Orbit(AbstractOrbit):
potential: gp.AbstractPotentialBase
"""Potential in which the orbit was integrated."""

_: KW_ONLY

meta: ImmutableMap[Any] = eqx.field(
default_factory=dict,
converter=ImmutableMap,
static=True,
repr=False,
compare=False,
)
"""Metadata about the orbit."""


# ==========================================================================

Expand All @@ -62,6 +75,17 @@ class InterpolatedOrbit(AbstractOrbit):
interpolant: PhaseSpacePositionInterpolant
"""The interpolation function."""

_: KW_ONLY

meta: ImmutableMap[Any] = eqx.field(
default_factory=dict,
converter=ImmutableMap,
static=True,
repr=False,
compare=False,
)
"""Metadata about the orbit."""

def __call__(self, t: BatchFloatQScalar) -> Orbit:
"""Call the interpolation."""
# TODO: more efficilent conversion to Orbit
Expand Down
22 changes: 17 additions & 5 deletions src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def evaluate_orbit(
*,
integrator: "Integrator | None" = None,
interpolated: Literal[True, False] = False,
include_meta: Literal[True, False] = False,
) -> "Orbit":
"""Compute an orbit in a potential.

Expand Down Expand Up @@ -357,7 +358,7 @@ def evaluate_orbit(
A :class:`~galax.coordinates.PhaseSpacePosition` will be
constructed, interpreting the array as the 'q', 'p' (each
Array[float, (*batch, 3)]) arguments, with 't' set to ``t[0]``.
t: Quantity[float, (time,)]
t : Quantity[float, (time,)]
Array of times at which to compute the orbit. The first element
should be the initial time and the last element should be the final
time and the array should be monotonically moving from the first to
Expand All @@ -375,10 +376,16 @@ def evaluate_orbit(
Integrator to use. If `None`, the default integrator
:class:`~galax.integrator.DiffraxIntegrator` is used.

interpolated: bool, optional keyword-only
If `True`, return an interpolated orbit. If `False`, return the orbit
at the requested times. Default is `False`.
interpolated : bool, optional keyword-only
If `True`, return an interpolated orbit. If `False`, return the
orbit at the requested times. Default is `False`.

include_meta : bool, optional keyword-only
Metadata is attached as an :class:`~immutable_map_jax.ImmutableMap`.
If `True`, the metadata is populated with:

- `'integrator'`: The integrator used.
- `'has_t0'`: Whether `w0` has time information.

Returns
-------
Expand All @@ -396,7 +403,12 @@ def evaluate_orbit(
return cast(
"Orbit",
evaluate_orbit(
self, w0, t, integrator=integrator, interpolated=interpolated
self,
w0,
t,
integrator=integrator,
interpolated=interpolated,
include_meta=include_meta,
),
)

Expand Down
Loading