Skip to content

Commit

Permalink
feat(ir): make impure ibis.random() and ibis.uuid() functions return unique node instances
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Apr 15, 2024
1 parent f8370b1 commit 813dc52
Show file tree
Hide file tree
Showing 23 changed files with 157 additions and 24 deletions.
5 changes: 3 additions & 2 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,6 @@ class BigQueryCompiler(SQLGlotCompiler):
ops.RPad: "rpad",
ops.Levenshtein: "edit_distance",
ops.Modulus: "mod",
ops.RandomScalar: "rand",
ops.RandomUUID: "generate_uuid",
ops.RegexReplace: "regexp_replace",
ops.RegexSearch: "regexp_contains",
ops.Time: "time",
Expand Down Expand Up @@ -698,3 +696,6 @@ def visit_CountDistinct(self, op, *, arg, where):
if where is not None:
arg = self.if_(where, arg, NULL)
return self.f.count(sge.Distinct(expressions=[arg]))

def visit_RandomUUID(self, op, **kwargs):
    """Return the result of calling ``self.f.generate_uuid()``.

    ``op`` and any extra keyword arguments are accepted but not used.
    """
    generate_uuid = self.f.generate_uuid
    return generate_uuid()
8 changes: 6 additions & 2 deletions ibis/backends/clickhouse/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.NotNull: "isNotNull",
ops.NullIf: "nullIf",
ops.RStrip: "trimRight",
ops.RandomScalar: "randCanonical",
ops.RandomUUID: "generateUUIDv4",
ops.RegexReplace: "replaceRegexpAll",
ops.RowNumber: "row_number",
ops.StartsWith: "startsWith",
Expand Down Expand Up @@ -637,6 +635,12 @@ def visit_TimestampRange(self, op, *, start, stop, step):
def visit_RegexSplit(self, op, *, arg, pattern):
return self.f.splitByRegexp(pattern, self.cast(arg, dt.String(nullable=False)))

def visit_RandomScalar(self, op, **kwargs):
    """Return the result of calling ``self.f.randCanonical()``.

    ``op`` and any extra keyword arguments are accepted but not used.
    """
    rand_canonical = self.f.randCanonical
    return rand_canonical()

def visit_RandomUUID(self, op, **kwargs):
    """Return the result of calling ``self.f.generateUUIDv4()``.

    ``op`` and any extra keyword arguments are accepted but not used.
    """
    generate_uuid_v4 = self.f.generateUUIDv4
    return generate_uuid_v4()

@staticmethod
def _generate_groups(groups):
return groups
1 change: 0 additions & 1 deletion ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class DataFusionCompiler(SQLGlotCompiler):
ops.Last: "last_value",
ops.Median: "median",
ops.StringLength: "character_length",
ops.RandomUUID: "uuid",
ops.RegexSplit: "regex_split",
ops.EndsWith: "ends_with",
ops.ArrayIntersect: "array_intersect",
Expand Down
7 changes: 6 additions & 1 deletion ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class DuckDBCompiler(SQLGlotCompiler):
ops.MapMerge: "map_concat",
ops.MapValues: "map_values",
ops.Mode: "mode",
ops.RandomUUID: "uuid",
ops.TimeFromHMS: "make_time",
ops.TypeOf: "typeof",
ops.GeoPoint: "st_point",
Expand Down Expand Up @@ -418,3 +417,9 @@ def visit_StructField(self, op, *, arg, field):
expression=sg.to_identifier(field, quoted=self.quoted),
)
return super().visit_StructField(op, arg=arg, field=field)

def visit_RandomScalar(self, op, **kwargs):
    """Return the result of calling ``self.f.random()``.

    ``op`` and any extra keyword arguments are accepted but not used.
    """
    random_call = self.f.random
    return random_call()

def visit_RandomUUID(self, op, **kwargs):
    """Return the result of calling ``self.f.uuid()``.

    ``op`` and any extra keyword arguments are accepted but not used.
    """
    uuid_call = self.f.uuid
    return uuid_call()
