From 2afcaee9e27d110de5a9d8fc780075ac52023fac Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Tue, 23 Apr 2024 11:36:21 -0800 Subject: [PATCH] test: add test for impure function correlation behavior Need to fix the a few broken cases. Related to https://github.com/ibis-project/ibis/issues/8921, trying to write down exactly what the expected behavior is. --- ibis/backends/tests/test_impure.py | 156 +++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 ibis/backends/tests/test_impure.py diff --git a/ibis/backends/tests/test_impure.py b/ibis/backends/tests/test_impure.py new file mode 100644 index 0000000000000..93eced0060ee4 --- /dev/null +++ b/ibis/backends/tests/test_impure.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import pandas.testing as tm +import pytest + +import ibis +import ibis.common.exceptions as com +from ibis import _ +from ibis.backends.tests.errors import ( + PsycoPg2InternalError, + PyDruidProgrammingError, +) + +no_random = [ + pytest.mark.notimpl( + ["dask", "pandas", "polars"], raises=com.OperationNotDefinedError + ), + pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), + pytest.mark.notimpl( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="function random() does not exist", + ), +] + +no_udfs = [ + pytest.mark.notyet(["datafusion"], raises=NotImplementedError), + pytest.mark.notimpl( + [ + "bigquery", + "clickhouse", + "dask", + "druid", + "exasol", + "impala", + "mssql", + "mysql", + "oracle", + "pandas", + "trino", + "risingwave", + ] + ), +] + +no_uuid = pytest.mark.notimpl( + ["druid", "exasol", "oracle", "polars", "pyspark", "risingwave", "pandas", "dask"], + raises=com.OperationNotDefinedError, +) + + +@ibis.udf.scalar.python(side_effects=True) +def my_random(x: float) -> float: + # need to make the whole UDF self-contained for postgres to work + import random + + return random.random() + + +mark_impures = pytest.mark.parametrize( + "impure", + [ + pytest.param( + lambda _: ibis.random(), + marks=no_random, + id="random", + ), + pytest.param( + lambda _: ibis.uuid().cast(str).contains("a").ifelse(1, 0), + marks=[ + pytest.mark.broken(["impala"], reason="instances are uncorrelated"), + *no_uuid, + ], + id="uuid", + ), + pytest.param( + lambda table: my_random(table.float_col), + marks=[ + # once this is fixed, can we unify these params with the other params? + pytest.mark.broken( + ["flink", "postgres"], reason="instances are uncorrelated" + ), + *no_udfs, + ], + id="udf", + ), + ], +) + + +@pytest.mark.broken("sqlite", reason="instances are uncorrelated") +@mark_impures +def test_impure_correlated(alltypes, impure): + df = ( + alltypes.select(common=impure(alltypes)) + .select(x=_.common, y=_.common) + .execute() + ) + tm.assert_series_equal(df.x, df.y, check_names=False) + + +@pytest.mark.broken("sqlite", reason="instances are uncorrelated") +@mark_impures +def test_chained_selections(alltypes, impure): + # https://github.com/ibis-project/ibis/issues/8921#issue-2234327722 + t = alltypes.mutate(num=impure(alltypes)) + t = t.mutate(isbig=(t.num > 0.5)) + df = t.select("num", "isbig").execute() + df["expected"] = df.num > 0.5 + tm.assert_series_equal(df.isbig, df.expected, check_names=False) + + +@pytest.mark.broken(["clickhouse"], reason="instances are correlated") +@pytest.mark.parametrize( + "impure", + [ + pytest.param( + lambda _: ibis.random(), + marks=[ + *no_random, + pytest.mark.broken( + ["impala", "trino"], reason="instances are correlated" + ), + ], + id="random", + ), + pytest.param( + # make this a float so we can compare to .5 + lambda _: ibis.uuid().cast(str).contains("a").ifelse(1, 0), + marks=[ + no_uuid, + pytest.mark.broken( + ["mssql", "trino"], reason="instances are correlated" + ), + ], + id="uuid", + ), + pytest.param( + lambda table: my_random(table.float_col), + # once this is fixed, can we unify these params with the other params? + marks=[ + *no_udfs, + pytest.mark.broken(["duckdb"], reason="instances are correlated"), + ], + id="udf", + ), + ], +) +def test_impure_uncorrelated(alltypes, impure): + df = alltypes.select(x=impure(alltypes), y=impure(alltypes)).execute() + assert (df.x == df.y).mean() < 1 + # Even if the two expressions have the exact same ID, they should still be + # uncorrelated + common = impure(alltypes) + df = alltypes.select(x=common, y=common).execute() + assert (df.x == df.y).mean() < 1