Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Implement inequality joins by translating to cross + filter #17000

Draft
wants to merge 23 commits into
base: branch-24.12
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
6248ec3
Renaming in typing for clarity
wence- Oct 14, 2024
8b5aaed
Extract abstract base for nodes into new file
wence- Oct 8, 2024
26b9d7d
Use new Node base class for expressions
wence- Oct 8, 2024
ffe460c
Infrastructure for traversal and visitors
wence- Oct 14, 2024
83a60f0
Use abstract Node infrastructure to define IR nodes
wence- Oct 8, 2024
a234e37
Add tests of traversal over IR nodes
wence- Oct 14, 2024
73019c8
Overview documentation for visitor pattern/utilities
wence- Oct 14, 2024
b14b150
Some grammar fixes
wence- Oct 14, 2024
a49846f
Reinstate docstrings for properties
wence- Oct 14, 2024
9449b44
Use side-effect free rather than pure
wence- Oct 14, 2024
e36ead1
A few type annotations
wence- Oct 11, 2024
6fced33
Expose all type ids and match order with libcudf
wence- Oct 16, 2024
7de74d4
Support all types for scalars in pylibcudf Expressions
wence- Oct 10, 2024
2caabfc
Expose compute_column
wence- Oct 16, 2024
0f5a670
Implement conversion from Expr nodes to pylibcudf Expressions
wence- Oct 10, 2024
6cb5440
Implement predicate pushdown into parquet read
wence- Oct 10, 2024
1545db8
Add tests of parquet filters
wence- Oct 16, 2024
dff0aa1
Add tests of to_ast and column compute
wence- Oct 16, 2024
d9175ad
WIP: expose mixed and conditional joins in pylibcudf
wence- Oct 16, 2024
57aed82
Add suffix property to Join node
wence- Oct 7, 2024
cc2a032
WIP: Implement inequality joins by translating to cross + filter
wence- Oct 4, 2024
83fa5b6
Fix expression references in inequality join translation
wence- Oct 7, 2024
4350006
join_where tests
wence- Oct 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Agg(Expr):
def __init__(
self, dtype: plc.DataType, name: str, options: Any, *children: Expr
) -> None:
super().__init__(dtype)
self.dtype = dtype
self.name = name
self.options = options
self.children = children
Expand Down
92 changes: 5 additions & 87 deletions python/cudf_polars/cudf_polars/dsl/expressions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
import pylibcudf as plc

from cudf_polars.containers import Column
from cudf_polars.dsl.nodebase import Node

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from collections.abc import Mapping

from cudf_polars.containers import Column, DataFrame

Expand All @@ -32,100 +33,17 @@ class ExecutionContext(IntEnum):
ROLLING = enum.auto()


class Expr:
"""
An abstract expression object.
class Expr(Node):
"""An abstract expression object."""

This contains a (potentially empty) tuple of child expressions,
along with non-child data. For uniform reconstruction and
implementation of hashing and equality schemes, child classes need
to provide a certain amount of metadata when they are defined.
Specifically, the ``_non_child`` attribute must list, in-order,
the names of the slots that are passed to the constructor. The
constructor must take arguments in the order ``(*_non_child,
*children).``
"""

__slots__ = ("dtype", "_hash_value", "_repr_value")
__slots__ = ("dtype",)
dtype: plc.DataType
"""Data type of the expression."""
_hash_value: int
"""Caching slot for the hash of the expression."""
_repr_value: str
"""Caching slot for repr of the expression."""
children: tuple[Expr, ...] = ()
"""Children of the expression."""
_non_child: ClassVar[tuple[str, ...]] = ("dtype",)
"""Names of non-child data (not Exprs) for reconstruction."""

# Constructor must take arguments in order (*_non_child, *children)
def __init__(self, dtype: plc.DataType) -> None:
self.dtype = dtype

def _ctor_arguments(self, children: Sequence[Expr]) -> Sequence:
return (*(getattr(self, attr) for attr in self._non_child), *children)

def get_hash(self) -> int:
"""
Return the hash of this expr.
Override this in subclasses, rather than __hash__.
Returns
-------
The integer hash value.
"""
return hash((type(self), self._ctor_arguments(self.children)))

def __hash__(self) -> int:
"""Hash of an expression with caching."""
try:
return self._hash_value
except AttributeError:
self._hash_value = self.get_hash()
return self._hash_value

def is_equal(self, other: Any) -> bool:
"""
Equality of two expressions.
Override this in subclasses, rather than __eq__.
Parameter
---------
other
object to compare to
Returns
-------
True if the two expressions are equal, false otherwise.
"""
if type(self) is not type(other):
return False # pragma: no cover; __eq__ trips first
return self._ctor_arguments(self.children) == other._ctor_arguments(
other.children
)

def __eq__(self, other: Any) -> bool:
"""Equality of expressions."""
if type(self) is not type(other) or hash(self) != hash(other):
return False
else:
return self.is_equal(other)

def __ne__(self, other: Any) -> bool:
"""Inequality of expressions."""
return not self.__eq__(other)

def __repr__(self) -> str:
"""String representation of an expression with caching."""
try:
return self._repr_value
except AttributeError:
args = ", ".join(f"{arg!r}" for arg in self._ctor_arguments(self.children))
self._repr_value = f"{type(self).__name__}({args})"
return self._repr_value

def do_evaluate(
self,
df: DataFrame,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
left: Expr,
right: Expr,
) -> None:
super().__init__(dtype)
self.dtype = dtype
if plc.traits.is_boolean(self.dtype):
# For boolean output types, bitand and bitor implement
# boolean logic, so translate. bitxor also does, but the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
options: tuple[Any, ...],
*children: Expr,
) -> None:
super().__init__(dtype)
self.dtype = dtype
self.options = options
self.name = name
self.children = children
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
options: tuple[Any, ...],
*children: Expr,
) -> None:
super().__init__(dtype)
self.dtype = dtype
self.options = options
self.name = name
self.children = children
Expand Down
10 changes: 5 additions & 5 deletions python/cudf_polars/cudf_polars/dsl/expressions/literal.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from cudf_polars.utils import dtypes

if TYPE_CHECKING:
from collections.abc import Mapping
from collections.abc import Hashable, Mapping

import pyarrow as pa

Expand All @@ -34,7 +34,7 @@ class Literal(Expr):
children: tuple[()]

def __init__(self, dtype: plc.DataType, value: pa.Scalar[Any]) -> None:
super().__init__(dtype)
self.dtype = dtype
assert value.type == plc.interop.to_arrow(dtype)
self.value = value

Expand All @@ -61,16 +61,16 @@ class LiteralColumn(Expr):
children: tuple[()]

def __init__(self, dtype: plc.DataType, value: pl.Series) -> None:
super().__init__(dtype)
self.dtype = dtype
data = value.to_arrow()
self.value = data.cast(dtypes.downcast_arrow_lists(data.type))

def get_hash(self) -> int:
def get_hashable(self) -> Hashable:
"""Compute a hash of the column."""
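# Keying on id(self.value) is stricter than necessary (equal data held
# in distinct objects will not compare equal), but this key is only
# needed for identity in groupby replacements, so that is OK. And this
# way we avoid doing potentially expensive compute over the column data.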
return hash((type(self), self.dtype, id(self.value)))
return (type(self), self.dtype, id(self.value))

def do_evaluate(
self,
Expand Down
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/expressions/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class RollingWindow(Expr):
children: tuple[Expr]

def __init__(self, dtype: plc.DataType, options: Any, agg: Expr) -> None:
super().__init__(dtype)
self.dtype = dtype
self.options = options
self.children = (agg,)
raise NotImplementedError("Rolling window not implemented")
Expand All @@ -34,7 +34,7 @@ class GroupedRollingWindow(Expr):
children: tuple[Expr, ...]

def __init__(self, dtype: plc.DataType, options: Any, agg: Expr, *by: Expr) -> None:
super().__init__(dtype)
self.dtype = dtype
self.options = options
self.children = (agg, *by)
raise NotImplementedError("Grouped rolling window not implemented")
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/expressions/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Gather(Expr):
children: tuple[Expr, Expr]

def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr) -> None:
super().__init__(dtype)
self.dtype = dtype
self.children = (values, indices)

def do_evaluate(
Expand Down Expand Up @@ -70,7 +70,7 @@ class Filter(Expr):
children: tuple[Expr, Expr]

def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr):
super().__init__(dtype)
self.dtype = dtype
self.children = (values, indices)

def do_evaluate(
Expand Down
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/expressions/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Sort(Expr):
def __init__(
self, dtype: plc.DataType, options: tuple[bool, bool, bool], column: Expr
) -> None:
super().__init__(dtype)
self.dtype = dtype
self.options = options
self.children = (column,)

Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(
column: Expr,
*by: Expr,
) -> None:
super().__init__(dtype)
self.dtype = dtype
self.options = options
self.children = (column, *by)

Expand Down
2 changes: 1 addition & 1 deletion python/cudf_polars/cudf_polars/dsl/expressions/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
options: tuple[Any, ...],
*children: Expr,
) -> None:
super().__init__(dtype)
self.dtype = dtype
self.options = options
self.name = name
self.children = children
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Ternary(Expr):
def __init__(
self, dtype: plc.DataType, when: Expr, then: Expr, otherwise: Expr
) -> None:
super().__init__(dtype)
self.dtype = dtype
self.children = (when, then, otherwise)

def do_evaluate(
Expand Down
7 changes: 5 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/expressions/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class Cast(Expr):
children: tuple[Expr]

def __init__(self, dtype: plc.DataType, value: Expr) -> None:
super().__init__(dtype)
self.dtype = dtype
self.children = (value,)
if not dtypes.can_cast(value.dtype, self.dtype):
raise NotImplementedError(
Expand Down Expand Up @@ -62,6 +62,9 @@ class Len(Expr):

children: tuple[()]

def __init__(self, dtype: plc.DataType) -> None:
self.dtype = dtype

def do_evaluate(
self,
df: DataFrame,
Expand Down Expand Up @@ -142,7 +145,7 @@ class UnaryFunction(Expr):
def __init__(
self, dtype: plc.DataType, name: str, options: tuple[Any, ...], *children: Expr
) -> None:
super().__init__(dtype)
self.dtype = dtype
self.name = name
self.options = options
self.children = children
Expand Down
Loading
Loading