Skip to content

Commit

Permalink
test: add test for impure function correlation behavior
Browse files Browse the repository at this point in the history
Need to fix the a few broken cases.

Related to ibis-project#8921,
trying to write down exactly what the expected behavior is.
  • Loading branch information
NickCrews committed Apr 23, 2024
1 parent 9355281 commit 37ae5dd
Showing 1 changed file with 176 additions and 0 deletions.
176 changes: 176 additions & 0 deletions ibis/backends/tests/test_impure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
from __future__ import annotations

import sys

import pandas.testing as tm
import pytest

import ibis
import ibis.common.exceptions as com
from ibis import _
from ibis.backends.tests.errors import (
PsycoPg2InternalError,
Py4JJavaError,
PyDruidProgrammingError,
)

no_randoms = [
pytest.mark.notimpl(
["dask", "pandas", "polars"], raises=com.OperationNotDefinedError
),
pytest.mark.notimpl("druid", raises=PyDruidProgrammingError),
pytest.mark.notimpl(
"risingwave",
raises=PsycoPg2InternalError,
reason="function random() does not exist",
),
]

no_udfs = [
pytest.mark.notyet("datafusion", raises=NotImplementedError),
pytest.mark.notimpl(
[
"bigquery",
"clickhouse",
"dask",
"druid",
"exasol",
"impala",
"mssql",
"mysql",
"oracle",
"pandas",
"trino",
"risingwave",
]
),
pytest.mark.notimpl("pyspark", reason="only supports pandas UDFs"),
pytest.mark.broken(
"flink",
condition=sys.version_info >= (3, 11),
raises=Py4JJavaError,
reason="Docker image has Python 3.10, results in `cloudpickle` version mismatch",
),
]

no_uuids = [
pytest.mark.notimpl(
[
"druid",
"exasol",
"oracle",
"polars",
"pyspark",
"risingwave",
"pandas",
"dask",
],
raises=com.OperationNotDefinedError,
),
pytest.mark.broken("mssql", reason="Unrelated bug: Incorrect syntax near '('"),
]


@ibis.udf.scalar.python(side_effects=True)
def my_random(x: float) -> float:
# need to make the whole UDF self-contained for postgres to work
import random

return random.random()


mark_impures = pytest.mark.parametrize(
"impure",
[
pytest.param(
lambda _: ibis.random(),
marks=no_randoms,
id="random",
),
pytest.param(
lambda _: ibis.uuid().cast(str).contains("a").ifelse(1, 0),
marks=[
*no_uuids,
pytest.mark.broken("impala", reason="instances are uncorrelated"),
],
id="uuid",
),
pytest.param(
lambda table: my_random(table.float_col),
marks=[
*no_udfs,
pytest.mark.broken(
["flink", "postgres"], reason="instances are uncorrelated"
),
],
id="udf",
),
],
)


@pytest.mark.broken("sqlite", reason="instances are uncorrelated")
@mark_impures
def test_impure_correlated(alltypes, impure):
df = (
alltypes.select(common=impure(alltypes))
.select(x=_.common, y=_.common)
.execute()
)
tm.assert_series_equal(df.x, df.y, check_names=False)


@pytest.mark.broken("sqlite", reason="instances are uncorrelated")
@mark_impures
def test_chained_selections(alltypes, impure):
# https://github.com/ibis-project/ibis/issues/8921#issue-2234327722
t = alltypes.mutate(num=impure(alltypes))
t = t.mutate(isbig=(t.num > 0.5))
df = t.select("num", "isbig").execute()
df["expected"] = df.num > 0.5
tm.assert_series_equal(df.isbig, df.expected, check_names=False)


@pytest.mark.broken(["clickhouse"], reason="instances are correlated")
@pytest.mark.parametrize(
"impure",
[
pytest.param(
lambda _: ibis.random(),
marks=[
*no_randoms,
pytest.mark.broken(
["impala", "trino"], reason="instances are correlated"
),
],
id="random",
),
pytest.param(
# make this a float so we can compare to .5
lambda _: ibis.uuid().cast(str).contains("a").ifelse(1, 0),
marks=[
*no_uuids,
pytest.mark.broken(
["mysql", "trino"], reason="instances are correlated"
),
],
id="uuid",
),
pytest.param(
lambda table: my_random(table.float_col),
marks=[
*no_udfs,
pytest.mark.broken("duckdb", reason="instances are correlated"),
],
id="udf",
),
],
)
def test_impure_uncorrelated(alltypes, impure):
df = alltypes.select(x=impure(alltypes), y=impure(alltypes)).execute()
assert (df.x == df.y).mean() < 1
# Even if the two expressions have the exact same ID, they should still be
# uncorrelated
common = impure(alltypes)
df = alltypes.select(x=common, y=common).execute()
assert (df.x == df.y).mean() < 1

0 comments on commit 37ae5dd

Please sign in to comment.