2 changes: 0 additions & 2 deletions ibis/backends/flink/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ class FlinkCompiler(SQLGlotCompiler):
ops.MapKeys: "map_keys",
ops.MapValues: "map_values",
ops.Power: "power",
ops.RandomScalar: "rand",
ops.RandomUUID: "uuid",
ops.RegexSearch: "regexp",
ops.StrRight: "right",
ops.StringLength: "char_length",
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/impala/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ class ImpalaCompiler(SQLGlotCompiler):
ops.Hash: "fnv_hash",
ops.LStrip: "ltrim",
ops.Ln: "ln",
ops.RandomUUID: "uuid",
ops.RStrip: "rtrim",
ops.Strip: "trim",
ops.TypeOf: "typeof",
Expand Down
5 changes: 3 additions & 2 deletions ibis/backends/mssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,6 @@ class MSSQLCompiler(SQLGlotCompiler):
ops.Ln: "log",
ops.Log10: "log10",
ops.Power: "power",
ops.RandomScalar: "rand",
ops.RandomUUID: "newid",
ops.Repeat: "replicate",
ops.Reverse: "reverse",
ops.StringAscii: "ascii",
Expand Down Expand Up @@ -172,6 +170,9 @@ def _minimize_spec(start, end, spec):
return None
return spec

def visit_RandomUUID(self, op, **kwargs):
    """Return the result of calling ``self.f.newid()``.

    ``op`` and any extra keyword arguments are accepted but not used.
    """
    newid = self.f.newid
    return newid()

def visit_StringLength(self, op, *, arg):
"""The MSSQL LEN function doesn't count trailing spaces.
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ class PostgresCompiler(SQLGlotCompiler):
ops.MapContains: "exist",
ops.MapKeys: "akeys",
ops.MapValues: "avals",
ops.RandomUUID: "gen_random_uuid",
ops.RegexSearch: "regexp_like",
ops.TimeFromHMS: "make_time",
}
Expand All @@ -111,6 +110,9 @@ def _aggregate(self, funcname: str, *args, where):
return sge.Filter(this=expr, expression=sge.Where(this=where))
return expr

def visit_RandomUUID(self, op, **kwargs):
    """Return the result of calling ``self.f.gen_random_uuid()``.

    ``op`` and any extra keyword arguments are accepted but not used.
    """
    gen_random_uuid = self.f.gen_random_uuid
    return gen_random_uuid()

def visit_Mode(self, op, *, arg, where):
expr = self.f.mode()
expr = sge.WithinGroup(
Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ class SnowflakeCompiler(SQLGlotCompiler):
ops.Hash: "hash",
ops.Median: "median",
ops.Mode: "mode",
ops.RandomUUID: "uuid_string",
ops.StringToTimestamp: "to_timestamp_tz",
ops.TimeFromHMS: "time_from_parts",
ops.TimestampFromYMDHMS: "timestamp_from_parts",
Expand Down Expand Up @@ -241,11 +240,14 @@ def visit_MapLength(self, op, *, arg):
def visit_Log(self, op, *, arg, base):
return self.f.log(base, arg, dialect=self.dialect)

def visit_RandomScalar(self, op):
def visit_RandomScalar(self, op, **kwargs):
    """Build ``uniform(to_double(0.0), to_double(1.0), random())`` via ``self.f``.

    ``op`` and any extra keyword arguments are accepted but not used.
    """
    f = self.f
    # Arguments are built in the same left-to-right order as the call site.
    lower = f.to_double(0.0)
    upper = f.to_double(1.0)
    generator = f.random()
    return f.uniform(lower, upper, generator)

def visit_RandomUUID(self, op, **kwargs):
    """Return the result of calling ``self.f.uuid_string()``.

    ``op`` and any extra keyword arguments are accepted but not used.
    """
    uuid_string = self.f.uuid_string
    return uuid_string()

def visit_ApproxMedian(self, op, *, arg, where):
return self.agg.approx_percentile(arg, 0.5, where=where)

Expand Down
9 changes: 8 additions & 1 deletion ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ class SQLGlotCompiler(abc.ABC):
ops.Power: "pow",
ops.RPad: "rpad",
ops.Radians: "radians",
ops.RandomScalar: "random",
ops.RegexSearch: "regexp_like",
ops.RegexSplit: "regexp_split",
ops.Repeat: "repeat",
Expand Down Expand Up @@ -687,6 +686,14 @@ def visit_Round(self, op, *, arg, digits):
return sge.Round(this=arg, decimals=digits)
return sge.Round(this=arg)

### Random Noise

def visit_RandomScalar(self, op, **kwargs):
    """Return the result of calling ``self.f.rand()``.

    ``op`` and any extra keyword arguments are accepted but not used.
    """
    rand = self.f.rand
    return rand()

def visit_RandomUUID(self, op, **kwargs):
    """Return the result of calling ``self.f.uuid()``.

    ``op`` and any extra keyword arguments are accepted but not used.
    """
    uuid_call = self.f.uuid
    return uuid_call()

### Dtype Dysmorphia

def visit_TryCast(self, op, *, arg, to):
Expand Down
8 changes: 7 additions & 1 deletion ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,13 @@ def merge_select_select(_, **kwargs):
from the inner Select are inlined into the outer Select.
"""
# don't merge if either the outer or the inner select has window functions
blocking = (ops.WindowFunction, ops.ExistsSubquery, ops.InSubquery, ops.Unnest)
blocking = (
ops.WindowFunction,
ops.ExistsSubquery,
ops.InSubquery,
ops.Unnest,
ops.Impure,
)
if _.find_below(blocking, filter=ops.Value):
return _
if _.parent.find_below(blocking, filter=ops.Value):
Expand Down
3 changes: 1 addition & 2 deletions ibis/backends/sqlite/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ class SQLiteCompiler(SQLGlotCompiler):
ops.Mode: "_ibis_mode",
ops.Time: "time",
ops.Date: "date",
ops.RandomUUID: "uuid",
}

