From 73c1d2b9e5fe56bde3889503c1b463c868fa6ea7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 1 Oct 2025 08:58:21 -0700 Subject: [PATCH 1/5] Implement shape merging in identity elimination pass Following https://github.com/microsoft/onnxscript/pull/2588. Handle shape info as well. Signed-off-by: Justin Chu --- .../passes/common/identity_elimination.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/onnx_ir/passes/common/identity_elimination.py b/src/onnx_ir/passes/common/identity_elimination.py index ae28bbcf..899455f7 100644 --- a/src/onnx_ir/passes/common/identity_elimination.py +++ b/src/onnx_ir/passes/common/identity_elimination.py @@ -15,6 +15,27 @@ logger = logging.getLogger(__name__) +def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None: + def merge_dims(dim1, dim2): + if dim1 == dim2: + return dim1 + if not isinstance(dim1, ir.SymbolicDim): + return dim1 # Prefer int value over symbolic dim + if not isinstance(dim2, ir.SymbolicDim): + return dim2 + if dim1.value is None: + return dim2 + return dim1 + + if shape1 is None: + return shape2 + if shape2 is None: + return shape1 + if len(shape1) != len(shape2): + raise ValueError("Shapes must have the same rank.") + return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)]) + + class IdentityEliminationPass(ir.passes.InPlacePass): """Pass for eliminating redundant Identity nodes. @@ -75,6 +96,11 @@ def _try_eliminate_identity_node(self, node: ir.Node) -> bool: if output_is_graph_output and input_is_graph_input: return False + # Copy over shape/type if the output has more complete information + input_value.shape = _merge_shapes(input_value.shape, output_value.shape) + if input_value.type is None: + input_value.type = output_value.type + # Case 1 & 2 (merged): Eliminate the identity node # Replace all uses of output with input ir.convenience.replace_all_uses_with(output_value, input_value) From 05735d21fcaf0e10eef8c7b51d00b608147b9314 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 1 Oct 2025 09:01:02 -0700 Subject: [PATCH 2/5] Update src/onnx_ir/passes/common/identity_elimination.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/identity_elimination.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onnx_ir/passes/common/identity_elimination.py b/src/onnx_ir/passes/common/identity_elimination.py index 899455f7..65d71dba 100644 --- a/src/onnx_ir/passes/common/identity_elimination.py +++ b/src/onnx_ir/passes/common/identity_elimination.py @@ -32,7 +32,7 @@ def merge_dims(dim1, dim2): if shape2 is None: return shape1 if len(shape1) != len(shape2): - raise ValueError("Shapes must have the same rank.") + raise ValueError(f"Shapes must have the same rank, got {len(shape1)} and {len(shape2)}.") return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)]) From fbc9fbf05f006ea3b3b56845a24d9607607b27b1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 1 Oct 2025 14:43:26 -0700 Subject: [PATCH 3/5] Format Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/identity_elimination.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/onnx_ir/passes/common/identity_elimination.py b/src/onnx_ir/passes/common/identity_elimination.py index 65d71dba..eb0895f2 100644 --- a/src/onnx_ir/passes/common/identity_elimination.py +++ b/src/onnx_ir/passes/common/identity_elimination.py @@ -32,7 +32,9 @@ def merge_dims(dim1, dim2): if shape2 is None: return shape1 if len(shape1) != len(shape2): - raise ValueError(f"Shapes must have the same rank, got {len(shape1)} and {len(shape2)}.") + raise ValueError( + f"Shapes must have the same rank, got {len(shape1)} and {len(shape2)}." + ) return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)]) From aec769fcf72b3967ba4f17a591481fafa1da59da Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 6 Oct 2025 09:59:34 -0700 Subject: [PATCH 4/5] tests Signed-off-by: Justin Chu --- .../common/identity_elimination_test.py | 352 ++++++++++++++++++ 1 file changed, 352 insertions(+) diff --git a/src/onnx_ir/passes/common/identity_elimination_test.py b/src/onnx_ir/passes/common/identity_elimination_test.py index 45903594..bafb6cce 100644 --- a/src/onnx_ir/passes/common/identity_elimination_test.py +++ b/src/onnx_ir/passes/common/identity_elimination_test.py @@ -578,6 +578,358 @@ def test_multiple_graph_outputs_with_identity(self): self.assertEqual(other_output.name, "other_output") self.assertIs(other_output, add_node.outputs[0]) + def test_shape_merging_input_has_no_shape(self): + """Test shape merging when input value has no shape but output has shape info.""" + # Create input value with no shape + input_value = ir.Value( + producer=None, name="input", shape=None, type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create Identity node with output that has shape info + identity_node = ir.Node("", "Identity", inputs=[input_value]) + identity_node.outputs[0].name = "identity_output" + identity_node.outputs[0].shape = ir.Shape([2, 3, 4]) + identity_node.outputs[0].type = ir.TensorType(ir.DataType.FLOAT) + + # Create Add node that uses the Identity output + add_input = ir.val( + "add_input", shape=ir.Shape([2, 3, 4]), type=ir.TensorType(ir.DataType.FLOAT) + ) + add_node = ir.Node("", "Add", inputs=[identity_node.outputs[0], add_input]) + add_node.outputs[0].name = "add_output" + add_node.outputs[0].shape = ir.Shape([2, 3, 4]) + add_node.outputs[0].type = ir.TensorType(ir.DataType.FLOAT) + + graph = ir.Graph( + inputs=[input_value, add_input], + outputs=[add_node.outputs[0]], + nodes=[identity_node, add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_elimination.IdentityEliminationPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify Identity node was removed + remaining_nodes = list(result.model.graph) + self.assertEqual(len(remaining_nodes), 1) + self.assertEqual(remaining_nodes[0].op_type, "Add") + + # Verify the input value now has the shape from the identity output + self.assertIsNotNone(input_value.shape) + self.assertEqual(input_value.shape, ir.Shape([2, 3, 4])) + + def test_shape_merging_output_has_no_shape(self): + """Test shape merging when output value has no shape but input has shape info.""" + # Create input value with shape + input_value = ir.val( + "input", shape=ir.Shape([5, 6]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create Identity node with output that has no shape info + identity_node = ir.Node("", "Identity", inputs=[input_value]) + identity_node.outputs[0].name = "identity_output" + identity_node.outputs[0].shape = None + identity_node.outputs[0].type = ir.TensorType(ir.DataType.FLOAT) + + # Create Add node that uses the Identity output + add_input = ir.val( + "add_input", shape=ir.Shape([5, 6]), type=ir.TensorType(ir.DataType.FLOAT) + ) + add_node = ir.Node("", "Add", inputs=[identity_node.outputs[0], add_input]) + add_node.outputs[0].name = "add_output" + add_node.outputs[0].shape = ir.Shape([5, 6]) + add_node.outputs[0].type = ir.TensorType(ir.DataType.FLOAT) + + graph = ir.Graph( + inputs=[input_value, add_input], + outputs=[add_node.outputs[0]], + nodes=[identity_node, add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_elimination.IdentityEliminationPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify the input value kept its original shape + self.assertEqual(input_value.shape, ir.Shape([5, 6])) + + def test_shape_merging_identical_shapes(self): + """Test shape merging when input and output have identical shape information.""" + # Create input value with shape + input_value = ir.val( + "input", shape=ir.Shape([3, 4, 5]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create Identity node with identical shape + identity_node = ir.Node("", "Identity", inputs=[input_value]) + identity_node.outputs[0].name = "identity_output" + identity_node.outputs[0].shape = ir.Shape([3, 4, 5]) + identity_node.outputs[0].type = ir.TensorType(ir.DataType.FLOAT) + + # Create Add node that uses the Identity output + add_input = ir.val( + "add_input", shape=ir.Shape([3, 4, 5]), type=ir.TensorType(ir.DataType.FLOAT) + ) + add_node = ir.Node("", "Add", inputs=[identity_node.outputs[0], add_input]) + add_node.outputs[0].name = "add_output" + add_node.outputs[0].shape = ir.Shape([3, 4, 5]) + add_node.outputs[0].type = ir.TensorType(ir.DataType.FLOAT) + + graph = ir.Graph( + inputs=[input_value, add_input], + outputs=[add_node.outputs[0]], + nodes=[identity_node, add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_elimination.IdentityEliminationPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify the input value kept its shape unchanged + self.assertEqual(input_value.shape, ir.Shape([3, 4, 5])) + + def test_shape_merging_int_vs_symbolic_dims(self): + """Test shape merging where one shape has int dims and other has symbolic dims.""" + # Create symbolic dimensions + sym_dim = ir.SymbolicDim("batch") + + # Create input value with symbolic dimensions + input_value = ir.Value( + producer=None, + name="input", + shape=ir.Shape([sym_dim, 10]), + type=ir.TensorType(ir.DataType.FLOAT), + ) + + # Create Identity node with int dimensions in the same positions + identity_node = ir.Node("", "Identity", inputs=[input_value]) + identity_node.outputs[0].name = "identity_output" + identity_node.outputs[0].shape = ir.Shape([32, 10]) # int dims + identity_node.outputs[0].type = ir.TensorType(ir.DataType.FLOAT) + + # Create Add node that uses the Identity output + add_input = ir.val( + "add_input", shape=ir.Shape([32, 10]), type=ir.TensorType(ir.DataType.FLOAT) + ) + add_node = ir.Node("", "Add", inputs=[identity_node.outputs[0], add_input]) + add_node.outputs[0].name = "add_output" + add_node.outputs[0].shape = ir.Shape([32, 10]) + add_node.outputs[0].type = ir.TensorType(ir.DataType.FLOAT) + + graph = ir.Graph( + inputs=[input_value, add_input], + outputs=[add_node.outputs[0]], + nodes=[identity_node, add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_elimination.IdentityEliminationPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify the input value now has the int dimensions (preferred over symbolic) + expected_shape = ir.Shape([32, 10]) + self.assertEqual(input_value.shape, expected_shape) + + def test_shape_merging_different_ranks_raises_error(self): + """Test that merging shapes with different ranks raises an error.""" + # Create input value with 2D shape + input_value = ir.Value( + producer=None, + name="input", + shape=ir.Shape([2, 3]), + type=ir.TensorType(ir.DataType.FLOAT), + ) + + # Create Identity node with 3D shape (different rank) + identity_node = ir.Node("", "Identity", inputs=[input_value]) + identity_node.outputs[0].name = "identity_output" + identity_node.outputs[0].shape = ir.Shape([2, 3, 4]) # Different rank + identity_node.outputs[0].type = ir.TensorType(ir.DataType.FLOAT) + + # Create Add node + add_input = ir.val( + "add_input", shape=ir.Shape([2, 3, 4]), type=ir.TensorType(ir.DataType.FLOAT) + ) + add_node = ir.Node("", "Add", inputs=[identity_node.outputs[0], add_input]) + add_node.outputs[0].name = "add_output" + add_node.outputs[0].shape = ir.Shape([2, 3, 4]) + add_node.outputs[0].type = ir.TensorType(ir.DataType.FLOAT) + + graph = ir.Graph( + inputs=[input_value, add_input], + outputs=[add_node.outputs[0]], + nodes=[identity_node, add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_elimination.IdentityEliminationPass() + + # Should raise ValueError due to different ranks + with self.assertRaises(ValueError) as context: + pass_instance(model) + + self.assertIn("Shapes must have the same rank", str(context.exception)) + + def test_type_copying_from_output_to_input(self): + """Test that type is copied from output to input when input has no type.""" + # Create input value with no type + input_value = ir.Value(producer=None, name="input", shape=ir.Shape([2, 2]), type=None) + + # Create Identity node with type info + identity_node = ir.Node("", "Identity", inputs=[input_value]) + identity_node.outputs[0].name = "identity_output" + identity_node.outputs[0].shape = ir.Shape([2, 2]) + identity_node.outputs[0].type = ir.TensorType(ir.DataType.INT32) + + # Create Add node that uses the Identity output + add_input = ir.val( + "add_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.INT32) + ) + add_node = ir.Node("", "Add", inputs=[identity_node.outputs[0], add_input]) + add_node.outputs[0].name = "add_output" + add_node.outputs[0].shape = ir.Shape([2, 2]) + add_node.outputs[0].type = ir.TensorType(ir.DataType.INT32) + + graph = ir.Graph( + inputs=[input_value, add_input], + outputs=[add_node.outputs[0]], + nodes=[identity_node, add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_elimination.IdentityEliminationPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify the input value now has the type from the identity output + self.assertIsNotNone(input_value.type) + self.assertEqual(input_value.type, ir.TensorType(ir.DataType.INT32)) + + def test_type_preservation_when_input_already_has_type(self): + """Test that existing input type is preserved when both input and output have types.""" + # Create input value with existing type + input_value = ir.val( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create Identity node with different type + identity_node = ir.Node("", "Identity", inputs=[input_value]) + identity_node.outputs[0].name = "identity_output" + identity_node.outputs[0].shape = ir.Shape([2, 2]) + identity_node.outputs[0].type = ir.TensorType(ir.DataType.INT32) # Different type + + # Create Add node that uses the Identity output + add_input = ir.val( + "add_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.INT32) + ) + add_node = ir.Node("", "Add", inputs=[identity_node.outputs[0], add_input]) + add_node.outputs[0].name = "add_output" + add_node.outputs[0].shape = ir.Shape([2, 2]) + add_node.outputs[0].type = ir.TensorType(ir.DataType.INT32) + + graph = ir.Graph( + inputs=[input_value, add_input], + outputs=[add_node.outputs[0]], + nodes=[identity_node, add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_elimination.IdentityEliminationPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify the input value kept its original type (not overwritten) + self.assertEqual(input_value.type, ir.TensorType(ir.DataType.FLOAT)) + + def test_symbolic_dim_merging_edge_cases(self): + """Test various symbolic dimension merging scenarios.""" + # Test case where one symbolic dim has value None, other has value + sym_dim_no_value = ir.SymbolicDim(None) # value is None by default + sym_dim_with_value = ir.SymbolicDim("batch") # different name + + # Create input value with symbolic dim that has no value + input_value = ir.Value( + producer=None, + name="input", + shape=ir.Shape([sym_dim_no_value, 10]), + type=ir.TensorType(ir.DataType.FLOAT), + ) + + # Create Identity node with symbolic dim that has a value + identity_node = ir.Node("", "Identity", inputs=[input_value]) + identity_node.outputs[0].name = "identity_output" + identity_node.outputs[0].shape = ir.Shape([sym_dim_with_value, 10]) + identity_node.outputs[0].type = ir.TensorType(ir.DataType.FLOAT) + + # Create Add node that uses the Identity output + add_input = ir.val( + "add_input", shape=ir.Shape([16, 10]), type=ir.TensorType(ir.DataType.FLOAT) + ) + add_node = ir.Node("", "Add", inputs=[identity_node.outputs[0], add_input]) + add_node.outputs[0].name = "add_output" + add_node.outputs[0].shape = ir.Shape([16, 10]) + add_node.outputs[0].type = ir.TensorType(ir.DataType.FLOAT) + + graph = ir.Graph( + inputs=[input_value, add_input], + outputs=[add_node.outputs[0]], + nodes=[identity_node, add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_elimination.IdentityEliminationPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify the input value now has the merged shape + self.assertIsNotNone(input_value.shape) + assert input_value.shape is not None # For type checker + self.assertEqual(input_value.shape[0], sym_dim_with_value) + self.assertEqual(input_value.shape[1], 10) + def test_duplicate_identity_output_in_graph_outputs(self): """Test case where the same Identity output appears multiple times in graph outputs.""" # Create intermediate value (not a graph input) From 6c0edccff7c135e77885277e72d803883328a126 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 6 Oct 2025 10:34:02 -0700 Subject: [PATCH 5/5] Apply suggestion from @justinchuby Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/identity_elimination_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onnx_ir/passes/common/identity_elimination_test.py b/src/onnx_ir/passes/common/identity_elimination_test.py index bafb6cce..165eaf76 100644 --- a/src/onnx_ir/passes/common/identity_elimination_test.py +++ b/src/onnx_ir/passes/common/identity_elimination_test.py @@ -882,7 +882,7 @@ def test_type_preservation_when_input_already_has_type(self): def test_symbolic_dim_merging_edge_cases(self): """Test various symbolic dimension merging scenarios.""" # Test case where one symbolic dim has value None, other has value - sym_dim_no_value = ir.SymbolicDim(None) # value is None by default + sym_dim_no_value = ir.SymbolicDim(None) sym_dim_with_value = ir.SymbolicDim("batch") # different name # Create input value with symbolic dim that has no value