From ee0ae0cd77917216c97cf7820f24023d5f885d5b Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Tue, 23 Apr 2024 11:36:21 -0800 Subject: [PATCH] test: add test for impure function correlation behavior Need to fix the a few broken cases. Related to https://github.com/ibis-project/ibis/issues/8921, trying to write down exactly what the expected behavior is. --- ibis/backends/tests/test_impure.py | 176 +++++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 ibis/backends/tests/test_impure.py diff --git a/ibis/backends/tests/test_impure.py b/ibis/backends/tests/test_impure.py new file mode 100644 index 0000000000000..6aa7334f3406f --- /dev/null +++ b/ibis/backends/tests/test_impure.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import sys + +import pandas.testing as tm +import pytest + +import ibis +import ibis.common.exceptions as com +from ibis import _ +from ibis.backends.tests.errors import ( + PsycoPg2InternalError, + Py4JJavaError, + PyDruidProgrammingError, +) + +no_randoms = [ + pytest.mark.notimpl( + ["dask", "pandas", "polars"], raises=com.OperationNotDefinedError + ), + pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), + pytest.mark.notimpl( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="function random() does not exist", + ), +] + +no_udfs = [ + pytest.mark.notyet(["datafusion"], raises=NotImplementedError), + pytest.mark.notimpl( + [ + "bigquery", + "clickhouse", + "dask", + "druid", + "exasol", + "impala", + "mssql", + "mysql", + "oracle", + "pandas", + "trino", + "risingwave", + ] + ), + pytest.mark.notimpl("pyspark", reason="only supports pandas UDFs"), + pytest.mark.broken( + ["flink"], + condition=sys.version_info >= (3, 11), + raises=Py4JJavaError, + reason="Docker image has Python 3.10, results in `cloudpickle` version mismatch", + ), +] + +no_uuids = [ + pytest.mark.notimpl( + [ + "druid", + "exasol", + "oracle", + "polars", + "pyspark", + "risingwave", + "pandas", + "dask", + ], + raises=com.OperationNotDefinedError, + ), + pytest.mark.broken("mssql", reason="Unrelated bug: Incorrect syntax near '('"), +] + + +@ibis.udf.scalar.python(side_effects=True) +def my_random(x: float) -> float: + # need to make the whole UDF self-contained for postgres to work + import random + + return random.random() + + +mark_impures = pytest.mark.parametrize( + "impure", + [ + pytest.param( + lambda _: ibis.random(), + marks=no_randoms, + id="random", + ), + pytest.param( + lambda _: ibis.uuid().cast(str).contains("a").ifelse(1, 0), + marks=[ + *no_uuids, + pytest.mark.broken(["impala"], reason="instances are uncorrelated"), + ], + id="uuid", + ), + pytest.param( + lambda table: my_random(table.float_col), + marks=[ + *no_udfs, + # once this is fixed, can we unify these params with the other params? + pytest.mark.broken(["postgres"], reason="instances are uncorrelated"), + ], + id="udf", + ), + ], +) + + +@pytest.mark.broken("sqlite", reason="instances are uncorrelated") +@mark_impures +def test_impure_correlated(alltypes, impure): + df = ( + alltypes.select(common=impure(alltypes)) + .select(x=_.common, y=_.common) + .execute() + ) + tm.assert_series_equal(df.x, df.y, check_names=False) + + +@pytest.mark.broken("sqlite", reason="instances are uncorrelated") +@mark_impures +def test_chained_selections(alltypes, impure): + # https://github.com/ibis-project/ibis/issues/8921#issue-2234327722 + t = alltypes.mutate(num=impure(alltypes)) + t = t.mutate(isbig=(t.num > 0.5)) + df = t.select("num", "isbig").execute() + df["expected"] = df.num > 0.5 + tm.assert_series_equal(df.isbig, df.expected, check_names=False) + + +@pytest.mark.broken(["clickhouse"], reason="instances are correlated") +@pytest.mark.parametrize( + "impure", + [ + pytest.param( + lambda _: ibis.random(), + marks=[ + *no_randoms, + pytest.mark.broken( + ["impala", "trino"], reason="instances are correlated" + ), + ], + id="random", + ), + pytest.param( + # make this a float so we can compare to .5 + lambda _: ibis.uuid().cast(str).contains("a").ifelse(1, 0), + marks=[ + *no_uuids, + pytest.mark.broken( + ["mysql", "trino"], reason="instances are correlated" + ), + ], + id="uuid", + ), + pytest.param( + lambda table: my_random(table.float_col), + # once this is fixed, can we unify these params with the other params? + marks=[ + *no_udfs, + pytest.mark.broken(["duckdb"], reason="instances are correlated"), + ], + id="udf", + ), + ], +) +def test_impure_uncorrelated(alltypes, impure): + df = alltypes.select(x=impure(alltypes), y=impure(alltypes)).execute() + assert (df.x == df.y).mean() < 1 + # Even if the two expressions have the exact same ID, they should still be + # uncorrelated + common = impure(alltypes) + df = alltypes.select(x=common, y=common).execute() + assert (df.x == df.y).mean() < 1