-
Notifications
You must be signed in to change notification settings - Fork 890
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Infrastructure for traversal and visitors
And tests of basic functionality.
- Loading branch information
Showing
2 changed files
with
269 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Traversal and visitor utilities for nodes.""" | ||
|
||
from __future__ import annotations | ||
|
||
from collections.abc import Hashable | ||
from typing import TYPE_CHECKING, Any, Generic, TypeVar | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Callable, Generator, MutableMapping | ||
|
||
from cudf_polars.dsl.nodebase import Node | ||
|
||
|
||
def traversal(node: Node) -> Generator[Node, None, None]:
    """
    Walk an expression from its root, producing each distinct node once.

    Parameters
    ----------
    node
        Root of the expression to walk.

    Yields
    ------
    Every unique node reachable from the root. A node is produced
    before any of its children, and the children of a node are
    produced in left-to-right order.
    """
    visited = {node}
    stack = [node]

    while stack:
        current = stack.pop()
        yield current
        # Push right-to-left so the leftmost child is the next one popped.
        for child in reversed(current.children):
            if child in visited:
                # Already scheduled (or produced) via another parent.
                continue
            visited.add(child)
            stack.append(child)
|
||
|
||
# Input type of a visitor function. Bound to Hashable; CachingVisitor uses
# values of this type as the keys of its cache mapping.
U = TypeVar("U", bound=Hashable) | ||
# Result type of a visitor function.
V = TypeVar("V") | ||
|
||
|
||
def reuse_if_unchanged(e: Node, fn: Callable[[Node], Node]) -> Node: | ||
""" | ||
Recipe for transforming nodes that returns the old object if unchanged. | ||
Parameters | ||
---------- | ||
e | ||
Node to recurse on | ||
fn | ||
Function to transform children | ||
Notes | ||
----- | ||
This can be used as a generic "base case" handler when | ||
writing transforms that take nodes and produce new nodes. | ||
Returns | ||
------- | ||
Existing node `e` if transformed children are unchanged, otherwise | ||
reconstructed node with new children. | ||
""" | ||
new_children = [fn(c) for c in e.children] | ||
if all(new == old for new, old in zip(new_children, e.children, strict=True)): | ||
return e | ||
return e.reconstruct(new_children) | ||
|
||
|
||
def make_recursive(fn: Callable[[U, Callable[[U], V]], V]) -> Callable[[U], V]:
    """
    Turn a visitor function into a plain recursive callable, without caching.

    This lets visitors that do not need caching be written in the same
    style as those driven by a :class:`CachingVisitor`.

    Arbitrary immutable state can be attached to the visitor by setting
    attributes on the returned wrapper, because the wrapper itself is
    passed to `fn` as its second argument.

    Parameters
    ----------
    fn
        Function to transform inputs to outputs. Its second argument
        is a callable from input to output, used to recurse.

    Returns
    -------
    Recursive function without caching.

    Notes
    -----
    All transformation functions *must* be pure.

    A :class:`CachingVisitor` is usually the better choice; when a
    transformation is known not to need caching, this no-op wrapper is
    slightly cheaper.

    See Also
    --------
    CachingVisitor
    """

    def recurse(node: U) -> V:
        # Hand the wrapper itself to fn so it can recurse and read attached state.
        result = fn(node, recurse)
        return result

    return recurse
|
||
|
||
class CachingVisitor(Generic[U, V]):
    """
    Memoising driver for recursive visitors.

    Results that have already been computed are stored and reused on
    later calls with an equal input. The cache is managed
    automatically and lives exactly as long as the wrapper does.

    Arbitrary immutable state can be attached to the visitor by setting
    attributes on the wrapper, because the wrapper itself is passed to
    the visitor function as its second argument.

    Parameters
    ----------
    fn
        Function to transform inputs to outputs. Its second argument
        is the recursive cache manager (this object).

    Returns
    -------
    Recursive function with caching.

    Notes
    -----
    All transformation functions *must* be pure.
    """

    def __init__(self, fn: Callable[[U, Callable[[U], V]], V]) -> None:
        # Results computed so far, keyed by the input value.
        self.cache: MutableMapping[U, V] = {}
        self.fn = fn

    def __call__(self, value: U) -> V:
        """
        Transform a value, reusing a cached result when one exists.

        Parameters
        ----------
        value
            The value to transform.

        Returns
        -------
        A transformed value.
        """
        try:
            cached = self.cache[value]
        except KeyError:
            # Cache miss: compute the result (fn may recurse through self),
            # then record it.
            computed = self.fn(value, self)
            return self.cache.setdefault(value, computed)
        else:
            return cached

    if TYPE_CHECKING:
        # Advertise to type-checkers that dynamic attributes are allowed
        def __setattr__(self, name: str, value: Any) -> None: ...  # noqa: D105
        def __getattr__(self, name: str) -> Any: ...  # noqa: D105
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from __future__ import annotations | ||
|
||
import pylibcudf as plc | ||
|
||
from cudf_polars.dsl import expr | ||
from cudf_polars.dsl.traversal import ( | ||
CachingVisitor, | ||
make_recursive, | ||
reuse_if_unchanged, | ||
traversal, | ||
) | ||
|
||
|
||
def make_expr(dt, n1, n2):
    """Build a multiplication of two columns named `n1` and `n2`, all of dtype `dt`."""
    lhs, rhs = (expr.Col(dt, name) for name in (n1, n2))
    multiply = plc.binaryop.BinaryOperator.MUL
    return expr.BinOp(dt, multiply, lhs, rhs)
|
||
|
||
def test_traversal_unique():
    """Check that traversal reports each distinct node once, root first."""
    dt = plc.DataType(plc.TypeId.INT8)
    col_a = expr.Col(dt, "a")
    col_b = expr.Col(dt, "b")
    columns = {"a": col_a, "b": col_b}

    # Both operands are the same column, so it must appear only once.
    e1 = make_expr(dt, "a", "a")
    visited = list(traversal(e1))

    assert len(visited) == 2
    assert set(visited) == {col_a, e1}
    assert visited == [e1, col_a]

    # Distinct operands come out in operand order, whichever way round.
    for first, second in (("a", "b"), ("b", "a")):
        root = make_expr(dt, first, second)
        visited = list(traversal(root))

        assert len(visited) == 3
        assert set(visited) == {col_a, col_b, root}
        assert visited == [root, columns[first], columns[second]]
|
||
|
||
def rename(e, rec):
    """Visitor that renames columns according to the ``mapping`` attribute of `rec`."""
    is_mapped_column = isinstance(e, expr.Col) and e.name in rec.mapping
    if not is_mapped_column:
        # Anything else: recurse into the children, reusing `e` if none changed.
        return reuse_if_unchanged(e, rec)
    return type(e)(e.dtype, rec.mapping[e.name])
|
||
|
||
def test_caching_visitor():
    """Check renaming through a CachingVisitor and the resulting cache size."""
    dt = plc.DataType(plc.TypeId.INT8)

    # (input operands, rename mapping, expected operands, expected cache entries)
    scenarios = [
        (("a", "b"), {"b": "c"}, ("a", "c"), 3),
        (("a", "a"), {"b": "c"}, ("a", "a"), 2),
        (("a", "a"), {"a": "c"}, ("c", "c"), 2),
    ]

    for operands, mapping, expected, n_cached in scenarios:
        mapper = CachingVisitor(rename)
        mapper.mapping = mapping

        renamed = mapper(make_expr(dt, *operands))

        assert renamed == make_expr(dt, *expected)
        assert len(mapper.cache) == n_cached
|
||
|
||
def test_noop_visitor():
    """Check renaming through the non-caching wrapper from make_recursive."""
    dt = plc.DataType(plc.TypeId.INT8)

    # (input operands, rename mapping, expected operands)
    scenarios = [
        (("a", "b"), {"b": "c"}, ("a", "c")),
        (("a", "a"), {"b": "c"}, ("a", "a")),
        (("a", "a"), {"a": "c"}, ("c", "c")),
    ]

    for operands, mapping, expected in scenarios:
        mapper = make_recursive(rename)
        mapper.mapping = mapping

        renamed = mapper(make_expr(dt, *operands))

        assert renamed == make_expr(dt, *expected)