Merged
2 changes: 1 addition & 1 deletion noxfile.py
@@ -41,7 +41,7 @@
"packaging",
"protobuf",
)
ONNX_IR = "onnx_ir==0.1.10"
ONNX_IR = "onnx_ir==0.1.12"
ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir"


59 changes: 35 additions & 24 deletions onnxscript/optimizer/_constant_folding.py
@@ -1039,24 +1039,29 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
e,
)

def new_constant(self, node: ir.Node, value) -> ir.Node | None:
irvalue = node.outputs[0]
if not isinstance(value, np.ndarray):
def new_initializer(self, node: ir.Node, array) -> ir.Value | None:
original_value = node.outputs[0]
if not isinstance(array, np.ndarray):
# ONNX does not have a way to represent non-tensor constants, e.g. a sequence.
# So, a constant-value of type sequence is not folded, but it can be used
# to optimize subsequent operations when possible.
logger.info(
"Skip storing constant folded value %s due to unsupported type %s.",
irvalue.name,
type(value),
original_value.name,
type(array),
)
return None

tensor = ir.tensor(value)
tensor.name = irvalue.name
irvalue.const_value = tensor
tensor = ir.tensor(array)
tensor.name = original_value.name
initializer = ir.Value(
name=original_value.name,
type=ir.TensorType(ir.DataType(tensor.dtype)),
shape=tensor.shape, # type: ignore[arg-type]
const_value=tensor,
)

if value.size > self.output_size_limit:
if array.size > self.output_size_limit:
# Handle examples like Transpose(weight) to be folded even if the size is large,
# as long as weight has no other uses. This won't increase model size.
removed_input_size = 0
@@ -1065,25 +1070,23 @@ def new_constant(self, node: ir.Node, value) -> ir.Node | None:
input_array = _get_numpy_value(input)
if input_array is not None:
removed_input_size += input_array.size
increased_size = value.size - removed_input_size
increased_size = array.size - removed_input_size
if increased_size > 0:
logger.info(
"Skip storing constant folded nvalue %s due to large size %s.",
irvalue.name,
value.size,
original_value.name,
array.size,
)
return None

logger.debug(
"New constant for value %s dtype: %s shape: %s",
irvalue.name,
value.dtype,
value.shape,
"New Initializer for value %s dtype: %s shape: %s",
original_value.name,
array.dtype,
array.shape,
)

attributes = ir.convenience.convert_attributes({"value": tensor})
node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1)
return node
return initializer

def process_node(self, node: ir.Node) -> Replacement | None:
"""Process a node and return a Replacement if the node can be replaced."""
@@ -1109,7 +1112,13 @@ def process_node(self, node: ir.Node) -> Replacement | None:
self._do_inference(node)

if node.domain not in self._opset_imports:
logger.debug(
"Skipping constant folding for node %r due to missing opset import for domain %r.",
node.name,
node.domain,
)
return None

version = self._opset_imports[node.domain]
op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version)
for optimizer in op_optimizers:
@@ -1153,7 +1162,7 @@ def process_node(self, node: ir.Node) -> Replacement | None:
)
return None

# Ensure all node inputs are constants
# Ensure all node inputs are constants or initializers
if any(x.const_value is None for x in node.inputs if x is not None):
return None

@@ -1227,10 +1236,13 @@ def convert(av):
if outputs is None:
return None
if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
replacement = self.new_constant(node, outputs)
if replacement is None:
new_initializer_value = self.new_initializer(node, outputs)
if new_initializer_value is None:
return None
return Replacement(replacement.outputs, [replacement])
# Add the new initializer to the graph
assert node.graph is not None
node.graph.register_initializer(new_initializer_value)
return Replacement([new_initializer_value], [])
else:
logger.warning(
"Skipping constant folding for op %s with multiple outputs.", node.op_type
@@ -1244,7 +1256,6 @@ def replace_node(

# Record the names of the values that have contributed to the replacement
_record_contributing_values(node, replacement)

ir.convenience.replace_nodes_and_values(
root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
)
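The Transpose(weight) comment in new_initializer above captures the size heuristic for storing folded results: a result larger than output_size_limit is still stored when the single-use inputs it replaces are at least as large, so the model does not grow. The sketch below restates that check standalone; the function and argument names are illustrative, not the actual implementation.

import numpy as np

def would_store_folded(
    folded: np.ndarray,
    single_use_inputs: list[np.ndarray],
    output_size_limit: int,
) -> bool:
    # Illustrative restatement of the size heuristic; not the actual implementation.
    # Small results are always stored.
    if folded.size <= output_size_limit:
        return True
    # Large results are still stored when the single-use inputs they replace
    # are at least as large, i.e. there is no net increase in model size.
    removed_input_size = sum(a.size for a in single_use_inputs)
    return folded.size - removed_input_size <= 0

# Example: folding Transpose of a single-use (256, 256) weight.
# 65536 (folded result) - 65536 (removed weight) = 0, so the folded tensor is
# stored even with a tiny output_size_limit.
weight = np.ones((256, 256), dtype=np.float32)
assert would_store_folded(weight.T, [weight], output_size_limit=1)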
42 changes: 23 additions & 19 deletions onnxscript/optimizer/_constant_folding_test.py
@@ -36,8 +36,8 @@ def test_fold_add(self):
"""

optimized = self._fold(model)
self.assertEqual(len(optimized.graph), 2)
self.assertEqual(optimized.graph[0].outputs[0].name, "four")
self.assertEqual(len(optimized.graph), 1)
self.assertIn("four", optimized.graph.initializers)

def test_fold_cast_like(self):
model = """
@@ -51,8 +51,8 @@ def test_fold_cast_like(self):
"""

optimized = self._fold(model)
self.assertEqual(len(optimized.graph), 2)
self.assertEqual(optimized.graph[0].outputs[0].name, "four")
self.assertEqual(len(optimized.graph), 1)
self.assertIn("four", optimized.graph.initializers)

def test_fold_shape(self):
model = """
@@ -67,8 +67,8 @@ def test_fold_shape(self):
"""

optimized = self._fold(model)
self.assertEqual(len(optimized.graph), 2)
self.assertEqual(optimized.graph[0].outputs[0].name, "four")
self.assertEqual(len(optimized.graph), 1)
self.assertIn("four", optimized.graph.initializers)

def test_fold_shape_slice(self):
model = """
@@ -83,8 +83,8 @@ def test_fold_shape_slice(self):
"""

optimized = self._fold(model)
self.assertEqual(len(optimized.graph), 2)
self.assertEqual(optimized.graph[0].outputs[0].name, "four")
self.assertEqual(len(optimized.graph), 1)
self.assertIn("four", optimized.graph.initializers)

def test_fold_if_cond(self):
model = """
@@ -130,9 +130,11 @@ def test_fold_inside_if_branch(self):
optimized = self._fold(model)
self.assertEqual(len(optimized.graph), 1)
then_graph = optimized.graph[0].attributes["then_branch"].as_graph()
self.assertEqual(len(then_graph), 2)
self.assertEqual(len(then_graph), 1)
self.assertIn("temp", then_graph.initializers)
else_graph = optimized.graph[0].attributes["else_branch"].as_graph()
self.assertEqual(len(else_graph), 2)
self.assertEqual(len(else_graph), 1)
self.assertIn("temp", else_graph.initializers)

def test_fold_if_propagate(self):
model = """
@@ -154,9 +156,8 @@ def test_fold_if_propagate(self):
"""

optimized = self._fold(model)
self.assertEqual(len(optimized.graph), 2)
self.assertEqual(optimized.graph[0].outputs[0].name, "m_square")
self.assertEqual(optimized.graph[0].op_type, "Constant")
self.assertEqual(len(optimized.graph), 1)
self.assertIn("m_square", optimized.graph.initializers)

def test_fold_redundant_cast(self):
model = """
@@ -209,8 +210,8 @@ def test_shape_inference(self):
"""

optimized = self._fold(model, onnx_shape_inference=True)
self.assertEqual(len(optimized.graph), 2)
self.assertEqual(optimized.graph[0].outputs[0].name, "C")
self.assertEqual(len(optimized.graph), 1)
self.assertIn("C", optimized.graph.initializers)

def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split(
self,
@@ -614,7 +615,8 @@ def test_input_size_limit(self):
# Since there is no increase in model-size, output-size is not a concern.
optimized = self._fold(model, input_size_limit=256 * 256, output_size_limit=256 * 256)
ops = [node.op_type for node in optimized.graph]
self.assertEqual(ops, ["Constant", "Add"])
self.assertEqual(ops, ["Add"])
self.assertIn("w_squared", optimized.graph.initializers)

def test_transpose_is_always_folded(self):
model_text = """
@@ -633,7 +635,8 @@ def test_transpose_is_always_folded(self):
# Input size limit will not prevent folding of Transpose op
optimized = self._fold(model, input_size_limit=1)
ops = [node.op_type for node in optimized.graph]
self.assertEqual(ops, ["Constant"])
self.assertEqual(ops, [])
self.assertIn("z", optimized.graph.initializers)

def test_node_is_folded_if_specified_as_should_fold(self):
model_text = """
@@ -656,9 +659,10 @@ def test_node_is_folded_if_specified_as_should_fold(self):
model, should_fold=lambda node: node.op_type == "ConstantOfShape" or None
)
ops = [node.op_type for node in optimized.graph]
self.assertEqual(ops, ["Constant"])
self.assertEqual(ops, [])
self.assertIn("z", optimized.graph.initializers)
np.testing.assert_array_equal(
optimized.graph.node(0).attributes["value"].as_tensor().numpy(),
optimized.graph.initializers["z"].const_value,
np.ones((42, 42), dtype=np.int64),
)

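Taken together, the updated tests reflect the new end-to-end behavior: folded values appear as graph initializers instead of inserted Constant nodes. Below is a minimal usage sketch, assuming an onnxscript build that includes this change; the model text and value names are illustrative.

import onnx.parser

from onnxscript import ir, optimizer

# Illustrative sketch; assumes an onnxscript version that includes this change.
_MODEL_TEXT = """
<ir_version: 7, opset_import: [ "" : 17 ]>
agraph (float[N] x) => (float[N] z) {
    two = Constant <value_float = 2.0> ()
    four = Add (two, two)
    z = Mul (x, four)
}
"""

model = ir.serde.deserialize_model(onnx.parser.parse_model(_MODEL_TEXT))
optimizer.optimize_ir(model)

# "four" is folded into a graph initializer; no Constant node is inserted.
assert "four" in model.graph.initializers
print(model.graph.initializers["four"].const_value)  # folded tensor (value 4.0)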