Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from cudf_polars.dsl.expressions import rolling, unary
from cudf_polars.dsl.expressions.base import ExecutionContext
from cudf_polars.dsl.nodebase import Node
from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter
from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter, validate_to_ast
from cudf_polars.dsl.tracing import log_do_evaluate, nvtx_annotate_cudf_polars
from cudf_polars.dsl.utils.reshape import broadcast
from cudf_polars.dsl.utils.windows import (
Expand Down Expand Up @@ -1934,22 +1934,26 @@ class Predicate:
"""Serializable wrapper for a predicate expression."""

predicate: expr.Expr
ast: plc.expressions.Expression

def __init__(self, predicate: expr.Expr):
validate_to_ast(predicate)
self.predicate = predicate
stream = get_cuda_stream()
ast_result = to_ast(predicate, stream=stream)
stream.synchronize()
if ast_result is None:
raise NotImplementedError(
f"Conditional join with predicate {predicate}"
) # pragma: no cover; polars never delivers expressions we can't handle
self.ast = ast_result

def __reduce__(self) -> tuple[Any, ...]:
"""Pickle a Predicate object."""
return (type(self), (self.predicate,))
def ast(self, stream: Stream) -> plc.expressions.Expression:
"""
Translate the predicate cudf-polars expression to a pylibcudf expression.

Parameters
----------
stream
CUDA stream used for device memory operations and kernel launches.

Returns
-------
plc.expressions.Expression
The pylibcudf expression representing the predicate.
"""
return to_ast(self.predicate, stream=stream)

__slots__ = ("ast_predicate", "options", "predicate")
_non_child = ("schema", "predicate", "options")
Expand Down Expand Up @@ -2027,7 +2031,7 @@ def do_evaluate(
lg, rg = plc.join.conditional_inner_join(
_apply_casts(left, left_casts).table,
_apply_casts(right, right_casts).table,
predicate_wrapper.ast,
predicate_wrapper.ast(stream=stream),
stream=stream,
)
left = DataFrame.from_table(
Expand Down
113 changes: 108 additions & 5 deletions python/cudf_polars/cudf_polars/dsl/to_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ class ASTState(TypedDict):
stream: Stream


class ValidateState(TypedDict):
"""State for validation of AST transformations."""

for_parquet: bool


class ExprTransformerState(TypedDict):
"""
State used for AST transformation when inserting column references.
Expand All @@ -126,13 +132,105 @@ class ExprTransformerState(TypedDict):
table_ref: plc.expressions.TableReference


ValidateTransformer: TypeAlias = GenericTransformer[expr.Expr, None, ValidateState]
Transformer: TypeAlias = GenericTransformer[expr.Expr, plc_expr.Expression, ASTState]
ExprTransformer: TypeAlias = GenericTransformer[
expr.Expr, expr.Expr, ExprTransformerState
]
"""Protocol for transformation of Expr nodes."""


@singledispatch
def _validate_to_ast(node: expr.Expr, self: ValidateTransformer) -> None:
print(f"Unhandled expression type {type(node)}")
raise NotImplementedError(f"Unhandled expression type {type(node)}")


@_validate_to_ast.register
def _(node: expr.Literal, self: ValidateTransformer) -> None:
return None


@_validate_to_ast.register
def _(node: expr.Col, self: ValidateTransformer) -> None:
if self.state["for_parquet"]:
return None
raise TypeError("Should always be wrapped in a ColRef node before translation")


@_validate_to_ast.register
def _(node: expr.ColRef, self: ValidateTransformer) -> None:
if self.state["for_parquet"]:
raise TypeError("Not expecting ColRef node in parquet filter")


@_validate_to_ast.register
def _(node: expr.BinOp, self: ValidateTransformer) -> None:
if node.op == plc.binaryop.BinaryOperator.NULL_NOT_EQUALS:
self(
expr.BinOp(
node.dtype, plc.binaryop.BinaryOperator.NULL_EQUALS, *node.children
)
) # check for validation errors
return None
if self.state["for_parquet"]:
op1_col, op2_col = (isinstance(op, expr.Col) for op in node.children)
if op1_col ^ op2_col:
op: plc.binaryop.BinaryOperator = node.op
if op not in SUPPORTED_STATISTICS_BINOPS:
raise NotImplementedError(
f"Parquet filter binop with column doesn't support {node.op!r}"
)
op1, op2 = node.children
if op2_col:
(op1, op2) = (op2, op1)
op = REVERSED_COMPARISON[op]
if not isinstance(op2, expr.Literal):
raise NotImplementedError(
"Parquet filter binops must have form 'col binop literal'"
)
self(op1) # check for validation errors
self(op2) # check for validation errors
return None
elif op1_col and op2_col:
raise NotImplementedError(
"Parquet filter binops must have one column reference not two"
)
for child in node.children:
self(child)


@_validate_to_ast.register
def _(node: expr.BooleanFunction, self: ValidateTransformer) -> None:
if node.name is expr.BooleanFunction.Name.IsIn:
needles, haystack = node.children
if isinstance(haystack, expr.LiteralColumn) and len(haystack.value) < 16:
self(needles)
return None
if self.state["for_parquet"] and isinstance(node.children[0], expr.Col):
raise NotImplementedError(
f"Parquet filters don't support {node.name} on columns"
)
if (
node.name is expr.BooleanFunction.Name.IsNull
or node.name is expr.BooleanFunction.Name.IsNotNull
or node.name is expr.BooleanFunction.Name.Not
):
self(node.children[0])
return None # check for validation errors
raise NotImplementedError(f"AST conversion does not support {node.name}")


@_validate_to_ast.register
def _(node: expr.UnaryFunction, self: ValidateTransformer) -> None:
if isinstance(node.children[0], expr.Col) and self.state["for_parquet"]:
raise NotImplementedError(
"Parquet filters don't support {node.name} on columns"
)
self(node.children[0]) # check for validation errors
return None


@singledispatch
def _to_ast(node: expr.Expr, self: Transformer) -> plc_expr.Expression:
"""
Expand Down Expand Up @@ -296,7 +394,15 @@ def to_parquet_filter(node: expr.Expr, stream: Stream) -> plc_expr.Expression |
return None


def to_ast(node: expr.Expr, stream: Stream) -> plc_expr.Expression | None:
def validate_to_ast(node: expr.Expr) -> None:
"""Validate it."""
mapper: ValidateTransformer = CachingVisitor(
_validate_to_ast, state={"for_parquet": False}
)
return mapper(node)


def to_ast(node: expr.Expr, stream: Stream) -> plc_expr.Expression:
"""
Convert an expression to libcudf AST nodes suitable for compute_column.

Expand All @@ -320,10 +426,7 @@ def to_ast(node: expr.Expr, stream: Stream) -> plc_expr.Expression | None:
mapper: Transformer = CachingVisitor(
_to_ast, state={"for_parquet": False, "stream": stream}
)
try:
return mapper(node)
except (KeyError, NotImplementedError):
return None
return mapper(node)


def _insert_colrefs(node: expr.Expr, rec: ExprTransformer) -> expr.Expr:
Expand Down
8 changes: 5 additions & 3 deletions python/cudf_polars/tests/dsl/test_to_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,16 @@ def compute_column(e):
)
with pytest.raises(NotImplementedError):
e_with_colrefs.evaluate(table)
ast = to_ast(e_with_colrefs, stream=stream)
if ast is not None:
try:
ast = to_ast(e_with_colrefs, stream=stream)
except (KeyError, NotImplementedError):
return e.evaluate(table)
else:
return NamedColumn(
plc.transform.compute_column(table.table, ast, stream=stream),
name=e.name,
dtype=e.value.dtype,
)
return e.evaluate(table)

got = DataFrame(map(compute_column, ir.exprs), stream=stream).to_polars()

Expand Down
5 changes: 4 additions & 1 deletion python/cudf_polars/tests/experimental/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from cudf_polars import Translator
from cudf_polars.experimental.parallel import lower_ir_graph
from cudf_polars.experimental.shuffle import Shuffle
from cudf_polars.testing.asserts import DEFAULT_CLUSTER, assert_gpu_result_equal
from cudf_polars.testing.asserts import (
DEFAULT_CLUSTER,
assert_gpu_result_equal,
)
from cudf_polars.utils.config import ConfigOptions


Expand Down
16 changes: 16 additions & 0 deletions python/cudf_polars/tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from cudf_polars.testing.asserts import (
assert_gpu_result_equal,
assert_ir_translation_raises,
get_default_engine,
)
from cudf_polars.utils.versions import POLARS_VERSION_LT_130, POLARS_VERSION_LT_132
Expand Down Expand Up @@ -312,3 +313,18 @@ def test_cross_join_filter_with_decimals(request, expr, left_dtype, right_dtype)
q = left.join(right, how="cross").filter(expr)

assert_gpu_result_equal(q, check_row_order=False)


def test_ie_join_projection_pd_19005() -> None:
lf = pl.LazyFrame({"a": [1, 2], "b": [3, 4]}).with_row_index()
q = (
lf.join_where(
lf,
pl.col.index < pl.col.index_right,
pl.col.index.cast(pl.Int64) + pl.col.a > pl.col.a_right,
)
.group_by(pl.col.index)
.agg(pl.col.index_right)
)

assert_ir_translation_raises(q, NotImplementedError)
Loading