Skip to content

Commit

Permalink
Implement a custom graph renaming routine for inline (#84)
Browse files Browse the repository at this point in the history
* Remove outdated check and test

* Fix test

* Implement anonymous reserved names in scopes

* Implement custom renaming routine

* Update CHANGELOG.rst

* Fix operator domain & opcode renaming

* Add support for graph-list attributes

* Add an explicit test for subgraph list renames

* Avoid initialising the NodeProto domain field

* Add a solid test for the renaming routine

* Avoid code repetition in recursive call
  • Loading branch information
jbachurski authored Jun 12, 2023
1 parent 7cc079d commit 5fb9026
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 70 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ Change log

- The constructor for ``ai.onnx@18::Split`` is no longer generated incorrectly. No extraneous attribute is generated anymore, and the ``num_outputs`` attribute is marked as required (so that Spox can infer the number of outputs).

**Other changes**

- Inlining now no longer adds redundant ``Identity`` nodes and supports subgraphs, thanks to reimplementing the ONNX renaming routine.


0.8.1 (2023-05-xx)
------------------
Expand Down
142 changes: 92 additions & 50 deletions src/spox/_inline.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import itertools
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Set, Tuple
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple

import onnx

from spox._exceptions import BuildError
from spox._fields import BaseAttributes, BaseInputs, BaseOutputs
from spox._internal_op import INTERNAL_MIN_OPSET, _InternalNode
from spox._node import OpType
Expand All @@ -13,6 +15,62 @@
from . import _value_prop


def rename_in_graph(
graph_: onnx.GraphProto,
rename: Callable[[str], str],
*,
rename_node: Optional[Callable[[str], str]] = None,
rename_op: Optional[Callable[[str, str], Tuple[str, str]]] = None,
) -> onnx.GraphProto:
def rename_in_subgraph(subgraph):
return rename_in_graph(
subgraph,
rename,
rename_node=rename_node,
rename_op=rename_op,
)

graph = onnx.GraphProto()
graph.CopyFrom(graph_)

for p in itertools.chain(graph.input, graph.initializer):
p.name = rename(p.name)
for si in graph.sparse_initializer:
si.values.name = rename(si.values.name)
si.indices.name = rename(si.indices.name)

for nd in graph.node:
if nd.name and rename_node is not None:
nd.name = rename_node(nd.name)
if rename_op is not None:
# This is a bit elaborate, but we do it this way as
# an unset domain field is different from an empty one.
if nd.HasField("domain"):
nd.domain, nd.op_type = rename_op(nd.domain, nd.op_type)
else:
# An empty domain is the default domain (ai.onnx)
domain, nd.op_type = rename_op("", nd.op_type)
if domain: # Only set the domain explicitly if it's changing
nd.domain = domain
for seq in (nd.input, nd.output):
for i, name in enumerate(seq):
seq[i] = rename(name)
for attr_proto in nd.attribute:
attr = onnx.helper.get_attribute_value(attr_proto)
if isinstance(attr, onnx.GraphProto):
attr_proto.g.CopyFrom(rename_in_subgraph(attr))
elif isinstance(attr, list) and all(
isinstance(g, onnx.GraphProto) for g in attr
):
for i, sub in enumerate(attr):
attr_proto.graphs[i].CopyFrom(rename_in_subgraph(sub))

for p in itertools.chain(graph.output, graph.value_info):
p.name = rename(p.name)

return graph


class _Inline(_InternalNode):
"""Internal operator used for inlining (embedding) an existing ONNX ModelProto inside a Spox graph."""

Expand Down Expand Up @@ -85,55 +143,39 @@ def propagate_values(self) -> Dict[str, _value_prop.PropValueType]:
def to_onnx(
self, scope: Scope, doc_string: Optional[str] = None, build_subgraph=None
) -> List[onnx.NodeProto]:
# Prefix all names in the graph to try and avoid name clashes
name = scope.node[self]
graph = onnx.GraphProto()
graph.CopyFrom(self.graph)
# FIXME: This is a bug upstream - when add_prefix_graph has rename_edges,
# unused inputs are not renamed. We apply identities to use the inputs.
for i in graph.input:
graph.node.append(
onnx.helper.make_node(
"Identity", [i.name], [f"__{i.name}_Identity_dummy_use"]
)
)
graph = onnx.compose.add_prefix_graph(graph, f"{name}__")
for _ in graph.input:
graph.node.pop()

nodes: List[onnx.NodeProto] = []
# Move initializers to Constant nodes
input_names = {i.name for i in graph.input}
nodes.extend(
onnx.helper.make_node("Constant", [], [i.name], value=i)
for i in graph.initializer
if i.name not in input_names
)
nodes.extend(
onnx.helper.make_node("Constant", [], [i.values.name], sparse_value=i)
for i in graph.sparse_initializer
if i.values.name not in input_names
)
# Apply a trivial renaming of inputs
for i, var in zip(graph.input, self.inputs.inputs):
nodes.append(
onnx.helper.make_node(
"Identity",
[scope.var[var]],
[i.name],
f"{i.name}__Identity_rename",
)
input_names: Dict[str, int] = {
p.name: i for i, p in enumerate(self.graph.input)
}
output_names: Dict[str, int] = {
p.name: i for i, p in enumerate(self.graph.output)
}
inner_renames: Dict[str, str] = {}
inner_node_renames: Dict[str, str] = {}

def reserve_prefixed(name: str) -> str:
return scope.var.reserve(
scope.var.maybe_enum(f"{scope.node[self]}__{name}")
)
# Then graph body
nodes.extend(graph.node)
# Finish with output renaming
for o, var in zip(graph.output, self.outputs.outputs):
nodes.append(
onnx.helper.make_node(
"Identity",
[o.name],
[scope.var[var]],
f"{o.name}__Identity_rename",
)

def apply_rename(name: str) -> str:
if name in input_names:
return scope.var[self.inputs.inputs[input_names[name]]]
if name in output_names:
return scope.var[self.outputs.outputs[output_names[name]]]
if name not in inner_renames:
inner_renames[name] = reserve_prefixed(name)
return inner_renames[name]

def apply_node_rename(name: str) -> str:
if name not in inner_node_renames:
inner_node_renames[name] = reserve_prefixed(name)
return inner_node_renames[name]

graph = rename_in_graph(self.graph, apply_rename, rename_node=apply_node_rename)

if graph.initializer:
raise BuildError(
"Inlined graph initializers should be handled beforehand and be removed from the graph."
)
nodes: List[onnx.NodeProto] = list(graph.node)
return nodes
32 changes: 23 additions & 9 deletions src/spox/_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import contextlib
import itertools
from typing import Dict, Optional, Protocol
from typing import Dict, List, Optional, Protocol

import numpy as np
import onnx
Expand Down Expand Up @@ -218,14 +218,6 @@ def inline(model: onnx.ModelProto) -> _InlineCall:
_signature_msg = f"signature {in_names}{_defaults_msg} -> {out_names}"

model = _copy_model(model)
# FIXME: Renaming does not work on subgraphs as of ONNX 1.13/1.14.
for node in model.graph.node:
for attr in node.attribute:
if attr.HasField("g") or attr.graphs:
raise ValueError(
"Inlining models with subgraphs is not supported due to "
"lack of upstream support for renaming values in subgraphs."
)
# FIXME: Support for functions is a bit involved, as it interacts with build.
if model.functions:
raise ValueError(
Expand All @@ -240,6 +232,28 @@ def inline(model: onnx.ModelProto) -> _InlineCall:
info.type.CopyFrom(
_strip_dim_symbol(Type._from_onnx(info.type), lambda x: True)._to_onnx()
)
# We handle everything related to initializers here, as currently build does not support them too well
# Overridable initializers are saved to in_defaults, non-overridable replaced with Constant
preamble: List[onnx.NodeProto] = []
input_names = {i.name for i in model.graph.input}
preamble.extend(
onnx.helper.make_node("Constant", [], [i.name], value=i)
for i in model.graph.initializer
if i.name not in input_names
)
preamble.extend(
onnx.helper.make_node("Constant", [], [i.values.name], sparse_value=i)
for i in model.graph.sparse_initializer
if i.values.name not in input_names
)
del model.graph.initializer[:]
del model.graph.sparse_initializer[:]
# The API on the protobuf list is a bit limited
# - this prepends the preamble before the rest of the nodes
model.graph.node.reverse()
model.graph.node.extend(reversed(preamble))
model.graph.node.reverse()
# Now we can assume the graph has no initializers

def inline_inner(*args: Var, **kwargs: Var) -> Dict[str, Var]:
for name, arg in zip(in_names, args):
Expand Down
16 changes: 14 additions & 2 deletions src/spox/_scope.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Generic, Hashable, Optional, TypeVar, Union, overload
from typing import Dict, Generic, Hashable, Optional, Set, TypeVar, Union, overload

from ._node import Node
from ._var import Var
Expand All @@ -22,12 +22,14 @@ class ScopeSpace(Generic[H]):

name_of: Dict[H, str]
of_name: Dict[str, H]
reserved: Set[str]
parent: "Optional[ScopeSpace[H]]"

def __init__(
self,
name_of: Optional[Dict[H, str]] = None,
of_name: Optional[Dict[str, H]] = None,
reserved: Optional[Set[str]] = None,
parent: "Optional[ScopeSpace[H]]" = None,
):
"""
Expand All @@ -37,17 +39,21 @@ def __init__(
Name of a given object in this namespace.
of_name
Object with a given name in this namespace.
reserved
Set of reserved names, taken up by anonymous objects.
parent
Parent scope's namespace. Is accessed first before all checks, but is not modified directly.
Namespace of a parent scope. Is accessed first before all checks, but never modified.
"""
self.name_of = name_of.copy() if name_of is not None else {}
self.of_name = of_name.copy() if of_name is not None else {}
self.reserved = reserved.copy() if reserved is not None else set()
self.parent = parent

def __contains__(self, item: Union[str, H]) -> bool:
"""Checks if a given name or object is declared in this (or outer) namespace."""
return (
(self.parent is not None and item in self.parent)
or item in self.reserved
or item in self.name_of
or item in self.of_name
)
Expand Down Expand Up @@ -125,6 +131,12 @@ def maybe_enum(self, base: str, suffix: str = "_{}") -> str:
return base
return self.enum(base, suffix)

def reserve(self, name: str) -> str:
if name in self:
raise ScopeError(f"Reserved name is already in use: {name}")
self.reserved.add(name)
return name


class Scope:
"""
Expand Down
Loading

0 comments on commit 5fb9026

Please sign in to comment.