From d72c5de26107b54036988fcb721ba48d7036b09d Mon Sep 17 00:00:00 2001 From: Deepyaman Datta Date: Tue, 23 Jul 2024 13:41:07 -0600 Subject: [PATCH] fix(core): make table normalization 9.2-compatible Co-authored-by: Jim Crist-Harif --- ibis_ml/core.py | 11 +++++++---- tests/test_core.py | 9 ++++++--- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/ibis_ml/core.py b/ibis_ml/core.py index e69c8e1..72f314a 100644 --- a/ibis_ml/core.py +++ b/ibis_ml/core.py @@ -101,7 +101,7 @@ def _(X, y=None, maintain_order=False): ) if set(y.columns).intersection(X.columns): - raise ValueError("X and y must not share column names") + raise ValueError("`X` and `y` must not share column names") X_op = X.op() y_op = y.op() @@ -110,10 +110,13 @@ def _(X, y=None, maintain_order=False): # >>> X = parent[cols] # >>> y = parent[single_or_multiple_cols] # >>> table = parent[cols + single_or_multiple_cols] + supported_ops = tuple( + cls for name in ["Project", "DropColumns"] if (cls := getattr(ops, name, None)) + ) if ( - hasattr(ops, "Project") - and isinstance(X_op, ops.Project) - and isinstance(y_op, ops.Project) + supported_ops + and isinstance(X_op, supported_ops) + and isinstance(y_op, supported_ops) and X_op.parent is y_op.parent ): # ibis 9.0 diff --git a/tests/test_core.py b/tests/test_core.py index 85412db..b16cf59 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -223,9 +223,12 @@ def test_can_use_in_sklearn_pipeline(): @pytest.mark.parametrize( "get_Xy", [ - pytest.param(lambda t: (t[["a", "b", "c"]], None), id="None"), - pytest.param(lambda t: (t[["a", "b", "c"]], t["y"]), id="column"), - pytest.param(lambda t: (t[["a", "b", "c"]], t[["y"]]), id="table"), + pytest.param(lambda t: (t[["a", "b", "c"]], None), id="Project-None"), + pytest.param(lambda t: (t[["a", "b", "c"]], t["y"]), id="Project-column"), + pytest.param(lambda t: (t[["a", "b", "c"]], t[["y"]]), id="Project-table"), + pytest.param(lambda t: (t.drop("y"), None), id="DropColumns-None"), + pytest.param(lambda t: (t.drop("y"), t["y"]), id="DropColumns-column"), + pytest.param(lambda t: (t.drop("y"), t[["y"]]), id="DropColumns-table"), pytest.param(lambda t: (t, "y"), id="col-name"), pytest.param(lambda t: (t, ["y"]), id="col-names"), ],