From e6e3d52531e3ca882888da8e466b47fd921d678d Mon Sep 17 00:00:00 2001
From: "G. Ramalingam"
Date: Thu, 14 Nov 2024 22:18:01 -0800
Subject: [PATCH] A couple of optimizations and refinements (#1947)

Extract the independent optimizations and refinements from
[the fusion PR](https://github.com/microsoft/onnxscript/pull/1938) into a
separate PR that is ready to be reviewed and merged. (The fusion work
itself is still WIP.)

* Replace Expand by Identity when applicable (in core optimization)
* Clean up the Dropout-to-Identity replacement for the case where Dropout
  has a mask output
* Make repeated (redundant) calls to the inliner efficient
---
 onnxscript/optimizer/__init__.py          | 15 ++-
 onnxscript/optimizer/_constant_folding.py | 43 +++++++--
 .../optimizer/_constant_folding_test.py    | 95 ++++++++++++-------
 onnxscript/optimizer/_inliner.py           |  7 +-
 4 files changed, 115 insertions(+), 45 deletions(-)

diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py
index f30976c24..8ba6229c1 100644
--- a/onnxscript/optimizer/__init__.py
+++ b/onnxscript/optimizer/__init__.py
@@ -4,13 +4,16 @@
 
 import onnx
 
+import onnxscript.optimizer._constant_folding as constant_folding
 import onnxscript.optimizer._legacy._optimizer as legacy_optimizer
+import onnxscript.optimizer._legacy.constant_folding as legacy_constant_folding
 from onnxscript import ir
-from onnxscript.optimizer._constant_folding import basic_constant_propagation
-from onnxscript.optimizer._legacy.constant_folding import fold_constants
 from onnxscript.optimizer._optimizer import optimize_ir
 from onnxscript.optimizer._remove_unused import remove_unused_nodes
 
+basic_constant_propagation = constant_folding.basic_constant_propagation
+fold_constants_ir = constant_folding.fold_constants
+
 
 def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs):
     if isinstance(model, ir.Model):
@@ -19,8 +22,16 @@ def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs):
     return legacy_optimizer.optimize(model, *args, **kwargs)
 
 
+def fold_constants(model: ir.Model | onnx.ModelProto, *args, **kwargs):
+    if isinstance(model, ir.Model):
+        return constant_folding.fold_constants(model, *args, **kwargs)
+    else:
+        return legacy_constant_folding.fold_constants(model, *args, **kwargs)
+
+
 __all__ = [
     "fold_constants",
    "fold_constants_ir",
     "remove_unused_nodes",
     "optimize",
     "optimize_ir",
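
A minimal usage sketch of the reworked entry points above. This is illustrative only: the `model.onnx` path is a placeholder, but `fold_constants`, `fold_constants_ir`, and `remove_unused_nodes` are the exports defined in this file, and `serde.deserialize_model` is the proto-to-IR converter the tests below rely on.

```python
import onnx

from onnxscript import optimizer
from onnxscript.ir import serde

# Proto-level call: dispatched to the legacy constant folder.
model_proto = onnx.load("model.onnx")  # placeholder path
optimizer.fold_constants(model_proto)

# IR-level call: the same entry point now dispatches to the new
# IR-based constant folder; fold_constants_ir is the explicit IR alias.
model = serde.deserialize_model(model_proto)
optimizer.fold_constants(model)
optimizer.remove_unused_nodes(model)
```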
diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index a5141c6bc..4053bb2a1 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -448,17 +448,25 @@ def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
 
 @register("Dropout", version=(12, None))
 def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     """Replace a Dropout by Identity when applicable."""
-    if len(node.outputs) != 1:
-        # If output mask is requested, optimization is more complex.
-        # TODO: handle this case. But unlikely to be needed in practice.
-        return None
+
+    def optimized_dropout():
+        input = node.inputs[0]
+        output = op.Identity(input)
+        if len(node.outputs) == 1:
+            return output
+        else:
+            true_tensor = ir.tensor([True])
+            input_shape = op.Shape(input)
+            mask = op.ConstantOfShape(input_shape, value=true_tensor)
+            return output, mask
+
     inputs = node.inputs
     if (len(inputs) <= 2) or inputs[2] is None:
         # No training_mode specified:
-        return op.Identity(inputs[0])
+        return optimized_dropout()
     if _get_bool_value(inputs[2]) is False:
         # training_mode is False: dropout is not applied.
-        return op.Identity(inputs[0])
+        return optimized_dropout()
     ratio = _get_numpy_value(inputs[1])
     if ratio is None:
         return None
@@ -466,7 +474,28 @@ def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
         return None
     if ratio.item() == 0:
         # dropout ratio is 0: dropout is not applied.
-        return op.Identity(inputs[0])
+        return optimized_dropout()
+    return None
+
+
+@register("Expand")
+def expand(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+    """Replace an Expand node by Identity when applicable."""
+    if len(node.inputs) != 2:
+        return None
+    if (input := node.inputs[0]) is None:
+        return None
+    if (input_shape := input.shape) is None:
+        # Input shape is not known.
+        return None
+    if (expanded_shape := _get_numpy_value(node.inputs[1])) is None:
+        # Target shape is not known.
+        return None
+    if expanded_shape.ndim != 1:
+        # Target shape must be a 1D tensor. Erroneous model.
+        return None
+    if input_shape.dims == tuple(expanded_shape.tolist()):
+        return op.Identity(input)
     return None
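
To make the mask-output case concrete, here is a small sketch of what the new `optimized_dropout` replacement produces, mirroring the tests added below; the model text and values are illustrative:

```python
import onnx.parser

from onnxscript import optimizer
from onnxscript.ir import serde
from onnxscript.optimizer import _constant_folding

# Dropout in inference mode (training_mode = false) with a requested mask output.
model_text = """
    <ir_version: 7, opset_import: [ "" : 17 ]>
    agraph (float[N] input) => (float[N] output, bool[N] mask)
    <float half = {0.5}, bool false = {0}>
    {
        output, mask = Dropout (input, half, false)
    }
"""
model = serde.deserialize_model(onnx.parser.parse_model(model_text))
_constant_folding.fold_constants(model)
optimizer.remove_unused_nodes(model)

# Dropout collapses to Identity for the data output, plus Shape +
# ConstantOfShape producing an all-True mask of the input's shape.
print([node.op_type for node in model.graph])
# ['Identity', 'Shape', 'ConstantOfShape']
```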
diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py
index d6a799116..8f2dc0026 100644
--- a/onnxscript/optimizer/_constant_folding_test.py
+++ b/onnxscript/optimizer/_constant_folding_test.py
@@ -395,6 +395,29 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1(
         self.assertEqual(optimized.graph.node[6].op_type, "Concat")
         onnx.checker.check_model(optimized)
 
+
+class FoldConstantsIrTest(unittest.TestCase):
+    def _fold(self, model_text: str, onnx_shape_inference=False) -> ir.Model:
+        model_proto = onnx.parser.parse_model(model_text)
+        model = serde.deserialize_model(model_proto)
+        _constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference)
+        optimizer.remove_unused_nodes(model)
+        return model
+
+    def test_initializer_input_not_folded(self):
+        model_text = """
+            <ir_version: 7, opset_import: [ "" : 17 ]>
+            agraph (float[N] x, float[1] c = {1.0} ) => (float[N] z)
+            {
+                # c is not a constant, and the following should not be folded.
+                two_c = Add (c, c)
+                z = Mul (x, two_c)
+            }
+        """
+        optimized = self._fold(model_text)
+        self.assertEqual(len(optimized.graph), 2)
+        self.assertEqual(optimized.graph.node(0).op_type, "Add")
+
     @parameterized.parameterized.expand(
         [
             ("output = Dropout(input)",),
             ("output = Dropout(input, zero, true)",),
             ("output = Dropout(input, half)",),
             ("output = Dropout(input, half, false)",),
         ]
     )
@@ -404,58 +427,64 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1(
     def test_dropout_identity(self, dropout_node: str):
-        if not self.using_ir:
-            self.skipTest("New optimizations not supported for legacy optimizer")
-        model = onnx.parser.parse_model(f"""
+        model = f"""
             <ir_version: 7, opset_import: [ "" : 17 ]>
             agraph (float[N] input) => (float[N] output)
             <float zero = {{0.0}}, float half = {{0.5}}, bool true = {{1}}, bool false = {{0}}>
             {{
                 {dropout_node}
             }}
-        """)
+        """
         optimized = self._fold(model)
-        self.assertEqual(len(optimized.graph.node), 1)
-        self.assertEqual(optimized.graph.node[0].op_type, "Identity")
+        self.assertEqual(len(optimized.graph), 1)
+        self.assertEqual(optimized.graph.node(0).op_type, "Identity")
+
+    @parameterized.parameterized.expand(
+        [
+            ("output, mask = Dropout(input)",),
+            ("output, mask = Dropout(input, zero, true)",),
+            ("output, mask = Dropout(input, half)",),
+            ("output, mask = Dropout(input, half, false)",),
+        ]
+    )
+    def test_dropout_identity_mask(self, dropout_node: str):
+        model = f"""
+            <ir_version: 7, opset_import: [ "" : 17 ]>
+            agraph (float[N] input) => (float[N] output, bool[N] mask)
+            <float zero = {{0.0}}, float half = {{0.5}}, bool true = {{1}}, bool false = {{0}}>
+            {{
+                {dropout_node}
+            }}
+        """
+        optimized = self._fold(model)
+        nodes = list(optimized.graph)
+        self.assertEqual(len(nodes), 3)
+        ops = [node.op_type for node in nodes]
+        self.assertEqual(ops, ["Identity", "Shape", "ConstantOfShape"])
 
     def test_concat_identity(self):
-        if not self.using_ir:
-            self.skipTest("New optimizations not supported for legacy optimizer")
-        model = onnx.parser.parse_model(
-            """
+        model = """
+            <ir_version: 7, opset_import: [ "" : 17 ]>
             agraph (float[N] x) => (float[N] z)
             {
                 z = Concat <axis = -1> (x)
             }
         """
-        )
         optimized = self._fold(model)
-        self.assertEqual(len(optimized.graph.node), 1)
-        self.assertEqual(optimized.graph.node[0].op_type, "Identity")
-
-
-class FoldConstantsIrTest(unittest.TestCase):
-    def _fold(self, model_text: str, onnx_shape_inference=False) -> ir.Model:
-        model_proto = onnx.parser.parse_model(model_text)
-        model = serde.deserialize_model(model_proto)
-        _constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference)
-        optimizer.remove_unused_nodes(model)
-        return model
-
-    def test_initializer_input_not_folded(self):
-        model_text = """
-            <ir_version: 7, opset_import: [ "" : 17 ]>
-            agraph (float[N] x, float[1] c = {1.0} ) => (float[N] z)
-            {
-                # c is not a constant, and the following should not be folded.
-                two_c = Add (c, c)
-                z = Mul (x, two_c)
-            }
-        """
-        optimized = self._fold(model_text)
-        self.assertEqual(len(optimized.graph), 2)
-        self.assertEqual(optimized.graph.node(0).op_type, "Add")
+        self.assertEqual(len(optimized.graph), 1)
+        self.assertEqual(optimized.graph.node(0).op_type, "Identity")
+
+    def test_expand_identity(self):
+        model = """
+            <ir_version: 7, opset_import: [ "" : 17 ]>
+            agraph (float[128, 256] x) => (float[128, 256] z)
+            {
+                shape = Constant <value_ints = [128, 256]> ()
+                z = Expand (x, shape)
+            }
+        """
+        optimized = self._fold(model)
+        self.assertEqual(optimized.graph.node(-1).op_type, "Identity")
 
 
 if __name__ == "__main__":
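
For contrast with `test_expand_identity` above, here is a sketch of a case the `expand` rewrite deliberately leaves alone: the target shape differs from the input shape, so the Expand performs a real broadcast and must be preserved. Illustrative only, not part of the test suite:

```python
import onnx.parser

from onnxscript import optimizer
from onnxscript.ir import serde
from onnxscript.optimizer import _constant_folding

# Expand genuinely broadcasts [128, 1] -> [128, 256] here.
model_text = """
    <ir_version: 7, opset_import: [ "" : 17 ]>
    agraph (float[128, 1] x) => (float[128, 256] z)
    {
        shape = Constant <value_ints = [128, 256]> ()
        z = Expand (x, shape)
    }
"""
model = serde.deserialize_model(onnx.parser.parse_model(model_text))
_constant_folding.fold_constants(model)
optimizer.remove_unused_nodes(model)

# The input and target shapes do not match, so Expand survives folding,
# e.g. op types are ['Constant', 'Expand'] rather than ending in Identity.
print([node.op_type for node in model.graph])
```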
diff --git a/onnxscript/optimizer/_inliner.py b/onnxscript/optimizer/_inliner.py
index 798bc302a..31bb92087 100644
--- a/onnxscript/optimizer/_inliner.py
+++ b/onnxscript/optimizer/_inliner.py
@@ -305,6 +305,7 @@ def inline_calls_in(self, graph: ir.Graph) -> None:
 
 def inline(model: ir.Model) -> None:
     """Inline all function calls (recursively) in the model."""
-    inliner = _Inliner(model)
-    inliner.inline_calls_in(model.graph)
-    model.functions.clear()
+    if model.functions:
+        inliner = _Inliner(model)
+        inliner.inline_calls_in(model.graph)
+        model.functions.clear()
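
With the guard above, `inline` becomes an idempotent, cheap-to-repeat pass: once `model.functions` is empty, a second call returns immediately instead of rebuilding the inliner and re-traversing the graph. A minimal sketch of the intended usage (the model path is a placeholder, and `inline` is imported from the private module this patch touches):

```python
import onnx

from onnxscript.ir import serde
from onnxscript.optimizer._inliner import inline

model = serde.deserialize_model(onnx.load("model.onnx"))  # placeholder path

inline(model)  # inlines all function calls and clears model.functions
inline(model)  # no-op: model.functions is now empty, so it returns immediately
```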