Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 67 additions & 6 deletions jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2345,7 +2345,8 @@ def cumulative_prod(
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array:
keepdims: bool = False, weights: ArrayLike | None = None, *,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should make weights keyword-only, as NumPy does.

interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array:
"""Compute the quantile of the data along the specified axis.

JAX implementation of :func:`numpy.quantile`.
Expand Down Expand Up @@ -2449,7 +2450,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, True)

def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
method: str, keepdims: bool, squash_nans: bool) -> Array:
method: str, keepdims: bool, squash_nans: bool, weights: ArrayLike | None = None) -> Array:
if method not in ["linear", "lower", "higher", "midpoint", "nearest"]:
raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', or 'nearest'")
a, = promote_dtypes_inexact(a)
Expand Down Expand Up @@ -2488,6 +2489,66 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
raise ValueError(f"q must be have rank <= 1, got shape {q.shape}")

a_shape = a.shape
# Handle weights
if weights is not None:
a, weights = promote_dtypes_inexact(a, weights)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a has already been promoted; we should promote only once. That requires making the previous promote_dtypes_inexact conditional on whether weights is supplied.

if axis is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Axis will never be None here, because of the if axis is None check above.

If we need to change the shape of weights based on the value of axis, that needs to be done before this block.

a = a.ravel()
weights = weights.ravel()
axis = 0
else:
weights = _broadcast_to(weights, a.shape)
if squash_nans:
nan_mask = ~lax_internal._isnan(a)
if axis is None:
a = a[nan_mask]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot do this, because it results in a dynamically shaped array, and will fail under JIT

weights = weights[nan_mask]
else:
weights = _where(nan_mask, weights, 0)
a_sorted, weights_sorted = lax.sort_key_val(a, weights, dimension=axis)

cum_weights = lax.cumsum(weights_sorted, axis=axis)
total_weight = lax.sum(weights_sorted, axis=axis, keepdims=True)
if lax_internal._all(total_weight == 0):
raise ValueError("Sum of weights must not be zero.")
cum_weights_norm = cum_weights / total_weight
quantile_pos = q
mask = cum_weights_norm >= quantile_pos[..., None]
idx = lax.argmin(mask.astype(int), axis=axis)
idx_prev = lax.max(idx - 1, _lax_const(idx, 0))
idx_next = idx
gather_shape = list(a_sorted.shape)
gather_shape[axis] = 1
dnums = lax.GatherDimensionNumbers(
offset_dims=tuple(range(len(a_sorted.shape))),
collapsed_slice_dims=(axis,),
start_index_map=(axis,))
prev_value = lax.gather(a_sorted, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=gather_shape)
next_value = lax.gather(a_sorted, idx_next[..., None], dimension_numbers=dnums, slice_sizes=gather_shape)
prev_cumw = lax.gather(cum_weights_norm, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=gather_shape)
next_cumw = lax.gather(cum_weights_norm, idx_next[..., None], dimension_numbers=dnums, slice_sizes=gather_shape)

if method == "linear":
denom = next_cumw - prev_cumw
denom = lax.select(denom == 0, _lax_const(denom, 1), denom)
weight = (quantile_pos - prev_cumw) / denom
result = prev_value * (1 - weight) + next_value * weight
elif method == "lower":
result = prev_value
elif method == "higher":
result = next_value
elif method == "nearest":
use_prev = (quantile_pos - prev_cumw) < (next_cumw - quantile_pos)
result = lax.select(use_prev, prev_value, next_value)
elif method == "midpoint":
result = (prev_value + next_value) / 2
else:
raise ValueError(f"{method=!r} not recognized")

if not keepdims:
result = lax.squeeze(result, axis)
return lax.convert_element_type(result, a.dtype)


if squash_nans:
a = _where(lax._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end.
Expand Down Expand Up @@ -2578,7 +2639,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
def percentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array:
keepdims: bool = False, weights: ArrayLike | None = None, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
"""Compute the percentile of the data along the specified axis.

JAX implementation of :func:`numpy.percentile`.
Expand Down Expand Up @@ -2623,7 +2684,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
raise TypeError("percentile() argument interpolation was removed in JAX"
" v0.8.0. Use method instead.")
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
method=method, keepdims=keepdims)
method=method, keepdims=keepdims, weights=weights)


# TODO(jakevdp): interpolation argument deprecated 2024-05-16
Expand All @@ -2632,7 +2693,7 @@ def percentile(a: ArrayLike, q: ArrayLike,
def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array:
keepdims: bool = False, weights: ArrayLike | None = None, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array:
"""Compute the percentile of the data along the specified axis, ignoring NaN values.

JAX implementation of :func:`numpy.nanpercentile`.
Expand Down Expand Up @@ -2680,7 +2741,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
raise TypeError("nanpercentile() argument interpolation was removed in JAX"
" v0.8.0. Use method instead.")
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
method=method, keepdims=keepdims)
method=method, keepdims=keepdims, weights=weights)


@export
Expand Down
9 changes: 9 additions & 0 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from jax._src import config
from jax._src import dtypes
from jax._src.numpy.reductions import quantile
from jax._src import test_util as jtu

config.parse_flags_with_absl()
Expand Down Expand Up @@ -824,6 +825,14 @@ def testPercentilePrecision(self):
x = jnp.float64([1, 2, 3, 4, 7, 10])
self.assertEqual(jnp.percentile(x, 50), 3.5)

def test_weighted_quantile_linear(self):
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
weights = jnp.array([1, 2, 1, 1, 1], dtype=float)
q = jnp.array([0.5])
expected = np.quantile(a, q, weights=weights)
result = quantile(a, q, weights=weights, method="linear")
np.testing.assert_allclose(result, expected, rtol=1e-6)

@jtu.sample_product(
[dict(a_shape=a_shape, axis=axis)
for a_shape, axis in (
Expand Down