Skip to content

Commit

Permalink
fix broadcasting shaped constant parameters
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Dec 13, 2023
1 parent ca112b8 commit 3507c84
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/galax/potential/_potential/param/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import astropy.units as u
import equinox as eqx
from jax.numpy import vectorize

from galax.typing import (
BatchableFloatOrIntScalarLike,
Expand All @@ -16,7 +17,7 @@
FloatScalar,
Unit,
)
from galax.utils import partial_jit, vectorize_method
from galax.utils import partial_jit
from galax.utils.dataclasses import converter_float_array


Expand Down Expand Up @@ -62,9 +63,6 @@ class ConstantParameter(AbstractParameter):
# TODO: link this shape to the return shape from __call__
value: FloatArrayAnyShape = eqx.field(converter=converter_float_array)

# This is a workaround since vectorized methods don't support kwargs.
@partial_jit()
@vectorize_method(signature="()->()")
def _call_helper(self, _: FloatOrIntScalar) -> FloatArrayAnyShape:
return self.value

Expand All @@ -88,7 +86,11 @@ def __call__(
Array
The constant parameter value.
"""
return self._call_helper(t)
# Vectorization to enable broadcasting over the time dimension.
# We can't vectorize a method since the output shape depends on the
# input shape, so we have to do it this way.
signature = "()->" + str(self.value.shape).replace(" ", "")
return vectorize(self._call_helper, signature=signature)(t)


#####################################################################
Expand Down

0 comments on commit 3507c84

Please sign in to comment.