## Installing polars

The `cudf-polars` `pyproject.toml` advertises which polars versions it
works with, so for pure `cudf-polars` development, installing as
normal and satisfying the dependencies in the repository is
sufficient. If development requires changes on the polars side, we
will need to build polars from source:

```sh
git clone https://github.com/pola-rs/polars
In addition to the other engine arguments, at the moment
`raise_on_fail` is also supported, which raises, rather than falling
back, during translation:

```python
result = q.collect(engine=pl.GPUEngine(raise_on_fail=True))
```

We can therefore attempt to detect the IR version appropriately. This
should be done during IR translation in `translate.py`.
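
As a hedged sketch of such a check (the `NodeTraverser` advertises a
version; the names here are illustrative, not the repository's actual
code):

```python
# A sketch: MAX_SUPPORTED_VERSION is a hypothetical constant; the real
# check lives in translate.py and may differ.
if (version := visitor.version()) > MAX_SUPPORTED_VERSION:
    raise NotImplementedError(f"polars IR version {version} is not supported")
```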

# IR design

As noted, we translate the polars DSL into our own IR. This is both so
that we can smooth out minor version differences (advertised by
`NodeTraverser` version changes) within `cudf-polars`, and so that we
have the freedom to introduce new IR nodes and rewrite rules as might
be appropriate for GPU execution.

To that end, we provide facilities for definition of nodes as well as
writing traversals and rewrite rules. The abstract base class `Node`
in `dsl/nodebase.py` defines the interface for implementing new nodes,
and provides many useful default methods. See also the docstrings of
the `Node` class.

> [!NOTE]
> This generic implementation relies on nodes being treated as
> *immutable*. Do not implement in-place modification of nodes; bad
> things will happen.

## Defining nodes

A concrete node type (`cudf-polars` has `Expr` nodes for expressions
and `IR` nodes for plans) should inherit from `Node`. Nodes have two
types of data:

1. `children`: a tuple (possibly empty) of concrete nodes
2. non-child: arbitrary data attached to the node that is _not_ a
concrete node.

The base `Node` class requires that one advertise the _names_ of the
non-child attributes in the `_non_child` class variable. The
constructor of the concrete node should take its arguments in the
order `*_non_child` (ordered as the class variable does) and then
`*children`. For example, the `Sort` node, which sorts a column
generated by an expression, has this definition:

```python
class Expr(Node):
    children: tuple[Expr, ...]


class Sort(Expr):
    _non_child = ("dtype", "options")
    children: tuple[Expr]

    def __init__(self, dtype, options, column: Expr):
        self.dtype = dtype
        self.options = options
        self.children = (column,)
```

By following this pattern, we get an automatic (caching)
implementation of `__hash__` and `__eq__`, as well as a useful
`reconstruct` method that will rebuild the node with new children.
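
For example, a sketch of using `reconstruct` to rebuild the `Sort`
node above with a rewritten child (assuming `reconstruct` accepts the
sequence of new children):

```python
# Non-child data (dtype, options) is carried over from the original node.
new_sort = sort.reconstruct([new_column])
# Morally equivalent to: type(sort)(sort.dtype, sort.options, new_column)
```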

If you want to control the behaviour of `__hash__` and `__eq__` for a
single node, override (respectively) the `get_hashable` and `is_equal`
methods.
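
For example, a node whose `options` is an unhashable dict might
override `get_hashable` like this (a sketch, not code from the
repository):

```python
def get_hashable(self) -> tuple:
    # Sketch: replace the unhashable options dict with a sorted tuple
    # of its items; include type, dtype, and children as the default
    # implementation does.
    return (type(self), self.dtype, tuple(sorted(self.options.items())), self.children)
```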

## Adding new translation rules from the polars IR

### Plan nodes

Plan node definitions live in `cudf_polars/dsl/ir.py`; these all
inherit from the base `IR` node. The evaluation of a plan node is done
by implementing the `evaluate` method.

To translate a plan node, add a case handler in `translate_ir`, which
lives in `cudf_polars/dsl/translate.py`.
For example, when translating a `Join` node, the left keys
(expressions) should be translated with the left input active (and
right keys with right input). To facilitate this, use the `set_node`
context manager.
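
A hedged sketch of that pattern inside a hypothetical `Join` handler
(the field names on the polars node are illustrative):

```python
# Make the left input the active node, then translate it and the key
# expressions that refer to its schema.
with set_node(visitor, node.input_left):
    inp_left = translate_ir(visitor, n=None)
    left_on = [translate_named_expr(visitor, n=e) for e in node.left_on]
```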

### Expression nodes

Expressions are defined in `cudf_polars/dsl/expressions/` and exported
into the `dsl` namespace via `expr.py`. They inherit
from `Expr`.

Expressions are evaluated by implementing a `do_evaluate` method that
takes a `DataFrame` as context (this provides columns) along with an
`ExecutionContext` parameter indicating the context in which the
evaluation is occurring.

To simplify state tracking, all columns should be considered immutable
on construction. This matches the "functional" description coming from
the logical plan in any case, so is reasonably natural.
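
A minimal sketch of what an implementation might look like, assuming a
simplified `Col` column-reference node and that the `DataFrame`
container exposes a name-to-`Column` mapping (the `column_map`
attribute here is an assumption):

```python
class Col(Expr):
    _non_child = ("dtype", "name")
    children: tuple[()]

    def __init__(self, dtype, name: str):
        self.dtype = dtype
        self.name = name
        self.children = ()

    def do_evaluate(self, df: DataFrame, *, context=None, mapping=None) -> Column:
        # A column reference just looks itself up in the evaluation
        # context; the returned Column is treated as immutable.
        return df.column_map[self.name]
```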

## Traversing and transforming nodes

As well as representing and evaluating nodes, we provide facilities
for traversing a tree of nodes and defining transformation rules in
`dsl/traversal.py`. The simplest is `traversal`; it yields all
_unique_ nodes in an expression, parent before child, children in
order left to right (i.e. a pre-order traversal). Use this if you
want to know some specific thing about an expression. For example, to
determine if an expression contains a `Literal` node:

```python
def has_literal(node: Expr) -> bool:
    return any(isinstance(e, Literal) for e in traversal(node))
```

For transformations and rewrites, we use the following generic
pattern. Rather than defining methods on each node in turn for a
particular rewrite rule, we prefer free functions and use
`functools.singledispatch` to provide dispatching.

It is often convenient to provide (immutable) state to a visitor, as
well as some facility to perform DAG-aware rewrites (reusing a
transformation for an expression if we have already seen it). We
therefore adopt the following pattern of writing DAG-aware visitors.
Suppose we want a rewrite rule (`rewrite`) between expressions
(`Expr`) and some new type `T`. We define our general transformation
function `rewrite` with type `Expr -> (Expr -> T) -> T`:

```python
@singledispatch
def rewrite(e: Expr, rec: Callable[[Expr], T]) -> T:
    ...
```

Note in particular that the function to perform the recursion is
passed as the second argument. We now, in the usual fashion, register
handlers for different expression types. To use this function, we need
to be able to provide both the expression to convert and the recursive
function itself. To do this we must convert our `rewrite` function
into something that only takes a single argument (the expression to
rewrite), but carries around information about how to perform the
recursion. To this end, we have two utilities in `traversal.py`:

- `make_recursive` and
- `CachingVisitor`.

Both of these can be wrapped around a transformation function like
`rewrite` to provide a function `Expr -> T`. We can also attach
arbitrary state to the objects they return, which `rewrite` can
inspect. `make_recursive` is very simple, and provides no caching of
intermediate results (so any DAGs that are visited will be viewed as
trees). `CachingVisitor` provides the same interface, but maintains a
cache of intermediate results, and reuses them if the same expression
is seen again.

Finally, for writing transformations that take nodes and deliver new
nodes (e.g. rewrite rules), we have a utility `reuse_if_unchanged`
that can be used as a base case transformation for node-to-node
rewrites. It is a depth-first visit that transforms children, but only
returns a new node with new children if the rewrite of the children
changed anything.

To see how these pieces fit together, let us consider writing a
`rename` function that takes an expression (potentially with
references to columns) along with a mapping defining a renaming
between (some subset of) column names. The goal is to deliver a new
expression with appropriate columns renamed.

To start, we define the dispatch function:
```python
@singledispatch
def _rename(e: Expr, rec: Callable[[Expr], Expr]) -> Expr:
    raise NotImplementedError(f"No handler for {type(e)}")
```
then we register specific handlers, first for columns:
```python
@_rename.register
def _(e: Col, rec: Callable[[Expr], Expr]) -> Expr:
    mapping = rec.mapping  # state set on rec
    if e.name in mapping:
        # If we have a rename, return a new Col reference
        # with a new name
        return type(e)(e.dtype, mapping[e.name])
    return e
```
and then for the remaining expressions:
```python
_rename.register(Expr)(reuse_if_unchanged)
```

> [!NOTE]
> In this case, we could have put the generic handler in the `_rename`
> function; however, we would then not get a nice error message if we
> accidentally passed in an object of the incorrect type.

Finally, we tie everything together with a public function:

```python
def rename(e: Expr, mapping: Mapping[str, str]) -> Expr:
    """Rename column references in an expression."""
    mapper = CachingVisitor(_rename)
    # or
    # mapper = make_recursive(_rename)
    mapper.mapping = mapping
    return mapper(e)
```
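
For instance (assuming `e` is some expression that references a column
`"a"`; purely illustrative):

```python
# References to "a" become "b"; a new (immutable) expression is returned.
renamed = rename(e, {"a": "b"})
```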

# Containers

Containers should be constructed as relatively lightweight objects