def _aggregate(self, funcname: str, *args, where):
Expand Down Expand Up @@ -213,7 +212,7 @@ def visit_Clip(self, op, *, arg, lower, upper):

return arg

def visit_RandomScalar(self, op):
def visit_RandomScalar(self, op, **kwargs):
    """Return ``0.5 + self.f.random() / <literal float(-1 << 64)>``.

    ``op`` and any extra keyword arguments are accepted but not used.
    """
    # NOTE(review): presumably rescales SQLite's signed 64-bit random()
    # into a unit-width interval -- confirm against the SQLite docs.
    raw = self.f.random()
    divisor = sge.Literal.number(float(-1 << 64))
    return 0.5 + raw / divisor

def visit_Cot(self, op, *, arg):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
"t1"."x",
"t1"."y",
"t1"."z",
CASE WHEN "t1"."y" = "t1"."z" THEN 'big' ELSE 'small' END AS "size"
FROM (
SELECT
"t0"."x",
randCanonical() AS "y",
randCanonical() AS "z"
FROM "t" AS "t0"
) AS "t1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
"t1"."x",
"t1"."y",
"t1"."z",
CASE WHEN "t1"."y" = "t1"."z" THEN 'big' ELSE 'small' END AS "size"
FROM (
SELECT
"t0"."x",
generateUUIDv4() AS "y",
generateUUIDv4() AS "z"
FROM "t" AS "t0"
) AS "t1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
"t1"."x",
"t1"."y",
"t1"."z",
CASE WHEN "t1"."y" = "t1"."z" THEN 'big' ELSE 'small' END AS "size"
FROM (
SELECT
"t0"."x",
RANDOM() AS "y",
RANDOM() AS "z"
FROM "t" AS "t0"
) AS "t1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
"t1"."x",
"t1"."y",
"t1"."z",
CASE WHEN "t1"."y" = "t1"."z" THEN 'big' ELSE 'small' END AS "size"
FROM (
SELECT
"t0"."x",
UUID() AS "y",
UUID() AS "z"
FROM "t" AS "t0"
) AS "t1"
10 changes: 10 additions & 0 deletions ibis/backends/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,13 @@ def test_union_generates_predictable_aliases(con):
expr = ibis.union(sub1, sub2)
df = con.execute(expr)
assert len(df) == 2


@pytest.mark.parametrize("value", [ibis.random(), ibis.uuid()])
def test_selects_with_impure_operations_not_merged(con, snapshot, value):
    """Snapshot the SQL compiled for two chained mutates reusing one impure value."""
    base = ibis.table({"x": "int64", "y": "float64"}, name="t")
    # The same impure expression feeds both new columns.
    with_impure = base.mutate(y=value, z=value)
    size = (with_impure.y == with_impure.z).ifelse("big", "small")
    labelled = with_impure.mutate(size=size)

    compiled = con.compile(labelled, pretty=True)
    snapshot.assert_match(compiled, "out.sql")
1 change: 0 additions & 1 deletion ibis/backends/trino/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ class TrinoCompiler(SQLGlotCompiler):
ops.ExtractPath: "url_extract_path",
ops.ExtractFragment: "url_extract_fragment",
ops.ArrayPosition: "array_position",
ops.RandomUUID: "uuid",
}

def _aggregate(self, funcname: str, *args, where):
Expand Down
17 changes: 15 additions & 2 deletions ibis/expr/operations/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,17 @@ class Constant(Scalar, Singleton):
shape = ds.scalar


