diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 4b4ca61c8bb5..392c107b71f1 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -23,12 +23,15 @@ import numpy as np +import jax +from jax import lax +import jax._src.numpy as jnp from jax._src import api from jax._src import config from jax._src import core from jax._src import dtypes from jax._src.numpy.util import ( - _broadcast_to, ensure_arraylike, + _broadcast_to, check_arraylike, ensure_arraylike, _complex_elem_type, promote_dtypes_inexact, promote_dtypes_numeric, _where) from jax._src.lax import control_flow from jax._src.lax import lax as lax @@ -2345,7 +2348,8 @@ def cumulative_prod( @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False, *, weights: ArrayLike | None = None, + interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: """Compute the quantile of the data along the specified axis. JAX implementation of :func:`numpy.quantile`. @@ -2383,7 +2387,10 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No >>> jnp.quantile(x, q, method='nearest') Array([2., 4., 7.], dtype=float32) """ - a, q = ensure_arraylike("quantile", a, q) + if weights is None: + a, q = ensure_arraylike("quantile", a, q) + else: + a, q, weights = ensure_arraylike("quantile", a, q, weights) if overwrite_input or out is not None: raise ValueError("jax.numpy.quantile does not support overwrite_input=True " "or out != None") @@ -2391,14 +2398,15 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No if not isinstance(interpolation, DeprecatedArg): raise TypeError("quantile() argument interpolation was removed in JAX" " v0.8.0. Use method instead.") - return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, False) + return _quantile(a, q, axis, method, keepdims, False, weights) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 @export @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False, *, weights: ArrayLike | None = None, + interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: """Compute the quantile of the data along the specified axis, ignoring NaNs. JAX implementation of :func:`numpy.nanquantile`. @@ -2437,7 +2445,10 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = >>> jnp.nanquantile(x, q) Array([1.5, 3. , 4.5], dtype=float32) """ - a, q = ensure_arraylike("nanquantile", a, q) + if weights is None: + a, q = ensure_arraylike("nanquantile", a, q) + else: + a, q, weights = ensure_arraylike("nanquantile", a, q, weights) if overwrite_input or out is not None: msg = ("jax.numpy.nanquantile does not support overwrite_input=True or " "out != None") @@ -2446,13 +2457,12 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = if not isinstance(interpolation, DeprecatedArg): raise TypeError("nanquantile() argument interpolation was removed in JAX" " v0.8.0. Use method instead.") - return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, True) + return _quantile(a, q, axis, method, keepdims, True, weights) def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, - method: str, keepdims: bool, squash_nans: bool) -> Array: - if method not in ["linear", "lower", "higher", "midpoint", "nearest"]: - raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', or 'nearest'") - a, = promote_dtypes_inexact(a) + method: str, keepdims: bool, squash_nans: bool, weights: Array | None = None) -> Array: + if method not in ["linear", "lower", "higher", "midpoint", "nearest", "inverted_cdf"]: + raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest', or 'inverted_cdf'") keepdim = [] if dtypes.issubdtype(a.dtype, np.complexfloating): raise ValueError("quantile does not support complex input, as the operation is poorly defined.") @@ -2482,12 +2492,83 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, else: axis = canonicalize_axis(axis, a.ndim) + q, = promote_dtypes_inexact(q) + q = jnp.atleast_1d(q) q_shape = q.shape q_ndim = q.ndim if q_ndim > 1: raise ValueError(f"q must be have rank <= 1, got shape {q.shape}") a_shape = a.shape + # Handle weights + if weights is None: + a, = promote_dtypes_inexact(a) + else: + a, weights = promote_dtypes_inexact(a, weights) + a_shape = a.shape + w_shape = np.shape(weights) + if w_shape != a_shape: + if axis is None: + raise TypeError("Axis must be specified when shapes of a and weights differ.") + if isinstance(axis, tuple): + if w_shape != tuple(a_shape[i] for i in axis): + raise ValueError("Shape of weights must match the shape of the axes being reduced.") + weights = lax.broadcast_in_dim( + weights, + shape=a_shape, + broadcast_dimensions=axis + ) + else: + if len(w_shape) != 1 or w_shape[0] != a_shape[axis]: + raise ValueError("Length of weights not compatible with specified axis.") + weights = lax.expand_dims(weights, axis) + weights = _broadcast_to(weights, a.shape) + + + if squash_nans: + nan_mask = ~lax_internal._isnan(a) + weights = _where(nan_mask, weights, 0) + else: + with jax.debug_nans(False): + a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) + + total_weight = sum(weights, axis=axis, keepdims=True) + a_sorted, weights_sorted = lax.sort_key_val(a, weights, dimension=axis) + cum_weights = lax.cumsum(weights_sorted, axis=axis) + cum_weights_norm = lax.div(cum_weights, total_weight) + + def _weighted_quantile(qi): + index_dtype = dtypes.default_int_dtype() + idx = sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype, keepdims=keepdims) + idx = lax.clamp(0, idx, a_sorted.shape[axis] - 1) + val = jnp.take_along_axis(a_sorted, idx, axis) + + idx_prev = lax.clamp(idx - 1, 0, a_sorted.shape[axis] - 1) + val_prev = jnp.take_along_axis(a_sorted, idx_prev, axis) + cw_prev = jnp.take_along_axis(cum_weights_norm, idx_prev, axis) + cw_next = jnp.take_along_axis(cum_weights_norm, idx, axis) + + if method == "linear": + denom = cw_next - cw_prev + denom = _where(denom == 0, 1, denom) + weight = (qi - cw_prev) / denom + out = val_prev * (1 - weight) + val * weight + elif method == "lower": + out = val_prev + elif method == "higher": + out = val + elif method == "nearest": + out = _where(lax.abs(qi - cw_prev) < lax.abs(qi - cw_next), val_prev, val) + elif method == "midpoint": + out = (val_prev + val) / 2 + elif method == "inverted_cdf": + out = val + else: + raise ValueError(f"{method=!r} not recognized") + return out + + result = jax.vmap(_weighted_quantile)(q) + return result if squash_nans: a = _where(lax._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. @@ -2563,6 +2644,8 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, result = lax.select(pred, low_value, high_value) elif method == "midpoint": result = lax.mul(lax.add(low_value, high_value), lax._const(low_value, 0.5)) + elif method == "inverted_cdf": + result = high_value else: raise ValueError(f"{method=!r} not recognized") if keepdims and keepdim: @@ -2578,7 +2661,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, def percentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False, *, weights: ArrayLike | None = None, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: """Compute the percentile of the data along the specified axis. JAX implementation of :func:`numpy.percentile`. @@ -2616,14 +2699,17 @@ def percentile(a: ArrayLike, q: ArrayLike, >>> jnp.percentile(x, q, method='nearest') Array([1., 3., 4.], dtype=float32) """ - a, q = ensure_arraylike("percentile", a, q) + if weights is None: + a, q = ensure_arraylike("percentile", a, q) + else: + a, q, weights = ensure_arraylike("percentile", a, q, weights) q, = promote_dtypes_inexact(q) # TODO(jakevdp): remove the interpolation argument in JAX v0.9.0 if not isinstance(interpolation, DeprecatedArg): raise TypeError("percentile() argument interpolation was removed in JAX" " v0.8.0. Use method instead.") return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, - method=method, keepdims=keepdims) + method=method, keepdims=keepdims, weights=weights) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 @@ -2632,7 +2718,7 @@ def percentile(a: ArrayLike, q: ArrayLike, def nanpercentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False, *, weights: ArrayLike | None = None, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: """Compute the percentile of the data along the specified axis, ignoring NaN values. JAX implementation of :func:`numpy.nanpercentile`. @@ -2672,7 +2758,10 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, >>> jnp.nanpercentile(x, q) Array([1.5, 3. , 4.5], dtype=float32) """ - a, q = ensure_arraylike("nanpercentile", a, q) + if weights is None: + a, q = ensure_arraylike("nanpercentile", a, q) + else: + a, q, weights = ensure_arraylike("nanpercentile", a, q, weights) q, = promote_dtypes_inexact(q) q = q / 100 # TODO(jakevdp): remove the interpolation argument in JAX v0.9.0 @@ -2680,7 +2769,7 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, raise TypeError("nanpercentile() argument interpolation was removed in JAX" " v0.8.0. Use method instead.") return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input, - method=method, keepdims=keepdims) + method=method, keepdims=keepdims, weights=weights) @export diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 9f79cec9ca18..9752f58db22e 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -824,6 +824,78 @@ def testPercentilePrecision(self): x = jnp.float64([1, 2, 3, 4, 7, 10]) self.assertEqual(jnp.percentile(x, 50), 3.5) + @jtu.sample_product( + [dict(a_shape=a_shape, axis=axis) + for a_shape, axis in ( + ((7,), None), + ((6, 7,), None), + ((47, 7), 0), + ((47, 7), ()), + ((4, 101), 1), + ((4, 47, 7), (1, 2)), + ((4, 47, 7), (0, 2)), + ((4, 47, 7), (1, 0, 2)), + ) + ], + a_dtype=default_dtypes, + q_dtype=[np.float32], + q_shape=scalar_shapes + [(1,), (4,)], + keepdims=[False, True], + method=['linear', 'lower', 'higher', 'nearest', 'midpoint', 'inverted_cdf'], +) + def testWeightedQuantile(self, a_shape, a_dtype, q_shape, q_dtype, axis, keepdims, method): + rng = jtu.rand_default(self.rng()) + a = rng(a_shape, a_dtype) + q = rng(q_shape, q_dtype) + if axis is None: + weights_shape = a_shape + elif isinstance(axis, tuple): + weights_shape = tuple(a_shape[i] for i in axis) + else: + weights_shape = (a_shape[axis],) + weights = np.abs(rng(weights_shape, a_dtype)) + 1e-3 + + def np_fun(a, q, weights): + return np.quantile(np.array(a), np.array(q), axis=axis, weights=np.array(weights), method=method, keepdims=keepdims) + def jnp_fun(a, q, weights): + return jnp.quantile(a, q, axis=axis, weights=weights, method=method, keepdims=keepdims) + args_maker = lambda: [ + rng(a_shape, a_dtype), + rng(q_shape, q_dtype), + np.abs(rng(weights_shape, a_dtype)) + 1e-3 + ] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=1e-6) + self._CompileAndCheck(jnp_fun, args_maker, rtol=1e-6) + + def test_weighted_quantile_negative_weights(self): + a = jnp.array([1, 2, 3, 4, 5], dtype=float) + weights = jnp.array([1, -1, 1, 1, 1], dtype=float) + q = jnp.array([0.5]) + with self.assertRaisesRegex(ValueError, "Weights must be non-negative"): + jnp.quantile(a, q, axis=0, method="linear", keepdims=False, weights=weights) + + def test_weighted_quantile_all_weights_zero(self): + a = jnp.array([1, 2, 3, 4, 5], dtype=float) + weights = jnp.zeros_like(a) + q = jnp.array([0.5]) + with self.assertRaisesRegex(ValueError, "Sum of weights must not be zero"): + jnp.quantile(a, q, axis=0, method="linear", keepdims=False, weights=weights) + + def test_weighted_quantile_weights_with_nan(self): + a = jnp.array([1, 2, 3, 4, 5], dtype=float) + weights = jnp.array([1, np.nan, 1, 1, 1], dtype=float) + q = jnp.array([0.5]) + result = jnp.quantile(a, q, axis=0, method="linear", keepdims=False, weights=weights) + assert np.isnan(np.array(result)).all() + + def test_weighted_quantile_scalar_q(self): + a = jnp.array([1, 2, 3, 4, 5], dtype=float) + weights = jnp.array([1, 2, 1, 1, 1], dtype=float) + q = 0.5 + result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, weights=weights) + assert jnp.issubdtype(result.dtype, jnp.floating) + assert result.shape == () + @jtu.sample_product( [dict(a_shape=a_shape, axis=axis) for a_shape, axis in (