Skip to content

Commit ffa38da

Browse files
committed
feat(ir): make impure ibis.random() and ibis.uuid() functions return unique node instances
1 parent f8370b1 commit ffa38da

File tree

24 files changed

+162
-24
lines changed

24 files changed

+162
-24
lines changed

ibis/backends/bigquery/compiler.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,6 @@ class BigQueryCompiler(SQLGlotCompiler):
120120
ops.RPad: "rpad",
121121
ops.Levenshtein: "edit_distance",
122122
ops.Modulus: "mod",
123-
ops.RandomScalar: "rand",
124-
ops.RandomUUID: "generate_uuid",
125123
ops.RegexReplace: "regexp_replace",
126124
ops.RegexSearch: "regexp_contains",
127125
ops.Time: "time",
@@ -698,3 +696,6 @@ def visit_CountDistinct(self, op, *, arg, where):
698696
if where is not None:
699697
arg = self.if_(where, arg, NULL)
700698
return self.f.count(sge.Distinct(expressions=[arg]))
699+
700+
def visit_RandomUUID(self, op, **kwargs):
701+
return self.f.generate_uuid()

ibis/backends/clickhouse/compiler.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,6 @@ class ClickHouseCompiler(SQLGlotCompiler):
9898
ops.NotNull: "isNotNull",
9999
ops.NullIf: "nullIf",
100100
ops.RStrip: "trimRight",
101-
ops.RandomScalar: "randCanonical",
102-
ops.RandomUUID: "generateUUIDv4",
103101
ops.RegexReplace: "replaceRegexpAll",
104102
ops.RowNumber: "row_number",
105103
ops.StartsWith: "startsWith",
@@ -637,6 +635,12 @@ def visit_TimestampRange(self, op, *, start, stop, step):
637635
def visit_RegexSplit(self, op, *, arg, pattern):
638636
return self.f.splitByRegexp(pattern, self.cast(arg, dt.String(nullable=False)))
639637

638+
def visit_RandomScalar(self, op, **kwargs):
639+
return self.f.randCanonical()
640+
641+
def visit_RandomUUID(self, op, **kwargs):
642+
return self.f.generateUUIDv4()
643+
640644
@staticmethod
641645
def _generate_groups(groups):
642646
return groups

ibis/backends/datafusion/compiler.py

-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ class DataFusionCompiler(SQLGlotCompiler):
7575
ops.Last: "last_value",
7676
ops.Median: "median",
7777
ops.StringLength: "character_length",
78-
ops.RandomUUID: "uuid",
7978
ops.RegexSplit: "regex_split",
8079
ops.EndsWith: "ends_with",
8180
ops.ArrayIntersect: "array_intersect",

ibis/backends/duckdb/compiler.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ class DuckDBCompiler(SQLGlotCompiler):
4848
ops.MapMerge: "map_concat",
4949
ops.MapValues: "map_values",
5050
ops.Mode: "mode",
51-
ops.RandomUUID: "uuid",
5251
ops.TimeFromHMS: "make_time",
5352
ops.TypeOf: "typeof",
5453
ops.GeoPoint: "st_point",
@@ -418,3 +417,9 @@ def visit_StructField(self, op, *, arg, field):
418417
expression=sg.to_identifier(field, quoted=self.quoted),
419418
)
420419
return super().visit_StructField(op, arg=arg, field=field)
420+
421+
def visit_RandomScalar(self, op, **kwargs):
422+
return self.f.random()
423+
424+
def visit_RandomUUID(self, op, **kwargs):
425+
return self.f.uuid()

ibis/backends/flink/compiler.py

-2
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,6 @@ class FlinkCompiler(SQLGlotCompiler):
7979
ops.MapKeys: "map_keys",
8080
ops.MapValues: "map_values",
8181
ops.Power: "power",
82-
ops.RandomScalar: "rand",
83-
ops.RandomUUID: "uuid",
8482
ops.RegexSearch: "regexp",
8583
ops.StrRight: "right",
8684
ops.StringLength: "char_length",

ibis/backends/impala/compiler.py

-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ class ImpalaCompiler(SQLGlotCompiler):
7676
ops.Hash: "fnv_hash",
7777
ops.LStrip: "ltrim",
7878
ops.Ln: "ln",
79-
ops.RandomUUID: "uuid",
8079
ops.RStrip: "rtrim",
8180
ops.Strip: "trim",
8281
ops.TypeOf: "typeof",

