Skip to content

Commit

Permalink
Infrastructure for traversal and visitors
Browse files Browse the repository at this point in the history
And tests of basic functionality.
  • Loading branch information
wence- committed Oct 14, 2024
1 parent a97f45a commit 65eb953
Show file tree
Hide file tree
Showing 2 changed files with 269 additions and 0 deletions.
166 changes: 166 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/traversal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""Traversal and visitor utilities for nodes."""

from __future__ import annotations

from collections.abc import Hashable
from typing import TYPE_CHECKING, Any, Generic, TypeVar

if TYPE_CHECKING:
from collections.abc import Callable, Generator, MutableMapping

from cudf_polars.dsl.nodebase import Node


def traversal(node: Node) -> Generator[Node, None, None]:
    """
    Walk an expression from its root, yielding each distinct node once.

    Parameters
    ----------
    node
        Root of the expression to walk.

    Yields
    ------
    Every unique node reachable from ``node``. A node is yielded after
    the parent through which it was first reached, and the children of
    a node are scheduled in order from left to right.

    Notes
    -----
    Nodes are deduplicated through a set, so they must be hashable. A
    node is recorded as seen at the moment it is scheduled, not when it
    is yielded.
    """
    visited = {node}
    pending = [node]

    while pending:
        current = pending.pop()
        yield current
        # Push right-to-left so that the leftmost child is popped first.
        for kid in reversed(current.children):
            if kid in visited:
                continue
            visited.add(kid)
            pending.append(kid)


# Input type of a visitor. Bound to Hashable because CachingVisitor
# uses the inputs as keys of its cache mapping.
U = TypeVar("U", bound=Hashable)
# Result type of a visitor.
V = TypeVar("V")


def reuse_if_unchanged(e: Node, fn: Callable[[Node], Node]) -> Node:
"""
Recipe for transforming nodes that returns the old object if unchanged.
Parameters
----------
e
Node to recurse on
fn
Function to transform children
Notes
-----
This can be used as a generic "base case" handler when
writing transforms that take nodes and produce new nodes.
Returns
-------
Existing node `e` if transformed children are unchanged, otherwise
reconstructed node with new children.
"""
new_children = [fn(c) for c in e.children]
if all(new == old for new, old in zip(new_children, e.children, strict=True)):
return e
return e.reconstruct(new_children)


def make_recursive(fn: Callable[[U, Callable[[U], V]], V]) -> Callable[[U], V]:
    """
    Wrap a visitor function so that it can recurse, without any caching.

    This lets visitors that do not need caching be written in the same
    two-argument style as those used with :class:`CachingVisitor`.

    Arbitrary immutable state can be attached to the visitor by
    setting properties on the returned wrapper, because ``fn``
    receives that wrapper as its second argument.

    Parameters
    ----------
    fn
        Function to transform inputs to outputs. Its second argument
        is a callable from input to output that it should use to
        recurse.

    Returns
    -------
    Recursive function without caching.

    Notes
    -----
    All transformation functions *must* be pure.

    Usually, prefer a :class:`CachingVisitor`. If a transformation is
    known not to need caching, this no-op approach is slightly
    cheaper.

    See Also
    --------
    CachingVisitor
    """

    def rec(node: U) -> V:
        # Pass the wrapper itself so ``fn`` can recurse through it.
        return fn(node, rec)

    return rec


class CachingVisitor(Generic[U, V]):
    """
    Caching wrapper for recursive visitors.

    Facilitates writing visitors where already computed results should
    be cached and reused. The cache is managed automatically, and is
    tied to the lifetime of the wrapper.

    Arbitrary immutable state can be attached to the visitor by
    setting properties on the wrapper, since the functions will
    receive the wrapper as an argument.

    Parameters
    ----------
    fn
        Function to transform inputs to outputs. Should take as its
        second argument the recursive cache manager (this wrapper).

    Notes
    -----
    All transformation functions *must* be pure.

    Inputs are used as keys of the cache, so they must be hashable.
    """

    def __init__(self, fn: Callable[[U, Callable[[U], V]], V]) -> None:
        self.fn = fn
        self.cache: MutableMapping[U, V] = {}

    def __call__(self, value: U) -> V:
        """
        Apply the function to a value, reusing a cached result if present.

        Parameters
        ----------
        value
            The value to transform.

        Returns
        -------
        A transformed value.
        """
        try:
            return self.cache[value]
        except KeyError:
            # Compute the result *after* leaving this handler: raising
            # from inside it would implicitly chain every exception
            # from ``fn`` to this internal cache-miss KeyError.
            pass
        return self.cache.setdefault(value, self.fn(value, self))

    if TYPE_CHECKING:
        # Advertise to type-checkers that dynamic attributes are allowed
        def __setattr__(self, name: str, value: Any) -> None: ...  # noqa: D105
        def __getattr__(self, name: str) -> Any: ...  # noqa: D105
103 changes: 103 additions & 0 deletions python/cudf_polars/tests/dsl/test_traversal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pylibcudf as plc

from cudf_polars.dsl import expr
from cudf_polars.dsl.traversal import (
CachingVisitor,
make_recursive,
reuse_if_unchanged,
traversal,
)


def make_expr(dt, n1, n2):
    """Build the product of two column references named ``n1`` and ``n2``."""
    lhs = expr.Col(dt, n1)
    rhs = expr.Col(dt, n2)
    multiply = plc.binaryop.BinaryOperator.MUL
    return expr.BinOp(dt, multiply, lhs, rhs)


def test_traversal_unique():
    """Check that traversal yields each unique node once, parent first."""
    dt = plc.DataType(plc.TypeId.INT8)

    # Each case is (operand names, column names in expected visiting order).
    cases = (
        (("a", "a"), ("a",)),
        (("a", "b"), ("a", "b")),
        (("b", "a"), ("b", "a")),
    )
    for operands, visited_columns in cases:
        root = make_expr(dt, *operands)
        expected = [root, *(expr.Col(dt, name) for name in visited_columns)]

        unique_exprs = list(traversal(root))

        assert len(unique_exprs) == len(expected)
        assert set(unique_exprs) == set(expected)
        assert unique_exprs == expected


def rename(e, rec):
    """Rename columns according to ``rec.mapping``, recursing into other nodes."""
    if not isinstance(e, expr.Col) or e.name not in rec.mapping:
        # Not a column we were asked to rename: transform the children only.
        return reuse_if_unchanged(e, rec)
    new_name = rec.mapping[e.name]
    return type(e)(e.dtype, new_name)


def test_caching_visitor():
    """Check results and cache sizes when renaming through a CachingVisitor."""
    dt = plc.DataType(plc.TypeId.INT8)

    distinct = make_expr(dt, "a", "b")
    repeated = make_expr(dt, "a", "a")

    # Each scenario is (input, mapping, expected operand names, cache size).
    scenarios = (
        (distinct, {"b": "c"}, ("a", "c"), 3),
        (repeated, {"b": "c"}, ("a", "a"), 2),
        (repeated, {"a": "c"}, ("c", "c"), 2),
    )
    for source, mapping, expected_names, cache_size in scenarios:
        # A fresh visitor per scenario so that caches are independent.
        mapper = CachingVisitor(rename)
        mapper.mapping = mapping

        renamed = mapper(source)

        assert renamed == make_expr(dt, *expected_names)
        assert len(mapper.cache) == cache_size


def test_noop_visitor():
    """Check renaming through the non-caching ``make_recursive`` wrapper."""
    dt = plc.DataType(plc.TypeId.INT8)

    distinct = make_expr(dt, "a", "b")
    repeated = make_expr(dt, "a", "a")

    # Each scenario is (input, mapping, expected operand names).
    scenarios = (
        (distinct, {"b": "c"}, ("a", "c")),
        (repeated, {"b": "c"}, ("a", "a")),
        (repeated, {"a": "c"}, ("c", "c")),
    )
    for source, mapping, expected_names in scenarios:
        mapper = make_recursive(rename)
        mapper.mapping = mapping

        renamed = mapper(source)

        assert renamed == make_expr(dt, *expected_names)

0 comments on commit 65eb953

Please sign in to comment.