Skip to content

Commit

Permalink
feat: add cumulative_sum (#63)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Jul 25, 2024
1 parent 79c657e commit 0a88ecf
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 6 deletions.
16 changes: 15 additions & 1 deletion src/quaxed/array_api/_statistical_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ["max", "mean", "min", "prod", "std", "sum", "var"]
__all__ = ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]


from jax.experimental import array_api
Expand All @@ -9,6 +9,20 @@
from quaxed._utils import quaxify


@quaxify
def cumulative_sum(
x: ArrayLike,
/,
*,
axis: int | None = None,
dtype: DType | None = None,
include_initial: bool = False,
) -> Value:
return array_api.cumulative_sum(
x, axis=axis, dtype=dtype, include_initial=include_initial
)


@quaxify
def max( # pylint: disable=redefined-builtin
x: ArrayLike,
Expand Down
36 changes: 36 additions & 0 deletions tests/array_api/test_myarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,42 @@ def test_sort():
# Statistical functions


def test_cumulative_sum():
"""Test `cumulative_sum`."""
x = MyArray(xp.asarray([1, 2, 3], dtype=float))

# No arguments
got = xp.cumulative_sum(x)
expected = MyArray(xp.asarray([1, 3, 6], dtype=float))

assert isinstance(got, MyArray)
assert jnp.array_equal(got.array, expected.array)

# axis
got = xp.cumulative_sum(x, axis=0)
expected = MyArray(xp.asarray([1, 3, 6], dtype=float))

assert isinstance(got, MyArray)
assert jnp.array_equal(got.array, expected.array)

with pytest.raises(ValueError, match="axis 1"):
_ = xp.cumulative_sum(x, axis=1)

# dtype
got = xp.cumulative_sum(x, dtype=int)
expected = MyArray(xp.asarray([1, 3, 6], dtype=int))

assert isinstance(got, MyArray)
assert jnp.array_equal(got.array, expected.array)

# initial
got = xp.cumulative_sum(x, include_initial=True)
expected = MyArray(xp.asarray([0, 1, 3, 6]))

assert isinstance(got, MyArray)
assert jnp.array_equal(got.array, expected.array)


@pytest.mark.skip("TODO")
def test_max():
"""Test `max`."""
Expand Down
12 changes: 7 additions & 5 deletions tests/myarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,16 +264,18 @@ def _complex_p(x: MyArray, y: MyArray) -> MyArray:
def _concatenate_p(
operand0: MyArray,
*operands: MyArray,
dimension: Any,
**kwargs: Any,
) -> MyArray:
return MyArray(
lax.concatenate(
[operand0.array] + [op.array for op in operands],
dimension=dimension,
),
lax.concatenate([operand0.array] + [op.array for op in operands], **kwargs)
)


@register(lax.concatenate_p)
def _concatenate_p(operand0: ArrayLike, operand1: MyArray, /, **kwargs: Any) -> MyArray:
return MyArray(lax.concatenate_p.bind(operand0, operand1.array, **kwargs))


# ==============================================================================


Expand Down

0 comments on commit 0a88ecf

Please sign in to comment.