Skip to content

Commit

Permalink
feat: extend include_meta to MockStream
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Apr 29, 2024
1 parent 5d94900 commit 7483ad2
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import jax
import jax.numpy as jnp
import quax.examples.prng as jr
from jax.lax import stop_gradient
from jax.lib.xla_bridge import get_backend

import quaxed.array_api as xp
Expand Down Expand Up @@ -144,7 +145,7 @@ def one_pt_intg(
lead_arm_w, trail_arm_w = jax.vmap(one_pt_intg)(pt_ids, w0_lead, w0_trail)
return lead_arm_w, trail_arm_w

@partial(jax.jit, static_argnames=("vmapped",))
@partial(jax.jit, static_argnames=("vmapped", "include_meta"))
def run(
self,
rng: jr.PRNG,
Expand All @@ -153,6 +154,7 @@ def run(
prog_mass: gt.FloatQScalar | ProgenitorMassCallable,
*,
vmapped: bool | None = None,
include_meta: bool = False,
) -> tuple[MockStream, PhaseSpacePosition]:
"""Generate mock stellar stream.
Expand All @@ -174,6 +176,9 @@ def run(
`None` (default), then `jax.vmap` is used on GPU and `jax.lax.scan`
otherwise.
include_meta : bool, optional keyword-only
Whether to include metadata in the output.
Returns
-------
mockstream : :class:`galax.dynamcis.MockStream`
Expand Down Expand Up @@ -241,16 +246,20 @@ def run(
raise ValueError(msg)

mockstream = MockStream(
q=Quantity(q, self.units["length"]),
q=Quantity(q, self.units["length"]), # TODO: already Q?
p=Quantity(p, self.units["speed"]),
t=t,
release_time=release_time,
meta={
"generator": self,
"rng": original_rng,
"mass": prog_mass,
"vmapped": use_vmap,
},
meta=(
{
"generator": stop_gradient(self),
"rng": stop_gradient(original_rng),
"mass": stop_gradient(prog_mass),
"vmapped": use_vmap,
}
if include_meta
else {}
),
)

return mockstream, prog_o[-1]

0 comments on commit 7483ad2

Please sign in to comment.