Skip to content

Commit 143bab4

Browse files
Fix: proper return types for MultiIndex.swaplevel & MultiIndex.union (#1437)
* Fix: proper return types for MultiIndex.swaplevel & MultiIndex.union * Index of and List of tuples in MultiIndex overriding * MultiIndex union tests * MultiIndex-Index union overload simplified * MultiIndex test amended and swaplevel test added * Simplified MultiIndex-Index union overload * Reverted S1 changes * amended MultiIndex union test and added swaplevel test * Removed runtime-failing test * Pre-commit refactoring * Removed a MultiIndex-Index test * Removed MultiIndex union overloading * Separated swaplevel and union tests Co-authored-by: Yi-Fan Wang <[email protected]> * test for MultiIndex-Index union * MultiIndex.union overload replaced with change of base return type * removed MultiIndex.union overload * Removed extra newline Co-authored-by: Yi-Fan Wang <[email protected]> * Added py.typed * existing tests fixed * tests fixed for runtime errors --------- Co-authored-by: Yi-Fan Wang <[email protected]>
1 parent 95ce46c commit 143bab4

File tree

4 files changed

+26
-5
lines changed

4 files changed

+26
-5
lines changed

pandas-stubs/core/indexes/base.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]):
405405
__bool__ = ...
406406
def union(
407407
self, other: list[HashableT] | Self, sort: bool | None = None
408-
) -> Index: ...
408+
) -> Self: ...
409409
def intersection(
410410
self, other: list[S1] | Self, sort: bool | None = False
411411
) -> Self: ...

pandas-stubs/core/indexes/multi.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class MultiIndex(Index):
135135
def append(self, other): ...
136136
def repeat(self, repeats, axis=...): ...
137137
def drop(self, codes, level: Level | None = None, errors: str = "raise") -> Self: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
138-
def swaplevel(self, i: int = -2, j: int = -1): ...
138+
def swaplevel(self, i: int = -2, j: int = -1) -> Self: ...
139139
def reorder_levels(self, order): ...
140140
def sortlevel(
141141
self,

pandas-stubs/core/indexes/range.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class RangeIndex(_IndexSubclassBase[int, np.int64]):
8080
def all(self, *args: Any, **kwargs: Any) -> bool: ...
8181
def any(self, *args: Any, **kwargs: Any) -> bool: ...
8282
@final
83-
def union(
83+
def union( # type: ignore[override]
8484
self, other: list[HashableT] | Index, sort: bool | None = None
8585
) -> Index | Index[int] | RangeIndex: ...
8686
@overload # type: ignore[override]

tests/indexes/test_indexes.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,12 +302,18 @@ def test_range_index_union() -> None:
302302
def test_index_union_sort() -> None:
303303
"""Test sort argument in pd.Index.union GH1264."""
304304
check(
305-
assert_type(pd.Index(["e", "f"]).union(["a", "b", "c"], sort=True), pd.Index),
305+
assert_type(
306+
pd.Index(["e", "f"]).union(["a", "b", "c"], sort=True), "pd.Index[str]"
307+
),
306308
pd.Index,
309+
str,
307310
)
308311
check(
309-
assert_type(pd.Index(["e", "f"]).union(["a", "b", "c"], sort=False), pd.Index),
312+
assert_type(
313+
pd.Index(["e", "f"]).union(["a", "b", "c"], sort=False), "pd.Index[str]"
314+
),
310315
pd.Index,
316+
str,
311317
)
312318

313319

@@ -1520,3 +1526,18 @@ def test_to_series() -> None:
15201526
np.complexfloating,
15211527
)
15221528
check(assert_type(Index(["1"]).to_series(), "pd.Series[str]"), pd.Series, str)
1529+
1530+
1531+
def test_multiindex_union() -> None:
1532+
"""Test that MultiIndex.union returns MultiIndex"""
1533+
mi = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["let", "num"])
1534+
mi2 = pd.MultiIndex.from_product([["a", "b"], [3, 4]], names=["let", "num"])
1535+
1536+
check(assert_type(mi.union(mi2), "pd.MultiIndex"), pd.MultiIndex)
1537+
check(assert_type(mi.union([("c", 3), ("d", 4)]), "pd.MultiIndex"), pd.MultiIndex)
1538+
1539+
1540+
def test_multiindex_swaplevel() -> None:
1541+
"""Test that MultiIndex.swaplevel returns MultiIndex"""
1542+
mi = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["let", "num"])
1543+
check(assert_type(mi.swaplevel(0, 1), "pd.MultiIndex"), pd.MultiIndex)

0 commit comments

Comments
 (0)