Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A couple of ir utilities #1972

Merged
merged 6 commits into from
Dec 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions onnxscript/rewriter/_ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,78 @@
# Licensed under the MIT License.
from __future__ import annotations

import numpy as np

import onnxscript.ir as ir
from onnxscript.optimizer import basic_constant_propagation


def display_slice(x: ir.Value | ir.Node, backward: bool = True, depth_limit: int = 5) -> None:
"""Display the (backward or forward) subgraph from a given value or node upto a certain depth."""
slice = []

Check warning on line 13 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L13

Added line #L13 was not covered by tests

def visit(node: ir.Node, depth):

Check warning on line 15 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L15

Added line #L15 was not covered by tests
if node in slice:
return
slice.append(node)

Check warning on line 18 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L17-L18

Added lines #L17 - L18 were not covered by tests
if depth < depth_limit:
if backward:
for inp in node.inputs:
if inp is not None and inp.producer() is not None:
visit(inp.producer(), depth + 1) # type: ignore[arg-type]

Check warning on line 23 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L23

Added line #L23 was not covered by tests
else:
for out in node.outputs:
for consumer, _ in out.uses():
visit(consumer, depth + 1)

Check warning on line 27 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L27

Added line #L27 was not covered by tests

if isinstance(x, ir.Node):
visit(x, 0)

Check warning on line 30 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L30

Added line #L30 was not covered by tests
elif isinstance(x, ir.Value) and x.producer() is not None:
visit(x.producer(), 0) # type: ignore[arg-type]

Check warning on line 32 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L32

Added line #L32 was not covered by tests
if slice:
graph = slice[0].graph

Check warning on line 34 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L34

Added line #L34 was not covered by tests
if graph:
# Display nodes in same order as in graph:
# Currently doesn't handle (control-flow) subgraphs
for node in graph:
if node in slice:
node.display()

Check warning on line 40 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L40

Added line #L40 was not covered by tests
else:
for node in reversed(slice):
node.display()

Check warning on line 43 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L43

Added line #L43 was not covered by tests


def get_const_value(value: ir.Value) -> ir.TensorProtocol | None:
node = value.producer()
if node is not None:
basic_constant_propagation([node])
return value.const_value


def get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
gramalingam marked this conversation as resolved.
Show resolved Hide resolved
"""Convenience wrapper to get (optional) numpy value from an optional IR Value.

This is intended for use in optimizations/rewriting. Note that this does not
yet handle the distinction between inputs with default values (values that are
both graph inputs and graph initializers), which should not be treated as a
constant, and true constant values. The caller should make the distinction, as
a value does not contain enough information to determine this. (TODO)
"""
if val is None:
return None
const_value = val.const_value

Check warning on line 64 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L63-L64

Added lines #L63 - L64 were not covered by tests
if const_value is not None:
try:
return const_value.numpy()
except FileNotFoundError:

Check warning on line 68 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L66-L68

Added lines #L66 - L68 were not covered by tests
# External data is not available.
return None
return None

Check warning on line 71 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L70-L71

Added lines #L70 - L71 were not covered by tests


def get_singleton_value(val: ir.Value | None):
shubhambhokare1 marked this conversation as resolved.
Show resolved Hide resolved
"""Returns element of a single element tensor constant value, and None otherwise."""
np_val = get_numpy_value(val)

Check warning on line 76 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L76

Added line #L76 was not covered by tests
if np_val is not None and np_val.size == 1:
return np_val.item()
return None

Check warning on line 79 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L78-L79

Added lines #L78 - L79 were not covered by tests
Loading