-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Add Weighted Quantile and Percentile Support to jax.numpy #32737
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2345,7 +2345,8 @@ def cumulative_prod( | |
@partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) | ||
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, | ||
out: None = None, overwrite_input: bool = False, method: str = "linear", | ||
keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: | ||
keepdims: bool = False, weights: ArrayLike | None = None, *, | ||
interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: | ||
"""Compute the quantile of the data along the specified axis. | ||
|
||
JAX implementation of :func:`numpy.quantile`. | ||
|
@@ -2449,7 +2450,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = | |
return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, True) | ||
|
||
def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, | ||
method: str, keepdims: bool, squash_nans: bool) -> Array: | ||
method: str, keepdims: bool, squash_nans: bool, weights: ArrayLike | None = None) -> Array: | ||
if method not in ["linear", "lower", "higher", "midpoint", "nearest"]: | ||
raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', or 'nearest'") | ||
a, = promote_dtypes_inexact(a) | ||
|
@@ -2488,6 +2489,66 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, | |
raise ValueError(f"q must be have rank <= 1, got shape {q.shape}") | ||
|
||
a_shape = a.shape | ||
# Handle weights | ||
if weights is not None: | ||
a, weights = promote_dtypes_inexact(a, weights) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
if axis is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Axis will never be None here, because of the If we need to change the shape of |
||
a = a.ravel() | ||
weights = weights.ravel() | ||
axis = 0 | ||
else: | ||
weights = _broadcast_to(weights, a.shape) | ||
if squash_nans: | ||
nan_mask = ~lax_internal._isnan(a) | ||
if axis is None: | ||
a = a[nan_mask] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We cannot do this, because it results in a dynamically shaped array, and will fail under JIT |
||
weights = weights[nan_mask] | ||
else: | ||
weights = _where(nan_mask, weights, 0) | ||
a_sorted, weights_sorted = lax.sort_key_val(a, weights, dimension=axis) | ||
|
||
cum_weights = lax.cumsum(weights_sorted, axis=axis) | ||
total_weight = lax.sum(weights_sorted, axis=axis, keepdims=True) | ||
if lax_internal._all(total_weight == 0): | ||
raise ValueError("Sum of weights must not be zero.") | ||
cum_weights_norm = cum_weights / total_weight | ||
quantile_pos = q | ||
mask = cum_weights_norm >= quantile_pos[..., None] | ||
idx = lax.argmin(mask.astype(int), axis=axis) | ||
idx_prev = lax.max(idx - 1, _lax_const(idx, 0)) | ||
idx_next = idx | ||
gather_shape = list(a_sorted.shape) | ||
gather_shape[axis] = 1 | ||
dnums = lax.GatherDimensionNumbers( | ||
offset_dims=tuple(range(len(a_sorted.shape))), | ||
collapsed_slice_dims=(axis,), | ||
start_index_map=(axis,)) | ||
prev_value = lax.gather(a_sorted, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=gather_shape) | ||
next_value = lax.gather(a_sorted, idx_next[..., None], dimension_numbers=dnums, slice_sizes=gather_shape) | ||
prev_cumw = lax.gather(cum_weights_norm, idx_prev[..., None], dimension_numbers=dnums, slice_sizes=gather_shape) | ||
next_cumw = lax.gather(cum_weights_norm, idx_next[..., None], dimension_numbers=dnums, slice_sizes=gather_shape) | ||
|
||
if method == "linear": | ||
denom = next_cumw - prev_cumw | ||
denom = lax.select(denom == 0, _lax_const(denom, 1), denom) | ||
weight = (quantile_pos - prev_cumw) / denom | ||
result = prev_value * (1 - weight) + next_value * weight | ||
elif method == "lower": | ||
result = prev_value | ||
elif method == "higher": | ||
result = next_value | ||
elif method == "nearest": | ||
use_prev = (quantile_pos - prev_cumw) < (next_cumw - quantile_pos) | ||
result = lax.select(use_prev, prev_value, next_value) | ||
elif method == "midpoint": | ||
result = (prev_value + next_value) / 2 | ||
else: | ||
raise ValueError(f"{method=!r} not recognized") | ||
|
||
if not keepdims: | ||
result = lax.squeeze(result, axis) | ||
return lax.convert_element_type(result, a.dtype) | ||
|
||
|
||
if squash_nans: | ||
a = _where(lax._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. | ||
|
@@ -2578,7 +2639,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, | |
def percentile(a: ArrayLike, q: ArrayLike, | ||
axis: int | tuple[int, ...] | None = None, | ||
out: None = None, overwrite_input: bool = False, method: str = "linear", | ||
keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: | ||
keepdims: bool = False, weights: ArrayLike | None = None, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: | ||
"""Compute the percentile of the data along the specified axis. | ||
|
||
JAX implementation of :func:`numpy.percentile`. | ||
|
@@ -2623,7 +2684,7 @@ def percentile(a: ArrayLike, q: ArrayLike, | |
raise TypeError("percentile() argument interpolation was removed in JAX" | ||
" v0.8.0. Use method instead.") | ||
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, | ||
method=method, keepdims=keepdims) | ||
method=method, keepdims=keepdims, weights=weights) | ||
|
||
|
||
# TODO(jakevdp): interpolation argument deprecated 2024-05-16 | ||
|
@@ -2632,7 +2693,7 @@ def percentile(a: ArrayLike, q: ArrayLike, | |
def nanpercentile(a: ArrayLike, q: ArrayLike, | ||
axis: int | tuple[int, ...] | None = None, | ||
out: None = None, overwrite_input: bool = False, method: str = "linear", | ||
keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: | ||
keepdims: bool = False, weights: ArrayLike | None = None, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: | ||
"""Compute the percentile of the data along the specified axis, ignoring NaN values. | ||
|
||
JAX implementation of :func:`numpy.nanpercentile`. | ||
|
@@ -2680,7 +2741,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, | |
raise TypeError("nanpercentile() argument interpolation was removed in JAX" | ||
" v0.8.0. Use method instead.") | ||
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input, | ||
method=method, keepdims=keepdims) | ||
method=method, keepdims=keepdims, weights=weights) | ||
|
||
|
||
@export | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should make
weights
keyword-only, as NumPy does.