Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"packaging",
"protobuf",
)
ONNX_IR = "onnx_ir==0.1.10"
ONNX_IR = "onnx_ir==0.1.12"
ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir"


Expand Down
59 changes: 35 additions & 24 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,24 +1039,29 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
e,
)

def new_constant(self, node: ir.Node, value) -> ir.Node | None:
irvalue = node.outputs[0]
if not isinstance(value, np.ndarray):
def new_initializer(self, node: ir.Node, array) -> ir.Value | None:
original_value = node.outputs[0]
if not isinstance(array, np.ndarray):
# ONNX does not have a way to represent non-tensor constants, eg. a sequence.
# So, a constant-value of type sequence is not folded, but it can be used
# to optimize subsequent operations when possible.
logger.info(
"Skip storing constant folded value %s due to unsupported type %s.",
irvalue.name,
type(value),
original_value.name,
type(array),
)
return None

tensor = ir.tensor(value)
tensor.name = irvalue.name
irvalue.const_value = tensor
tensor = ir.tensor(array)
tensor.name = original_value.name
initializer = ir.Value(
name=original_value.name,
type=ir.TensorType(ir.DataType(tensor.dtype)),
shape=tensor.shape, # type: ignore[arg-type]
const_value=tensor,
)

if value.size > self.output_size_limit:
if array.size > self.output_size_limit:
# Handle examples like Transpose(weight) to be folded even if the size is large,
# as long as weight has no other uses. This won't increase model size.
removed_input_size = 0
Expand All @@ -1065,25 +1070,23 @@ def new_constant(self, node: ir.Node, value) -> ir.Node | None:
array = _get_numpy_value(input)
if array is not None:
removed_input_size += array.size
increased_size = value.size - removed_input_size
increased_size = array.size - removed_input_size
if increased_size > 0:
logger.info(
"Skip storing constant folded value %s due to large size %s.",
irvalue.name,
value.size,
original_value.name,
array.size,
)
return None

logger.debug(
"New constant for value %s dtype: %s shape: %s",
irvalue.name,
value.dtype,
value.shape,
"New Initializer for value %s dtype: %s shape: %s",
original_value.name,
array.dtype,
array.shape,
)

attributes = ir.convenience.convert_attributes({"value": tensor})
node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1)
return node
return initializer

def process_node(self, node: ir.Node) -> Replacement | None:
"""Process a node and return a Replacement if the node can be replaced."""
Expand All @@ -1109,7 +1112,13 @@ def process_node(self, node: ir.Node) -> Replacement | None:
self._do_inference(node)

if node.domain not in self._opset_imports:
logger.debug(
"Skipping constant folding for node %r due to missing opset import for domain %r.",
node.name,
node.domain,
)
return None

version = self._opset_imports[node.domain]
op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version)
for optimizer in op_optimizers:
Expand Down Expand Up @@ -1153,7 +1162,7 @@ def process_node(self, node: ir.Node) -> Replacement | None:
)
return None

# Ensure all node inputs are constants
# Ensure all node inputs are constants or initializers
if any(x.const_value is None for x in node.inputs if x is not None):
return None

