From 7700941e63f80dc3ba8c675693d5594ed893a8f1 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Thu, 18 Apr 2024 16:27:15 -0800 Subject: [PATCH] test: add test for impure function correlation behavior Need to fix the a few broken cases. Related to https://github.com/ibis-project/ibis/issues/8921, trying to write down exactly what the expected behavior is. --- ibis/backends/tests/test_impure.py | 135 +++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 ibis/backends/tests/test_impure.py diff --git a/ibis/backends/tests/test_impure.py b/ibis/backends/tests/test_impure.py new file mode 100644 index 0000000000000..920dbd2e166ec --- /dev/null +++ b/ibis/backends/tests/test_impure.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +import pandas.testing as tm +import pytest + +import ibis +import ibis.common.exceptions as com +from ibis import _ +from ibis.backends.tests.errors import ( + PsycoPg2InternalError, + PyDruidProgrammingError, +) + +no_random = [ + pytest.mark.notimpl( + ["dask", "pandas", "polars"], raises=com.OperationNotDefinedError + ), + pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), + pytest.mark.notimpl( + ["risingwave"], + raises=PsycoPg2InternalError, + reason="function random() does not exist", + ), +] + +no_udfs = pytest.mark.notimpl( + [ + "bigquery", + "clickhouse", + "dask", + "druid", + "exasol", + "impala", + "mssql", + "mysql", + "oracle", + "pandas", + "trino", + "risingwave", + ] +) + +no_uuid = pytest.mark.notimpl( + ["druid", "exasol", "oracle", "polars", "pyspark", "risingwave", "pandas", "dask"], + raises=com.OperationNotDefinedError, +) + + +@ibis.udf.scalar.python(side_effects=True) +def my_random(x: float) -> float: + # need to make the whole UDF self-contained for postgres to work + import random + + return random.random() + + +mark_impures = pytest.mark.parametrize( + "impure", + [ + pytest.param( + lambda _: ibis.random(), + marks=no_random, + id="random", + ), + pytest.param( + lambda _: ibis.uuid().cast(str).contains("a").ifelse(1, 0), + marks=no_uuid, + id="uuid", + ), + pytest.param( + lambda table: my_random(table.float_col), + marks=[ + # once this is fixed, can we unify these params with the other params? + pytest.mark.broken("postgres", reason="instances are uncorrelated"), + no_udfs, + ], + id="udf", + ), + ], +) + + +@mark_impures +def test_impure_correlated(alltypes, impure): + df = ( + alltypes.select(common=impure(alltypes)) + .select(x=_.common, y=_.common) + .execute() + ) + tm.assert_series_equal(df.x, df.y, check_names=False) + + +@mark_impures +def test_chained_selections(alltypes, impure): + # https://github.com/ibis-project/ibis/issues/8921#issue-2234327722 + t = alltypes.mutate(num=impure(alltypes)) + t = t.mutate(isbig=(t.num > 0.5)) + df = t.select("num", "isbig").execute() + df["expected"] = df.num > 0.5 + tm.assert_series_equal(df.isbig, df.expected, check_names=False) + + +@pytest.mark.parametrize( + "impure", + [ + pytest.param( + lambda _: ibis.random(), + marks=no_random, + id="random", + ), + pytest.param( + # make this a float so we can compare to .5 + lambda _: ibis.uuid().cast(str).contains("a").ifelse(1, 0), + marks=no_uuid, + id="uuid", + ), + pytest.param( + lambda table: my_random(table.float_col), + # once this is fixed, can we unify these params with the other params? + marks=[ + pytest.mark.broken("duckdb", reason="instances are correlated"), + no_udfs, + ], + id="udf", + ), + ], +) +def test_impure_uncorrelated(alltypes, impure): + df = alltypes.select(x=impure(alltypes), y=impure(alltypes)).execute() + assert (df.x == df.y).mean() < 1 + # Even if the two expressions have the exact same ID, they should still be + # uncorrelated + common = impure(alltypes) + df = alltypes.select(x=common, y=common).execute() + assert (df.x == df.y).mean() < 1