From 48d4a5156857e7ae4439394168d409563f01ada1 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Thu, 18 Apr 2024 11:47:41 -0800 Subject: [PATCH] test: add test for impure function correlation behavior Need to fix the UDF test case. Related to https://github.com/ibis-project/ibis/issues/8921, trying to write down exactly what the expected behavior is. --- ibis/backends/tests/test_impure.py | 67 ++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 ibis/backends/tests/test_impure.py diff --git a/ibis/backends/tests/test_impure.py b/ibis/backends/tests/test_impure.py new file mode 100644 index 0000000000000..10a694e1cae99 --- /dev/null +++ b/ibis/backends/tests/test_impure.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import pandas.testing as tm +import pytest + +import ibis +from ibis import _ + +_counter = 0 + + +@ibis.udf.scalar.python(side_effects=True) +def get_id() -> int: + global _counter # noqa: PLW0603 + _counter += 1 + return _counter + + +@pytest.mark.parametrize( + "impure", + [ + pytest.param( + lambda: ibis.random(), + id="random", + ), + pytest.param( + lambda: ibis.uuid(), + id="uuid", + ), + pytest.param( + get_id, + id="udf", + ), + ], +) +def test_impure_uncorrelated(alltypes, impure): + df = alltypes.select(x=impure(), y=impure()).execute() + assert (df.x != df.y).mean() >= 0.999 + # Even if the two expressions have the exact same ID, they should still be + # uncorrelated + common = impure() + df = alltypes.select(x=common, y=common).execute() + assert (df.x != df.y).mean() >= 0.999 + + +@pytest.mark.parametrize( + "impure", + [ + pytest.param( + lambda: ibis.random(), + id="random", + ), + pytest.param( + lambda: ibis.uuid(), + id="uuid", + ), + pytest.param( + get_id, + id="udf", + # once this is fixed, can we unify these params with the above params? + marks=pytest.mark.xfail(reason="executed multiple times"), + ), + ], +) +def test_impure_correlated(alltypes, impure): + df = alltypes.select(common=impure()).select(x=_.common, y=_.common).execute() + tm.assert_series_equal(df.x, df.y, check_names=False)