@public
class Impure(Value):
    """Value node carrying a ``uid`` field.

    When ``uid`` is passed as ``None`` the next integer from a class-level
    ``itertools.count()`` is used instead, so nodes built without an explicit
    ``uid`` receive different ones. An explicitly supplied ``uid`` is
    forwarded unchanged.
    """

    # Counter that supplies fresh uids; defined on this class.
    _counter = itertools.count()
    uid: Optional[int] = None

    def __init__(self, uid, **kwargs):
        resolved_uid = next(self._counter) if uid is None else uid
        super().__init__(uid=resolved_uid, **kwargs)


@public
class TimestampNow(Constant):
dtype = dt.timestamp
Expand All @@ -194,13 +205,15 @@ class DateNow(Constant):


@public
class RandomScalar(Constant):
# Impure operation with float64 dtype and scalar shape.
# The shape is declared here because Impure, unlike Constant, does not
# define one.
class RandomScalar(Impure):
dtype = dt.float64
shape = ds.scalar


@public
class RandomUUID(Constant):
# Impure operation with uuid dtype and scalar shape.
# The shape is declared here because Impure, unlike Constant, does not
# define one.
class RandomUUID(Impure):
dtype = dt.uuid
shape = ds.scalar


@public
Expand Down
10 changes: 10 additions & 0 deletions ibis/expr/operations/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,13 @@ def test_NULL():
assert ops.NULL.dtype is dt.null
assert ops.NULL == ops.Literal(None, dt.null)
assert ops.NULL is not ops.Literal(None, dt.int8)


@pytest.mark.parametrize("op", [ops.RandomScalar, ops.RandomUUID])
def test_unique_impure_values(op):
    """Separately built impure nodes differ; a copy equals its source."""
    left, right = op(), op()
    assert left != right

    # Fresh instances again, as in the one-line form of this check.
    hash_left, hash_right = hash(op()), hash(op())
    assert hash_left != hash_right

    original = op()
    duplicate = original.copy()
    assert original == duplicate
25 changes: 25 additions & 0 deletions ibis/expr/tests/test_newrels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,3 +1606,28 @@ def test_subsequent_order_by_calls():
first = ops.Sort(t, [t.int_col.desc()]).to_expr()
second = ops.Sort(first, [first.int_col.asc()]).to_expr()
assert ts.equals(second)


@pytest.mark.parametrize("func", [ibis.random, ibis.uuid])
def test_impure_operation_dereferencing(func):
    """Check how an impure value is dereferenced across chained mutates."""
    t = ibis.table({"x": "int64"}, name="t")

    # Case 1: the very same impure expression is reused in the second
    # mutate, so it is expected to resolve to the parent's ``y`` column.
    shared = func()
    parent = t.mutate(y=shared)
    child = parent.mutate(z=shared.cast("string"))

    expected_shared = ops.Project(
        parent=parent,
        values={"x": parent.x, "y": parent.y, "z": parent.y.cast("string")},
    )
    assert child.op() == expected_shared

    # Case 2: two separately created impure expressions; the second one is
    # expected to stay as its own value rather than resolve to ``y``.
    first = func()
    second = func()

    parent = t.mutate(y=first)
    child = parent.mutate(z=second.cast("string"))
    expected_distinct = ops.Project(
        parent=parent,
        values={"x": parent.x, "y": parent.y, "z": second.cast("string")},
    )
    assert child.op() == expected_distinct
6 changes: 5 additions & 1 deletion ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,11 +562,15 @@ def test_order_by_asc_deferred_sort_key(table):
assert_equal(result, expected2)


# different instantiations create unique objects
rand = ibis.random()


@pytest.mark.parametrize(
("key", "expected"),
[
param(ibis.NA, ibis.NA.op(), id="na"),
param(ibis.random(), ibis.random().op(), id="random"),
param(rand, rand.op(), id="random"),
param(1.0, ibis.literal(1.0).op(), id="float"),
param(ibis.literal("a"), ibis.literal("a").op(), id="string"),
param(ibis.literal([1, 2, 3]), ibis.literal([1, 2, 3]).op(), id="array"),
Expand Down
5 changes: 3 additions & 2 deletions ibis/tests/expr/test_window_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,15 +246,16 @@ def test_window_api_supports_scalar_order_by(t):
)
assert expr == expected

window = ibis.window(order_by=ibis.random())
rand = ibis.random()
window = ibis.window(order_by=rand)
expr = t.a.sum().over(window).op()
expected = ops.WindowFunction(
t.a.sum(),
how="rows",
start=None,
end=None,
group_by=(),
order_by=[ibis.random()],
order_by=[rand],
)
assert expr == expected

Expand Down

0 comments on commit 813dc52

Please sign in to comment.