From 6248ec372e63a674be10bda205548a1698a85faa Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 14 Oct 2024 11:54:29 +0000 Subject: [PATCH 01/23] Renaming in typing for clarity --- python/cudf_polars/cudf_polars/typing/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/cudf_polars/cudf_polars/typing/__init__.py b/python/cudf_polars/cudf_polars/typing/__init__.py index 240b11bdf59..75c76d87242 100644 --- a/python/cudf_polars/cudf_polars/typing/__init__.py +++ b/python/cudf_polars/cudf_polars/typing/__init__.py @@ -18,7 +18,7 @@ import polars as pl -IR: TypeAlias = Union[ +PolarsIR: TypeAlias = Union[ pl_ir.PythonScan, pl_ir.Scan, pl_ir.Cache, @@ -38,7 +38,7 @@ pl_ir.ExtContext, ] -Expr: TypeAlias = Union[ +PolarsExpr: TypeAlias = Union[ pl_expr.Function, pl_expr.Window, pl_expr.Literal, @@ -68,7 +68,7 @@ def set_node(self, n: int) -> None: """Set the current plan node to n.""" ... - def view_current_node(self) -> IR: + def view_current_node(self) -> PolarsIR: """Convert current plan node to python rep.""" ... @@ -80,7 +80,7 @@ def get_dtype(self, n: int) -> pl.DataType: """Get the datatype of the given expression id.""" ... - def view_expression(self, n: int) -> Expr: + def view_expression(self, n: int) -> PolarsExpr: """Convert the given expression to python rep.""" ... From 8b5aaedd1bc5d4500a8dc04081875fff0405e3cc Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 8 Oct 2024 11:27:41 +0000 Subject: [PATCH 02/23] Extract abstract base for nodes into new file We will use this to provide infrastructure for making IR nodes easier to traverse. Expr nodes already use this facility, but we want to share it. --- .../cudf_polars/cudf_polars/dsl/nodebase.py | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 python/cudf_polars/cudf_polars/dsl/nodebase.py diff --git a/python/cudf_polars/cudf_polars/dsl/nodebase.py b/python/cudf_polars/cudf_polars/dsl/nodebase.py new file mode 100644 index 00000000000..6b3f05c7bf8 --- /dev/null +++ b/python/cudf_polars/cudf_polars/dsl/nodebase.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +"""Base class for IR nodes, and utilities.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar + +if TYPE_CHECKING: + from collections.abc import Hashable, Sequence + + from typing_extensions import Self + + +__all__: list[str] = ["Node"] + + +class Node: + """ + An abstract node type. + + Nodes are immutable! + + This contains a (potentially empty) tuple of child nodes, + along with non-child data. For uniform reconstruction and + implementation of hashing and equality schemes, child classes need + to provide a certain amount of metadata when they are defined. + Specifically, the ``_non_child`` attribute must list, in-order, + the names of the slots that are passed to the constructor. The + constructor must take arguments in the order ``(*_non_child, + *children).`` + """ + + __slots__ = ("_hash_value", "_repr_value", "children") + _hash_value: int + _repr_value: str + children: tuple[Node, ...] + _non_child: ClassVar[tuple[str, ...]] = () + + def _ctor_arguments(self, children: Sequence[Node]) -> Sequence: + return (*(getattr(self, attr) for attr in self._non_child), *children) + + def reconstruct( + self, children: Sequence[Node] + ) -> Self: # pragma: no cover; not yet used + """ + Rebuild this node with new children. 
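+
+        Children are passed positionally after the non-child data, matching
+        the ``(*_non_child, *children)`` constructor convention described in
+        the class docstring.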
+
+        Parameters
+        ----------
+        children
+            New children
+
+        Returns
+        -------
+        New node with new children. Non-child data is shared with the input.
+        """
+        return type(self)(*self._ctor_arguments(children))
+
+    def get_hashable(self) -> Hashable:
+        """
+        Return a hashable object for the node.
+
+        Returns
+        -------
+        Hashable object.
+
+        Notes
+        -----
+        This method is used by the :meth:`__hash__` implementation
+        (which does caching). If your node type needs special-case
+        handling for some of its attributes, override this method, not
+        :meth:`__hash__`.
+        """
+        return (type(self), self._ctor_arguments(self.children))
+
+    def __hash__(self) -> int:
+        """
+        Hash of a node with caching.
+
+        See Also
+        --------
+        get_hashable
+        """
+        try:
+            return self._hash_value
+        except AttributeError:
+            self._hash_value = hash(self.get_hashable())
+            return self._hash_value
+
+    def is_equal(self, other: Any) -> bool:
+        """
+        Equality of two nodes.
+
+        Override this in subclasses, rather than :meth:`__eq__`.
+
+        Parameters
+        ----------
+        other
+            Object to compare to.
+
+        Notes
+        -----
+        Since nodes are immutable, this does common subexpression
+        elimination when two nodes are determined to be equal.
+
+        Returns
+        -------
+        True if the two nodes are equal, false otherwise.
+        """
+        if self is other:
+            return True
+        if type(self) is not type(other):
+            return False  # pragma: no cover; __eq__ trips first
+        result = self._ctor_arguments(self.children) == other._ctor_arguments(
+            other.children
+        )
+        # Eager CSE for nodes that match.
+        if result and len(self.children) > 0:
+            self.children = other.children
+        return result
+
+    def __eq__(self, other: Any) -> bool:
+        """
+        Equality of nodes.
+
+        See Also
+        --------
+        is_equal
+        """
+        if type(self) is not type(other) or hash(self) != hash(other):
+            return False
+        else:
+            return self.is_equal(other)
+
+    def __ne__(self, other: Any) -> bool:
+        """Inequality of nodes."""
+        return not self.__eq__(other)
+
+    def __repr__(self) -> str:
+        """String representation of a node with caching."""
+        try:
+            return self._repr_value
+        except AttributeError:
+            args = ", ".join(f"{arg!r}" for arg in self._ctor_arguments(self.children))
+            self._repr_value = f"{type(self).__name__}({args})"
+            return self._repr_value

From 26b9d7dd552aeccbc75f8d75c5ba577edce635a4 Mon Sep 17 00:00:00 2001
From: Lawrence Mitchell
Date: Tue, 8 Oct 2024 11:31:49 +0000
Subject: [PATCH 03/23] Use new Node base class for expressions

---
 .../dsl/expressions/aggregation.py           |  2 +-
 .../cudf_polars/dsl/expressions/base.py      | 92 +------------------
 .../cudf_polars/dsl/expressions/binaryop.py  |  2 +-
 .../cudf_polars/dsl/expressions/boolean.py   |  2 +-
 .../cudf_polars/dsl/expressions/datetime.py  |  2 +-
 .../cudf_polars/dsl/expressions/literal.py   | 10 +-
 .../cudf_polars/dsl/expressions/rolling.py   |  4 +-
 .../cudf_polars/dsl/expressions/selection.py |  4 +-
 .../cudf_polars/dsl/expressions/sorting.py   |  4 +-
 .../cudf_polars/dsl/expressions/string.py    |  2 +-
 .../cudf_polars/dsl/expressions/ternary.py   |  2 +-
 .../cudf_polars/dsl/expressions/unary.py     |  7 +-
 python/cudf_polars/tests/dsl/test_expr.py    | 23 +++++
 13 files changed, 50 insertions(+), 106 deletions(-)

diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py b/python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py
index b8b18ec5039..3b78bb3dd54 100644
--- a/python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py
+++ b/python/cudf_polars/cudf_polars/dsl/expressions/aggregation.py
@@ -37,7 +37,7 @@ class Agg(Expr):
def __init__( self, dtype: plc.DataType, name: str, options: Any, *children: Expr ) -> None: - super().__init__(dtype) + self.dtype = dtype self.name = name self.options = options self.children = children diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/base.py b/python/cudf_polars/cudf_polars/dsl/expressions/base.py index 8d021b0231d..46eaeb7e4d6 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/base.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/base.py @@ -13,9 +13,10 @@ import pylibcudf as plc from cudf_polars.containers import Column +from cudf_polars.dsl.nodebase import Node if TYPE_CHECKING: - from collections.abc import Mapping, Sequence + from collections.abc import Mapping from cudf_polars.containers import Column, DataFrame @@ -32,100 +33,17 @@ class ExecutionContext(IntEnum): ROLLING = enum.auto() -class Expr: - """ - An abstract expression object. +class Expr(Node): + """An abstract expression object.""" - This contains a (potentially empty) tuple of child expressions, - along with non-child data. For uniform reconstruction and - implementation of hashing and equality schemes, child classes need - to provide a certain amount of metadata when they are defined. - Specifically, the ``_non_child`` attribute must list, in-order, - the names of the slots that are passed to the constructor. The - constructor must take arguments in the order ``(*_non_child, - *children).`` - """ - - __slots__ = ("dtype", "_hash_value", "_repr_value") + __slots__ = ("dtype",) dtype: plc.DataType """Data type of the expression.""" - _hash_value: int - """Caching slot for the hash of the expression.""" - _repr_value: str - """Caching slot for repr of the expression.""" children: tuple[Expr, ...] = () """Children of the expression.""" _non_child: ClassVar[tuple[str, ...]] = ("dtype",) """Names of non-child data (not Exprs) for reconstruction.""" - # Constructor must take arguments in order (*_non_child, *children) - def __init__(self, dtype: plc.DataType) -> None: - self.dtype = dtype - - def _ctor_arguments(self, children: Sequence[Expr]) -> Sequence: - return (*(getattr(self, attr) for attr in self._non_child), *children) - - def get_hash(self) -> int: - """ - Return the hash of this expr. - - Override this in subclasses, rather than __hash__. - - Returns - ------- - The integer hash value. - """ - return hash((type(self), self._ctor_arguments(self.children))) - - def __hash__(self) -> int: - """Hash of an expression with caching.""" - try: - return self._hash_value - except AttributeError: - self._hash_value = self.get_hash() - return self._hash_value - - def is_equal(self, other: Any) -> bool: - """ - Equality of two expressions. - - Override this in subclasses, rather than __eq__. - - Parameter - --------- - other - object to compare to - - Returns - ------- - True if the two expressions are equal, false otherwise. 
- """ - if type(self) is not type(other): - return False # pragma: no cover; __eq__ trips first - return self._ctor_arguments(self.children) == other._ctor_arguments( - other.children - ) - - def __eq__(self, other: Any) -> bool: - """Equality of expressions.""" - if type(self) is not type(other) or hash(self) != hash(other): - return False - else: - return self.is_equal(other) - - def __ne__(self, other: Any) -> bool: - """Inequality of expressions.""" - return not self.__eq__(other) - - def __repr__(self) -> str: - """String representation of an expression with caching.""" - try: - return self._repr_value - except AttributeError: - args = ", ".join(f"{arg!r}" for arg in self._ctor_arguments(self.children)) - self._repr_value = f"{type(self).__name__}({args})" - return self._repr_value - def do_evaluate( self, df: DataFrame, diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py b/python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py index 19baae3611d..5ff72f7a9ba 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py @@ -35,7 +35,7 @@ def __init__( left: Expr, right: Expr, ) -> None: - super().__init__(dtype) + self.dtype = dtype if plc.traits.is_boolean(self.dtype): # For boolean output types, bitand and bitor implement # boolean logic, so translate. bitxor also does, but the diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py index ff9973a47d5..cd03e182076 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py @@ -42,7 +42,7 @@ def __init__( options: tuple[Any, ...], *children: Expr, ) -> None: - super().__init__(dtype) + self.dtype = dtype self.options = options self.name = name self.children = children diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py index f752a23b628..8811c85b7f5 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py @@ -48,7 +48,7 @@ def __init__( options: tuple[Any, ...], *children: Expr, ) -> None: - super().__init__(dtype) + self.dtype = dtype self.options = options self.name = name self.children = children diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/literal.py b/python/cudf_polars/cudf_polars/dsl/expressions/literal.py index 562a2255033..55b9bb9fd10 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/literal.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/literal.py @@ -16,7 +16,7 @@ from cudf_polars.utils import dtypes if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Hashable, Mapping import pyarrow as pa @@ -34,7 +34,7 @@ class Literal(Expr): children: tuple[()] def __init__(self, dtype: plc.DataType, value: pa.Scalar[Any]) -> None: - super().__init__(dtype) + self.dtype = dtype assert value.type == plc.interop.to_arrow(dtype) self.value = value @@ -61,16 +61,16 @@ class LiteralColumn(Expr): children: tuple[()] def __init__(self, dtype: plc.DataType, value: pl.Series) -> None: - super().__init__(dtype) + self.dtype = dtype data = value.to_arrow() self.value = data.cast(dtypes.downcast_arrow_lists(data.type)) - def get_hash(self) -> int: + def get_hashable(self) -> Hashable: """Compute a hash of the column.""" # This is stricter than necessary, but we only need this hash 
# for identity in groupby replacements so it's OK. And this # way we avoid doing potentially expensive compute. - return hash((type(self), self.dtype, id(self.value))) + return (type(self), self.dtype, id(self.value)) def do_evaluate( self, diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/rolling.py b/python/cudf_polars/cudf_polars/dsl/expressions/rolling.py index f7dcc3c542c..bef95779745 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/rolling.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/rolling.py @@ -22,7 +22,7 @@ class RollingWindow(Expr): children: tuple[Expr] def __init__(self, dtype: plc.DataType, options: Any, agg: Expr) -> None: - super().__init__(dtype) + self.dtype = dtype self.options = options self.children = (agg,) raise NotImplementedError("Rolling window not implemented") @@ -34,7 +34,7 @@ class GroupedRollingWindow(Expr): children: tuple[Expr, ...] def __init__(self, dtype: plc.DataType, options: Any, agg: Expr, *by: Expr) -> None: - super().__init__(dtype) + self.dtype = dtype self.options = options self.children = (agg, *by) raise NotImplementedError("Grouped rolling window not implemented") diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/selection.py b/python/cudf_polars/cudf_polars/dsl/expressions/selection.py index a7a3e68a28c..9aada61bce6 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/selection.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/selection.py @@ -28,7 +28,7 @@ class Gather(Expr): children: tuple[Expr, Expr] def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr) -> None: - super().__init__(dtype) + self.dtype = dtype self.children = (values, indices) def do_evaluate( @@ -70,7 +70,7 @@ class Filter(Expr): children: tuple[Expr, Expr] def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr): - super().__init__(dtype) + self.dtype = dtype self.children = (values, indices) def do_evaluate( diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/sorting.py b/python/cudf_polars/cudf_polars/dsl/expressions/sorting.py index 861b73ce6a0..29d81f6d948 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/sorting.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/sorting.py @@ -30,7 +30,7 @@ class Sort(Expr): def __init__( self, dtype: plc.DataType, options: tuple[bool, bool, bool], column: Expr ) -> None: - super().__init__(dtype) + self.dtype = dtype self.options = options self.children = (column,) @@ -70,7 +70,7 @@ def __init__( column: Expr, *by: Expr, ) -> None: - super().__init__(dtype) + self.dtype = dtype self.options = options self.children = (column, *by) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/string.py b/python/cudf_polars/cudf_polars/dsl/expressions/string.py index 6669669aadc..f424720f5e4 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/string.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/string.py @@ -39,7 +39,7 @@ def __init__( options: tuple[Any, ...], *children: Expr, ) -> None: - super().__init__(dtype) + self.dtype = dtype self.options = options self.name = name self.children = children diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/ternary.py b/python/cudf_polars/cudf_polars/dsl/expressions/ternary.py index c7d7a802ded..f6bd304a41f 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/ternary.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/ternary.py @@ -33,7 +33,7 @@ class Ternary(Expr): def __init__( self, dtype: plc.DataType, when: Expr, then: Expr, otherwise: Expr ) -> 
None: - super().__init__(dtype) + self.dtype = dtype self.children = (when, then, otherwise) def do_evaluate( diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/unary.py b/python/cudf_polars/cudf_polars/dsl/expressions/unary.py index 3d4d15be1ce..6bd9afb2ce2 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/unary.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/unary.py @@ -31,7 +31,7 @@ class Cast(Expr): children: tuple[Expr] def __init__(self, dtype: plc.DataType, value: Expr) -> None: - super().__init__(dtype) + self.dtype = dtype self.children = (value,) if not dtypes.can_cast(value.dtype, self.dtype): raise NotImplementedError( @@ -62,6 +62,9 @@ class Len(Expr): children: tuple[()] + def __init__(self, dtype: plc.DataType) -> None: + self.dtype = dtype + def do_evaluate( self, df: DataFrame, @@ -142,7 +145,7 @@ class UnaryFunction(Expr): def __init__( self, dtype: plc.DataType, name: str, options: tuple[Any, ...], *children: Expr ) -> None: - super().__init__(dtype) + self.dtype = dtype self.name = name self.options = options self.children = children diff --git a/python/cudf_polars/tests/dsl/test_expr.py b/python/cudf_polars/tests/dsl/test_expr.py index b7d4672daca..48546f4aaa6 100644 --- a/python/cudf_polars/tests/dsl/test_expr.py +++ b/python/cudf_polars/tests/dsl/test_expr.py @@ -73,3 +73,26 @@ def test_namedexpr_repr_stable(): b2 = expr.NamedExpr("b1", expr.Col(plc.DataType(plc.TypeId.INT8), "a")) assert repr(b1) == repr(b2) + + +def test_equality_cse(): + dt = plc.DataType(plc.TypeId.INT8) + + def make_expr(n1, n2): + a = expr.Col(plc.DataType(plc.TypeId.INT8), n1) + b = expr.Col(plc.DataType(plc.TypeId.INT8), n2) + + return expr.BinOp(dt, plc.binaryop.BinaryOperator.ADD, a, b) + + e1 = make_expr("a", "b") + e2 = make_expr("a", "b") + e3 = make_expr("a", "c") + + assert e1.children is not e2.children + assert e1 == e2 + assert e1.children is e2.children + assert e1 == e1 + assert e2 == e2 + assert e1 != e3 + assert e2 != e3 + assert e3 == e3 From ffe460c35f49b40690d8844752ee69e7cd2c65c4 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 14 Oct 2024 09:39:58 +0000 Subject: [PATCH 04/23] Infrastructure for traversal and visitors And tests of basic functionality. --- .../cudf_polars/cudf_polars/dsl/traversal.py | 191 ++++++++++++++++++ .../cudf_polars/typing/__init__.py | 41 +++- python/cudf_polars/pyproject.toml | 2 +- .../cudf_polars/tests/dsl/test_traversal.py | 98 +++++++++ 4 files changed, 329 insertions(+), 3 deletions(-) create mode 100644 python/cudf_polars/cudf_polars/dsl/traversal.py create mode 100644 python/cudf_polars/tests/dsl/test_traversal.py diff --git a/python/cudf_polars/cudf_polars/dsl/traversal.py b/python/cudf_polars/cudf_polars/dsl/traversal.py new file mode 100644 index 00000000000..2469b167a83 --- /dev/null +++ b/python/cudf_polars/cudf_polars/dsl/traversal.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-License-Identifier: Apache-2.0

+"""Traversal and visitor utilities for nodes."""
+
+from __future__ import annotations
+
+from collections.abc import Hashable
+from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
+
+if TYPE_CHECKING:
+    from collections.abc import Callable, Generator, Mapping, MutableMapping
+
+    from cudf_polars.dsl import expr, ir
+    from cudf_polars.dsl.nodebase import Node
+    from cudf_polars.typing import ExprTransformer, GenericTransformer, IRTransformer
+
+
+__all__: list[str] = [
+    "traversal",
+    "reuse_if_unchanged",
+    "make_recursive",
+    "CachingVisitor",
+]
+
+
+def traversal(node: Node) -> Generator[Node, None, None]:
+    """
+    Pre-order traversal of nodes in an expression.
+
+    Parameters
+    ----------
+    node
+        Root of expression to traverse.
+
+    Yields
+    ------
+    Unique nodes in the expression, parent before child, children
+    in-order from left to right.
+    """
+    seen = {node}
+    lifo = [node]
+
+    while lifo:
+        node = lifo.pop()
+        yield node
+        for child in reversed(node.children):
+            if child not in seen:
+                seen.add(child)
+                lifo.append(child)
+
+
+# reuse_if_unchanged can either be applied to expressions, in which
+# case we need an ExprTransformer...
+@overload
+def reuse_if_unchanged(node: expr.Expr, fn: ExprTransformer) -> expr.Expr: ...
+
+
+# ... or to plan nodes (IR), in which case we need an IRTransformer
+@overload
+def reuse_if_unchanged(node: ir.IR, fn: IRTransformer) -> ir.IR: ...
+
+
+def reuse_if_unchanged(node, fn):
+    """
+    Recipe for transforming nodes that returns the old object if unchanged.
+
+    Parameters
+    ----------
+    node
+        Node to recurse on
+    fn
+        Function to transform children
+
+    Notes
+    -----
+    This can be used as a generic "base case" handler when
+    writing transforms that take nodes and produce new nodes.
+
+    Returns
+    -------
+    The existing node if the transformed children are unchanged,
+    otherwise a reconstructed node with the new children.
+    """
+    new_children = [fn(c) for c in node.children]
+    if all(new == old for new, old in zip(new_children, node.children, strict=True)):
+        return node
+    return node.reconstruct(new_children)
+
+
+U_contra = TypeVar("U_contra", bound=Hashable, contravariant=True)
+V_co = TypeVar("V_co", covariant=True)
+
+
+def make_recursive(
+    fn: Callable[[U_contra, GenericTransformer[U_contra, V_co]], V_co],
+    *,
+    state: Mapping[str, Any] | None = None,
+) -> GenericTransformer[U_contra, V_co]:
+    """
+    No-op wrapper for recursive visitors.
+
+    Facilitates using visitors that don't need caching but are written
+    in the same style.
+
+    Parameters
+    ----------
+    fn
+        Function to transform inputs to outputs. Should take as its
+        second argument a callable from input to output.
+    state
+        Arbitrary *immutable* state that should be accessible to the
+        visitor through the `state` property.
+
+    Notes
+    -----
+    All transformation functions *must* be pure.
+
+    Usually, prefer a :class:`CachingVisitor`, but if we know that we
+    don't need caching in a transformation, this no-op approach is
+    slightly cheaper.
+
+    Returns
+    -------
+    Recursive function without caching.
+
+    See Also
+    --------
+    CachingVisitor
+    """
+
+    def rec(node: U_contra) -> V_co:
+        return fn(node, rec)  # type: ignore[arg-type]
+
+    rec.state = state if state is not None else {}  # type: ignore[attr-defined]
+    return rec  # type: ignore[return-value]
+
+
+class CachingVisitor(Generic[U_contra, V_co]):
+    """
+    Caching wrapper for recursive visitors.
+ + Facilitates writing visitors where already computed results should + be cached and reused. The cache is managed automatically, and is + tied to the lifetime of the wrapper. + + Parameters + ---------- + fn + Function to transform inputs to outputs. Should take as its + second argument the recursive cache manager. + state + Arbitrary *immutable* state that should be accessible to the + visitor through the `state` property. + + Notes + ----- + All transformation functions *must* be pure. + + Returns + ------- + Recursive function with caching. + """ + + def __init__( + self, + fn: Callable[[U_contra, GenericTransformer[U_contra, V_co]], V_co], + *, + state: Mapping[str, Any] | None = None, + ) -> None: + self.fn = fn + self.cache: MutableMapping[U_contra, V_co] = {} + self.state = state if state is not None else {} + + def __call__(self, value: U_contra) -> V_co: + """ + Apply the function to a value. + + Parameters + ---------- + value + The value to transform. + + Returns + ------- + A transformed value. + """ + try: + return self.cache[value] + except KeyError: + return self.cache.setdefault(value, self.fn(value, self)) diff --git a/python/cudf_polars/cudf_polars/typing/__init__.py b/python/cudf_polars/cudf_polars/typing/__init__.py index 75c76d87242..45e70c808de 100644 --- a/python/cudf_polars/cudf_polars/typing/__init__.py +++ b/python/cudf_polars/cudf_polars/typing/__init__.py @@ -5,8 +5,8 @@ from __future__ import annotations -from collections.abc import Mapping -from typing import TYPE_CHECKING, Literal, Protocol, Union +from collections.abc import Hashable, Mapping +from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, Union import pylibcudf as plc @@ -18,6 +18,18 @@ import polars as pl + from cudf_polars.dsl import expr, ir + +__all__: list[str] = [ + "PolarsIR", + "PolarsExpr", + "NodeTraverser", + "OptimizationArgs", + "GenericTransformer", + "ExprTransformer", + "IRTransformer", +] + PolarsIR: TypeAlias = Union[ pl_ir.PythonScan, pl_ir.Scan, @@ -107,3 +119,28 @@ def set_udf( "cluster_with_columns", "no_optimization", ] + + +U_contra = TypeVar("U_contra", bound=Hashable, contravariant=True) +V_co = TypeVar("V_co", covariant=True) + + +class GenericTransformer(Protocol[U_contra, V_co]): + """Abstract protocol for recursive visitors.""" + + def __call__(self, __value: U_contra) -> V_co: + """Apply the visitor to the node.""" + ... + + @property + def state(self) -> Mapping[str, Any]: + """Arbitrary immutable state.""" + ... 
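+
+# Note: ``CachingVisitor`` and the wrapper returned by ``make_recursive``
+# (both in ``cudf_polars.dsl.traversal``) satisfy this protocol: they are
+# callables that also carry an immutable ``state`` mapping.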
+ + +# Quotes to avoid circular import +ExprTransformer: TypeAlias = GenericTransformer["expr.Expr", "expr.Expr"] +"""Protocol for transformation of Expr nodes.""" + +IRTransformer: TypeAlias = GenericTransformer["ir.IR", "ir.IR"] +"""Protocol for transformation of IR nodes.""" diff --git a/python/cudf_polars/pyproject.toml b/python/cudf_polars/pyproject.toml index 5345fad41a2..a8bb634732f 100644 --- a/python/cudf_polars/pyproject.toml +++ b/python/cudf_polars/pyproject.toml @@ -60,7 +60,7 @@ xfail_strict = true [tool.coverage.report] exclude_also = [ "if TYPE_CHECKING:", - "class .*\\bProtocol\\):", + "class .*\\bProtocol(?:\\[[^]]+\\])?\\):", "assert_never\\(" ] # The cudf_polars test suite doesn't exercise the plugin, so we omit diff --git a/python/cudf_polars/tests/dsl/test_traversal.py b/python/cudf_polars/tests/dsl/test_traversal.py new file mode 100644 index 00000000000..bfafe2dce92 --- /dev/null +++ b/python/cudf_polars/tests/dsl/test_traversal.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pylibcudf as plc + +from cudf_polars.dsl import expr +from cudf_polars.dsl.traversal import ( + CachingVisitor, + make_recursive, + reuse_if_unchanged, + traversal, +) + + +def make_expr(dt, n1, n2): + a1 = expr.Col(dt, n1) + a2 = expr.Col(dt, n2) + + return expr.BinOp(dt, plc.binaryop.BinaryOperator.MUL, a1, a2) + + +def test_traversal_unique(): + dt = plc.DataType(plc.TypeId.INT8) + + e1 = make_expr(dt, "a", "a") + unique_exprs = list(traversal(e1)) + + assert len(unique_exprs) == 2 + assert set(unique_exprs) == {expr.Col(dt, "a"), e1} + assert unique_exprs == [e1, expr.Col(dt, "a")] + + e2 = make_expr(dt, "a", "b") + unique_exprs = list(traversal(e2)) + + assert len(unique_exprs) == 3 + assert set(unique_exprs) == {expr.Col(dt, "a"), expr.Col(dt, "b"), e2} + assert unique_exprs == [e2, expr.Col(dt, "a"), expr.Col(dt, "b")] + + e3 = make_expr(dt, "b", "a") + unique_exprs = list(traversal(e3)) + + assert len(unique_exprs) == 3 + assert set(unique_exprs) == {expr.Col(dt, "a"), expr.Col(dt, "b"), e3} + assert unique_exprs == [e3, expr.Col(dt, "b"), expr.Col(dt, "a")] + + +def rename(e, rec): + mapping = rec.state["mapping"] + if isinstance(e, expr.Col) and e.name in mapping: + return type(e)(e.dtype, mapping[e.name]) + return reuse_if_unchanged(e, rec) + + +def test_caching_visitor(): + dt = plc.DataType(plc.TypeId.INT8) + + e1 = make_expr(dt, "a", "b") + + mapper = CachingVisitor(rename, state={"mapping": {"b": "c"}}) + + renamed = mapper(e1) + assert renamed == make_expr(dt, "a", "c") + assert len(mapper.cache) == 3 + + e2 = make_expr(dt, "a", "a") + mapper = CachingVisitor(rename, state={"mapping": {"b": "c"}}) + + renamed = mapper(e2) + assert renamed == make_expr(dt, "a", "a") + assert len(mapper.cache) == 2 + mapper = CachingVisitor(rename, state={"mapping": {"a": "c"}}) + + renamed = mapper(e2) + assert renamed == make_expr(dt, "c", "c") + assert len(mapper.cache) == 2 + + +def test_noop_visitor(): + dt = plc.DataType(plc.TypeId.INT8) + + e1 = make_expr(dt, "a", "b") + + mapper = make_recursive(rename, state={"mapping": {"b": "c"}}) + + renamed = mapper(e1) + assert renamed == make_expr(dt, "a", "c") + + e2 = make_expr(dt, "a", "a") + mapper = make_recursive(rename, state={"mapping": {"b": "c"}}) + + renamed = mapper(e2) + assert renamed == make_expr(dt, "a", "a") + mapper = make_recursive(rename, state={"mapping": {"a": "c"}}) + + renamed = mapper(e2) 
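+    # e2 is ``a * a`` and the mapping renames "a" to "c", so both
+    # children come back renamed.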
+    assert renamed == make_expr(dt, "c", "c")

From 83a60f08fd962b80bd18ec039fb925335f5a7369 Mon Sep 17 00:00:00 2001
From: Lawrence Mitchell
Date: Tue, 8 Oct 2024 11:35:29 +0000
Subject: [PATCH 05/23] Use abstract Node infrastructure to define IR nodes

This way we will be able to write generic traversals more easily.
---
 python/cudf_polars/cudf_polars/dsl/ir.py     | 514 ++++++++++++------
 .../cudf_polars/cudf_polars/dsl/translate.py |  40 +-
 python/cudf_polars/tests/test_config.py      |   6 +-
 3 files changed, 370 insertions(+), 190 deletions(-)

diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py
index e319c363a23..f503ea3f1d1 100644
--- a/python/cudf_polars/cudf_polars/dsl/ir.py
+++ b/python/cudf_polars/cudf_polars/dsl/ir.py
@@ -13,8 +13,8 @@
 
 from __future__ import annotations
 
-import dataclasses
 import itertools
+import json
 from functools import cache
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, ClassVar
@@ -27,10 +27,11 @@
 
 import cudf_polars.dsl.expr as expr
 from cudf_polars.containers import Column, DataFrame
-from cudf_polars.utils import dtypes, sorting
+from cudf_polars.dsl.nodebase import Node
+from cudf_polars.utils import dtypes
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, MutableMapping
+    from collections.abc import Callable, Hashable, MutableMapping, Sequence
     from typing import Literal
 
     from cudf_polars.typing import Schema
@@ -121,16 +122,27 @@ def broadcast(*columns: Column, target_length: int | None = None) -> list[Column
     ]
 
 
-@dataclasses.dataclass
-class IR:
+class IR(Node):
     """Abstract plan node, representing an unevaluated dataframe."""
 
+    __slots__ = ("schema",)
+    _non_child: ClassVar[tuple[str, ...]] = ("schema",)
     schema: Schema
     """Mapping from column names to their data types."""
+    children: tuple[IR, ...] = ()
 
-    def __post_init__(self):
-        """Validate preconditions."""
-        pass  # noqa: PIE790
+    def get_hashable(self) -> Hashable:
+        """
+        Hashable representation of the node, treating the schema dictionary specially.
+
+        Since the schema is a dictionary, even though it is morally
+        immutable, it is not hashable. We therefore convert it to
+        tuples for hashing purposes.
+ """ + # Schema is the first constructor argument + args = self._ctor_arguments(self.children)[1:] + schema_hash = tuple(self.schema.items()) + return (type(self), schema_hash, args) def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """ @@ -159,24 +171,49 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: ) # pragma: no cover -@dataclasses.dataclass class PythonScan(IR): """Representation of input from a python function.""" + __slots__ = ("options", "predicate") + _non_child = ("schema", "options", "predicate") options: Any """Arbitrary options.""" predicate: expr.NamedExpr | None """Filter to apply to the constructed dataframe before returning it.""" - def __post_init__(self): - """Validate preconditions.""" + def __init__(self, schema: Schema, options: Any, predicate: expr.NamedExpr | None): + self.schema = schema + self.options = options + self.predicate = predicate raise NotImplementedError("PythonScan not implemented") -@dataclasses.dataclass class Scan(IR): """Input from files.""" + __slots__ = ( + "typ", + "reader_options", + "cloud_options", + "paths", + "with_columns", + "skip_rows", + "n_rows", + "row_index", + "predicate", + ) + _non_child = ( + "schema", + "typ", + "reader_options", + "cloud_options", + "paths", + "with_columns", + "skip_rows", + "n_rows", + "row_index", + "predicate", + ) typ: str """What type of file are we reading? Parquet, CSV, etc...""" reader_options: dict[str, Any] @@ -185,7 +222,7 @@ class Scan(IR): """Cloud-related authentication options, currently ignored.""" paths: list[str] """List of paths to read from.""" - with_columns: list[str] + with_columns: list[str] | None """Projected columns to return.""" skip_rows: int """Rows to skip at the start when reading.""" @@ -196,9 +233,29 @@ class Scan(IR): predicate: expr.NamedExpr | None """Mask to apply to the read dataframe.""" - def __post_init__(self) -> None: - """Validate preconditions.""" - super().__post_init__() + def __init__( + self, + schema: Schema, + typ: str, + reader_options: dict[str, Any], + cloud_options: dict[str, Any] | None, + paths: list[str], + with_columns: list[str] | None, + skip_rows: int, + n_rows: int, + row_index: tuple[str, int] | None, + predicate: expr.NamedExpr | None, + ): + self.schema = schema + self.typ = typ + self.reader_options = reader_options + self.cloud_options = cloud_options + self.paths = paths + self.with_columns = with_columns + self.skip_rows = skip_rows + self.n_rows = n_rows + self.row_index = row_index + self.predicate = predicate if self.typ not in ("csv", "parquet", "ndjson"): # pragma: no cover # This line is unhittable ATM since IPC/Anonymous scan raise # on the polars side @@ -258,6 +315,28 @@ def __post_init__(self) -> None: "Reading only parquet metadata to produce row index." ) + def get_hashable(self) -> Hashable: + """ + Hashable representation of the node. + + The options dictionaries are serialised for hashing purposes + as json strings. 
+ """ + schema_hash = tuple(self.schema.items()) + return ( + type(self), + schema_hash, + self.typ, + json.dumps(self.reader_options), + json.dumps(self.cloud_options), + tuple(self.paths), + tuple(self.with_columns) if self.with_columns is not None else None, + self.skip_rows, + self.n_rows, + self.row_index, + self.predicate, + ) + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" with_columns = self.with_columns @@ -401,7 +480,6 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: return df.filter(mask) -@dataclasses.dataclass class Cache(IR): """ Return a cached plan node. @@ -409,20 +487,28 @@ class Cache(IR): Used for CSE at the plan level. """ + __slots__ = ("key", "children") + _non_child = ("schema", "key") + children: tuple[IR] key: int """The cache key.""" value: IR """The unevaluated node to cache.""" + def __init__(self, schema: Schema, key: int, value: IR): + self.schema = schema + self.key = key + self.children = (value,) + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" try: return cache[self.key] except KeyError: - return cache.setdefault(self.key, self.value.evaluate(cache=cache)) + (value,) = self.children + return cache.setdefault(self.key, value.evaluate(cache=cache)) -@dataclasses.dataclass class DataFrameScan(IR): """ Input from an existing polars DataFrame. @@ -430,13 +516,37 @@ class DataFrameScan(IR): This typically arises from ``q.collect().lazy()`` """ + __slots__ = ("df", "projection", "predicate") + _non_child = ("schema", "df", "projection", "predicate") df: Any """Polars LazyFrame object.""" - projection: list[str] + projection: tuple[str, ...] | None """List of columns to project out.""" predicate: expr.NamedExpr | None """Mask to apply.""" + def __init__( + self, + schema: Schema, + df: Any, + projection: Sequence[str] | None, + predicate: expr.NamedExpr | None, + ): + self.schema = schema + self.df = df + self.projection = tuple(projection) if projection is not None else None + self.predicate = predicate + + def get_hashable(self) -> Hashable: + """ + Hashable representation of the node. + + The (heavy) dataframe object is hashed as its id, so this is + not stable across runs, or repeat instances of the same equal dataframes. + """ + schema_hash = tuple(self.schema.items()) + return (type(self), schema_hash, id(self.df), self.projection, self.predicate) + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" pdf = pl.DataFrame._from_pydf(self.df) @@ -454,28 +564,42 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: return df -@dataclasses.dataclass class Select(IR): """Produce a new dataframe selecting given expressions from an input.""" + __slots__ = ("exprs", "children", "should_broadcast") + _non_child = ("schema", "exprs", "should_broadcast") + children: tuple[IR] df: IR """Input dataframe.""" - expr: list[expr.NamedExpr] + exprs: tuple[expr.NamedExpr, ...] 
"""List of expressions to evaluate to form the new dataframe.""" should_broadcast: bool """Should columns be broadcast?""" + def __init__( + self, + schema: Schema, + exprs: Sequence[expr.NamedExpr], + should_broadcast: bool, # noqa: FBT001 + df: IR, + ): + self.schema = schema + self.exprs = tuple(exprs) + self.should_broadcast = should_broadcast + self.children = (df,) + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" - df = self.df.evaluate(cache=cache) + (child,) = self.children + df = child.evaluate(cache=cache) # Handle any broadcasting - columns = [e.evaluate(df) for e in self.expr] + columns = [e.evaluate(df) for e in self.exprs] if self.should_broadcast: columns = broadcast(*columns) return DataFrame(columns) -@dataclasses.dataclass class Reduce(IR): """ Produce a new dataframe selecting given expressions from an input. @@ -483,36 +607,70 @@ class Reduce(IR): This is a special case of :class:`Select` where all outputs are a single row. """ + __slots__ = ("exprs", "children") + _non_child = ("schema", "exprs") + df: IR """Input dataframe.""" - expr: list[expr.NamedExpr] + exprs: tuple[expr.NamedExpr, ...] """List of expressions to evaluate to form the new dataframe.""" + def __init__( + self, schema: Schema, exprs: Sequence[expr.NamedExpr], df: IR + ): # pragma: no cover; polars doesn't emit this node yet + self.schema = schema + self.exprs = tuple(exprs) + self.children = (df,) + def evaluate( self, *, cache: MutableMapping[int, DataFrame] ) -> DataFrame: # pragma: no cover; polars doesn't emit this node yet """Evaluate and return a dataframe.""" - df = self.df.evaluate(cache=cache) - columns = broadcast(*(e.evaluate(df) for e in self.expr)) + (child,) = self.children + df = child.evaluate(cache=cache) + columns = broadcast(*(e.evaluate(df) for e in self.exprs)) assert all(column.obj.size() == 1 for column in columns) return DataFrame(columns) -@dataclasses.dataclass class GroupBy(IR): """Perform a groupby.""" - df: IR - """Input dataframe.""" - agg_requests: list[expr.NamedExpr] - """List of expressions to evaluate groupwise.""" - keys: list[expr.NamedExpr] - """List of expressions forming the keys.""" - maintain_order: bool - """Should the order of the input dataframe be maintained?""" - options: Any - """Options controlling style of groupby.""" - agg_infos: list[expr.AggInfo] = dataclasses.field(init=False) + __slots__ = ( + "agg_requests", + "keys", + "maintain_order", + "options", + "agg_infos", + "children", + ) + _non_child = ("schema", "keys", "agg_requests", "maintain_order", "options") + children: tuple[IR] + + def __init__( + self, + schema: Schema, + keys: Sequence[expr.NamedExpr], + agg_requests: Sequence[expr.NamedExpr], + maintain_order: bool, # noqa: FBT001 + options: Any, + df: IR, + ): + self.schema = schema + self.keys = tuple(keys) + self.agg_requests = tuple(agg_requests) + self.maintain_order = maintain_order + self.options = options + self.children = (df,) + if self.options.rolling: + raise NotImplementedError( + "rolling window/groupby" + ) # pragma: no cover; rollingwindow constructor has already raised + if any(GroupBy.check_agg(a.value) > 1 for a in self.agg_requests): + raise NotImplementedError("Nested aggregations in groupby") + self.agg_infos = [req.collect_agg(depth=0) for req in self.agg_requests] + if len(self.keys) == 0: + raise NotImplementedError("dynamic groupby") @staticmethod def check_agg(agg: expr.Expr) -> int: @@ -542,22 +700,10 @@ def check_agg(agg: expr.Expr) -> int: 
else: raise NotImplementedError(f"No handler for {agg=}") - def __post_init__(self) -> None: - """Check whether all the aggregations are implemented.""" - super().__post_init__() - if self.options.rolling: - raise NotImplementedError( - "rolling window/groupby" - ) # pragma: no cover; rollingwindow constructor has already raised - if any(GroupBy.check_agg(a.value) > 1 for a in self.agg_requests): - raise NotImplementedError("Nested aggregations in groupby") - self.agg_infos = [req.collect_agg(depth=0) for req in self.agg_requests] - if len(self.keys) == 0: - raise NotImplementedError("dynamic groupby") - def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" - df = self.df.evaluate(cache=cache) + (child,) = self.children + df = child.evaluate(cache=cache) keys = broadcast( *(k.evaluate(df) for k in self.keys), target_length=df.num_rows ) @@ -646,17 +792,14 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: return DataFrame(broadcasted).slice(self.options.slice) -@dataclasses.dataclass class Join(IR): """A join of two dataframes.""" - left: IR - """Left frame.""" - right: IR - """Right frame.""" - left_on: list[expr.NamedExpr] + __slots__ = ("left_on", "right_on", "options", "children") + _non_child = ("schema", "left_on", "right_on", "options") + left_on: tuple[expr.NamedExpr, ...] """List of expressions used as keys in the left frame.""" - right_on: list[expr.NamedExpr] + right_on: tuple[expr.NamedExpr, ...] """List of expressions used as keys in the right frame.""" options: tuple[ Literal["inner", "left", "right", "full", "leftsemi", "leftanti", "cross"], @@ -674,9 +817,20 @@ class Join(IR): - coalesce: should key columns be coalesced (only makes sense for outer joins) """ - def __post_init__(self) -> None: - """Validate preconditions.""" - super().__post_init__() + def __init__( + self, + schema: Schema, + left_on: Sequence[expr.NamedExpr], + right_on: Sequence[expr.NamedExpr], + options: Any, + left: IR, + right: IR, + ): + self.schema = schema + self.left_on = tuple(left_on) + self.right_on = tuple(right_on) + self.options = options + self.children = (left, right) if any( isinstance(e.value, expr.Literal) for e in itertools.chain(self.left_on, self.right_on) @@ -777,8 +931,7 @@ def _reorder_maps( def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" - left = self.left.evaluate(cache=cache) - right = self.right.evaluate(cache=cache) + left, right = (c.evaluate(cache=cache) for c in self.children) how, join_nulls, zlice, suffix, coalesce = self.options suffix = "_right" if suffix is None else suffix if how == "cross": @@ -866,20 +1019,29 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: return result.slice(zlice) -@dataclasses.dataclass class HStack(IR): """Add new columns to a dataframe.""" - df: IR - """Input dataframe.""" - columns: list[expr.NamedExpr] - """List of expressions to produce new columns.""" - should_broadcast: bool - """Should columns be broadcast?""" + __slots__ = ("columns", "should_broadcast", "children") + _non_child = ("schema", "columns", "should_broadcast") + children: tuple[IR] + + def __init__( + self, + schema: Schema, + columns: Sequence[expr.NamedExpr], + should_broadcast: bool, # noqa: FBT001 + df: IR, + ): + self.schema = schema + self.columns = tuple(columns) + self.should_broadcast = should_broadcast + self.children = (df,) def evaluate(self, *, cache: MutableMapping[int, DataFrame]) 
-> DataFrame: """Evaluate and return a dataframe.""" - df = self.df.evaluate(cache=cache) + (child,) = self.children + df = child.evaluate(cache=cache) columns = [c.evaluate(df) for c in self.columns] if self.should_broadcast: columns = broadcast(*columns, target_length=df.num_rows) @@ -895,20 +1057,28 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: return df.with_columns(columns) -@dataclasses.dataclass class Distinct(IR): """Produce a new dataframe with distinct rows.""" - df: IR - """Input dataframe.""" - keep: plc.stream_compaction.DuplicateKeepOption - """Which rows to keep.""" - subset: set[str] | None - """Which columns to inspect when computing distinct rows.""" - zlice: tuple[int, int] | None - """Optional slice to perform after compaction.""" - stable: bool - """Should order be preserved?""" + __slots__ = ("keep", "subset", "zlice", "stable", "children") + _non_child = ("schema", "keep", "subset", "zlice", "stable") + children: tuple[IR] + + def __init__( + self, + schema: Schema, + keep: plc.stream_compaction.DuplicateKeepOption, + subset: frozenset[str] | None, + zlice: tuple[int, int] | None, + stable: bool, # noqa: FBT001 + df: IR, + ): + self.schema = schema + self.keep = keep + self.subset = subset + self.zlice = zlice + self.stable = stable + self.children = (df,) _KEEP_MAP: ClassVar[dict[str, plc.stream_compaction.DuplicateKeepOption]] = { "first": plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST, @@ -917,18 +1087,10 @@ class Distinct(IR): "any": plc.stream_compaction.DuplicateKeepOption.KEEP_ANY, } - def __init__(self, schema: Schema, df: IR, options: Any) -> None: - self.schema = schema - self.df = df - (keep, subset, maintain_order, zlice) = options - self.keep = Distinct._KEEP_MAP[keep] - self.subset = set(subset) if subset is not None else None - self.stable = maintain_order - self.zlice = zlice - def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" - df = self.df.evaluate(cache=cache) + (child,) = self.children + df = child.evaluate(cache=cache) if self.subset is None: indices = list(range(df.num_columns)) keys_sorted = all(c.is_sorted for c in df.column_map.values()) @@ -967,46 +1129,35 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: return result.slice(self.zlice) -@dataclasses.dataclass class Sort(IR): """Sort a dataframe.""" - df: IR - """Input.""" - by: list[expr.NamedExpr] - """List of expressions to produce sort keys.""" - do_sort: Callable[..., plc.Table] - """pylibcudf sorting function.""" - zlice: tuple[int, int] | None - """Optional slice to apply after sorting.""" - order: list[plc.types.Order] - """Order keys should be sorted in.""" - null_order: list[plc.types.NullOrder] - """Where nulls sort to.""" + __slots__ = ("by", "order", "null_order", "stable", "zlice", "children") + _non_child = ("schema", "by", "order", "null_order", "stable", "zlice") + children: tuple[IR] def __init__( self, schema: Schema, - df: IR, - by: list[expr.NamedExpr], - options: Any, + by: Sequence[expr.NamedExpr], + order: Sequence[plc.types.Order], + null_order: Sequence[plc.types.NullOrder], + stable: bool, # noqa: FBT001 zlice: tuple[int, int] | None, - ) -> None: + df: IR, + ): self.schema = schema - self.df = df - self.by = by + self.by = tuple(by) + self.order = tuple(order) + self.null_order = tuple(null_order) + self.stable = stable self.zlice = zlice - stable, nulls_last, descending = options - self.order, self.null_order = sorting.sort_order( - 
descending, nulls_last=nulls_last, num_keys=len(by) - ) - self.do_sort = ( - plc.sorting.stable_sort_by_key if stable else plc.sorting.sort_by_key - ) + self.children = (df,) def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" - df = self.df.evaluate(cache=cache) + (child,) = self.children + df = child.evaluate(cache=cache) sort_keys = broadcast( *(k.evaluate(df) for k in self.by), target_length=df.num_rows ) @@ -1016,11 +1167,14 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: for i, k in enumerate(sort_keys) if k.name in df.column_map and k.obj is df.column_map[k.name].obj } - table = self.do_sort( + do_sort = ( + plc.sorting.stable_sort_by_key if self.stable else plc.sorting.sort_by_key + ) + table = do_sort( df.table, plc.Table([k.obj for k in sort_keys]), - self.order, - self.null_order, + list(self.order), + list(self.null_order), ) columns: list[Column] = [] for name, c in zip(df.column_map, table.columns(), strict=True): @@ -1037,49 +1191,65 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: return DataFrame(columns).slice(self.zlice) -@dataclasses.dataclass class Slice(IR): """Slice a dataframe.""" - df: IR - """Input.""" + __slots__ = ("offset", "length", "children") + _non_child = ("schema", "offset", "length") + children: tuple[IR] offset: int """Start of the slice.""" length: int """Length of the slice.""" + def __init__(self, schema: Schema, offset: int, length: int, df: IR): + self.schema = schema + self.offset = offset + self.length = length + self.children = (df,) + def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" - df = self.df.evaluate(cache=cache) + (child,) = self.children + df = child.evaluate(cache=cache) return df.slice((self.offset, self.length)) -@dataclasses.dataclass class Filter(IR): """Filter a dataframe with a boolean mask.""" - df: IR - """Input.""" - mask: expr.NamedExpr - """Expression evaluating to a mask.""" + __slots__ = ("mask", "children") + _non_child = ("schema", "mask") + children: tuple[IR] + + def __init__(self, schema: Schema, mask: expr.NamedExpr, df: IR): + self.schema = schema + self.mask = mask + self.children = (df,) def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" - df = self.df.evaluate(cache=cache) + (child,) = self.children + df = child.evaluate(cache=cache) (mask,) = broadcast(self.mask.evaluate(df), target_length=df.num_rows) return df.filter(mask) -@dataclasses.dataclass class Projection(IR): """Select a subset of columns from a dataframe.""" - df: IR - """Input.""" + __slots__ = ("children",) + _non_child = ("schema",) + children: tuple[IR] + + def __init__(self, schema: Schema, df: IR): + self.schema = schema + self.children = (df,) def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" - df = self.df.evaluate(cache=cache) + (child,) = self.children + df = child.evaluate(cache=cache) # This can reorder things. 
columns = broadcast( *(df.column_map[name] for name in self.schema), target_length=df.num_rows @@ -1087,16 +1257,13 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: return DataFrame(columns) -@dataclasses.dataclass class MapFunction(IR): """Apply some function to a dataframe.""" - df: IR - """Input.""" - name: str - """Function name.""" + __slots__ = ("name", "options", "children") + _non_child = ("schema", "name", "options") + children: tuple[IR] options: Any - """Arbitrary options, interpreted per function.""" _NAMES: ClassVar[frozenset[str]] = frozenset( [ @@ -1111,9 +1278,11 @@ class MapFunction(IR): ] ) - def __post_init__(self) -> None: - """Validate preconditions.""" - super().__post_init__() + def __init__(self, schema: Schema, name: str, options: Any, df: IR): + self.schema = schema + self.name = name + self.options = options + self.children = (df,) if self.name not in MapFunction._NAMES: raise NotImplementedError(f"Unhandled map function {self.name}") if self.name == "explode": @@ -1127,7 +1296,7 @@ def __post_init__(self) -> None: old, new, _ = self.options # TODO: perhaps polars should validate renaming in the IR? if len(new) != len(set(new)) or ( - set(new) & (set(self.df.schema.keys()) - set(old)) + set(new) & (set(df.schema.keys()) - set(old)) ): raise NotImplementedError("Duplicate new names in rename.") elif self.name == "unpivot": @@ -1136,31 +1305,31 @@ def __post_init__(self) -> None: variable_name = "variable" if variable_name is None else variable_name if len(pivotees) == 0: index = frozenset(indices) - pivotees = [name for name in self.df.schema if name not in index] + pivotees = [name for name in df.schema if name not in index] if not all( - dtypes.can_cast(self.df.schema[p], self.schema[value_name]) - for p in pivotees + dtypes.can_cast(df.schema[p], self.schema[value_name]) for p in pivotees ): raise NotImplementedError( "Unpivot cannot cast all input columns to " f"{self.schema[value_name].id()}" ) - self.options = (indices, pivotees, variable_name, value_name) + self.options = (tuple(indices), tuple(pivotees), variable_name, value_name) def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" + (child,) = self.children if self.name == "rechunk": # No-op in our data model # Don't think this appears in a plan tree from python - return self.df.evaluate(cache=cache) # pragma: no cover + return child.evaluate(cache=cache) # pragma: no cover elif self.name == "rename": - df = self.df.evaluate(cache=cache) + df = child.evaluate(cache=cache) # final tag is "swapping" which is useful for the # optimiser (it blocks some pushdown operations) old, new, _ = self.options return df.rename_columns(dict(zip(old, new, strict=True))) elif self.name == "explode": - df = self.df.evaluate(cache=cache) + df = child.evaluate(cache=cache) ((to_explode,),) = self.options index = df.column_names.index(to_explode) subset = df.column_names_set - {to_explode} @@ -1170,7 +1339,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: elif self.name == "unpivot": indices, pivotees, variable_name, value_name = self.options npiv = len(pivotees) - df = self.df.evaluate(cache=cache) + df = child.evaluate(cache=cache) index_columns = [ Column(col, name=name) for col, name in zip( @@ -1209,37 +1378,38 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: raise AssertionError("Should never be reached") # pragma: no cover -@dataclasses.dataclass class Union(IR): 
"""Concatenate dataframes vertically.""" - dfs: list[IR] - """List of inputs.""" - zlice: tuple[int, int] | None - """Optional slice to apply after concatenation.""" + __slots__ = ("zlice", "children") + _non_child = ("schema", "zlice") - def __post_init__(self) -> None: - """Validate preconditions.""" - super().__post_init__() - schema = self.dfs[0].schema - if not all(s.schema == schema for s in self.dfs[1:]): + def __init__(self, schema: Schema, zlice: tuple[int, int] | None, *children: IR): + self.schema = schema + self.zlice = zlice + self.children = children + schema = self.children[0].schema + if not all(s.schema == schema for s in self.children[1:]): raise NotImplementedError("Schema mismatch") def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" # TODO: only evaluate what we need if we have a slice - dfs = [df.evaluate(cache=cache) for df in self.dfs] + dfs = [df.evaluate(cache=cache) for df in self.children] return DataFrame.from_table( plc.concatenate.concatenate([df.table for df in dfs]), dfs[0].column_names ).slice(self.zlice) -@dataclasses.dataclass class HConcat(IR): """Concatenate dataframes horizontally.""" - dfs: list[IR] - """List of inputs.""" + __slots__ = ("children",) + _non_child = ("schema",) + + def __init__(self, schema: Schema, *children: IR): + self.schema = schema + self.children = children @staticmethod def _extend_with_nulls(table: plc.Table, *, nrows: int) -> plc.Table: @@ -1271,7 +1441,7 @@ def _extend_with_nulls(table: plc.Table, *, nrows: int) -> plc.Table: def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" - dfs = [df.evaluate(cache=cache) for df in self.dfs] + dfs = [df.evaluate(cache=cache) for df in self.children] max_rows = max(df.num_rows for df in dfs) # Horizontal concatenation extends shorter tables with nulls dfs = [ diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index a0291037f01..522c4a6729c 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -20,7 +20,7 @@ from cudf_polars.dsl import expr, ir from cudf_polars.typing import NodeTraverser -from cudf_polars.utils import dtypes +from cudf_polars.utils import dtypes, sorting __all__ = ["translate_ir", "translate_named_expr"] @@ -148,7 +148,7 @@ def _( with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) exprs = [translate_named_expr(visitor, n=e) for e in node.expr] - return ir.Select(schema, inp, exprs, node.should_broadcast) + return ir.Select(schema, exprs, node.should_broadcast, inp) @_translate_ir.register @@ -161,11 +161,11 @@ def _( keys = [translate_named_expr(visitor, n=e) for e in node.keys] return ir.GroupBy( schema, - inp, - aggs, keys, + aggs, node.maintain_order, node.options, + inp, ) @@ -182,7 +182,7 @@ def _( with set_node(visitor, node.input_right): inp_right = translate_ir(visitor, n=None) right_on = [translate_named_expr(visitor, n=e) for e in node.right_on] - return ir.Join(schema, inp_left, inp_right, left_on, right_on, node.options) + return ir.Join(schema, left_on, right_on, node.options, inp_left, inp_right) @_translate_ir.register @@ -192,7 +192,7 @@ def _( with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) exprs = [translate_named_expr(visitor, n=e) for e in node.exprs] - return ir.HStack(schema, inp, exprs, node.should_broadcast) + return ir.HStack(schema, exprs, 
node.should_broadcast, inp) @_translate_ir.register @@ -202,17 +202,23 @@ def _( with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) exprs = [translate_named_expr(visitor, n=e) for e in node.expr] - return ir.Reduce(schema, inp, exprs) + return ir.Reduce(schema, exprs, inp) @_translate_ir.register def _( node: pl_ir.Distinct, visitor: NodeTraverser, schema: dict[str, plc.DataType] ) -> ir.IR: + (keep, subset, maintain_order, zlice) = node.options + keep = ir.Distinct._KEEP_MAP[keep] + subset = frozenset(subset) if subset is not None else None return ir.Distinct( schema, + keep, + subset, + zlice, + maintain_order, translate_ir(visitor, n=node.input), - node.options, ) @@ -223,14 +229,18 @@ def _( with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) by = [translate_named_expr(visitor, n=e) for e in node.by_column] - return ir.Sort(schema, inp, by, node.sort_options, node.slice) + stable, nulls_last, descending = node.sort_options + order, null_order = sorting.sort_order( + descending, nulls_last=nulls_last, num_keys=len(by) + ) + return ir.Sort(schema, by, order, null_order, stable, node.slice, inp) @_translate_ir.register def _( node: pl_ir.Slice, visitor: NodeTraverser, schema: dict[str, plc.DataType] ) -> ir.IR: - return ir.Slice(schema, translate_ir(visitor, n=node.input), node.offset, node.len) + return ir.Slice(schema, node.offset, node.len, translate_ir(visitor, n=node.input)) @_translate_ir.register @@ -240,7 +250,7 @@ def _( with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) mask = translate_named_expr(visitor, n=node.predicate) - return ir.Filter(schema, inp, mask) + return ir.Filter(schema, mask, inp) @_translate_ir.register @@ -259,10 +269,10 @@ def _( name, *options = node.function return ir.MapFunction( schema, - # TODO: merge_sorted breaks this pattern - translate_ir(visitor, n=node.input), name, options, + # TODO: merge_sorted breaks this pattern + translate_ir(visitor, n=node.input), ) @@ -271,7 +281,7 @@ def _( node: pl_ir.Union, visitor: NodeTraverser, schema: dict[str, plc.DataType] ) -> ir.IR: return ir.Union( - schema, [translate_ir(visitor, n=n) for n in node.inputs], node.options + schema, node.options, *(translate_ir(visitor, n=n) for n in node.inputs) ) @@ -279,7 +289,7 @@ def _( def _( node: pl_ir.HConcat, visitor: NodeTraverser, schema: dict[str, plc.DataType] ) -> ir.IR: - return ir.HConcat(schema, [translate_ir(visitor, n=n) for n in node.inputs]) + return ir.HConcat(schema, *(translate_ir(visitor, n=n) for n in node.inputs)) def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR: diff --git a/python/cudf_polars/tests/test_config.py b/python/cudf_polars/tests/test_config.py index 3c3986be19b..9900f598e5f 100644 --- a/python/cudf_polars/tests/test_config.py +++ b/python/cudf_polars/tests/test_config.py @@ -10,7 +10,7 @@ import rmm -from cudf_polars.dsl.ir import IR +from cudf_polars.dsl.ir import DataFrameScan from cudf_polars.testing.asserts import ( assert_gpu_result_equal, assert_ir_translation_raises, @@ -18,10 +18,10 @@ def test_polars_verbose_warns(monkeypatch): - def raise_unimplemented(self): + def raise_unimplemented(self, *args): raise NotImplementedError("We don't support this") - monkeypatch.setattr(IR, "__post_init__", raise_unimplemented) + monkeypatch.setattr(DataFrameScan, "__init__", raise_unimplemented) q = pl.LazyFrame({}) # Ensure that things raise assert_ir_translation_raises(q, NotImplementedError) From a234e37310666e26ca8d30367ed66ecc21965bd9 Mon Sep 17 
00:00:00 2001 From: Lawrence Mitchell Date: Mon, 14 Oct 2024 09:42:58 +0000 Subject: [PATCH 06/23] Add tests of traversal over IR nodes Now that we have a uniform child attribute, this is easier. --- .../cudf_polars/tests/dsl/test_traversal.py | 61 ++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/python/cudf_polars/tests/dsl/test_traversal.py b/python/cudf_polars/tests/dsl/test_traversal.py index bfafe2dce92..f0b03e70be9 100644 --- a/python/cudf_polars/tests/dsl/test_traversal.py +++ b/python/cudf_polars/tests/dsl/test_traversal.py @@ -5,7 +5,11 @@ import pylibcudf as plc -from cudf_polars.dsl import expr +import polars as pl +from polars.testing import assert_frame_equal + +from cudf_polars import translate_ir +from cudf_polars.dsl import expr, ir from cudf_polars.dsl.traversal import ( CachingVisitor, make_recursive, @@ -96,3 +100,58 @@ def test_noop_visitor(): renamed = mapper(e2) assert renamed == make_expr(dt, "c", "c") + + +def test_rewrite_ir_node(): + df = pl.LazyFrame({"a": [1, 2, 1], "b": [1, 3, 4]}) + q = df.group_by("a").agg(pl.col("b").sum()).sort("b") + + orig = translate_ir(q._ldf.visit()) + + new_df = pl.DataFrame({"a": [1, 1, 2], "b": [-1, -2, -4]}) + + def replace_df(node, rec): + if isinstance(node, ir.DataFrameScan): + return ir.DataFrameScan( + node.schema, new_df._df, node.projection, node.predicate + ) + return reuse_if_unchanged(node, rec) + + mapper = CachingVisitor(replace_df) + + new = mapper(orig) + + result = new.evaluate(cache={}).to_polars() + + expect = pl.DataFrame({"a": [2, 1], "b": [-4, -3]}) + + assert_frame_equal(result, expect) + + +def test_rewrite_scan_node(tmp_path): + left = pl.LazyFrame({"a": [1, 2, 3], "b": [1, 3, 4]}) + right = pl.DataFrame({"a": [1, 4, 2], "c": [1, 2, 3]}) + + right.write_parquet(tmp_path / "right.pq") + + right_s = pl.scan_parquet(tmp_path / "right.pq") + + q = left.join(right_s, on="a", how="inner") + + def replace_scan(node, rec): + if isinstance(node, ir.Scan): + return ir.DataFrameScan( + node.schema, right._df, node.with_columns, node.predicate + ) + return reuse_if_unchanged(node, rec) + + mapper = CachingVisitor(replace_scan) + + orig = translate_ir(q._ldf.visit()) + new = mapper(orig) + + result = new.evaluate(cache={}).to_polars() + + expect = q.collect() + + assert_frame_equal(result, expect, check_row_order=False) From 73019c83e5c001d5b2a9bb6b677a058c39b4b268 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 14 Oct 2024 11:02:49 +0000 Subject: [PATCH 07/23] Overview documentation for visitor pattern/utilities --- python/cudf_polars/docs/overview.md | 214 ++++++++++++++++++++++++---- 1 file changed, 188 insertions(+), 26 deletions(-) diff --git a/python/cudf_polars/docs/overview.md b/python/cudf_polars/docs/overview.md index 7837a275f20..967fbf95ea0 100644 --- a/python/cudf_polars/docs/overview.md +++ b/python/cudf_polars/docs/overview.md @@ -15,10 +15,11 @@ You will need: ## Installing polars -`cudf-polars` works with polars >= 1.3, as long as the internal IR -version doesn't get a major version bump. So `pip install polars>=1.3` -should work. For development, if we're adding things to the polars -side of things, we will need to build polars from source: +The `cudf-polars` `pyproject.toml` advertises which polars versions it +works with. So for pure `cudf-polars` development, installing as +normal and satisfying the dependencies in the repository is +sufficient. 
For development, if we're adding things to the polars side
+of things, we will need to build polars from source:

```sh
git clone https://github.com/pola-rs/polars
@@ -126,7 +127,6 @@ arguments, at the moment, `raise_on_fail` is also supported, which
raises, rather than falling back, during translation:

```python
-
result = q.collect(engine=pl.GPUEngine(raise_on_fail=True))
```

@@ -144,11 +144,69 @@ changes. We can therefore attempt to detect the IR version
appropriately. This should be done during IR translation in
`translate.py`.

-## Adding a handler for a new plan node
+# IR design
+
+As noted, we translate the polars DSL into our own IR. This is both so
+that we can smooth out minor version differences (advertised by
+`NodeTraverser` version changes) within `cudf-polars`, and so that we
+have the freedom to introduce new IR nodes and rewrite rules as might
+be appropriate for GPU execution.
+
+To that end, we provide facilities for definition of nodes as well as
+writing traversals and rewrite rules. The abstract base class `Node`
+in `dsl/nodebase.py` defines the interface for implementing new nodes,
+and provides many useful default methods. See also the docstrings of
+the `Node` class.
+
+> [!NOTE] This generic implementation relies on nodes being treated as
+> *immutable*. Do not implement in-place modification of nodes; bad
+> things will happen.
+
+## Defining nodes
+
+A concrete node type (`cudf-polars` has ones for expressions `Expr`
+and ones for plan nodes `IR`), should inherit from `Node`. Nodes have
+two types of data:
+
+1. `children`: a tuple (possibly empty) of concrete nodes
+2. non-child: arbitrary data attached to the node that is _not_ a
+   concrete node.
+
+The base `Node` class requires that one advertise the _names_ of the
+non-child attributes in the `_non_child` class variable. The
+constructor of the concrete node should take its arguments in the
+order `*_non_child` (ordered as the class variable does) and then
+`*children`. For example, the `Sort` node, which sorts a column
+generated by an expression, has this definition:
+
+```python
+class Expr(Node):
+    children: tuple[Expr, ...]
+
+class Sort(Expr):
+    _non_child = ("dtype", "options")
+    children: tuple[Expr]
+    def __init__(self, dtype, options, column: Expr):
+        self.dtype = dtype
+        self.options = options
+        self.children = (column,)
+```
+
+By following this pattern, we get an automatic (caching)
+implementation of `__hash__` and `__eq__`, as well as a useful
+`reconstruct` method that will rebuild the node with new children.
+
+If you want to control the behaviour of `__hash__` and `__eq__` for a
+single node, override (respectively) the `get_hashable` and `is_equal`
+methods.
+
+## Adding new translation rules from the polars IR
+
+### Plan nodes

-Plan node definitions live in `cudf_polars/dsl/ir.py`, these are
-`dataclasses` that inherit from the base `IR` node. The evaluation of
-a plan node is done by implementing the `evaluate` method.
+Plan node definitions live in `cudf_polars/dsl/ir.py`, these all
+inherit from the base `IR` node. The evaluation of a plan node is done
+by implementing the `evaluate` method.

To translate the plan node, add a case handler in `translate_ir`
which lives in `cudf_polars/dsl/translate.py`.
@@ -163,25 +221,12 @@ translating a `Join` node, the left keys (expressions) should be
translated with the left input active (and right keys with right
input). To facilitate this, use the `set_node` context manager.
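+
+As a concrete example, the `Filter` handler in
+`cudf_polars/dsl/translate.py` (updated earlier in this series)
+follows this pattern: the input plan node is translated first, and the
+predicate expression is translated while that input is still the
+active node:
+
+```python
+@_translate_ir.register
+def _(
+    node: pl_ir.Filter, visitor: NodeTraverser, schema: dict[str, plc.DataType]
+) -> ir.IR:
+    with set_node(visitor, node.input):
+        inp = translate_ir(visitor, n=None)
+        mask = translate_named_expr(visitor, n=node.predicate)
+    return ir.Filter(schema, mask, inp)
+```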
-## Adding a handler for a new expression node
+### Expression nodes

Adding a handler for an expression node is very similar to a plan node.
-Expressions are all defined in `cudf_polars/dsl/expr.py` and inherit
-from `Expr`. Unlike plan nodes, these are not `dataclasses`, since it
-is simpler for us to implement efficient hashing, repr, and equality if we
-can write that ourselves.
-
-Every expression consists of two types of data:
-1. child data (other `Expr`s)
-2. non-child data (anything other than an `Expr`)
-The generic implementations of special methods in the base `Expr` base
-class require that the subclasses advertise which arguments to the
-constructor are non-child in a `_non_child` class slot. The
-constructor should then take arguments:
-```python
-def __init__(self, *non_child_data: Any, *children: Expr):
-```
-Read the docstrings in the `Expr` class for more details.
+Expressions are defined in `cudf_polars/dsl/expressions/` and exported
+into the `dsl` namespace via `expr.py`. They inherit
+from `Expr`.

Expressions are evaluated by implementing a `do_evaluate` method that
takes a `DataFrame` as context (this provides columns) along with an
@@ -198,6 +243,123 @@ To simplify state tracking, all columns should be considered immutable
on construction. This matches the "functional" description coming from
the logical plan in any case, so is reasonably natural.

+## Traversing and transforming nodes
+
+As well as just representing and evaluating nodes. We also provide
+facilities for traversing a tree of nodes and defining transformation
+rules in `dsl/traversal.py`. The simplest is `traversal`, this yields
+all _unique_ nodes in an expression parent before child, children
+in-order left to right (i.e. a pre-order traversal). Use this if you
+want to know some specific thing about an expression. For example, to
+determine if an expression contains a `Literal` node:
+
+```python
+def has_literal(node: Expr) -> bool:
+    return any(isinstance(e, Literal) for e in traversal(node))
+```
+
+For transformations and rewrites, we use the following generic
+pattern. Rather than defining methods on each node in turn for a
+particular rewrite rule, we prefer free functions and use
+`functools.singledispatch` to provide dispatching.
+
+It is often convenient to provide (immutable) state to a visitor, as
+well as some facility to perform DAG-aware rewrites (reusing a
+transformation for an expression if we have already seen it). We
+therefore adopt the following pattern of writing DAG-aware visitors.
+Suppose we want a rewrite rule (`rewrite`) between expressions
+(`Expr`) and some new type `T`. We define our general transformation
+function `rewrite` with type `Expr -> (Expr -> T) -> T`:
+
+```python
+from cudf_polars.typing import GenericTransformer
+
+@singledispatch
+def rewrite(e: Expr, rec: GenericTransformer[Expr, T]) -> T:
+    ...
+```
+
+Note in particular that the function to perform the recursion is
+passed as the second argument. We now, in the usual fashion, register
+handlers for different expression types. To use this function, we need
+to be able to provide both the expression to convert and the recursive
+function itself. To do this we must convert our `rewrite` function
+into something that only takes a single argument (the expression to
+rewrite), but carries around information about how to perform the
+recursion. To this end, we have two utilities in `traversal.py`:
+
+- `make_recursive` and
+- `CachingVisitor`.
+
+These both implement the `GenericTransformer` protocol, and can be
+wrapped around a transformation function like `rewrite` to provide a
+function `Expr -> T`. They also allow us to attach arbitrary
+*immutable* state to our visitor by passing a `state` dictionary. This
+dictionary can then be inspected by the concrete transformation
+function. `make_recursive` is very simple, and provides no caching of
+intermediate results (so any DAGs that are visited will be viewed as
+trees). `CachingVisitor` provides the same interface, but maintains a
+cache of intermediate results, and reuses them if the same expression
+is seen again.
+
+Finally, for writing transformations that take nodes and deliver new
+nodes (e.g. rewrite rules), we have a final utility
+`reuse_if_unchanged` which can be used as a base case transformation
+for node to node rewrites. It is a depth-first visit that transforms
+children but only returns a new node with new children if the rewrite
+on children changed things.
+
+To see how these pieces fit together, let us consider writing a
+`rename` function that takes an expression (potentially with
+references to columns) along with a mapping defining a renaming
+between (some subset of) column names. The goal is to deliver a new
+expression with appropriate columns renamed.
+
+To start, we define the dispatch function
+```python
+from collections.abc import Mapping
+from functools import singledispatch
+from cudf_polars.dsl.traversal import (
+    CachingVisitor, make_recursive, reuse_if_unchanged
+)
+from cudf_polars.dsl.expr import Col, Expr
+from cudf_polars.typing import ExprTransformer
+
+
+@singledispatch
+def _rename(e: Expr, rec: ExprTransformer) -> Expr:
+    raise NotImplementedError(f"No handler for {type(e)}")
+```
+then we register specific handlers, first for columns:
+```python
+@_rename.register
+def _(e: Col, rec: ExprTransformer) -> Expr:
+    mapping = rec.state["mapping"]  # state set on rec
+    if e.name in mapping:
+        # If we have a rename, return a new Col reference
+        # with a new name
+        return type(e)(e.dtype, mapping[e.name])
+    return e
+```
+and then for the remaining expressions
+```python
+_rename.register(Expr)(reuse_if_unchanged)
+```
+> [!NOTE] In this case, we could have put the generic handler in
+> the `_rename` function, however, then we would not get a nice error
+> message if we accidentally sent in an object of the incorrect type.
+
+Finally we tie everything together with a public function:
+
+```python
+def rename(e: Expr, mapping: Mapping[str, str]) -> Expr:
+    """Rename column references in an expression."""
+    mapper = CachingVisitor(_rename, state={"mapping": mapping})
+    # or
+    # mapper = make_recursive(_rename, state={"mapping": mapping})
+    return mapper(e)
+```

# Containers

Containers should be constructed as relatively lightweight objects

From b14b150aec62f6e161babe18ab3df4ee9a7ba20f Mon Sep 17 00:00:00 2001
From: Lawrence Mitchell
Date: Mon, 14 Oct 2024 16:41:13 +0000
Subject: [PATCH 08/23] Some grammar fixes

---
 python/cudf_polars/docs/overview.md | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/python/cudf_polars/docs/overview.md b/python/cudf_polars/docs/overview.md
index 967fbf95ea0..ccebb519760 100644
--- a/python/cudf_polars/docs/overview.md
+++ b/python/cudf_polars/docs/overview.md
@@ -164,11 +164,11 @@ the `Node` class.

## Defining nodes

-A concrete node type (`cudf-polars` has ones for expressions `Expr`
-and ones for plan nodes `IR`), should inherit from `Node`.
Nodes have +A concrete node type (`cudf-polars` has expression nodes, `Expr`; +and plan nodes, `IR`), should inherit from `Node`. Nodes have two types of data: -1. `children`: a tuple (possibly empty) of concrete nodes +1. `children`: a tuple (possibly empty) of concrete nodes; 2. non-child: arbitrary data attached to the node that is _not_ a concrete node. @@ -247,11 +247,11 @@ the logical plan in any case, so is reasonably natural. As well as just representing and evaluating nodes. We also provide facilities for traversing a tree of nodes and defining transformation -rules in `dsl/traversal.py`. The simplest is `traversal`, this yields -all _unique_ nodes in an expression parent before child, children -in-order left to right (i.e. a pre-order traversal). Use this if you -want to know some specific thing about an expression. For example, to -determine if an expression contains a `Literal` node: +rules in `dsl/traversal.py`. The simplest is `traversal`, a +[pre-order](https://en.wikipedia.org/wiki/Tree_traversal) visit of all +unique nodes in an expression. Use this if you want to know some +specific thing about an expression. For example, to determine if an +expression contains a `Literal` node: ```python def has_literal(node: Expr) -> bool: From a49846f6000f043862bec46a153dfe9dd0ea09fd Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 14 Oct 2024 16:49:17 +0000 Subject: [PATCH 09/23] Reinstate docstrings for properties --- python/cudf_polars/cudf_polars/dsl/ir.py | 41 ++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index f503ea3f1d1..d24a4c7fcf5 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -130,6 +130,7 @@ class IR(Node): schema: Schema """Mapping from column names to their data types.""" children: tuple[IR, ...] = () + """Child IR nodes that are inputs to this one.""" def get_hashable(self) -> Hashable: """ @@ -645,6 +646,14 @@ class GroupBy(IR): "children", ) _non_child = ("schema", "keys", "agg_requests", "maintain_order", "options") + keys: tuple[expr.NamedExpr, ...] + """Grouping keys.""" + agg_requests: tuple[expr.NamedExpr, ...] + """Aggregation expressions.""" + maintain_order: bool + """Preserve order in groupby.""" + options: Any + """Arbitrary options.""" children: tuple[IR] def __init__( @@ -1024,6 +1033,8 @@ class HStack(IR): __slots__ = ("columns", "should_broadcast", "children") _non_child = ("schema", "columns", "should_broadcast") + should_broadcast: bool + """Should the resulting evaluated columns be broadcast to the same length.""" children: tuple[IR] def __init__( @@ -1062,6 +1073,15 @@ class Distinct(IR): __slots__ = ("keep", "subset", "zlice", "stable", "children") _non_child = ("schema", "keep", "subset", "zlice", "stable") + keep: plc.stream_compaction.DuplicateKeepOption + """Which distinct value to keep.""" + subset: frozenset[str] | None + """Which columns should be used to define distinctness. If None, + then all columns are used.""" + zlice: tuple[int, int] | None + """Optional slice to apply to the result.""" + stable: bool + """Should the result maintain ordering.""" children: tuple[IR] def __init__( @@ -1134,6 +1154,16 @@ class Sort(IR): __slots__ = ("by", "order", "null_order", "stable", "zlice", "children") _non_child = ("schema", "by", "order", "null_order", "stable", "zlice") + by: tuple[expr.NamedExpr, ...] + """Sort keys.""" + order: tuple[plc.types.Order, ...] 
+ """Sort order for each sort key.""" + null_order: tuple[plc.types.NullOrder, ...] + """Null sorting location for each sort key.""" + stable: bool + """Should the sort be stable?""" + zlice: tuple[int, int] | None + """Optional slice to apply to the result.""" children: tuple[IR] def __init__( @@ -1196,11 +1226,11 @@ class Slice(IR): __slots__ = ("offset", "length", "children") _non_child = ("schema", "offset", "length") - children: tuple[IR] offset: int """Start of the slice.""" length: int """Length of the slice.""" + children: tuple[IR] def __init__(self, schema: Schema, offset: int, length: int, df: IR): self.schema = schema @@ -1220,6 +1250,8 @@ class Filter(IR): __slots__ = ("mask", "children") _non_child = ("schema", "mask") + mask: expr.NamedExpr + """Expression to produce the filter mask.""" children: tuple[IR] def __init__(self, schema: Schema, mask: expr.NamedExpr, df: IR): @@ -1262,8 +1294,11 @@ class MapFunction(IR): __slots__ = ("name", "options", "children") _non_child = ("schema", "name", "options") - children: tuple[IR] + name: str + """Name of the function to apply""" options: Any + """Arbitrary name-specific options""" + children: tuple[IR] _NAMES: ClassVar[frozenset[str]] = frozenset( [ @@ -1383,6 +1418,8 @@ class Union(IR): __slots__ = ("zlice", "children") _non_child = ("schema", "zlice") + zlice: tuple[int, int] | None + """Optional slice to apply to the result.""" def __init__(self, schema: Schema, zlice: tuple[int, int] | None, *children: IR): self.schema = schema From 9449b44ebde4cae76303f5312b748a224799feb6 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 14 Oct 2024 16:51:41 +0000 Subject: [PATCH 10/23] Use side-effect free rather than pure --- python/cudf_polars/cudf_polars/dsl/traversal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/traversal.py b/python/cudf_polars/cudf_polars/dsl/traversal.py index 2469b167a83..0f5f38c623b 100644 --- a/python/cudf_polars/cudf_polars/dsl/traversal.py +++ b/python/cudf_polars/cudf_polars/dsl/traversal.py @@ -114,7 +114,7 @@ def make_recursive( Notes ----- - All transformation functions *must* be pure. + All transformation functions *must* be free of side-effects. Usually, prefer a :class:`CachingVisitor`, but if we know that we don't need caching in a transformation and then this no-op @@ -155,7 +155,7 @@ class CachingVisitor(Generic[U_contra, V_co]): Notes ----- - All transformation functions *must* be pure. + All transformation functions *must* be free of side-effects. Returns ------- From e36ead1879183eeaddae4c81c47f6abe3e3945bd Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 11 Oct 2024 13:56:35 +0000 Subject: [PATCH 11/23] A few type annotations --- python/cudf_polars/cudf_polars/testing/asserts.py | 2 +- python/cudf_polars/cudf_polars/testing/plugin.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/cudf_polars/cudf_polars/testing/asserts.py b/python/cudf_polars/cudf_polars/testing/asserts.py index 7b6f3848fc4..7b45c1eaa06 100644 --- a/python/cudf_polars/cudf_polars/testing/asserts.py +++ b/python/cudf_polars/cudf_polars/testing/asserts.py @@ -151,7 +151,7 @@ def assert_collect_raises( collect_kwargs: dict[OptimizationArgs, bool] | None = None, polars_collect_kwargs: dict[OptimizationArgs, bool] | None = None, cudf_collect_kwargs: dict[OptimizationArgs, bool] | None = None, -): +) -> None: """ Assert that collecting the result of a query raises the expected exceptions. 
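For reference, a hypothetical use of this helper (its leading
parameters are not visible in this hunk, so the argument names below
are assumptions, not the actual signature):

```python
# Sketch only: assumes the helper takes the lazy query plus the
# expected exception types for the polars and cudf engines.
assert_collect_raises(
    q,
    polars_except=pl.exceptions.InvalidOperationError,
    cudf_except=NotImplementedError,
)
```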
diff --git a/python/cudf_polars/cudf_polars/testing/plugin.py b/python/cudf_polars/cudf_polars/testing/plugin.py index 05b76d76808..a3e0896806b 100644 --- a/python/cudf_polars/cudf_polars/testing/plugin.py +++ b/python/cudf_polars/cudf_polars/testing/plugin.py @@ -16,7 +16,7 @@ from collections.abc import Mapping -def pytest_addoption(parser: pytest.Parser): +def pytest_addoption(parser: pytest.Parser) -> None: """Add plugin-specific options.""" group = parser.getgroup( "cudf-polars", "Plugin to set GPU as default engine for polars tests" @@ -28,7 +28,7 @@ def pytest_addoption(parser: pytest.Parser): ) -def pytest_configure(config: pytest.Config): +def pytest_configure(config: pytest.Config) -> None: """Enable use of this module as a pytest plugin to enable GPU collection.""" no_fallback = config.getoption("--cudf-polars-no-fallback") collect = polars.LazyFrame.collect @@ -148,7 +148,7 @@ def pytest_configure(config: pytest.Config): def pytest_collection_modifyitems( session: pytest.Session, config: pytest.Config, items: list[pytest.Item] -): +) -> None: """Mark known failing tests.""" if config.getoption("--cudf-polars-no-fallback"): # Don't xfail tests if running without fallback From 6fced33daea90cbf3e34fd422496651c9cd8f995 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 16 Oct 2024 10:30:24 +0000 Subject: [PATCH 12/23] Expose all type ids and match order with libcudf --- python/pylibcudf/pylibcudf/libcudf/types.pxd | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/pylibcudf/pylibcudf/libcudf/types.pxd b/python/pylibcudf/pylibcudf/libcudf/types.pxd index eabae68bc90..60e293e5cdb 100644 --- a/python/pylibcudf/pylibcudf/libcudf/types.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/types.pxd @@ -70,18 +70,19 @@ cdef extern from "cudf/types.hpp" namespace "cudf" nogil: TIMESTAMP_MILLISECONDS TIMESTAMP_MICROSECONDS TIMESTAMP_NANOSECONDS - DICTIONARY32 - STRING - LIST - STRUCT - NUM_TYPE_IDS + DURATION_DAYS DURATION_SECONDS DURATION_MILLISECONDS DURATION_MICROSECONDS DURATION_NANOSECONDS + DICTIONARY32 + STRING + LIST DECIMAL32 DECIMAL64 DECIMAL128 + STRUCT + NUM_TYPE_IDS cdef cppclass data_type: data_type() except + From 7de74d4a0a49d111948ca9b93bfe3cba87b21b12 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 10 Oct 2024 15:40:29 +0000 Subject: [PATCH 13/23] Support all types for scalars in pylibcudf Expressions --- python/pylibcudf/pylibcudf/expressions.pyx | 50 ++++++++++++++++++- .../pylibcudf/libcudf/wrappers/durations.pxd | 5 +- .../pylibcudf/libcudf/wrappers/timestamps.pxd | 5 +- 3 files changed, 55 insertions(+), 5 deletions(-) diff --git a/python/pylibcudf/pylibcudf/expressions.pyx b/python/pylibcudf/pylibcudf/expressions.pyx index a44c9e25987..1535f68366b 100644 --- a/python/pylibcudf/pylibcudf/expressions.pyx +++ b/python/pylibcudf/pylibcudf/expressions.pyx @@ -5,7 +5,17 @@ from pylibcudf.libcudf.expressions import \ table_reference as TableReference # no-cython-lint from cython.operator cimport dereference -from libc.stdint cimport int32_t, int64_t +from libc.stdint cimport ( + int8_t, + int16_t, + int32_t, + int64_t, + uint8_t, + uint16_t, + uint32_t, + uint64_t, +) +from libcpp cimport bool from libcpp.memory cimport make_unique, unique_ptr from libcpp.string cimport string from libcpp.utility cimport move @@ -18,12 +28,14 @@ from pylibcudf.libcudf.scalar.scalar cimport ( ) from pylibcudf.libcudf.types cimport size_type, type_id from pylibcudf.libcudf.wrappers.durations cimport ( + duration_D, duration_ms, duration_ns, 
duration_s, duration_us, ) from pylibcudf.libcudf.wrappers.timestamps cimport ( + timestamp_D, timestamp_ms, timestamp_ns, timestamp_s, @@ -78,6 +90,34 @@ cdef class Literal(Expression): self.c_obj = move(make_unique[libcudf_exp.literal]( dereference(self.scalar.c_obj) )) + elif tid == type_id.INT16: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) + elif tid == type_id.INT8: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) + elif tid == type_id.UINT64: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) + elif tid == type_id.UINT32: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) + elif tid == type_id.UINT16: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) + elif tid == type_id.UINT8: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) + elif tid == type_id.BOOL8: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) elif tid == type_id.FLOAT64: self.c_obj = move(make_unique[libcudf_exp.literal]( dereference(self.scalar.c_obj) @@ -110,6 +150,10 @@ cdef class Literal(Expression): self.c_obj = move(make_unique[libcudf_exp.literal]( dereference(self.scalar.c_obj) )) + elif tid == type_id.TIMESTAMP_DAYS: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) elif tid == type_id.DURATION_NANOSECONDS: self.c_obj = move(make_unique[libcudf_exp.literal]( dereference(self.scalar.c_obj) @@ -130,6 +174,10 @@ cdef class Literal(Expression): self.c_obj = move(make_unique[libcudf_exp.literal]( dereference(self.scalar.c_obj) )) + elif tid == type_id.DURATION_DAYS: + self.c_obj = move(make_unique[libcudf_exp.literal]( + dereference(self.scalar.c_obj) + )) else: raise NotImplementedError( f"Don't know how to make literal with type id {tid}" diff --git a/python/pylibcudf/pylibcudf/libcudf/wrappers/durations.pxd b/python/pylibcudf/pylibcudf/libcudf/wrappers/durations.pxd index 7c648425eb5..c9c960d0a79 100644 --- a/python/pylibcudf/pylibcudf/libcudf/wrappers/durations.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/wrappers/durations.pxd @@ -1,9 +1,10 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. -from libc.stdint cimport int64_t +from libc.stdint cimport int32_t, int64_t cdef extern from "cudf/wrappers/durations.hpp" namespace "cudf" nogil: + ctypedef int32_t duration_D ctypedef int64_t duration_s ctypedef int64_t duration_ms ctypedef int64_t duration_us diff --git a/python/pylibcudf/pylibcudf/libcudf/wrappers/timestamps.pxd b/python/pylibcudf/pylibcudf/libcudf/wrappers/timestamps.pxd index 50d37fd0a68..5dcd144529d 100644 --- a/python/pylibcudf/pylibcudf/libcudf/wrappers/timestamps.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/wrappers/timestamps.pxd @@ -1,9 +1,10 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. 
-from libc.stdint cimport int64_t +from libc.stdint cimport int32_t, int64_t cdef extern from "cudf/wrappers/timestamps.hpp" namespace "cudf" nogil: + ctypedef int32_t timestamp_D ctypedef int64_t timestamp_s ctypedef int64_t timestamp_ms ctypedef int64_t timestamp_us From 2caabfc7fffb61e0232ed2dfa411c136763dca7d Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 16 Oct 2024 10:34:56 +0000 Subject: [PATCH 14/23] Expose compute_column --- .../pylibcudf/pylibcudf/libcudf/transform.pxd | 5 ++++ python/pylibcudf/pylibcudf/transform.pxd | 3 ++ python/pylibcudf/pylibcudf/transform.pyx | 29 +++++++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/python/pylibcudf/pylibcudf/libcudf/transform.pxd b/python/pylibcudf/pylibcudf/libcudf/transform.pxd index d21510bd731..47d79083b66 100644 --- a/python/pylibcudf/pylibcudf/libcudf/transform.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/transform.pxd @@ -27,6 +27,11 @@ cdef extern from "cudf/transform.hpp" namespace "cudf" nogil: column_view input ) except + + cdef unique_ptr[column] compute_column( + table_view table, + expression expr + ) except + + cdef unique_ptr[column] transform( column_view input, string unary_udf, diff --git a/python/pylibcudf/pylibcudf/transform.pxd b/python/pylibcudf/pylibcudf/transform.pxd index b530f433c97..4fb623158f0 100644 --- a/python/pylibcudf/pylibcudf/transform.pxd +++ b/python/pylibcudf/pylibcudf/transform.pxd @@ -3,6 +3,7 @@ from libcpp cimport bool from pylibcudf.libcudf.types cimport bitmask_type, data_type from .column cimport Column +from .expressions cimport Expression from .gpumemoryview cimport gpumemoryview from .table cimport Table from .types cimport DataType @@ -10,6 +11,8 @@ from .types cimport DataType cpdef tuple[gpumemoryview, int] nans_to_nulls(Column input) +cpdef Column compute_column(Table input, Expression expr) + cpdef tuple[gpumemoryview, int] bools_to_mask(Column input) cpdef Column mask_to_bools(Py_ssize_t bitmask, int begin_bit, int end_bit) diff --git a/python/pylibcudf/pylibcudf/transform.pyx b/python/pylibcudf/pylibcudf/transform.pyx index 74134caeb78..6b7bc6ddb37 100644 --- a/python/pylibcudf/pylibcudf/transform.pyx +++ b/python/pylibcudf/pylibcudf/transform.pyx @@ -1,5 +1,6 @@ # Copyright (c) 2024, NVIDIA CORPORATION. +from cython.operator cimport dereference from libcpp.memory cimport unique_ptr from libcpp.string cimport string from libcpp.utility cimport move, pair @@ -43,6 +44,34 @@ cpdef tuple[gpumemoryview, int] nans_to_nulls(Column input): ) +cpdef Column compute_column(Table input, Expression expr): + """Create a column by evaluating an expression on a table. + + For details see :cpp:func:`compute_column`. + + Parameters + ---------- + input : Table + Table used for expression evaluation + expr : Expression + Expression to evaluate + + Returns + ------- + Column of the evaluated expression + """ + cdef unique_ptr[column] c_result + + with nogil: + c_result = move( + cpp_transform.compute_column( + input.view(), dereference(expr.c_obj.get()) + ) + ) + + return Column.from_libcudf(move(c_result)) + + cpdef tuple[gpumemoryview, int] bools_to_mask(Column input): """Create a bitmask from a column of boolean elements From 0f5a67018f6b2e37085dade221eb6e4aaba9a803 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 10 Oct 2024 11:35:35 +0000 Subject: [PATCH 15/23] Implement conversion from Expr nodes to pylibcudf Expressions We will use this for inequality joins and filter pushdown in the parquet reader. 
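As a sketch, the two public entry points this patch adds are used like
so (`predicate` and the `name_to_index` mapping here are illustrative):

    from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter

    # Both return None if the expression cannot be converted.
    pq_filter = to_parquet_filter(predicate)
    ast = to_ast(predicate, name_to_index={"a": 0, "b": 1})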
The handling is a bit complicated, since the subset of expressions
that the parquet filter accepts is smaller than all possible
expressions. Since much of the logic is similar, however, we just
dispatch on a transformer state variable to determine which case we're
handling.
---
 python/cudf_polars/cudf_polars/dsl/to_ast.py | 263 +++++++++++++++++++
 1 file changed, 263 insertions(+)
 create mode 100644 python/cudf_polars/cudf_polars/dsl/to_ast.py

diff --git a/python/cudf_polars/cudf_polars/dsl/to_ast.py b/python/cudf_polars/cudf_polars/dsl/to_ast.py
new file mode 100644
index 00000000000..b90ab1d3869
--- /dev/null
+++ b/python/cudf_polars/cudf_polars/dsl/to_ast.py
@@ -0,0 +1,263 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Conversion of expression nodes to libcudf AST nodes."""
+
+from __future__ import annotations
+
+from functools import partial, reduce, singledispatch
+from typing import TYPE_CHECKING, TypeAlias
+
+import pylibcudf as plc
+from pylibcudf import expressions as pexpr
+
+from polars.polars import _expr_nodes as pl_expr
+
+from cudf_polars.dsl import expr
+from cudf_polars.dsl.traversal import make_recursive
+from cudf_polars.typing import GenericTransformer
+
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+# Can't merge these op-mapping dictionaries because scoped enum values
+# are exposed by cython with equality/hash based on their underlying
+# representation type. So in a dict they are just treated as integers.
+BINOP_TO_ASTOP = {
+    plc.binaryop.BinaryOperator.EQUAL: pexpr.ASTOperator.EQUAL,
+    plc.binaryop.BinaryOperator.NULL_EQUALS: pexpr.ASTOperator.NULL_EQUAL,
+    plc.binaryop.BinaryOperator.NOT_EQUAL: pexpr.ASTOperator.NOT_EQUAL,
+    plc.binaryop.BinaryOperator.LESS: pexpr.ASTOperator.LESS,
+    plc.binaryop.BinaryOperator.LESS_EQUAL: pexpr.ASTOperator.LESS_EQUAL,
+    plc.binaryop.BinaryOperator.GREATER: pexpr.ASTOperator.GREATER,
+    plc.binaryop.BinaryOperator.GREATER_EQUAL: pexpr.ASTOperator.GREATER_EQUAL,
+    plc.binaryop.BinaryOperator.ADD: pexpr.ASTOperator.ADD,
+    plc.binaryop.BinaryOperator.SUB: pexpr.ASTOperator.SUB,
+    plc.binaryop.BinaryOperator.MUL: pexpr.ASTOperator.MUL,
+    plc.binaryop.BinaryOperator.DIV: pexpr.ASTOperator.DIV,
+    plc.binaryop.BinaryOperator.TRUE_DIV: pexpr.ASTOperator.TRUE_DIV,
+    plc.binaryop.BinaryOperator.FLOOR_DIV: pexpr.ASTOperator.FLOOR_DIV,
+    plc.binaryop.BinaryOperator.PYMOD: pexpr.ASTOperator.PYMOD,
+    plc.binaryop.BinaryOperator.BITWISE_AND: pexpr.ASTOperator.BITWISE_AND,
+    plc.binaryop.BinaryOperator.BITWISE_OR: pexpr.ASTOperator.BITWISE_OR,
+    plc.binaryop.BinaryOperator.BITWISE_XOR: pexpr.ASTOperator.BITWISE_XOR,
+    plc.binaryop.BinaryOperator.LOGICAL_AND: pexpr.ASTOperator.LOGICAL_AND,
+    plc.binaryop.BinaryOperator.LOGICAL_OR: pexpr.ASTOperator.LOGICAL_OR,
+    plc.binaryop.BinaryOperator.NULL_LOGICAL_AND: pexpr.ASTOperator.NULL_LOGICAL_AND,
+    plc.binaryop.BinaryOperator.NULL_LOGICAL_OR: pexpr.ASTOperator.NULL_LOGICAL_OR,
+}
+
+UOP_TO_ASTOP = {
+    plc.unary.UnaryOperator.SIN: pexpr.ASTOperator.SIN,
+    plc.unary.UnaryOperator.COS: pexpr.ASTOperator.COS,
+    plc.unary.UnaryOperator.TAN: pexpr.ASTOperator.TAN,
+    plc.unary.UnaryOperator.ARCSIN: pexpr.ASTOperator.ARCSIN,
+    plc.unary.UnaryOperator.ARCCOS: pexpr.ASTOperator.ARCCOS,
+    plc.unary.UnaryOperator.ARCTAN: pexpr.ASTOperator.ARCTAN,
+    plc.unary.UnaryOperator.SINH: pexpr.ASTOperator.SINH,
+    plc.unary.UnaryOperator.COSH: pexpr.ASTOperator.COSH,
+    plc.unary.UnaryOperator.TANH: pexpr.ASTOperator.TANH,
plc.unary.UnaryOperator.ARCSINH: pexpr.ASTOperator.ARCSINH, + plc.unary.UnaryOperator.ARCCOSH: pexpr.ASTOperator.ARCCOSH, + plc.unary.UnaryOperator.ARCTANH: pexpr.ASTOperator.ARCTANH, + plc.unary.UnaryOperator.EXP: pexpr.ASTOperator.EXP, + plc.unary.UnaryOperator.LOG: pexpr.ASTOperator.LOG, + plc.unary.UnaryOperator.SQRT: pexpr.ASTOperator.SQRT, + plc.unary.UnaryOperator.CBRT: pexpr.ASTOperator.CBRT, + plc.unary.UnaryOperator.CEIL: pexpr.ASTOperator.CEIL, + plc.unary.UnaryOperator.FLOOR: pexpr.ASTOperator.FLOOR, + plc.unary.UnaryOperator.ABS: pexpr.ASTOperator.ABS, + plc.unary.UnaryOperator.RINT: pexpr.ASTOperator.RINT, + plc.unary.UnaryOperator.BIT_INVERT: pexpr.ASTOperator.BIT_INVERT, + plc.unary.UnaryOperator.NOT: pexpr.ASTOperator.NOT, +} + +SUPPORTED_STATISTICS_BINOPS = { + plc.binaryop.BinaryOperator.EQUAL, + plc.binaryop.BinaryOperator.NOT_EQUAL, + plc.binaryop.BinaryOperator.LESS, + plc.binaryop.BinaryOperator.LESS_EQUAL, + plc.binaryop.BinaryOperator.GREATER, + plc.binaryop.BinaryOperator.GREATER_EQUAL, +} + +REVERSED_COMPARISON = { + plc.binaryop.BinaryOperator.EQUAL: plc.binaryop.BinaryOperator.EQUAL, + plc.binaryop.BinaryOperator.NOT_EQUAL: plc.binaryop.BinaryOperator.NOT_EQUAL, + plc.binaryop.BinaryOperator.LESS: plc.binaryop.BinaryOperator.GREATER, + plc.binaryop.BinaryOperator.LESS_EQUAL: plc.binaryop.BinaryOperator.GREATER_EQUAL, + plc.binaryop.BinaryOperator.GREATER: plc.binaryop.BinaryOperator.LESS, + plc.binaryop.BinaryOperator.GREATER_EQUAL: plc.binaryop.BinaryOperator.LESS_EQUAL, +} + + +Transformer: TypeAlias = GenericTransformer[expr.Expr, pexpr.Expression] + + +@singledispatch +def _to_ast(node: expr.Expr, self: Transformer) -> pexpr.Expression: + """ + Translate an expression to a pylibcudf Expression. + + Parameters + ---------- + node + Expression to translate. + self + Recursive transformer. The state dictionary should contain a + `for_parquet` key indicating if this transformation should + provide an expression suitable for use in parquet filters. + + If `for_parquet` is `False`, the dictionary should contain a + `name_to_index` mapping that maps column names to their + integer index in the table that will be used for evaluation of + the expression. + + Returns + ------- + pylibcudf Expression. + + Raises + ------ + NotImplementedError or KeyError if the expression cannot be translated. + """ + raise NotImplementedError(f"Unhandled expression type {type(node)}") + + +@_to_ast.register +def _(node: expr.Col, self: Transformer) -> pexpr.Expression: + if self.state["for_parquet"]: + return pexpr.ColumnNameReference(node.name) + return pexpr.ColumnReference(self.state["name_to_index"][node.name]) + + +@_to_ast.register +def _(node: expr.Literal, self: Transformer) -> pexpr.Expression: + return pexpr.Literal(plc.interop.from_arrow(node.value)) + + +@_to_ast.register +def _(node: expr.BinOp, self: Transformer) -> pexpr.Expression: + if node.op == plc.binaryop.BinaryOperator.NULL_NOT_EQUALS: + return pexpr.Operation( + pexpr.ASTOperator.NOT, + self( + # Reconstruct and apply, rather than directly + # constructing the right expression so we get the + # handling of parquet special cases for free. 
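+                # That is, NULL_NOT_EQUALS(a, b) is translated as
+                # NOT(NULL_EQUALS(a, b)).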
+                expr.BinOp(
+                    node.dtype, plc.binaryop.BinaryOperator.NULL_EQUALS, *node.children
+                )
+            ),
+        )
+    if self.state["for_parquet"]:
+        op1_col, op2_col = (isinstance(op, expr.Col) for op in node.children)
+        if op1_col ^ op2_col:
+            op = node.op
+            if op not in SUPPORTED_STATISTICS_BINOPS:
+                raise NotImplementedError(
+                    f"Parquet filter binop with column doesn't support {node.op!r}"
+                )
+            op1, op2 = node.children
+            if op2_col:
+                (op1, op2) = (op2, op1)
+                op = REVERSED_COMPARISON[op]
+            if not isinstance(op2, expr.Literal):
+                raise NotImplementedError(
+                    "Parquet filter binops must have form 'col binop literal'"
+                )
+            return pexpr.Operation(BINOP_TO_ASTOP[op], self(op1), self(op2))
+        elif op1_col and op2_col:
+            raise NotImplementedError(
+                "Parquet filter binops must have one column reference not two"
+            )
+    return pexpr.Operation(BINOP_TO_ASTOP[node.op], *map(self, node.children))
+
+
+@_to_ast.register
+def _(node: expr.BooleanFunction, self: Transformer) -> pexpr.Expression:
+    if node.name == pl_expr.BooleanFunction.IsIn:
+        needles, haystack = node.children
+        if isinstance(haystack, expr.LiteralColumn) and len(haystack.value) < 16:
+            # 16 is an arbitrary limit
+            needle_ref = self(needles)
+            values = [pexpr.Literal(plc.interop.from_arrow(v)) for v in haystack.value]
+            return reduce(
+                partial(pexpr.Operation, pexpr.ASTOperator.LOGICAL_OR),
+                (
+                    pexpr.Operation(pexpr.ASTOperator.EQUAL, needle_ref, value)
+                    for value in values
+                ),
+            )
+    if self.state["for_parquet"] and isinstance(node.children[0], expr.Col):
+        raise NotImplementedError(
+            f"Parquet filters don't support {node.name} on columns"
+        )
+    if node.name == pl_expr.BooleanFunction.IsNull:
+        return pexpr.Operation(pexpr.ASTOperator.IS_NULL, self(node.children[0]))
+    elif node.name == pl_expr.BooleanFunction.IsNotNull:
+        return pexpr.Operation(
+            pexpr.ASTOperator.NOT,
+            pexpr.Operation(pexpr.ASTOperator.IS_NULL, self(node.children[0])),
+        )
+    elif node.name == pl_expr.BooleanFunction.Not:
+        return pexpr.Operation(pexpr.ASTOperator.NOT, self(node.children[0]))
+    raise NotImplementedError(f"AST conversion does not support {node.name}")
+
+
+@_to_ast.register
+def _(node: expr.UnaryFunction, self: Transformer) -> pexpr.Expression:
+    if isinstance(node.children[0], expr.Col) and self.state["for_parquet"]:
+        raise NotImplementedError(
+            f"Parquet filters don't support {node.name} on columns"
+        )
+    return pexpr.Operation(
+        UOP_TO_ASTOP[node._OP_MAPPING[node.name]], self(node.children[0])
+    )
+
+
+def to_parquet_filter(node: expr.Expr) -> pexpr.Expression | None:
+    """
+    Convert an expression to libcudf AST nodes suitable for parquet filtering.
+
+    Parameters
+    ----------
+    node
+        Expression to convert.
+
+    Returns
+    -------
+    pylibcudf Expression if conversion is possible, otherwise None.
+    """
+    mapper: Transformer = make_recursive(_to_ast, state={"for_parquet": True})
+    try:
+        return mapper(node)
+    except (KeyError, NotImplementedError):
+        return None
+
+
+def to_ast(
+    node: expr.Expr, *, name_to_index: Mapping[str, int]
+) -> pexpr.Expression | None:
+    """
+    Convert an expression to libcudf AST nodes suitable for compute_column.
+
+    Parameters
+    ----------
+    node
+        Expression to convert.
+    name_to_index
+        Mapping from column names to their index in the table that
+        will be used for expression evaluation.
+
+    Returns
+    -------
+    pylibcudf Expression if conversion is possible, otherwise None.
+ """ + mapper: Transformer = make_recursive( + _to_ast, state={"for_parquet": False, "name_to_index": name_to_index} + ) + try: + return mapper(node) + except (KeyError, NotImplementedError): + return None From 6cb5440acb804001a3f365c3441ed472839ba4b9 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 10 Oct 2024 18:04:59 +0000 Subject: [PATCH 16/23] Implement predicate pushdown into parquet read We attempt to turn the predicate into a filter expression that the parquet reader understands. If successful then we don't have to apply the predicate as a post-filter. We can only do this when a row index is not requested. --- python/cudf_polars/cudf_polars/dsl/ir.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index d24a4c7fcf5..0dc53eba4e0 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -28,6 +28,7 @@ import cudf_polars.dsl.expr as expr from cudf_polars.containers import Column, DataFrame from cudf_polars.dsl.nodebase import Node +from cudf_polars.dsl.to_ast import to_parquet_filter from cudf_polars.utils import dtypes if TYPE_CHECKING: @@ -417,9 +418,14 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: colnames[0], ) elif self.typ == "parquet": + filters = None + if self.predicate is not None and self.row_index is None: + # Can't apply filters during read if we have a row index. + filters = to_parquet_filter(self.predicate.value) tbl_w_meta = plc.io.parquet.read_parquet( plc.io.SourceInfo(self.paths), columns=with_columns, + filters=filters, nrows=n_rows, skip_rows=self.skip_rows, ) @@ -428,6 +434,9 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: # TODO: consider nested column names? tbl_w_meta.column_names(include_children=False), ) + if filters is not None: + # Mask must have been applied. + return df elif self.typ == "ndjson": json_schema: list[tuple[str, str, list]] = [ (name, typ, []) for name, typ in self.schema.items() From 1545db850a2da827d0d6ea8b7edbdaf2c19996d3 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 16 Oct 2024 10:35:36 +0000 Subject: [PATCH 17/23] Add tests of parquet filters --- .../cudf_polars/tests/test_parquet_filters.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 python/cudf_polars/tests/test_parquet_filters.py diff --git a/python/cudf_polars/tests/test_parquet_filters.py b/python/cudf_polars/tests/test_parquet_filters.py new file mode 100644 index 00000000000..652a7452c7e --- /dev/null +++ b/python/cudf_polars/tests/test_parquet_filters.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +@pytest.fixture(scope="module") +def df(): + return pl.DataFrame( + { + "c": ["a", "b", "c", "d", "e", "f"], + "a": [1, 2, 3, None, 4, 5], + "b": pl.Series([None, None, 3, None, 4, 0], dtype=pl.Float64), + } + ) + + +@pytest.fixture(scope="module") +def pq_file(tmp_path_factory, df): + tmp_path = tmp_path_factory.mktemp("parquet_filter") + df.write_parquet(tmp_path / "tmp.pq", row_group_size=3) + return pl.scan_parquet(tmp_path / "tmp.pq") + + +@pytest.mark.parametrize( + "expr", + [ + pl.col("a").is_in([0, 1]), + pl.col("a").is_between(0, 2), + (pl.col("a") < 2).not_(), + pl.lit(2) > pl.col("a"), + pl.lit(2) >= pl.col("a"), + pl.lit(2) < pl.col("a"), + pl.lit(2) <= pl.col("a"), + pl.lit(0) == pl.col("a"), + pl.lit(1) != pl.col("a"), + (pl.col("b") < pl.lit(2, dtype=pl.Float64).sqrt()), + (pl.col("a") >= pl.lit(2)) & (pl.col("b") > 0), + pl.col("a").is_null(), + pl.col("a").is_not_null(), + pl.col("a").abs().is_between(0, 2), + pl.col("a").ne_missing(pl.lit(None, dtype=pl.Int64)), + ], +) +@pytest.mark.parametrize("selection", [["c", "b"], ["a"], ["a", "c"], ["b"], "c"]) +def test_scan_by_hand(expr, selection, pq_file): + df = pq_file.collect() + q = pq_file.filter(expr).select(*selection) + # Not using assert_gpu_result_equal because + # https://github.com/pola-rs/polars/issues/19238 + got = q.collect(engine=pl.GPUEngine(raise_on_fail=True)) + expect = df.filter(expr).select(*selection) + assert_frame_equal(got, expect) From dff0aa154661f656fbef72b65f437ccc695bde9a Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 16 Oct 2024 10:37:46 +0000 Subject: [PATCH 18/23] Add tests of to_ast and column compute --- python/cudf_polars/tests/dsl/test_to_ast.py | 70 +++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 python/cudf_polars/tests/dsl/test_to_ast.py diff --git a/python/cudf_polars/tests/dsl/test_to_ast.py b/python/cudf_polars/tests/dsl/test_to_ast.py new file mode 100644 index 00000000000..5d185c5adff --- /dev/null +++ b/python/cudf_polars/tests/dsl/test_to_ast.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pylibcudf as plc +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + +import cudf_polars.dsl.ir as ir_nodes +from cudf_polars import translate_ir +from cudf_polars.containers.dataframe import DataFrame, NamedColumn +from cudf_polars.dsl.to_ast import to_ast + + +@pytest.fixture(scope="module") +def df(): + return pl.LazyFrame( + { + "c": ["a", "b", "c", "d", "e", "f"], + "a": [1, 2, 3, None, 4, 5], + "b": pl.Series([None, None, 3, None, 4, 0], dtype=pl.Float64), + } + ) + + +@pytest.mark.parametrize( + "expr", + [ + pl.col("a").is_in([0, 1]), + pl.col("a").is_between(0, 2), + (pl.col("a") < pl.col("b")).not_(), + pl.lit(2) > pl.col("a"), + pl.lit(2) >= pl.col("a"), + pl.lit(2) < pl.col("a"), + pl.lit(2) <= pl.col("a"), + pl.lit(0) == pl.col("a"), + pl.lit(1) != pl.col("a"), + (pl.col("b") < pl.lit(2, dtype=pl.Float64).sqrt()), + (pl.col("a") >= pl.lit(2)) & (pl.col("b") > 0), + pl.col("a").is_null(), + pl.col("a").is_not_null(), + pl.col("a").abs().is_between(0, 2), + pl.col("a").ne_missing(pl.lit(None, dtype=pl.Int64)), + [pl.col("a") * 2, pl.col("b") + pl.col("a")], + ], +) +def test_compute_column(expr, df): + q = df.select(expr) + ir = translate_ir(q._ldf.visit()) + + assert isinstance(ir, ir_nodes.Select) + table = ir.children[0].evaluate(cache={}) + name_to_index = {c.name: i for i, c in enumerate(table.columns)} + + def compute_column(e): + ast = to_ast(e.value, name_to_index=name_to_index) + if ast is not None: + return NamedColumn( + plc.transform.compute_column(table.table, ast), name=e.name + ) + return e.evaluate(table) + + got = DataFrame(map(compute_column, ir.exprs)).to_polars() + + expect = q.collect() + + assert_frame_equal(expect, got) From d9175ad119c240180e225d2d03e0c726d7bf32f8 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 16 Oct 2024 13:11:40 +0000 Subject: [PATCH 19/23] WIP: expose mixed and conditional joins in pylibcudf --- python/pylibcudf/pylibcudf/join.pxd | 76 ++++ python/pylibcudf/pylibcudf/join.pyx | 405 ++++++++++++++++++++ python/pylibcudf/pylibcudf/libcudf/join.pxd | 113 ++++++ 3 files changed, 594 insertions(+) diff --git a/python/pylibcudf/pylibcudf/join.pxd b/python/pylibcudf/pylibcudf/join.pxd index 06969b4a2db..bb9162b466a 100644 --- a/python/pylibcudf/pylibcudf/join.pxd +++ b/python/pylibcudf/pylibcudf/join.pxd @@ -3,6 +3,7 @@ from pylibcudf.libcudf.types cimport null_equality from .column cimport Column +from .expressions cimport Expression from .table cimport Table @@ -37,3 +38,78 @@ cpdef Column left_anti_join( ) cpdef Table cross_join(Table left, Table right) + +cpdef tuple conditional_inner_join( + Table left, + Table right, + Expression binary_predicate, +) + +cpdef tuple conditional_left_join( + Table left, + Table right, + Expression binary_predicate, +) + +cpdef tuple conditional_full_join( + Table left, + Table right, + Expression binary_predicate, +) + +cpdef Column conditional_left_semi_join( + Table left, + Table right, + Expression binary_predicate, +) + +cpdef Column conditional_left_anti_join( + Table left, + Table right, + Expression binary_predicate, +) + +cpdef tuple mixed_inner_join( + Table left_keys, + Table right_keys, + Table left_conditional, + Table right_conditional, + Expression binary_predicate, + null_equality nulls_equal +) + +cpdef tuple mixed_left_join( + Table left_keys, + Table right_keys, + Table left_conditional, + Table right_conditional, + Expression binary_predicate, + 
null_equality nulls_equal +) + +cpdef tuple mixed_full_join( + Table left_keys, + Table right_keys, + Table left_conditional, + Table right_conditional, + Expression binary_predicate, + null_equality nulls_equal +) + +cpdef Column mixed_left_semi_join( + Table left_keys, + Table right_keys, + Table left_conditional, + Table right_conditional, + Expression binary_predicate, + null_equality nulls_equal +) + +cpdef Column mixed_left_anti_join( + Table left_keys, + Table right_keys, + Table left_conditional, + Table right_conditional, + Expression binary_predicate, + null_equality nulls_equal +) diff --git a/python/pylibcudf/pylibcudf/join.pyx b/python/pylibcudf/pylibcudf/join.pyx index b019ed8f099..728c9499d04 100644 --- a/python/pylibcudf/pylibcudf/join.pyx +++ b/python/pylibcudf/pylibcudf/join.pyx @@ -12,6 +12,7 @@ from pylibcudf.libcudf.types cimport null_equality from rmm.librmm.device_buffer cimport device_buffer from .column cimport Column +from .expressions cimport Expression from .table cimport Table @@ -214,3 +215,407 @@ cpdef Table cross_join(Table left, Table right): with nogil: result = move(cpp_join.cross_join(left.view(), right.view())) return Table.from_libcudf(move(result)) + + +cpdef tuple conditional_inner_join( + Table left, + Table right, + Expression binary_predicate, +): + """Perform a conditional inner join between two tables. + + For details, see :cpp:func:`conditional_inner_join`. + + Parameters + ---------- + left : Table + The left table to join. + right : Table + The right table to join. + binary_predicate : Expression + Condition to join on. + + Returns + ------- + Tuple[Column, Column] + A tuple containing the row indices from the left and right tables after the + join. + """ + cdef cpp_join.gather_map_pair_type c_result + with nogil: + c_result = cpp_join.conditional_inner_join( + left.view(), right.view(), dereference(binary_predicate.c_obj.get()) + ) + return ( + _column_from_gather_map(move(c_result.first)), + _column_from_gather_map(move(c_result.second)), + ) + + +cpdef tuple conditional_left_join( + Table left, + Table right, + Expression binary_predicate, +): + """Perform a conditional left join between two tables. + + For details, see :cpp:func:`conditional_left_join`. + + Parameters + ---------- + left : Table + The left table to join. + right : Table + The right table to join. + binary_predicate : Expression + Condition to join on. + + Returns + ------- + Tuple[Column, Column] + A tuple containing the row indices from the left and right tables after the + join. + """ + cdef cpp_join.gather_map_pair_type c_result + with nogil: + c_result = cpp_join.conditional_left_join( + left.view(), right.view(), dereference(binary_predicate.c_obj.get()) + ) + return ( + _column_from_gather_map(move(c_result.first)), + _column_from_gather_map(move(c_result.second)), + ) + + +cpdef tuple conditional_full_join( + Table left, + Table right, + Expression binary_predicate, +): + """Perform a conditional full join between two tables. + + For details, see :cpp:func:`conditional_full_join`. + + Parameters + ---------- + left : Table + The left table to join. + right : Table + The right table to join. + binary_predicate : Expression + Condition to join on. + + Returns + ------- + Tuple[Column, Column] + A tuple containing the row indices from the left and right tables after the + join. 
+ """ + cdef cpp_join.gather_map_pair_type c_result + with nogil: + c_result = cpp_join.conditional_full_join( + left.view(), right.view(), dereference(binary_predicate.c_obj.get()) + ) + return ( + _column_from_gather_map(move(c_result.first)), + _column_from_gather_map(move(c_result.second)), + ) + + +cpdef Column conditional_left_semi_join( + Table left, + Table right, + Expression binary_predicate, +): + """Perform a conditional left semi join between two tables. + + For details, see :cpp:func:`conditional_left_semi_join`. + + Parameters + ---------- + left : Table + The left table to join. + right : Table + The right table to join. + binary_predicate : Expression + Condition to join on. + + Returns + ------- + Column + A column containing the row indices from the left table after the join. + """ + cdef cpp_join.gather_map_type c_result + with nogil: + c_result = cpp_join.conditional_left_semi_join( + left.view(), right.view(), dereference(binary_predicate.c_obj.get()) + ) + return _column_from_gather_map(move(c_result)) + + +cpdef Column conditional_left_anti_join( + Table left, + Table right, + Expression binary_predicate, +): + """Perform a conditional left anti join between two tables. + + For details, see :cpp:func:`conditional_left_anti_join`. + + Parameters + ---------- + left : Table + The left table to join. + right : Table + The right table to join. + binary_predicate : Expression + Condition to join on. + + Returns + ------- + Column + A column containing the row indices from the left table after the join. + """ + cdef cpp_join.gather_map_type c_result + with nogil: + c_result = cpp_join.conditional_left_anti_join( + left.view(), right.view(), dereference(binary_predicate.c_obj.get()) + ) + return _column_from_gather_map(move(c_result)) + + +cpdef tuple mixed_inner_join( + Table left_keys, + Table right_keys, + Table left_conditional, + Table right_conditional, + Expression binary_predicate, + null_equality nulls_equal +): + """Perform a mixed inner join between two tables. + + For details, see :cpp:func:`mixed_inner_join`. + + Parameters + ---------- + left_keys : Table + The left table to use for the equality join. + right_keys : Table + The right table to use for the equality join. + left_conditional : Table + The left table to use for the conditional join. + right_conditional : Table + The right table to use for the conditional join. + binary_predicate : Expression + Condition to join on. + nulls_equal : NullEquality + Should nulls compare equal in the equality join? + + Returns + ------- + Tuple[Column, Column] + A tuple containing the row indices from the left and right tables after the + join. + """ + cdef cpp_join.gather_map_pair_type c_result + with nogil: + c_result = cpp_join.mixed_inner_join( + left_keys.view(), + right_keys.view(), + left_conditional.view(), + right_conditional.view(), + dereference(binary_predicate.c_obj.get()), + nulls_equal, + ) + return ( + _column_from_gather_map(move(c_result.first)), + _column_from_gather_map(move(c_result.second)), + ) + + +cpdef tuple mixed_left_join( + Table left_keys, + Table right_keys, + Table left_conditional, + Table right_conditional, + Expression binary_predicate, + null_equality nulls_equal +): + """Perform a mixed left join between two tables. + + For details, see :cpp:func:`mixed_left_join`. + + Parameters + ---------- + left_keys : Table + The left table to use for the equality join. + right_keys : Table + The right table to use for the equality join. 
+ left_conditional : Table + The left table to use for the conditional join. + right_conditional : Table + The right table to use for the conditional join. + binary_predicate : Expression + Condition to join on. + nulls_equal : NullEquality + Should nulls compare equal in the equality join? + + Returns + ------- + Tuple[Column, Column] + A tuple containing the row indices from the left and right tables after the + join. + """ + cdef cpp_join.gather_map_pair_type c_result + with nogil: + c_result = cpp_join.mixed_left_join( + left_keys.view(), + right_keys.view(), + left_conditional.view(), + right_conditional.view(), + dereference(binary_predicate.c_obj.get()), + nulls_equal, + ) + return ( + _column_from_gather_map(move(c_result.first)), + _column_from_gather_map(move(c_result.second)), + ) + + +cpdef tuple mixed_full_join( + Table left_keys, + Table right_keys, + Table left_conditional, + Table right_conditional, + Expression binary_predicate, + null_equality nulls_equal +): + """Perform a mixed full join between two tables. + + For details, see :cpp:func:`mixed_full_join`. + + Parameters + ---------- + left_keys : Table + The left table to use for the equality join. + right_keys : Table + The right table to use for the equality join. + left_conditional : Table + The left table to use for the conditional join. + right_conditional : Table + The right table to use for the conditional join. + binary_predicate : Expression + Condition to join on. + nulls_equal : NullEquality + Should nulls compare equal in the equality join? + + Returns + ------- + Tuple[Column, Column] + A tuple containing the row indices from the left and right tables after the + join. + """ + cdef cpp_join.gather_map_pair_type c_result + with nogil: + c_result = cpp_join.mixed_full_join( + left_keys.view(), + right_keys.view(), + left_conditional.view(), + right_conditional.view(), + dereference(binary_predicate.c_obj.get()), + nulls_equal, + ) + return ( + _column_from_gather_map(move(c_result.first)), + _column_from_gather_map(move(c_result.second)), + ) + + +cpdef Column mixed_left_semi_join( + Table left_keys, + Table right_keys, + Table left_conditional, + Table right_conditional, + Expression binary_predicate, + null_equality nulls_equal +): + """Perform a mixed left semi join between two tables. + + For details, see :cpp:func:`mixed_left_semi_join`. + + Parameters + ---------- + left_keys : Table + The left table to use for the equality join. + right_keys : Table + The right table to use for the equality join. + left_conditional : Table + The left table to use for the conditional join. + right_conditional : Table + The right table to use for the conditional join. + binary_predicate : Expression + Condition to join on. + nulls_equal : NullEquality + Should nulls compare equal in the equality join? + + Returns + ------- + Column + A column containing the row indices from the left table after the join. + """ + cdef cpp_join.gather_map_type c_result + with nogil: + c_result = cpp_join.mixed_left_semi_join( + left_keys.view(), + right_keys.view(), + left_conditional.view(), + right_conditional.view(), + dereference(binary_predicate.c_obj.get()), + nulls_equal, + ) + return _column_from_gather_map(move(c_result)) + + +cpdef Column mixed_left_anti_join( + Table left_keys, + Table right_keys, + Table left_conditional, + Table right_conditional, + Expression binary_predicate, + null_equality nulls_equal +): + """Perform a mixed left anti join between two tables. + + For details, see :cpp:func:`mixed_left_anti_join`. 
+ + Parameters + ---------- + left_keys : Table + The left table to use for the equality join. + right_keys : Table + The right table to use for the equality join. + left_conditional : Table + The left table to use for the conditional join. + right_conditional : Table + The right table to use for the conditional join. + binary_predicate : Expression + Condition to join on. + nulls_equal : NullEquality + Should nulls compare equal in the equality join? + + Returns + ------- + Column + A column containing the row indices from the left table after the join. + """ + cdef cpp_join.gather_map_type c_result + with nogil: + c_result = cpp_join.mixed_left_anti_join( + left_keys.view(), + right_keys.view(), + left_conditional.view(), + right_conditional.view(), + dereference(binary_predicate.c_obj.get()), + nulls_equal, + ) + return _column_from_gather_map(move(c_result)) diff --git a/python/pylibcudf/pylibcudf/libcudf/join.pxd b/python/pylibcudf/pylibcudf/libcudf/join.pxd index 21033a0284e..f7b77ca5907 100644 --- a/python/pylibcudf/pylibcudf/libcudf/join.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/join.pxd @@ -1,10 +1,13 @@ # Copyright (c) 2020-2024, NVIDIA CORPORATION. +from libc.stddef cimport size_t from libcpp cimport bool from libcpp.memory cimport unique_ptr +from libcpp.optional cimport optional from libcpp.pair cimport pair from libcpp.vector cimport vector from pylibcudf.libcudf.column.column cimport column +from pylibcudf.libcudf.expressions cimport expression from pylibcudf.libcudf.table.table cimport table from pylibcudf.libcudf.table.table_view cimport table_view from pylibcudf.libcudf.types cimport null_equality, size_type @@ -74,3 +77,113 @@ cdef extern from "cudf/join.hpp" namespace "cudf" nogil: const table_view left, const table_view right, ) except + + + cdef gather_map_pair_type conditional_inner_join( + const table_view left, + const table_view right, + const expression binary_predicate, + ) except + + + cdef gather_map_pair_type conditional_inner_join( + const table_view left, + const table_view right, + const expression binary_predicate, + optional[size_t] output_size + ) except + + + cdef gather_map_pair_type conditional_left_join( + const table_view left, + const table_view right, + const expression binary_predicate, + ) except + + + cdef gather_map_pair_type conditional_left_join( + const table_view left, + const table_view right, + const expression binary_predicate, + optional[size_t] output_size + ) except + + + cdef gather_map_pair_type conditional_full_join( + const table_view left, + const table_view right, + const expression binary_predicate, + ) except + + + cdef gather_map_pair_type conditional_full_join( + const table_view left, + const table_view right, + const expression binary_predicate, + optional[size_t] output_size + ) except + + + cdef gather_map_type conditional_left_semi_join( + const table_view left, + const table_view right, + const expression binary_predicate, + ) except + + + cdef gather_map_type conditional_left_semi_join( + const table_view left, + const table_view right, + const expression binary_predicate, + optional[size_t] output_size + ) except + + + cdef gather_map_type conditional_left_anti_join( + const table_view left, + const table_view right, + const expression binary_predicate, + ) except + + + cdef gather_map_type conditional_left_anti_join( + const table_view left, + const table_view right, + const expression binary_predicate, + optional[size_t] output_size + ) except + + + cdef gather_map_pair_type mixed_inner_join( + const 
table_view left_equality, + const table_view right_equality, + const table_view left_conditional, + const table_view right_conditional, + const expression binary_predicate, + null_equality compare_nulls + ) except + + + cdef gather_map_pair_type mixed_left_join( + const table_view left_equality, + const table_view right_equality, + const table_view left_conditional, + const table_view right_conditional, + const expression binary_predicate, + null_equality compare_nulls + ) except + + + cdef gather_map_pair_type mixed_full_join( + const table_view left_equality, + const table_view right_equality, + const table_view left_conditional, + const table_view right_conditional, + const expression binary_predicate, + null_equality compare_nulls + ) except + + + cdef gather_map_type mixed_left_semi_join( + const table_view left_equality, + const table_view right_equality, + const table_view left_conditional, + const table_view right_conditional, + const expression binary_predicate, + null_equality compare_nulls + ) except + + + cdef gather_map_type mixed_left_anti_join( + const table_view left_equality, + const table_view right_equality, + const table_view left_conditional, + const table_view right_conditional, + const expression binary_predicate, + null_equality compare_nulls + ) except + From 57aed824219dee92107f07096b1020af315e4e76 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 7 Oct 2024 11:13:49 +0000 Subject: [PATCH 20/23] Add suffix property to Join node --- python/cudf_polars/cudf_polars/dsl/ir.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 0dc53eba4e0..1c42ac7e032 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -15,7 +15,7 @@ import itertools import json -from functools import cache +from functools import cache, cached_property from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar @@ -855,6 +855,17 @@ def __init__( ): raise NotImplementedError("Join with literal as join key.") + @cached_property + def suffix(self) -> str: + """ + The suffix to append to the names of columns in the right frame. + + The suffix is only applied if the name overlaps with a name in + the left frame. + """ + suffix = self.options[3] + return "_right" if suffix is None else suffix + @staticmethod @cache def _joiners( @@ -951,7 +962,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" left, right = (c.evaluate(cache=cache) for c in self.children) how, join_nulls, zlice, suffix, coalesce = self.options - suffix = "_right" if suffix is None else suffix + suffix = self.suffix if how == "cross": # Separate implementation, since cross_join returns the # result, not the gather maps From cc2a03239f91e3847eb9db9191a7ec0db45fae2a Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 4 Oct 2024 17:02:47 +0000 Subject: [PATCH 21/23] WIP: Implement inequality joins by translating to cross + filter Before working through the plumbing in pylibcudf for mixed and conditional joins and the ast evaluator, let's just support inequality joins by doing the basic thing. 
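Concretely, the translation rewrites a plan of the form (a sketch in
polars terms, not the exact internal IR; "a_right" is the name the
cross join gives the right table's overlapping column):

    left.join_where(right, pl.col("a") < pl.col("a_right"))

into the equivalent

    left.join(right, how="cross").filter(pl.col("a") < pl.col("a_right"))

with multiple predicates combined into a single boolean mask via
LOGICAL_AND. This materialises the full cross product, so it is
O(|left| * |right|) in time and memory: a stopgap until the
conditional/mixed join bindings can be used directly.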
--- .../cudf_polars/cudf_polars/dsl/translate.py | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 522c4a6729c..0bec7f1c354 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -5,6 +5,7 @@ from __future__ import annotations +import functools import json from contextlib import AbstractContextManager, nullcontext from functools import singledispatch @@ -182,7 +183,41 @@ def _( with set_node(visitor, node.input_right): inp_right = translate_ir(visitor, n=None) right_on = [translate_named_expr(visitor, n=e) for e in node.right_on] - return ir.Join(schema, left_on, right_on, node.options, inp_left, inp_right) + if (how := node.options[0]) in { + "inner", + "left", + "right", + "full", + "cross", + "leftsemi", + "leftanti", + }: + return ir.Join(schema, left_on, right_on, node.options, inp_left, inp_right) + else: + how, op1, op2 = how + if how != "inequality": + raise NotImplementedError(f"Unsupported join type {how}") + # No exposure of mixed/conditional joins in pylibcudf yet, so in + # the first instance, implement by doing a cross join followed by + # a filter. + cross = ir.Join( + schema, [], [], ("cross", *node.options[1:]), inp_left, inp_right + ) + dtype = plc.DataType(plc.TypeId.BOOL8) + if op2 is None: + ops = [op1] + else: + ops = [op1, op2] + mask = functools.reduce( + functools.partial( + expr.BinOp, dtype, plc.binaryop.BinaryOperator.LOGICAL_AND + ), + ( + expr.BinOp(dtype, expr.BinOp._MAPPING[op], left.value, right.value) + for op, left, right in zip(ops, left_on, right_on, strict=True) + ), + ) + return ir.Filter(schema, expr.NamedExpr("mask", mask), cross) @_translate_ir.register From 83fa5b66b45615dfae5267dd1ed9b3c940d7db4d Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 7 Oct 2024 11:36:11 +0000 Subject: [PATCH 22/23] Fix expression references in inequality join translation Expressions referring to the right table must be suffixed if the name overlaps with that in the left table. 
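As a standalone sketch of the renaming rule (hypothetical schema, not
part of the diff): with left columns {"a", "b"} and the default suffix
"_right", only right-table names that collide with the left schema are
rewritten:

    left_schema = {"a", "b"}  # hypothetical left-table columns

    def renamer(name: str) -> str:
        # A right-table reference keeps its name unless it collides
        # with a left column, in which case the join suffix applies.
        return name if name not in left_schema else f"{name}_right"

    assert renamer("a") == "a_right"  # collides, so suffixed
    assert renamer("c") == "c"        # unique to the right, unchanged

The rewrite itself only needs to touch Col nodes; every other
expression node is rebuilt unchanged, which is what the generic
reuse_if_unchanged fallback provides.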
--- .../cudf_polars/cudf_polars/dsl/translate.py | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 0bec7f1c354..a5723c2c856 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -20,7 +20,8 @@ from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir from cudf_polars.dsl import expr, ir -from cudf_polars.typing import NodeTraverser +from cudf_polars.dsl.traversal import make_recursive, reuse_if_unchanged +from cudf_polars.typing import ExprTransformer, NodeTraverser from cudf_polars.utils import dtypes, sorting __all__ = ["translate_ir", "translate_named_expr"] @@ -170,6 +171,22 @@ def _( ) +@singledispatch +def _rename(e: expr.Expr, self: ExprTransformer) -> expr.Expr: + raise NotImplementedError() + + +_rename.register(expr.Expr)(reuse_if_unchanged) + + +@_rename.register +def _(e: expr.Col, self: ExprTransformer) -> expr.Expr: + new_name = self.state["namer"](e.name) + if new_name != e.name: + return type(e)(e.dtype, new_name) + return e + + @_translate_ir.register def _( node: pl_ir.Join, visitor: NodeTraverser, schema: dict[str, plc.DataType] @@ -208,6 +225,20 @@ def _( ops = [op1] else: ops = [op1, op2] + suffix = cross.suffix + + # Column references in the right table refer to the post-join + # names, so with suffixes. + def renamer(name): + return name if name not in inp_left.schema else f"{name}{suffix}" + + mapper = make_recursive(_rename, state={"namer": renamer}) + right_on = [ + expr.NamedExpr(renamer(old.name), new) + for new, old in zip( + (mapper(e.value) for e in right_on), right_on, strict=True + ) + ] mask = functools.reduce( functools.partial( expr.BinOp, dtype, plc.binaryop.BinaryOperator.LOGICAL_AND From 4350006194115e64ce91e8b4a907a29807eee9de Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 7 Oct 2024 13:20:36 +0000 Subject: [PATCH 23/23] join_where tests --- python/cudf_polars/tests/test_join.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/cudf_polars/tests/test_join.py b/python/cudf_polars/tests/test_join.py index 7d9ec98db97..ff080b716d1 100644 --- a/python/cudf_polars/tests/test_join.py +++ b/python/cudf_polars/tests/test_join.py @@ -39,6 +39,7 @@ def right(): { "a": [1, 4, 3, 7, None, None], "c": [2, 3, 4, 5, 6, 7], + "d": [6, None, 7, 8, -1, 2], } ) @@ -86,3 +87,24 @@ def test_join_literal_key_unsupported(left, right, left_on, right_on): q = left.join(right, left_on=left_on, right_on=right_on, how="inner") assert_ir_translation_raises(q, NotImplementedError) + + +@pytest.mark.parametrize( + "conditions", + [ + [pl.col("a") < pl.col("a_right")], + [pl.col("a_right") <= pl.col("a") * 2], + [pl.col("b") * 2 > pl.col("a_right"), pl.col("a") == pl.col("c_right")], + [pl.col("b") * 2 <= pl.col("a_right"), pl.col("a") < pl.col("c_right")], + pytest.param( + [pl.col("b") <= pl.col("a_right") * 7, pl.col("a") < pl.col("d") * 2], + marks=pytest.mark.xfail( + reason="https://github.com/pola-rs/polars/issues/19119" + ), + ), + ], +) +def test_join_where(left, right, conditions): + q = left.join_where(right, *conditions) + + assert_gpu_result_equal(q, check_row_order=False)
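For reference, the user-facing behaviour these tests exercise looks
like the following (a sketch, assuming a polars version with
join_where; "a_right" is the suffixed name polars gives the right
table's colliding column, as in the predicates above):

    import polars as pl

    left = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    right = pl.LazyFrame({"a": [2, 3, 4], "c": [7, 8, 9]})

    # Keep row pairs where left.a < right.a; the predicate refers to
    # the right table's "a" via its post-join name "a_right".
    q = left.join_where(right, pl.col("a") < pl.col("a_right"))

assert_gpu_result_equal then compares the CPU and GPU results of such
queries, with check_row_order=False since join output order is
unspecified.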