Skip to content

Commit

Permalink
Add tests of traversal over IR nodes
Browse files Browse the repository at this point in the history
Now that we have a uniform child attribute, this is easier.
  • Loading branch information
wence- committed Oct 14, 2024
1 parent fec6632 commit 8f7d610
Showing 1 changed file with 60 additions and 1 deletion.
61 changes: 60 additions & 1 deletion python/cudf_polars/tests/dsl/test_traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

import pylibcudf as plc

from cudf_polars.dsl import expr
import polars as pl
from polars.testing import assert_frame_equal

from cudf_polars import translate_ir
from cudf_polars.dsl import expr, ir
from cudf_polars.dsl.traversal import (
CachingVisitor,
make_recursive,
Expand Down Expand Up @@ -96,3 +100,58 @@ def test_noop_visitor():

renamed = mapper(e2)
assert renamed == make_expr(dt, "c", "c")


def test_rewrite_ir_node():
df = pl.LazyFrame({"a": [1, 2, 1], "b": [1, 3, 4]})
q = df.group_by("a").agg(pl.col("b").sum()).sort("b")

orig = translate_ir(q._ldf.visit())

new_df = pl.DataFrame({"a": [1, 1, 2], "b": [-1, -2, -4]})

def replace_df(node, rec):
if isinstance(node, ir.DataFrameScan):
return ir.DataFrameScan(
node.schema, new_df._df, node.projection, node.predicate
)
return reuse_if_unchanged(node, rec)

mapper = CachingVisitor(replace_df)

new = mapper(orig)

result = new.evaluate(cache={}).to_polars()

expect = pl.DataFrame({"a": [2, 1], "b": [-4, -3]})

assert_frame_equal(result, expect)


def test_rewrite_scan_node(tmp_path):
left = pl.LazyFrame({"a": [1, 2, 3], "b": [1, 3, 4]})
right = pl.DataFrame({"a": [1, 4, 2], "c": [1, 2, 3]})

right.write_parquet(tmp_path / "right.pq")

right_s = pl.scan_parquet(tmp_path / "right.pq")

q = left.join(right_s, on="a", how="inner")

def replace_scan(node, rec):
if isinstance(node, ir.Scan):
return ir.DataFrameScan(
node.schema, right._df, node.with_columns, node.predicate
)
return reuse_if_unchanged(node, rec)

mapper = CachingVisitor(replace_scan)

orig = translate_ir(q._ldf.visit())
new = mapper(orig)

result = new.evaluate(cache={}).to_polars()

expect = q.collect()

assert_frame_equal(result, expect, check_row_order=False)

0 comments on commit 8f7d610

Please sign in to comment.