ibis/backends/mssql/compiler.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,6 @@ class MSSQLCompiler(SQLGlotCompiler):
129129
ops.Ln: "log",
130130
ops.Log10: "log10",
131131
ops.Power: "power",
132-
ops.RandomScalar: "rand",
133-
ops.RandomUUID: "newid",
134132
ops.Repeat: "replicate",
135133
ops.Reverse: "reverse",
136134
ops.StringAscii: "ascii",
@@ -172,6 +170,9 @@ def _minimize_spec(start, end, spec):
172170
return None
173171
return spec
174172

173+
def visit_RandomUUID(self, op, **kwargs):
174+
return self.f.newid()
175+
175176
def visit_StringLength(self, op, *, arg):
176177
"""The MSSQL LEN function doesn't count trailing spaces.
177178

ibis/backends/postgres/compiler.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ class PostgresCompiler(SQLGlotCompiler):
100100
ops.MapContains: "exist",
101101
ops.MapKeys: "akeys",
102102
ops.MapValues: "avals",
103-
ops.RandomUUID: "gen_random_uuid",
104103
ops.RegexSearch: "regexp_like",
105104
ops.TimeFromHMS: "make_time",
106105
}
@@ -111,6 +110,9 @@ def _aggregate(self, funcname: str, *args, where):
111110
return sge.Filter(this=expr, expression=sge.Where(this=where))
112111
return expr
113112

113+
def visit_RandomUUID(self, op, **kwargs):
114+
return self.f.gen_random_uuid()
115+
114116
def visit_Mode(self, op, *, arg, where):
115117
expr = self.f.mode()
116118
expr = sge.WithinGroup(

ibis/backends/snowflake/compiler.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ class SnowflakeCompiler(SQLGlotCompiler):
7878
ops.Hash: "hash",
7979
ops.Median: "median",
8080
ops.Mode: "mode",
81-
ops.RandomUUID: "uuid_string",
8281
ops.StringToTimestamp: "to_timestamp_tz",
8382
ops.TimeFromHMS: "time_from_parts",
8483
ops.TimestampFromYMDHMS: "timestamp_from_parts",
@@ -241,11 +240,14 @@ def visit_MapLength(self, op, *, arg):
241240
def visit_Log(self, op, *, arg, base):
242241
return self.f.log(base, arg, dialect=self.dialect)
243242

244-
def visit_RandomScalar(self, op):
243+
def visit_RandomScalar(self, op, **kwargs):
245244
return self.f.uniform(
246245
self.f.to_double(0.0), self.f.to_double(1.0), self.f.random()
247246
)
248247

248+
def visit_RandomUUID(self, op, **kwargs):
249+
return self.f.uuid_string()
250+
249251
def visit_ApproxMedian(self, op, *, arg, where):
250252
return self.agg.approx_percentile(arg, 0.5, where=where)
251253

ibis/backends/sql/compiler.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,6 @@ class SQLGlotCompiler(abc.ABC):
262262
ops.Power: "pow",
263263
ops.RPad: "rpad",
264264
ops.Radians: "radians",
265-
ops.RandomScalar: "random",
266265
ops.RegexSearch: "regexp_like",
267266
ops.RegexSplit: "regexp_split",
268267
ops.Repeat: "repeat",
@@ -687,6 +686,14 @@ def visit_Round(self, op, *, arg, digits):
687686
return sge.Round(this=arg, decimals=digits)
688687
return sge.Round(this=arg)
689688

689+
### Random Noise
690+
691+
def visit_RandomScalar(self, op, **kwargs):
692+
return self.f.rand()
693+
694+
def visit_RandomUUID(self, op, **kwargs):
695+
return self.f.uuid()
696+
690697
### Dtype Dysmorphia
691698

692699
def visit_TryCast(self, op, *, arg, to):

ibis/backends/sql/rewrites.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,13 @@ def merge_select_select(_, **kwargs):
152152
from the inner Select are inlined into the outer Select.
153153
"""
154154
# don't merge if either the outer or the inner select has window functions
155-
blocking = (ops.WindowFunction, ops.ExistsSubquery, ops.InSubquery, ops.Unnest)
155+
blocking = (
156+
ops.WindowFunction,
157+
ops.ExistsSubquery,
158+
ops.InSubquery,
159+
ops.Unnest,
160+
ops.Impure,
161+
)
156162
if _.find_below(blocking, filter=ops.Value):
157163
return _
158164
if _.parent.find_below(blocking, filter=ops.Value):

