Commit 28ffc54

max-sixty and claude committed
Fix mypy errors with newer pandas-stubs
Remove the pandas-stubs version pin (<=2.2.3.241126) and fix the mypy errors that appear with newer versions. This resolves #10110.

Changes by category:

1. Added type: ignore comments with explanatory notes for cases where pandas-stubs is stricter than actual pandas behavior:
   - CFTimeIndex.__add__/__radd__ return Self instead of overloaded types
   - Index.get_indexer accepts ndarray/list, not just Index
   - CategoricalIndex.remove_unused_categories missing from stubs
   - Series.where accepts broader argument types
   - ExtensionArray.astype accepts ExtensionDtype
   - Series[datetime].__setitem__ accepts np.nan (converts to NaT)
   - MultiIndex.rename accepts list of names

2. Added explicit type annotations to help mypy infer correct types:
   - coordinates.py: codes as list[np.ndarray]
   - dataset.py: arrays and extension_arrays list types
   - indexes.py: xr_index variable with PandasIndex | PandasMultiIndex

3. Removed redundant casts and fixed variable shadowing:
   - Removed unnecessary cast in remove_unused_levels_categories
   - Converted dict_keys to list for reorder_levels
   - Fixed test_backends.py variable naming (tdf -> tdf_series)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent 18ebe98 commit 28ffc54

File tree

11 files changed (+41, -27 lines)
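Most of the fixes below follow one pattern: keep the runtime call unchanged, silence only the specific stub complaint with an error-code-scoped `# type: ignore[...]`, and note the actual runtime behaviour next to it. A minimal standalone sketch of that pattern (not taken from the diff), using the MultiIndex.rename case from the commit message:

```python
import pandas as pd

midx = pd.MultiIndex.from_product([[1, 2], ["a", "b"]], names=["x", "y"])

# pandas-stubs types rename's argument as Hashable, but a list of names is
# valid for a MultiIndex at runtime, so only that one error code is ignored.
renamed = midx.rename(["level0", "level1"])  # type: ignore[call-overload]
print(renamed.names)  # ['level0', 'level1']
```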

pixi.toml

Lines changed: 1 addition & 1 deletion
@@ -255,7 +255,7 @@ mypy = "==1.18.1"
 pyright = "*"
 hypothesis = "*"
 lxml = "*"
-pandas-stubs = "<=2.2.3.241126" # https://github.com/pydata/xarray/issues/10110
+pandas-stubs = ">=2.3.3.251219"
 types-colorama = "*"
 types-docutils = "*"
 types-psutil = "*"

xarray/coding/cftimeindex.py

Lines changed: 4 additions & 2 deletions
@@ -514,12 +514,14 @@ def shift( # type: ignore[override,unused-ignore]
                 f"'freq' must be of type str or datetime.timedelta, got {type(freq)}."
             )
 
-    def __add__(self, other) -> Self:
+    # pandas-stubs defines many overloads for Index.__add__/__radd__ with specific
+    # return types, but CFTimeIndex legitimately returns Self for all cases
+    def __add__(self, other) -> Self:  # type: ignore[override]
         if isinstance(other, pd.TimedeltaIndex):
             other = other.to_pytimedelta()
         return type(self)(np.array(self) + other)
 
-    def __radd__(self, other) -> Self:
+    def __radd__(self, other) -> Self:  # type: ignore[override]
         if isinstance(other, pd.TimedeltaIndex):
             other = other.to_pytimedelta()
         return type(self)(other + np.array(self))

xarray/core/coordinates.py

Lines changed: 2 additions & 1 deletion
@@ -169,7 +169,8 @@ def to_index(self, ordered_dims: Sequence[Hashable] | None = None) -> pd.Index:
 
         for i, index in enumerate(indexes):
             if isinstance(index, pd.MultiIndex):
-                codes, levels = index.codes, index.levels
+                codes: list[np.ndarray] = list(index.codes)
+                levels = index.levels
             else:
                 code, level = pd.factorize(index)
                 codes = [code]

xarray/core/dataset.py

Lines changed: 2 additions & 2 deletions
@@ -7402,8 +7402,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
                 "cannot convert a DataFrame with a non-unique MultiIndex into xarray"
             )
 
-        arrays = []
-        extension_arrays = []
+        arrays: list[tuple[Hashable, np.ndarray]] = []
+        extension_arrays: list[tuple[Hashable, pd.Series]] = []
         for k, v in dataframe.items():
             if not is_allowed_extension_array(v) or isinstance(
                 v.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES
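The two annotations above are the usual mypy idiom for containers that start empty: declare the element type up front rather than leaving mypy to infer it from later appends. A standalone sketch of the idiom (the DataFrame and names here are illustrative, not from the diff):

```python
from collections.abc import Hashable

import numpy as np
import pandas as pd

# Without these annotations mypy has to guess the element type of the empty lists.
arrays: list[tuple[Hashable, np.ndarray]] = []
extension_arrays: list[tuple[Hashable, pd.Series]] = []

df = pd.DataFrame({"a": [1.0, 2.0], "b": pd.array([1, 2], dtype="Int64")})
for name, column in df.items():
    if isinstance(column.dtype, pd.api.extensions.ExtensionDtype):
        extension_arrays.append((name, column))
    else:
        arrays.append((name, np.asarray(column)))
print(len(arrays), len(extension_arrays))  # 1 1
```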

xarray/core/extension_array.py

Lines changed: 6 additions & 3 deletions
@@ -104,7 +104,9 @@ def as_extension_array(
             [array_or_scalar], dtype=dtype
         )
     else:
-        return array_or_scalar.astype(dtype, copy=copy)  # type: ignore[union-attr]
+        # pandas-stubs is overly strict about astype's dtype parameter and return type;
+        # ExtensionArray.astype accepts ExtensionDtype and returns ExtensionArray
+        return array_or_scalar.astype(dtype, copy=copy)  # type: ignore[union-attr,return-value,arg-type]
 
 
 @implements(np.result_type)
@@ -192,10 +194,11 @@ def __extension_duck_array__where(
     # pd.where won't broadcast 0-dim arrays across a scalar-like series; scalar y's must be preserved
     if hasattr(y, "shape") and len(y.shape) == 1 and y.shape[0] == 1:
         y = y[0]  # type: ignore[index]
-    return cast(T_ExtensionArray, pd.Series(x).where(condition, y).array)  # type: ignore[arg-type]
+    # pandas-stubs has strict overloads for Series.where that don't cover all valid arg types
+    return cast(T_ExtensionArray, pd.Series(x).where(condition, y).array)  # type: ignore[arg-type,call-overload]
 
 
-def _replace_duck(args, replacer: Callable[[PandasExtensionArray], list]) -> list:
+def _replace_duck(args, replacer: Callable[[PandasExtensionArray], Any]) -> list:
     args_as_list = list(args)
     for index, value in enumerate(args_as_list):
         if isinstance(value, PandasExtensionArray):

xarray/core/indexes.py

Lines changed: 13 additions & 10 deletions
@@ -628,7 +628,8 @@ def get_indexer_nd(index: pd.Index, labels, method=None, tolerance=None) -> np.n
     flat_labels = np.ravel(labels)
     if flat_labels.dtype == "float16":
         flat_labels = flat_labels.astype("float64")
-    flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance)
+    # pandas-stubs expects Index for get_indexer, but ndarray works at runtime
+    flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance)  # type: ignore[arg-type]
     indexer = flat_indexer.reshape(labels.shape)
     return indexer
 
@@ -978,14 +979,15 @@ def remove_unused_levels_categories(index: T_PDIndex) -> T_PDIndex:
     Remove unused levels from MultiIndex and unused categories from CategoricalIndex
     """
     if isinstance(index, pd.MultiIndex):
-        new_index = cast(pd.MultiIndex, index.remove_unused_levels())
+        new_index = index.remove_unused_levels()
         # if it contains CategoricalIndex, we need to remove unused categories
         # manually. See https://github.com/pandas-dev/pandas/issues/30846
         if any(isinstance(lev, pd.CategoricalIndex) for lev in new_index.levels):
             levels = []
             for i, level in enumerate(new_index.levels):
                 if isinstance(level, pd.CategoricalIndex):
-                    level = level[new_index.codes[i]].remove_unused_categories()
+                    # pandas-stubs is missing remove_unused_categories on CategoricalIndex
+                    level = level[new_index.codes[i]].remove_unused_categories()  # type: ignore[attr-defined]
                 else:
                     level = level[new_index.codes[i]]
                 levels.append(level)
@@ -1229,7 +1231,7 @@ def reorder_levels(
         its corresponding coordinates.
 
         """
-        index = cast(pd.MultiIndex, self.index.reorder_levels(level_variables.keys()))
+        index = self.index.reorder_levels(list(level_variables.keys()))
         level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names}
         return self._replace(index, level_coords_dtype=level_coords_dtype)
 
@@ -1378,28 +1380,29 @@ def sel(self, labels, method=None, tolerance=None) -> IndexSelResult:
                 indexer = DataArray(indexer, coords=coords, dims=label.dims)
 
         if new_index is not None:
+            xr_index: PandasIndex | PandasMultiIndex
             if isinstance(new_index, pd.MultiIndex):
                 level_coords_dtype = {
                     k: self.level_coords_dtype[k] for k in new_index.names
                 }
-                new_index = self._replace(
+                xr_index = self._replace(
                     new_index, level_coords_dtype=level_coords_dtype
                 )
                 dims_dict = {}
-                drop_coords = []
+                drop_coords: list[Hashable] = []
             else:
-                new_index = PandasIndex(
+                xr_index = PandasIndex(
                     new_index,
                     new_index.name,
                     coord_dtype=self.level_coords_dtype[new_index.name],
                 )
-                dims_dict = {self.dim: new_index.index.name}
+                dims_dict = {self.dim: xr_index.index.name}
                 drop_coords = [self.dim]
 
             # variable(s) attrs and encoding metadata are propagated
             # when replacing the indexes in the resulting xarray object
-            new_vars = new_index.create_variables()
-            indexes = cast(dict[Any, Index], dict.fromkeys(new_vars, new_index))
+            new_vars = xr_index.create_variables()
+            indexes = cast(dict[Any, Index], dict.fromkeys(new_vars, xr_index))
 
             # add scalar variable for each dropped level
             variables = new_vars
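The `xr_index` variable is the shadowing fix called out in the commit message: `new_index` enters the block as a `pd.Index` and was previously rebound to an xarray index object, so one name carried two unrelated types. A generic sketch of the pattern with toy stand-in classes (not xarray's real types):

```python
import pandas as pd


class WrappedIndex:
    """Toy stand-in for xarray's PandasIndex."""

    def __init__(self, index: pd.Index) -> None:
        self.index = index


class WrappedMultiIndex(WrappedIndex):
    """Toy stand-in for xarray's PandasMultiIndex."""


def wrap(new_index: pd.Index) -> WrappedIndex:
    # Rebinding `new_index` itself to a wrapper would change its type
    # mid-function; a separate, explicitly annotated name stays well-typed
    # in both branches.
    xr_index: WrappedIndex | WrappedMultiIndex
    if isinstance(new_index, pd.MultiIndex):
        xr_index = WrappedMultiIndex(new_index.remove_unused_levels())
    else:
        xr_index = WrappedIndex(new_index)
    return xr_index
```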

xarray/core/missing.py

Lines changed: 2 additions & 1 deletion
@@ -588,7 +588,8 @@ def _localize(obj: T, indexes_coords: SourceDest) -> tuple[T, SourceDest]:
        minval = np.nanmin(new_x_loaded)
        maxval = np.nanmax(new_x_loaded)
        index = x.to_index()
-       imin, imax = index.get_indexer([minval, maxval], method="nearest")
+       # pandas-stubs expects Index for get_indexer, but list works at runtime
+       imin, imax = index.get_indexer([minval, maxval], method="nearest")  # type: ignore[arg-type]
        indexes[dim] = slice(max(imin - 2, 0), imax + 2)
        indexes_coords[dim] = (x[indexes[dim]], new_x)
    return obj.isel(indexes), indexes_coords  # type: ignore[attr-defined]
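As the comment notes, the ignore covers a stub limitation only; at runtime `Index.get_indexer` accepts any array-like target. A quick standalone check:

```python
import pandas as pd

index = pd.Index([0.0, 10.0, 20.0])
# pandas-stubs types the target as Index, but a plain list works at runtime.
imin, imax = index.get_indexer([1.0, 19.0], method="nearest")
print(imin, imax)  # 0 2
```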

xarray/tests/test_backends.py

Lines changed: 3 additions & 3 deletions
@@ -7500,10 +7500,10 @@ def test_write_file_from_np_str(str_type: type[str | np.str_], tmpdir: str) -> N
     )
     tdf.index.name = "scenario"
     tdf.columns.name = "year"
-    tdf = cast(pd.DataFrame, tdf.stack())
-    tdf.name = "tas"
+    tdf_series = cast(pd.Series, tdf.stack())
+    tdf_series.name = "tas"
 
-    txr = tdf.to_xarray()
+    txr = tdf_series.to_xarray()
 
     txr.to_netcdf(tmpdir.join("test.nc"))
 
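The rename reflects what actually happens at runtime: stacking a DataFrame with single-level columns returns a Series, so reusing the `tdf` name gave one variable two incompatible types. A minimal standalone check (toy data, not the test's):

```python
import pandas as pd

tdf = pd.DataFrame({"2020": [1.0], "2021": [2.0]}, index=["ssp126"])
tdf.index.name = "scenario"
tdf.columns.name = "year"

# stack() on single-level columns yields a Series indexed by (scenario, year),
# so it gets its own name instead of rebinding the DataFrame variable.
tdf_series = tdf.stack()
print(type(tdf_series).__name__)  # Series
```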

xarray/tests/test_dataarray.py

Lines changed: 2 additions & 1 deletion
@@ -3931,7 +3931,8 @@ def test_to_and_from_dict_with_nan_nat(self) -> None:
         y = np.random.randn(10, 3)
         y[2] = np.nan
         t = pd.Series(pd.date_range("20130101", periods=10))
-        t[2] = np.nan
+        # pandas-stubs doesn't allow np.nan for datetime Series, but it converts to NaT
+        t[2] = np.nan  # type: ignore[call-overload]
         lat = [77.7, 83.2, 76]
         da = DataArray(y, {"t": t, "lat": lat}, dims=["t", "lat"])
         roundtripped = DataArray.from_dict(da.to_dict())

xarray/tests/test_dataset.py

Lines changed: 4 additions & 2 deletions
@@ -3497,7 +3497,8 @@ def test_rename_multiindex(self) -> None:
         midx_coords = Coordinates.from_pandas_multiindex(midx, "x")
         original = Dataset({}, midx_coords)
 
-        midx_renamed = midx.rename(["a", "c"])
+        # pandas-stubs expects Hashable for rename, but list of names works for MultiIndex
+        midx_renamed = midx.rename(["a", "c"])  # type: ignore[call-overload]
         midx_coords_renamed = Coordinates.from_pandas_multiindex(midx_renamed, "x")
         expected = Dataset({}, midx_coords_renamed)
 
@@ -5602,7 +5603,8 @@ def test_to_and_from_dict_with_nan_nat(
         y = np.random.randn(10, 3)
         y[2] = np.nan
         t = pd.Series(pd.date_range("20130101", periods=10))
-        t[2] = np.nan
+        # pandas-stubs doesn't allow np.nan for datetime Series, but it converts to NaT
+        t[2] = np.nan  # type: ignore[call-overload]
 
         lat = [77.7, 83.2, 76]
         ds = Dataset(
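Both test-file ignores cover the same runtime behaviour: assigning `np.nan` into a datetime64 Series stores `NaT`, which is what the round-trip tests depend on. A quick standalone check:

```python
import numpy as np
import pandas as pd

t = pd.Series(pd.date_range("20130101", periods=3))
# pandas-stubs rejects a float here, but pandas coerces NaN to NaT (missing timestamp).
t[1] = np.nan  # type: ignore[call-overload]
print(t[1])  # NaT
```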
