Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ mypy = "==1.18.1"
pyright = "*"
hypothesis = "*"
lxml = "*"
pandas-stubs = "<=2.2.3.241126" # https://github.com/pydata/xarray/issues/10110
pandas-stubs = "<=2.3.3.251219"
types-colorama = "*"
types-docutils = "*"
types-psutil = "*"
Expand Down
6 changes: 4 additions & 2 deletions xarray/coding/cftimeindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,12 +514,14 @@ def shift( # type: ignore[override,unused-ignore]
f"'freq' must be of type str or datetime.timedelta, got {type(freq)}."
)

def __add__(self, other) -> Self:
# pandas-stubs defines many overloads for Index.__add__/__radd__ with specific
# return types, but CFTimeIndex legitimately returns Self for all cases
def __add__(self, other) -> Self: # type: ignore[override]
if isinstance(other, pd.TimedeltaIndex):
other = other.to_pytimedelta()
return type(self)(np.array(self) + other)

def __radd__(self, other) -> Self:
def __radd__(self, other) -> Self: # type: ignore[override]
if isinstance(other, pd.TimedeltaIndex):
other = other.to_pytimedelta()
return type(self)(other + np.array(self))
Expand Down
3 changes: 2 additions & 1 deletion xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ def to_index(self, ordered_dims: Sequence[Hashable] | None = None) -> pd.Index:

for i, index in enumerate(indexes):
if isinstance(index, pd.MultiIndex):
codes, levels = index.codes, index.levels
codes: list[np.ndarray] = list(index.codes)
levels = index.levels
else:
code, level = pd.factorize(index)
codes = [code]
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7402,8 +7402,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
"cannot convert a DataFrame with a non-unique MultiIndex into xarray"
)

arrays = []
extension_arrays = []
arrays: list[tuple[Hashable, np.ndarray]] = []
extension_arrays: list[tuple[Hashable, pd.Series]] = []
for k, v in dataframe.items():
if not is_allowed_extension_array(v) or isinstance(
v.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES
Expand Down
9 changes: 6 additions & 3 deletions xarray/core/extension_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def as_extension_array(
[array_or_scalar], dtype=dtype
)
else:
return array_or_scalar.astype(dtype, copy=copy) # type: ignore[union-attr]
# pandas-stubs is overly strict about astype's dtype parameter and return type;
# ExtensionArray.astype accepts ExtensionDtype and returns ExtensionArray
return array_or_scalar.astype(dtype, copy=copy) # type: ignore[union-attr,return-value,arg-type]


@implements(np.result_type)
Expand Down Expand Up @@ -192,10 +194,11 @@ def __extension_duck_array__where(
# Series.where won't broadcast 0-dim arrays across a scalar-like series, so a length-1 y is unwrapped to a scalar below; scalar y's must be preserved
if hasattr(y, "shape") and len(y.shape) == 1 and y.shape[0] == 1:
y = y[0] # type: ignore[index]
return cast(T_ExtensionArray, pd.Series(x).where(condition, y).array) # type: ignore[arg-type]
# pandas-stubs has strict overloads for Series.where that don't cover all valid arg types
return cast(T_ExtensionArray, pd.Series(x).where(condition, y).array) # type: ignore[call-overload]


def _replace_duck(args, replacer: Callable[[PandasExtensionArray], list]) -> list:
def _replace_duck(args, replacer: Callable[[PandasExtensionArray], Any]) -> list:
args_as_list = list(args)
for index, value in enumerate(args_as_list):
if isinstance(value, PandasExtensionArray):
Expand Down
24 changes: 14 additions & 10 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,9 @@ def get_indexer_nd(index: pd.Index, labels, method=None, tolerance=None) -> np.n
flat_labels = np.ravel(labels)
if flat_labels.dtype == "float16":
flat_labels = flat_labels.astype("float64")
flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance)
flat_indexer = index.get_indexer(
pd.Index(flat_labels), method=method, tolerance=tolerance
)
indexer = flat_indexer.reshape(labels.shape)
return indexer

Expand Down Expand Up @@ -978,14 +980,15 @@ def remove_unused_levels_categories(index: T_PDIndex) -> T_PDIndex:
Remove unused levels from MultiIndex and unused categories from CategoricalIndex
"""
if isinstance(index, pd.MultiIndex):
new_index = cast(pd.MultiIndex, index.remove_unused_levels())
new_index = index.remove_unused_levels()
# if it contains CategoricalIndex, we need to remove unused categories
# manually. See https://github.com/pandas-dev/pandas/issues/30846
if any(isinstance(lev, pd.CategoricalIndex) for lev in new_index.levels):
levels = []
for i, level in enumerate(new_index.levels):
if isinstance(level, pd.CategoricalIndex):
level = level[new_index.codes[i]].remove_unused_categories()
# pandas-stubs is missing remove_unused_categories on CategoricalIndex
level = level[new_index.codes[i]].remove_unused_categories() # type: ignore[attr-defined]
else:
level = level[new_index.codes[i]]
levels.append(level)
Expand Down Expand Up @@ -1229,7 +1232,7 @@ def reorder_levels(
its corresponding coordinates.

"""
index = cast(pd.MultiIndex, self.index.reorder_levels(level_variables.keys()))
index = self.index.reorder_levels(list(level_variables.keys()))
level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names}
return self._replace(index, level_coords_dtype=level_coords_dtype)

