Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ mypy = "==1.18.1"
pyright = "*"
hypothesis = "*"
lxml = "*"
pandas-stubs = "<=2.2.3.241126" # https://github.com/pydata/xarray/issues/10110
pandas-stubs = "<=2.3.3.251219"
types-colorama = "*"
types-docutils = "*"
types-psutil = "*"
Expand Down
6 changes: 4 additions & 2 deletions xarray/coding/cftimeindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,12 +514,14 @@ def shift( # type: ignore[override,unused-ignore]
f"'freq' must be of type str or datetime.timedelta, got {type(freq)}."
)

def __add__(self, other) -> Self:
# pandas-stubs defines many overloads for Index.__add__/__radd__ with specific
# return types, but CFTimeIndex legitimately returns Self for all cases
def __add__(self, other) -> Self: # type: ignore[override]
if isinstance(other, pd.TimedeltaIndex):
other = other.to_pytimedelta()
return type(self)(np.array(self) + other)

def __radd__(self, other) -> Self:
def __radd__(self, other) -> Self: # type: ignore[override]
if isinstance(other, pd.TimedeltaIndex):
other = other.to_pytimedelta()
return type(self)(other + np.array(self))
Expand Down
3 changes: 2 additions & 1 deletion xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ def to_index(self, ordered_dims: Sequence[Hashable] | None = None) -> pd.Index:

for i, index in enumerate(indexes):
if isinstance(index, pd.MultiIndex):
codes, levels = index.codes, index.levels
codes: list[np.ndarray] = list(index.codes)
levels = index.levels
else:
code, level = pd.factorize(index)
codes = [code]
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7402,8 +7402,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
"cannot convert a DataFrame with a non-unique MultiIndex into xarray"
)

arrays = []
extension_arrays = []
arrays: list[tuple[Hashable, np.ndarray]] = []
extension_arrays: list[tuple[Hashable, pd.Series]] = []
for k, v in dataframe.items():
if not is_allowed_extension_array(v) or isinstance(
v.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES
Expand Down
9 changes: 6 additions & 3 deletions xarray/core/extension_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def as_extension_array(
[array_or_scalar], dtype=dtype
)
else:
return array_or_scalar.astype(dtype, copy=copy) # type: ignore[union-attr]
# pandas-stubs is overly strict about astype's dtype parameter and return type;
# ExtensionArray.astype accepts ExtensionDtype and returns ExtensionArray
return array_or_scalar.astype(dtype, copy=copy) # type: ignore[union-attr,return-value,arg-type]


@implements(np.result_type)
Expand Down Expand Up @@ -192,10 +194,11 @@ def __extension_duck_array__where(
# pd.where won't broadcast 0-dim arrays across a scalar-like series; scalar y's must be preserved
if hasattr(y, "shape") and len(y.shape) == 1 and y.shape[0] == 1:
y = y[0] # type: ignore[index]
return cast(T_ExtensionArray, pd.Series(x).where(condition, y).array) # type: ignore[arg-type]
# pandas-stubs has strict overloads for Series.where that don't cover all valid arg types
return cast(T_ExtensionArray, pd.Series(x).where(condition, y).array) # type: ignore[call-overload]


def _replace_duck(args, replacer: Callable[[PandasExtensionArray], list]) -> list:
def _replace_duck(args, replacer: Callable[[PandasExtensionArray], Any]) -> list:
args_as_list = list(args)
for index, value in enumerate(args_as_list):
if isinstance(value, PandasExtensionArray):
Expand Down
23 changes: 13 additions & 10 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,8 @@ def get_indexer_nd(index: pd.Index, labels, method=None, tolerance=None) -> np.n
flat_labels = np.ravel(labels)
if flat_labels.dtype == "float16":
flat_labels = flat_labels.astype("float64")
flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance)
# pandas-stubs expects Index for get_indexer, but ndarray works at runtime
flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance) # type: ignore[arg-type]
indexer = flat_indexer.reshape(labels.shape)
return indexer

Expand Down Expand Up @@ -978,14 +979,15 @@ def remove_unused_levels_categories(index: T_PDIndex) -> T_PDIndex:
Remove unused levels from MultiIndex and unused categories from CategoricalIndex
"""
if isinstance(index, pd.MultiIndex):
new_index = cast(pd.MultiIndex, index.remove_unused_levels())
new_index = index.remove_unused_levels()
# if it contains CategoricalIndex, we need to remove unused categories
# manually. See https://github.com/pandas-dev/pandas/issues/30846
if any(isinstance(lev, pd.CategoricalIndex) for lev in new_index.levels):
levels = []
for i, level in enumerate(new_index.levels):
if isinstance(level, pd.CategoricalIndex):
level = level[new_index.codes[i]].remove_unused_categories()
# pandas-stubs is missing remove_unused_categories on CategoricalIndex
level = level[new_index.codes[i]].remove_unused_categories() # type: ignore[attr-defined]
else:
level = level[new_index.codes[i]]
levels.append(level)
Expand Down Expand Up @@ -1229,7 +1231,7 @@ def reorder_levels(
its corresponding coordinates.

"""
index = cast(pd.MultiIndex, self.index.reorder_levels(level_variables.keys()))
index = self.index.reorder_levels(list(level_variables.keys()))
level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names}
return self._replace(index, level_coords_dtype=level_coords_dtype)

Expand Down Expand Up @@ -1378,28 +1380,29 @@ def sel(self, labels, method=None, tolerance=None) -> IndexSelResult:
indexer = DataArray(indexer, coords=coords, dims=label.dims)

