Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 37 additions & 35 deletions xarray/computation/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@

import functools
from collections import Counter
from collections.abc import (
Callable,
Hashable,
)
from collections.abc import Callable, Hashable
from typing import TYPE_CHECKING, Any, Literal, cast, overload

import numpy as np
Expand All @@ -23,10 +20,7 @@
from xarray.core.duck_array_ops import datetime_to_numeric
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import Dims, T_DataArray
from xarray.core.utils import (
is_scalar,
parse_dims_as_set,
)
from xarray.core.utils import is_scalar, parse_dims_as_set
from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
Expand Down Expand Up @@ -912,14 +906,17 @@ def _calc_idxminmax(
# The dim is not specified and ambiguous. Don't guess.
raise ValueError("Must supply 'dim' argument for multidimensional arrays")

if dim not in array.dims:
raise KeyError(
f"Dimension {dim!r} not found in array dimensions {array.dims!r}"
)
if dim not in array.coords:
raise KeyError(
f"Dimension {dim!r} is not one of the coordinates {tuple(array.coords.keys())}"
)
dims = [dim] if isinstance(dim, str) else list(dim)

for _dim in dims:
if _dim not in array.dims:
raise KeyError(
f"Dimension {_dim!r} not found in array dimensions {array.dims!r}"
)
if _dim not in array.coords:
raise KeyError(
f"Dimension {_dim!r} is not one of the coordinates {tuple(array.coords.keys())}"
)

# These are dtypes with NaN values argmin and argmax can handle
na_dtypes = "cfO"
Expand All @@ -931,25 +928,30 @@ def _calc_idxminmax(

# This will run argmin or argmax.
indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)
# Force dictionary format in case of single dim so that we can iterate over it in for loop below
if len(dims) == 1:
indx = {dims[0]: indx}

res = {}
for _dim, _da_idx in zip(dims, indx.values(), strict=False):
# Handle chunked arrays (e.g. dask).
coord = array[_dim]._variable.to_base_variable()
if is_chunked_array(array.data):
chunkmanager = get_chunked_array_type(array.data)
coord_array = chunkmanager.from_array(
array[_dim].data, chunks=((array.sizes[_dim],),)
)
coord = coord.copy(data=coord_array)
else:
coord = coord.copy(data=to_like_array(array[_dim].data, array.data))

# Handle chunked arrays (e.g. dask).
coord = array[dim]._variable.to_base_variable()
if is_chunked_array(array.data):
chunkmanager = get_chunked_array_type(array.data)
coord_array = chunkmanager.from_array(
array[dim].data, chunks=((array.sizes[dim],),)
)
coord = coord.copy(data=coord_array)
else:
coord = coord.copy(data=to_like_array(array[dim].data, array.data))

res = indx._replace(coord[(indx.variable,)]).rename(dim)

if skipna or (skipna is None and array.dtype.kind in na_dtypes):
# Put the NaN values back in after removing them
res = res.where(~allna, fill_value)

# Copy attributes from argmin/argmax, if any
res.attrs = indx.attrs
_res = _da_idx._replace(coord[(_da_idx.variable,)]).rename(_dim)
if skipna or (skipna is None and array.dtype.kind in na_dtypes):
# Put the NaN values back in after removing them
_res = _res.where(~allna, fill_value)
_res.attrs = _da_idx.attrs
res[_dim] = _res

if len(dims) == 1:
res = res[dims[0]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should have some type stability here so

idmax(dim) -> array; idxmax((dim,)) -> tuple[array]; idxmax((dim0, dim1, ...)) -> tuple[array, ...]

Copy link
Contributor Author

@gcaria gcaria Sep 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have changed the code to match the behavior of DataArray.arg* which returns a dict for both idx*((dim,)) and idx*((dim0, dim1, ...))

Does that seem sensible?

Currently navigating the existing tests for arg* and multiple dims

return res
Loading