Expand Down Expand Up @@ -1378,28 +1381,29 @@ def sel(self, labels, method=None, tolerance=None) -> IndexSelResult:
indexer = DataArray(indexer, coords=coords, dims=label.dims)

if new_index is not None:
xr_index: PandasIndex | PandasMultiIndex
if isinstance(new_index, pd.MultiIndex):
level_coords_dtype = {
k: self.level_coords_dtype[k] for k in new_index.names
}
new_index = self._replace(
xr_index = self._replace(
new_index, level_coords_dtype=level_coords_dtype
)
dims_dict = {}
drop_coords = []
drop_coords: list[Hashable] = []
else:
new_index = PandasIndex(
xr_index = PandasIndex(
new_index,
new_index.name,
coord_dtype=self.level_coords_dtype[new_index.name],
)
dims_dict = {self.dim: new_index.index.name}
dims_dict = {self.dim: xr_index.index.name}
drop_coords = [self.dim]

# variable(s) attrs and encoding metadata are propagated
# when replacing the indexes in the resulting xarray object
new_vars = new_index.create_variables()
indexes = cast(dict[Any, Index], dict.fromkeys(new_vars, new_index))
new_vars = xr_index.create_variables()
indexes = cast(dict[Any, Index], dict.fromkeys(new_vars, xr_index))

# add scalar variable for each dropped level
variables = new_vars
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def _localize(obj: T, indexes_coords: SourceDest) -> tuple[T, SourceDest]:
minval = np.nanmin(new_x_loaded)
maxval = np.nanmax(new_x_loaded)
index = x.to_index()
imin, imax = index.get_indexer([minval, maxval], method="nearest")
imin, imax = index.get_indexer(pd.Index([minval, maxval]), method="nearest")
indexes[dim] = slice(max(imin - 2, 0), imax + 2)
indexes_coords[dim] = (x[indexes[dim]], new_x)
return obj.isel(indexes), indexes_coords # type: ignore[attr-defined]
Expand Down
6 changes: 3 additions & 3 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7500,10 +7500,10 @@ def test_write_file_from_np_str(str_type: type[str | np.str_], tmpdir: str) -> N
)
tdf.index.name = "scenario"
tdf.columns.name = "year"
tdf = cast(pd.DataFrame, tdf.stack())
tdf.name = "tas"
tdf_series = tdf.stack()
tdf_series.name = "tas" # type: ignore[union-attr]

txr = tdf.to_xarray()
txr = tdf_series.to_xarray()

txr.to_netcdf(tmpdir.join("test.nc"))

Expand Down
3 changes: 2 additions & 1 deletion xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3931,7 +3931,8 @@ def test_to_and_from_dict_with_nan_nat(self) -> None:
y = np.random.randn(10, 3)
y[2] = np.nan
t = pd.Series(pd.date_range("20130101", periods=10))
t[2] = np.nan
# pandas-stubs doesn't allow assigning np.nan into a datetime Series, but at runtime pandas converts it to NaT
t[2] = np.nan # type: ignore[call-overload]
lat = [77.7, 83.2, 76]
da = DataArray(y, {"t": t, "lat": lat}, dims=["t", "lat"])
roundtripped = DataArray.from_dict(da.to_dict())
Expand Down
6 changes: 4 additions & 2 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3497,7 +3497,8 @@ def test_rename_multiindex(self) -> None:
midx_coords = Coordinates.from_pandas_multiindex(midx, "x")
original = Dataset({}, midx_coords)

midx_renamed = midx.rename(["a", "c"])
# pandas-stubs expects a Hashable for rename, but a list of names works for a MultiIndex
midx_renamed = midx.rename(["a", "c"]) # type: ignore[call-overload]
midx_coords_renamed = Coordinates.from_pandas_multiindex(midx_renamed, "x")
expected = Dataset({}, midx_coords_renamed)

Expand Down Expand Up @@ -5602,7 +5603,8 @@ def test_to_and_from_dict_with_nan_nat(
y = np.random.randn(10, 3)
y[2] = np.nan
t = pd.Series(pd.date_range("20130101", periods=10))
t[2] = np.nan
# pandas-stubs doesn't allow assigning np.nan into a datetime Series, but at runtime pandas converts it to NaT
t[2] = np.nan # type: ignore[call-overload]

lat = [77.7, 83.2, 76]
ds = Dataset(
Expand Down
3 changes: 2 additions & 1 deletion xarray/tests/test_duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,7 @@ def test_extension_array_attr():
assert (roundtripped == wrapped).all()

interval_array = pd.arrays.IntervalArray.from_breaks([0, 1, 2, 3], closed="right")
wrapped = PandasExtensionArray(interval_array)
# pandas-stubs types PandasExtensionArray too narrowly; IntervalArray is valid
wrapped = PandasExtensionArray(interval_array) # type: ignore[arg-type]
assert_array_equal(wrapped.left, interval_array.left, strict=True)
assert wrapped.closed == interval_array.closed
2 changes: 1 addition & 1 deletion xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def test__infer_interval_breaks(self) -> None:
[-0.5, 0.5, 5.0, 9.5, 10.5], _infer_interval_breaks([0, 1, 9, 10])
)
assert_array_equal(
pd.date_range("20000101", periods=4) - np.timedelta64(12, "h"), # type: ignore[operator]
pd.date_range("20000101", periods=4) - np.timedelta64(12, "h"),
_infer_interval_breaks(pd.date_range("20000101", periods=3)),
)

Expand Down
Loading