ibis/backends/sqlite/compiler.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ class SQLiteCompiler(SQLGlotCompiler):
103103
ops.Mode: "_ibis_mode",
104104
ops.Time: "time",
105105
ops.Date: "date",
106-
ops.RandomUUID: "uuid",
107106
}
108107

109108
def _aggregate(self, funcname: str, *args, where):
@@ -213,7 +212,7 @@ def visit_Clip(self, op, *, arg, lower, upper):
213212

214213
return arg
215214

216-
def visit_RandomScalar(self, op):
215+
def visit_RandomScalar(self, op, **kwargs):
217216
return 0.5 + self.f.random() / sge.Literal.number(float(-1 << 64))
218217

219218
def visit_Cot(self, op, *, arg):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
SELECT
2+
"t1"."x",
3+
"t1"."y",
4+
"t1"."z",
5+
CASE WHEN "t1"."y" = "t1"."z" THEN 'big' ELSE 'small' END AS "size"
6+
FROM (
7+
SELECT
8+
"t0"."x",
9+
randCanonical() AS "y",
10+
randCanonical() AS "z"
11+
FROM "t" AS "t0"
12+
) AS "t1"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
SELECT
2+
"t1"."x",
3+
"t1"."y",
4+
"t1"."z",
5+
CASE WHEN "t1"."y" = "t1"."z" THEN 'big' ELSE 'small' END AS "size"
6+
FROM (
7+
SELECT
8+
"t0"."x",
9+
generateUUIDv4() AS "y",
10+
generateUUIDv4() AS "z"
11+
FROM "t" AS "t0"
12+
) AS "t1"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
SELECT
2+
"t1"."x",
3+
"t1"."y",
4+
"t1"."z",
5+
CASE WHEN "t1"."y" = "t1"."z" THEN 'big' ELSE 'small' END AS "size"
6+
FROM (
7+
SELECT
8+
"t0"."x",
9+
RANDOM() AS "y",
10+
RANDOM() AS "z"
11+
FROM "t" AS "t0"
12+
) AS "t1"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
SELECT
2+
"t1"."x",
3+
"t1"."y",
4+
"t1"."z",
5+
CASE WHEN "t1"."y" = "t1"."z" THEN 'big' ELSE 'small' END AS "size"
6+
FROM (
7+
SELECT
8+
"t0"."x",
9+
UUID() AS "y",
10+
UUID() AS "z"
11+
FROM "t" AS "t0"
12+
) AS "t1"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
SELECT
2+
"t0"."x",
3+
RANDOM() AS "y",
4+
CASE WHEN RANDOM() > CAST(0.5 AS DOUBLE) THEN 'big' ELSE 'small' END AS "size"
5+
FROM "t" AS "t0"

ibis/backends/tests/test_sql.py

+10
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,13 @@ def test_union_generates_predictable_aliases(con):
177177
expr = ibis.union(sub1, sub2)
178178
df = con.execute(expr)
179179
assert len(df) == 2
180+
181+
182+
@pytest.mark.parametrize("value", [ibis.random(), ibis.uuid()])
183+
def test_selects_with_impure_operations_not_merged(con, snapshot, value):
184+
t = ibis.table({"x": "int64", "y": "float64"}, name="t")
185+
t = t.mutate(y=value, z=value)
186+
t = t.mutate(size=(t.y == t.z).ifelse("big", "small"))
187+
188+
sql = con.compile(t, pretty=True)
189+
snapshot.assert_match(sql, "out.sql")

ibis/backends/trino/compiler.py

-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ class TrinoCompiler(SQLGlotCompiler):
8080
ops.ExtractPath: "url_extract_path",
8181
ops.ExtractFragment: "url_extract_fragment",
8282
ops.ArrayPosition: "array_position",
83-
ops.RandomUUID: "uuid",
8483
}
8584

8685
def _aggregate(self, funcname: str, *args, where):

ibis/expr/operations/generic.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,17 @@ class Constant(Scalar, Singleton):
183183
shape = ds.scalar
184184

185185

