Skip to content

Commit

Permalink
fix(core): make table normalization 9.2-compatible
Browse files Browse the repository at this point in the history
Co-authored-by: Jim Crist-Harif <[email protected]>
  • Loading branch information
deepyaman and jcrist committed Jul 24, 2024
1 parent 043272b commit d72c5de
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
11 changes: 7 additions & 4 deletions ibis_ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _(X, y=None, maintain_order=False):
)

if set(y.columns).intersection(X.columns):
raise ValueError("X and y must not share column names")
raise ValueError("`X` and `y` must not share column names")

X_op = X.op()
y_op = y.op()
Expand All @@ -110,10 +110,13 @@ def _(X, y=None, maintain_order=False):
# >>> X = parent[cols]
# >>> y = parent[single_or_multiple_cols]
# >>> table = parent[cols + single_or_multiple_cols]
supported_ops = tuple(
cls for name in ["Project", "DropColumns"] if (cls := getattr(ops, name, None))
)
if (
hasattr(ops, "Project")
and isinstance(X_op, ops.Project)
and isinstance(y_op, ops.Project)
supported_ops
and isinstance(X_op, supported_ops)
and isinstance(y_op, supported_ops)
and X_op.parent is y_op.parent
):
# ibis 9.0
Expand Down
9 changes: 6 additions & 3 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,12 @@ def test_can_use_in_sklearn_pipeline():
@pytest.mark.parametrize(
"get_Xy",
[
pytest.param(lambda t: (t[["a", "b", "c"]], None), id="None"),
pytest.param(lambda t: (t[["a", "b", "c"]], t["y"]), id="column"),
pytest.param(lambda t: (t[["a", "b", "c"]], t[["y"]]), id="table"),
pytest.param(lambda t: (t[["a", "b", "c"]], None), id="Project-None"),
pytest.param(lambda t: (t[["a", "b", "c"]], t["y"]), id="Project-column"),
pytest.param(lambda t: (t[["a", "b", "c"]], t[["y"]]), id="Project-table"),
pytest.param(lambda t: (t.drop("y"), None), id="DropColumns-None"),
pytest.param(lambda t: (t.drop("y"), t["y"]), id="DropColumns-column"),
pytest.param(lambda t: (t.drop("y"), t[["y"]]), id="DropColumns-table"),
pytest.param(lambda t: (t, "y"), id="col-name"),
pytest.param(lambda t: (t, ["y"]), id="col-names"),
],
Expand Down

0 comments on commit d72c5de

Please sign in to comment.