Skip to content

Commit

Permalink
Handle input initializers correctly in constant folding
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam committed Nov 14, 2024
1 parent 1cfe0ca commit fd31cf5
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
24 changes: 24 additions & 0 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class Replacement:
class OptimizerState:
def __init__(self):
self._sym_value_map: dict[ir.Value, Any] = {}
self._initializer_inputs: list[set[ir.Value]] = []

def get_sym_value(self, value: ir.Value | None) -> Any:
if value is None:
Expand All @@ -146,6 +147,19 @@ def get_sym_value(self, value: ir.Value | None) -> Any:
def set_sym_value(self, value: ir.Value, sym_value: Any) -> None:
self._sym_value_map[value] = sym_value

def push_initializer_inputs(self) -> None:
self._initializer_inputs.append(set())

def pop_initializer_inputs(self) -> None:
self._initializer_inputs.pop()

def add_initializer_input(self, value: ir.Value) -> None:
assert self._initializer_inputs
self._initializer_inputs[-1].add(value)

def is_initializer_input(self, value: ir.Value) -> bool:
return any(value in inputs for inputs in self._initializer_inputs)


# The "partial evaluators" below are non-standard evaluators. They are used to perform
# partial evaluation and/or static program analysis (abstract interpretation).
Expand Down Expand Up @@ -754,6 +768,9 @@ def process_node(self, node: ir.Node):
if any(x is None for x in input_values):
return None

if any(self._state.is_initializer_input(x) for x in node.inputs):

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 1 to "is_initializer_input" of "OptimizerState" has incompatible type "Value | None"; expected "Value" To disable, use # type: ignore[arg-type]
return None

if any(input.nbytes > self._input_size_limit for input in input_values): # type: ignore[union-attr]
if logger.isEnabledFor(logging.DEBUG):
input_sizes = [input.size for input in input_values] # type: ignore[union-attr]
Expand Down Expand Up @@ -817,9 +834,16 @@ def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function):
self.replace_node(node, replacement, root)

def visit_graph(self, graph: ir.Graph) -> None:
self._state.push_initializer_inputs()
for input in graph.inputs:
if input.const_value is not None:
self._state.add_initializer_input(input)

for node in graph:
self.visit_node(node, graph)

self._state.pop_initializer_inputs()

def visit_function(self, function: ir.Function) -> None:
for node in function:
self.visit_node(node, function)
Expand Down
24 changes: 24 additions & 0 deletions onnxscript/optimizer/_constant_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import parameterized
import pytest

import onnxscript.ir as ir
import onnxscript.optimizer as optimizer
from onnxscript.ir import serde
from onnxscript.optimizer import _constant_folding
Expand Down Expand Up @@ -434,5 +435,28 @@ def test_concat_identity(self):
self.assertEqual(optimized.graph.node[0].op_type, "Identity")


class FoldConstantsIrTest(unittest.TestCase):
def _fold(self, model_text: str, onnx_shape_inference=False) -> ir.Model:
model_proto = onnx.parser.parse_model(model_text)
model = serde.deserialize_model(model_proto)
_constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference)
optimizer.remove_unused_nodes(model)
return model

def test_initializer_input_not_folded(self):
model_text = """
<ir_version: 7, opset_import: [ "" : 18]>
agraph (float[N] x, float[1] c = {1.0} ) => (float[N] z)
{
# c is not a constant, and following should not be folded.
two_c = Add (c, c)
z = Mul (x, two_c)
}
"""
optimized = self._fold(model_text)
self.assertEqual(len(optimized.graph), 2)
self.assertEqual(optimized.graph.node(0).op_type, "Add")


if __name__ == "__main__":
unittest.main()

0 comments on commit fd31cf5

Please sign in to comment.