186+
@public
187+
class Impure(Value):
188+
_counter = itertools.count()
189+
uid: Optional[int] = None
190+
191+
def __init__(self, uid, **kwargs):
192+
if uid is None:
193+
uid = next(self._counter)
194+
super().__init__(uid=uid, **kwargs)
195+
196+
186197
@public
187198
class TimestampNow(Constant):
188199
dtype = dt.timestamp
@@ -194,13 +205,15 @@ class DateNow(Constant):
194205

195206

196207
@public
197-
class RandomScalar(Constant):
208+
class RandomScalar(Impure):
198209
dtype = dt.float64
210+
shape = ds.scalar
199211

200212

201213
@public
202-
class RandomUUID(Constant):
214+
class RandomUUID(Impure):
203215
dtype = dt.uuid
216+
shape = ds.scalar
204217

205218

206219
@public

ibis/expr/operations/tests/test_generic.py

+10
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,13 @@ def test_NULL():
143143
assert ops.NULL.dtype is dt.null
144144
assert ops.NULL == ops.Literal(None, dt.null)
145145
assert ops.NULL is not ops.Literal(None, dt.int8)
146+
147+
148+
@pytest.mark.parametrize("op", [ops.RandomScalar, ops.RandomUUID])
149+
def test_unique_impure_values(op):
150+
assert op() != op()
151+
assert hash(op()) != hash(op())
152+
153+
node = op()
154+
other = node.copy()
155+
assert node == other

ibis/expr/tests/test_newrels.py

+25
Original file line numberDiff line numberDiff line change
@@ -1606,3 +1606,28 @@ def test_subsequent_order_by_calls():
16061606
first = ops.Sort(t, [t.int_col.desc()]).to_expr()
16071607
second = ops.Sort(first, [first.int_col.asc()]).to_expr()
16081608
assert ts.equals(second)
1609+
1610+
1611+
@pytest.mark.parametrize("func", [ibis.random, ibis.uuid])
1612+
def test_impure_operation_dereferencing(func):
1613+
t = ibis.table({"x": "int64"}, name="t")
1614+
1615+
impure = func()
1616+
t1 = t.mutate(y=impure)
1617+
t2 = t1.mutate(z=impure.cast("string"))
1618+
1619+
expected = ops.Project(
1620+
parent=t1,
1621+
values={"x": t1.x, "y": t1.y, "z": t1.y.cast("string")},
1622+
)
1623+
assert t2.op() == expected
1624+
1625+
v1 = func()
1626+
v2 = func()
1627+
1628+
t1 = t.mutate(y=v1)
1629+
t2 = t1.mutate(z=v2.cast("string"))
1630+
expected = ops.Project(
1631+
parent=t1, values={"x": t1.x, "y": t1.y, "z": v2.cast("string")}
1632+
)
1633+
assert t2.op() == expected

ibis/tests/expr/test_table.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -562,11 +562,15 @@ def test_order_by_asc_deferred_sort_key(table):
562562
assert_equal(result, expected2)
563563

564564

565+
# different instantiations create unique objects
566+
rand = ibis.random()
567+
568+
565569
@pytest.mark.parametrize(
566570
("key", "expected"),
567571
[
568572
param(ibis.NA, ibis.NA.op(), id="na"),
569-
param(ibis.random(), ibis.random().op(), id="random"),
573+
param(rand, rand.op(), id="random"),
570574
param(1.0, ibis.literal(1.0).op(), id="float"),
571575
param(ibis.literal("a"), ibis.literal("a").op(), id="string"),
572576
param(ibis.literal([1, 2, 3]), ibis.literal([1, 2, 3]).op(), id="array"),

ibis/tests/expr/test_window_frames.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -246,15 +246,16 @@ def test_window_api_supports_scalar_order_by(t):
246246
)
247247
assert expr == expected
248248

249-
window = ibis.window(order_by=ibis.random())
249+
rand = ibis.random()
250+
window = ibis.window(order_by=rand)
250251
expr = t.a.sum().over(window).op()
251252
expected = ops.WindowFunction(
252253
t.a.sum(),
253254
how="rows",
254255
start=None,
255256
end=None,
256257
group_by=(),
257-
order_by=[ibis.random()],
258+
order_by=[rand],
258259
)
259260
assert expr == expected
260261

0 commit comments

Comments
 (0)