Skip to content

Commit

Permalink
fix(python): Address incorrect selector & col expansion (#19742)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Nov 13, 2024
1 parent 6808bd8 commit 37ae8e7
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 21 deletions.
30 changes: 11 additions & 19 deletions py-polars/polars/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,10 +385,6 @@ def __and__(self, other: Any) -> Expr: ...
def __and__(self, other: Any) -> SelectorType | Expr:
if is_column(other):
colname = other.meta.output_name()
if self._attrs["name"] == "by_name" and (
params := self._attrs["params"]
).get("require_all", True):
return by_name(*params["*names"], colname)
other = by_name(colname)
if is_selector(other):
return _selector_proxy_(
Expand All @@ -399,6 +395,12 @@ def __and__(self, other: Any) -> SelectorType | Expr:
else:
return self.as_expr().__and__(other)

def __rand__(self, other: Any) -> Expr:
if is_column(other):
colname = other.meta.output_name()
return by_name(colname) & self
return self.as_expr().__rand__(other)

@overload
def __or__(self, other: SelectorType) -> SelectorType: ...

Expand All @@ -417,6 +419,11 @@ def __or__(self, other: Any) -> SelectorType | Expr:
else:
return self.as_expr().__or__(other)

def __ror__(self, other: Any) -> Expr:
if is_column(other):
other = by_name(other.meta.output_name())
return self.as_expr().__ror__(other)

@overload
def __xor__(self, other: SelectorType) -> SelectorType: ...

Expand All @@ -435,21 +442,6 @@ def __xor__(self, other: Any) -> SelectorType | Expr:
else:
return self.as_expr().__or__(other)

def __rand__(self, other: Any) -> Expr:
if is_column(other):
colname = other.meta.output_name()
if self._attrs["name"] == "by_name" and (
params := self._attrs["params"]
).get("require_all", True):
return by_name(colname, *params["*names"])
other = by_name(colname)
return self.as_expr().__rand__(other)

def __ror__(self, other: Any) -> Expr:
if is_column(other):
other = by_name(other.meta.output_name())
return self.as_expr().__ror__(other)

def __rxor__(self, other: Any) -> Expr:
if is_column(other):
other = by_name(other.meta.output_name())
Expand Down
10 changes: 8 additions & 2 deletions py-polars/tests/unit/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,17 @@ def test_selector_by_name(df: pl.DataFrame) -> None:

# check "by_name & col"
for selector_expr, expected in (
(cs.by_name("abc", "cde") & pl.col("ghi"), ["abc", "cde", "ghi"]),
(pl.col("ghi") & cs.by_name("cde", "abc"), ["ghi", "cde", "abc"]),
(cs.by_name("abc", "cde") & pl.col("ghi"), []),
(cs.by_name("abc", "cde") & pl.col("cde"), ["cde"]),
(pl.col("cde") & cs.by_name("cde", "abc"), ["cde"]),
):
assert df.select(selector_expr).columns == expected

# check "by_name & by_name"
assert df.select(
cs.by_name("abc", "cde", "def", "eee") & cs.by_name("cde", "eee", "fgg")
).columns == ["cde", "eee"]

# expected errors
with pytest.raises(ColumnNotFoundError, match="xxx"):
df.select(cs.by_name("xxx", "fgg", "!!!"))
Expand Down

0 comments on commit 37ae8e7

Please sign in to comment.