if new_index is not None:
xr_index: PandasIndex | PandasMultiIndex
if isinstance(new_index, pd.MultiIndex):
level_coords_dtype = {
k: self.level_coords_dtype[k] for k in new_index.names
}
new_index = self._replace(
xr_index = self._replace(
new_index, level_coords_dtype=level_coords_dtype
)
dims_dict = {}
drop_coords = []
drop_coords: list[Hashable] = []
else:
new_index = PandasIndex(
xr_index = PandasIndex(
new_index,
new_index.name,
coord_dtype=self.level_coords_dtype[new_index.name],
)
dims_dict = {self.dim: new_index.index.name}
dims_dict = {self.dim: xr_index.index.name}
drop_coords = [self.dim]

# variable(s) attrs and encoding metadata are propagated
# when replacing the indexes in the resulting xarray object
new_vars = new_index.create_variables()
indexes = cast(dict[Any, Index], dict.fromkeys(new_vars, new_index))
new_vars = xr_index.create_variables()
indexes = cast(dict[Any, Index], dict.fromkeys(new_vars, xr_index))

# add scalar variable for each dropped level
variables = new_vars
Expand Down
3 changes: 2 additions & 1 deletion xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,8 @@ def _localize(obj: T, indexes_coords: SourceDest) -> tuple[T, SourceDest]:
minval = np.nanmin(new_x_loaded)
maxval = np.nanmax(new_x_loaded)
index = x.to_index()
imin, imax = index.get_indexer([minval, maxval], method="nearest")
# pandas-stubs expects Index for get_indexer, but list works at runtime
imin, imax = index.get_indexer([minval, maxval], method="nearest") # type: ignore[arg-type]
indexes[dim] = slice(max(imin - 2, 0), imax + 2)
indexes_coords[dim] = (x[indexes[dim]], new_x)
return obj.isel(indexes), indexes_coords # type: ignore[attr-defined]
Expand Down
6 changes: 3 additions & 3 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7500,10 +7500,10 @@ def test_write_file_from_np_str(str_type: type[str | np.str_], tmpdir: str) -> N
)
tdf.index.name = "scenario"
tdf.columns.name = "year"
tdf = cast(pd.DataFrame, tdf.stack())
tdf.name = "tas"
tdf_series = cast(pd.Series, tdf.stack())
tdf_series.name = "tas"

txr = tdf.to_xarray()
txr = tdf_series.to_xarray()

txr.to_netcdf(tmpdir.join("test.nc"))

Expand Down
3 changes: 2 additions & 1 deletion xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3931,7 +3931,8 @@ def test_to_and_from_dict_with_nan_nat(self) -> None:
y = np.random.randn(10, 3)
y[2] = np.nan
t = pd.Series(pd.date_range("20130101", periods=10))
t[2] = np.nan
# pandas-stubs doesn't allow np.nan for datetime Series, but it converts to NaT
t[2] = np.nan # type: ignore[call-overload]
lat = [77.7, 83.2, 76]
da = DataArray(y, {"t": t, "lat": lat}, dims=["t", "lat"])
roundtripped = DataArray.from_dict(da.to_dict())
Expand Down
6 changes: 4 additions & 2 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3497,7 +3497,8 @@ def test_rename_multiindex(self) -> None:
midx_coords = Coordinates.from_pandas_multiindex(midx, "x")
original = Dataset({}, midx_coords)

midx_renamed = midx.rename(["a", "c"])
# pandas-stubs expects Hashable for rename, but list of names works for MultiIndex
midx_renamed = midx.rename(["a", "c"]) # type: ignore[call-overload]
midx_coords_renamed = Coordinates.from_pandas_multiindex(midx_renamed, "x")
expected = Dataset({}, midx_coords_renamed)

Expand Down Expand Up @@ -5602,7 +5603,8 @@ def test_to_and_from_dict_with_nan_nat(
y = np.random.randn(10, 3)
y[2] = np.nan
t = pd.Series(pd.date_range("20130101", periods=10))
t[2] = np.nan
# pandas-stubs doesn't allow np.nan for datetime Series, but it converts to NaT
t[2] = np.nan # type: ignore[call-overload]

lat = [77.7, 83.2, 76]
ds = Dataset(
Expand Down
3 changes: 2 additions & 1 deletion xarray/tests/test_duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,7 @@ def test_extension_array_attr():
assert (roundtripped == wrapped).all()

interval_array = pd.arrays.IntervalArray.from_breaks([0, 1, 2, 3], closed="right")
wrapped = PandasExtensionArray(interval_array)
# pandas-stubs types PandasExtensionArray too narrowly; IntervalArray is valid
wrapped = PandasExtensionArray(interval_array) # type: ignore[arg-type]
assert_array_equal(wrapped.left, interval_array.left, strict=True)
assert wrapped.closed == interval_array.closed
2 changes: 1 addition & 1 deletion xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def test__infer_interval_breaks(self) -> None:
[-0.5, 0.5, 5.0, 9.5, 10.5], _infer_interval_breaks([0, 1, 9, 10])
)
assert_array_equal(
pd.date_range("20000101", periods=4) - np.timedelta64(12, "h"), # type: ignore[operator]
pd.date_range("20000101", periods=4) - np.timedelta64(12, "h"),
_infer_interval_breaks(pd.date_range("20000101", periods=3)),
)

Expand Down
Loading