From 704d5f2772639fd22c07d52a040459f411ad46c7 Mon Sep 17 00:00:00 2001 From: nstarman Date: Mon, 22 Apr 2024 11:21:49 -0400 Subject: [PATCH] feat: extend include_meta to MockStream Signed-off-by: nstarman --- .../mockstream/mockstream_generator.py | 25 +++++++++++++------ 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py b/src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py index f498779e..d43d175e 100644 --- a/src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py +++ b/src/galax/dynamics/_dynamics/mockstream/mockstream_generator.py @@ -10,6 +10,7 @@ import jax import jax.numpy as jnp import quax.examples.prng as jr +from jax.lax import stop_gradient from jax.lib.xla_bridge import get_backend import quaxed.array_api as xp @@ -144,7 +145,7 @@ def one_pt_intg( lead_arm_w, trail_arm_w = jax.vmap(one_pt_intg)(pt_ids, w0_lead, w0_trail) return lead_arm_w, trail_arm_w - @partial(jax.jit, static_argnames=("vmapped",)) + @partial(jax.jit, static_argnames=("vmapped", "include_meta")) def run( self, rng: jr.PRNG, @@ -153,6 +154,7 @@ def run( prog_mass: gt.FloatQScalar | ProgenitorMassCallable, *, vmapped: bool | None = None, + include_meta: bool = False, ) -> tuple[MockStream, PhaseSpacePosition]: """Generate mock stellar stream. @@ -174,6 +176,9 @@ def run( `None` (default), then `jax.vmap` is used on GPU and `jax.lax.scan` otherwise. + include_meta : bool, optional keyword-only + Whether to include metadata in the output. + Returns ------- mockstream : :class:`galax.dynamcis.MockStream` @@ -241,16 +246,20 @@ def run( raise ValueError(msg) mockstream = MockStream( - q=Quantity(q, self.units["length"]), + q=Quantity(q, self.units["length"]), # TODO: already Q? p=Quantity(p, self.units["speed"]), t=t, release_time=release_time, - meta={ - "generator": self, - "rng": original_rng, - "mass": prog_mass, - "vmapped": use_vmap, - }, + meta=( + { + "generator": stop_gradient(self), + "rng": stop_gradient(original_rng), + "mass": stop_gradient(prog_mass), + "vmapped": use_vmap, + } + if include_meta + else {} + ), ) return mockstream, prog_o[-1]