Expand Down Expand Up @@ -1227,10 +1236,13 @@ def convert(av):
if outputs is None:
return None
if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
replacement = self.new_constant(node, outputs)
if replacement is None:
new_initializer_value = self.new_initializer(node, outputs)
if new_initializer_value is None:
return None
return Replacement(replacement.outputs, [replacement])
# Add the new initializer to the graph
assert node.graph is not None
node.graph.register_initializer(new_initializer_value)
return Replacement([new_initializer_value], [])
else:
logger.warning(
"Skipping constant folding for op %s with multiple outputs.", node.op_type
Expand All @@ -1244,7 +1256,6 @@ def replace_node(

# Record the names of the values that has contributed to the replacement
_record_contributing_values(node, replacement)

ir.convenience.replace_nodes_and_values(
root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
)
Expand Down
42 changes: 23 additions & 19 deletions onnxscript/optimizer/_constant_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def test_fold_add(self):
"""

optimized = self._fold(model)
self.assertEqual(len(optimized.graph), 2)
self.assertEqual(optimized.graph[0].outputs[0].name, "four")
self.assertEqual(len(optimized.graph), 1)
self.assertIn("four", optimized.graph.initializers)

def test_fold_cast_like(self):
model = """
Expand All @@ -51,8 +51,8 @@ def test_fold_cast_like(self):
"""

optimized = self._fold(model)
self.assertEqual(len(optimized.graph), 2)
self.assertEqual(optimized.graph[0].outputs[0].name, "four")
self.assertEqual(len(optimized.graph), 1)
self.assertIn("four", optimized.graph.initializers)

def test_fold_shape(self):
model = """
Expand All @@ -67,8 +67,8 @@ def test_fold_shape(self):
"""

optimized = self._fold(model)
self.assertEqual(len(optimized.graph), 2)
self.assertEqual(optimized.graph[0].outputs[0].name, "four")
self.assertEqual(len(optimized.graph), 1)
self.assertIn("four", optimized.graph.initializers)

def test_fold_shape_slice(self):
model = """
Expand All @@ -83,8 +83,8 @@ def test_fold_shape_slice(self):
"""

optimized = self._fold(model)
self.assertEqual(len(optimized.graph), 2)
self.assertEqual(optimized.graph[0].outputs[0].name, "four")
self.assertEqual(len(optimized.graph), 1)
self.assertIn("four", optimized.graph.initializers)

def test_fold_if_cond(self):
model = """
Expand Down Expand Up @@ -130,9 +130,11 @@ def test_fold_inside_if_branch(self):
optimized = self._fold(model)
self.assertEqual(len(optimized.graph), 1)
then_graph = optimized.graph[0].attributes["then_branch"].as_graph()
self.assertEqual(len(then_graph), 2)
self.assertEqual(len(then_graph), 1)
self.assertIn("temp", then_graph.initializers)
else_graph = optimized.graph[0].attributes["else_branch"].as_graph()
self.assertEqual(len(else_graph), 2)
self.assertEqual(len(else_graph), 1)
self.assertIn("temp", else_graph.initializers)

def test_fold_if_propagate(self):
model = """
Expand All @@ -154,9 +156,8 @@ def test_fold_if_propagate(self):
"""

optimized = self._fold(model)
self.assertEqual(len(optimized.graph), 2)
self.assertEqual(optimized.graph[0].outputs[0].name, "m_square")
self.assertEqual(optimized.graph[0].op_type, "Constant")
self.assertEqual(len(optimized.graph), 1)
self.assertIn("m_square", optimized.graph.initializers)

def test_fold_redundant_cast(self):
model = """
Expand Down Expand Up @@ -209,8 +210,8 @@ def test_shape_inference(self):
"""

optimized = self._fold(model, onnx_shape_inference=True)
self.assertEqual(len(optimized.graph), 2)
self.assertEqual(optimized.graph[0].outputs[0].name, "C")
self.assertEqual(len(optimized.graph), 1)
self.assertIn("C", optimized.graph.initializers)

def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split(
self,
Expand Down Expand Up @@ -614,7 +615,8 @@ def test_input_size_limit(self):
# Since there is no increase in model-size, output-size is not a concern.
optimized = self._fold(model, input_size_limit=256 * 256, output_size_limit=256 * 256)
ops = [node.op_type for node in optimized.graph]
self.assertEqual(ops, ["Constant", "Add"])
self.assertEqual(ops, ["Add"])
self.assertIn("w_squared", optimized.graph.initializers)

def test_transpose_is_always_folded(self):
model_text = """
Expand All @@ -633,7 +635,8 @@ def test_transpose_is_always_folded(self):
# Input size limit will not prevent folding of Transpose op
optimized = self._fold(model, input_size_limit=1)
ops = [node.op_type for node in optimized.graph]
self.assertEqual(ops, ["Constant"])
self.assertEqual(ops, [])
self.assertIn("z", optimized.graph.initializers)

def test_node_is_folded_if_specified_as_should_fold(self):
model_text = """
Expand All @@ -656,9 +659,10 @@ def test_node_is_folded_if_specified_as_should_fold(self):
model, should_fold=lambda node: node.op_type == "ConstantOfShape" or None
)
ops = [node.op_type for node in optimized.graph]
self.assertEqual(ops, ["Constant"])
self.assertEqual(ops, [])
self.assertIn("z", optimized.graph.initializers)
np.testing.assert_array_equal(
optimized.graph.node(0).attributes["value"].as_tensor().numpy(),
optimized.graph.initializers["z"].const_value,
np.ones((42, 42), dtype=np.int64),
)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ classifiers = [
dependencies = [
"ml_dtypes",
"numpy",
"onnx_ir>=0.1.10,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range.
"onnx_ir>=0.1.12,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range.
"onnx>=1.16",
"packaging",
"typing_extensions>=4.10",
Expand Down
Loading