A couple of optimizations and refinements (#1947)
Extract the independent optimizations/refinements from [the fusion PR](#1938) as a separate PR, ready to be reviewed/merged. (The fusion work is still WIP.)

* Replace Expand by Identity when applicable (in the core optimizer)
* Clean up the Dropout-to-Identity replacement for the case where Dropout has a mask output
* Make repeated (redundant) calls to the inliner efficient
gramalingam authored Nov 15, 2024
1 parent d81480b commit e6e3d52
Showing 4 changed files with 115 additions and 45 deletions.
15 changes: 13 additions & 2 deletions onnxscript/optimizer/__init__.py
@@ -4,13 +4,16 @@

 import onnx

+import onnxscript.optimizer._constant_folding as constant_folding
 import onnxscript.optimizer._legacy._optimizer as legacy_optimizer
+import onnxscript.optimizer._legacy.constant_folding as legacy_constant_folding
 from onnxscript import ir
-from onnxscript.optimizer._constant_folding import basic_constant_propagation
-from onnxscript.optimizer._legacy.constant_folding import fold_constants
 from onnxscript.optimizer._optimizer import optimize_ir
 from onnxscript.optimizer._remove_unused import remove_unused_nodes
+
+basic_constant_propagation = constant_folding.basic_constant_propagation
+fold_constants_ir = constant_folding.fold_constants


 def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs):
     if isinstance(model, ir.Model):
@@ -19,8 +22,16 @@ def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs):
         return legacy_optimizer.optimize(model, *args, **kwargs)


+def fold_constants(model: ir.Model | onnx.ModelProto, *args, **kwargs):
+    if isinstance(model, ir.Model):
+        return constant_folding.fold_constants(model, *args, **kwargs)
+    else:
+        return legacy_constant_folding.fold_constants(model, *args, **kwargs)
+
+
 __all__ = [
     "fold_constants",
+    "fold_constants_ir",
     "remove_unused_nodes",
     "optimize",
     "optimize_ir",
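For context, a quick usage sketch of the new dispatching entry point (editorial, not part of this commit; `model.onnx` is a hypothetical placeholder path): `optimizer.fold_constants` now accepts either an `ir.Model` or a `ModelProto` and routes to the IR-based or legacy implementation respectively, while `fold_constants_ir` remains a direct alias for the IR-based pass.

```python
import onnx

from onnxscript import optimizer
from onnxscript.ir import serde

model_proto = onnx.load("model.onnx")  # hypothetical model file

# ModelProto input dispatches to the legacy implementation.
optimizer.fold_constants(model_proto)

# ir.Model input dispatches to the new IR-based implementation.
ir_model = serde.deserialize_model(model_proto)
optimizer.fold_constants(ir_model)
```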
43 changes: 36 additions & 7 deletions onnxscript/optimizer/_constant_folding.py
@@ -448,25 +448,54 @@ def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
 @register("Dropout", version=(12, None))
 def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     """Replace a Dropout by Identity when applicable."""
-    if len(node.outputs) != 1:
-        # If output mask is requested, optimization is more complex.
-        # TODO: handle this case. But unlikely to be needed in practice.
-        return None
+
+    def optimized_dropout():
+        input = node.inputs[0]
+        output = op.Identity(input)
+        if len(node.outputs) == 1:
+            return output
+        else:
+            true_tensor = ir.tensor([True])
+            input_shape = op.Shape(input)
+            mask = op.ConstantOfShape(input_shape, value=true_tensor)
+            return output, mask
+
     inputs = node.inputs
     if (len(inputs) <= 2) or inputs[2] is None:
         # No training_mode specified:
-        return op.Identity(inputs[0])
+        return optimized_dropout()
     if _get_bool_value(inputs[2]) is False:
         # training_mode is False: dropout is not applied.
-        return op.Identity(inputs[0])
+        return optimized_dropout()
     ratio = _get_numpy_value(inputs[1])
     if ratio is None:
         return None
     if ratio.size != 1:  # Only scalar dropout ratio is supported.
         return None
     if ratio.item() == 0:
         # dropout ratio is 0: dropout is not applied.
-        return op.Identity(inputs[0])
+        return optimized_dropout()
     return None
+
+
+@register("Expand")
+def expand(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+    """Replace an Expand node by Identity when applicable."""
+    if len(node.inputs) != 2:
+        return None
+    if (input := node.inputs[0]) is None:
+        return None
+    if (input_shape := input.shape) is None:
+        # Input shape is not known.
+        return None
+    if (expanded_shape := _get_numpy_value(node.inputs[1])) is None:
+        # Target shape is not known.
+        return None
+    if expanded_shape.ndim != 1:
+        # Target shape must be a 1D tensor. Erroneous model.
+        return None
+    if input_shape.dims == tuple(expanded_shape.tolist()):
+        return op.Identity(input)
+    return None
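As a sanity check on the two rewrites above, here is a small numpy sketch (editorial, not part of the commit) of their semantics: an inactive Dropout returns its input unchanged plus an all-True mask of the input's shape, and an Expand whose target shape equals the input shape is a no-op.

```python
import numpy as np

x = np.random.rand(128, 256).astype(np.float32)

# Inactive Dropout: the output is the input itself, and the optional mask is
# all-True with the input's shape -- exactly what the rewritten subgraph
# Identity(input) + ConstantOfShape(Shape(input), value=True) computes.
output, mask = x, np.full(x.shape, True)
assert output is x and mask.all()

# Expand is a broadcast; when the requested shape already equals the input
# shape it is the identity, so replacing it with Identity is sound.
expanded = np.broadcast_to(x, (128, 256))
assert expanded.shape == x.shape and (expanded == x).all()
```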
95 changes: 62 additions & 33 deletions onnxscript/optimizer/_constant_folding_test.py
@@ -395,6 +395,29 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1(
         self.assertEqual(optimized.graph.node[6].op_type, "Concat")
         onnx.checker.check_model(optimized)

+
+class FoldConstantsIrTest(unittest.TestCase):
+    def _fold(self, model_text: str, onnx_shape_inference=False) -> ir.Model:
+        model_proto = onnx.parser.parse_model(model_text)
+        model = serde.deserialize_model(model_proto)
+        _constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference)
+        optimizer.remove_unused_nodes(model)
+        return model
+
+    def test_initializer_input_not_folded(self):
+        model_text = """
+            <ir_version: 7, opset_import: [ "" : 18]>
+            agraph (float[N] x, float[1] c = {1.0} ) => (float[N] z)
+            {
+                # c is not a constant, and following should not be folded.
+                two_c = Add (c, c)
+                z = Mul (x, two_c)
+            }
+        """
+        optimized = self._fold(model_text)
+        self.assertEqual(len(optimized.graph), 2)
+        self.assertEqual(optimized.graph.node(0).op_type, "Add")
+
     @parameterized.parameterized.expand(
         [
             ("output = Dropout(input)",),
@@ -404,58 +427,64 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1(
         ]
     )
     def test_dropout_identity(self, dropout_node: str):
-        if not self.using_ir:
-            self.skipTest("New optimizations not supported for legacy optimizer")
-        model = onnx.parser.parse_model(f"""
+        model = f"""
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[N] input) => (float[N] output)
            <float zero = {{0.0}}, float half = {{0.5}}, bool true = {{1}}, bool false = {{0}}>
            {{
                {dropout_node}
            }}
-        """)
+        """
         optimized = self._fold(model)
-        self.assertEqual(len(optimized.graph.node), 1)
-        self.assertEqual(optimized.graph.node[0].op_type, "Identity")
+        self.assertEqual(len(optimized.graph), 1)
+        self.assertEqual(optimized.graph.node(0).op_type, "Identity")
+
+    @parameterized.parameterized.expand(
+        [
+            ("output, mask = Dropout(input)",),
+            ("output, mask = Dropout(input, zero, true)",),
+            ("output, mask = Dropout(input, half)",),
+            ("output, mask = Dropout(input, half, false)",),
+        ]
+    )
+    def test_dropout_identity_mask(self, dropout_node: str):
+        model = f"""
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[N] input) => (float[N] output, bool[N] mask)
+            <float zero = {{0.0}}, float half = {{0.5}}, bool true = {{1}}, bool false = {{0}}>
+            {{
+                {dropout_node}
+            }}
+        """
+        optimized = self._fold(model)
+        nodes = list(optimized.graph)
+        self.assertEqual(len(nodes), 3)
+        ops = [node.op_type for node in nodes]
+        self.assertEqual(ops, ["Identity", "Shape", "ConstantOfShape"])

     def test_concat_identity(self):
-        if not self.using_ir:
-            self.skipTest("New optimizations not supported for legacy optimizer")
-        model = onnx.parser.parse_model(
-            """
+        model = """
             <ir_version: 7, opset_import: [ "" : 17]>
             agraph (float[N] x) => (float[N] z)
             {
                 z = Concat <axis=-1> (x)
             }
             """
-        )
         optimized = self._fold(model)
-        self.assertEqual(len(optimized.graph.node), 1)
-        self.assertEqual(optimized.graph.node[0].op_type, "Identity")
-
-
-class FoldConstantsIrTest(unittest.TestCase):
-    def _fold(self, model_text: str, onnx_shape_inference=False) -> ir.Model:
-        model_proto = onnx.parser.parse_model(model_text)
-        model = serde.deserialize_model(model_proto)
-        _constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference)
-        optimizer.remove_unused_nodes(model)
-        return model
-
-    def test_initializer_input_not_folded(self):
-        model_text = """
-            <ir_version: 7, opset_import: [ "" : 18]>
-            agraph (float[N] x, float[1] c = {1.0} ) => (float[N] z)
-            {
-                # c is not a constant, and following should not be folded.
-                two_c = Add (c, c)
-                z = Mul (x, two_c)
-            }
-        """
-        optimized = self._fold(model_text)
-        self.assertEqual(len(optimized.graph), 2)
-        self.assertEqual(optimized.graph.node(0).op_type, "Add")
+        self.assertEqual(len(optimized.graph), 1)
+        self.assertEqual(optimized.graph.node(0).op_type, "Identity")
+
+    def test_expand_identity(self):
+        model = """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[128, 256] x) => (float[128, 256] z)
+            {
+                shape = Constant <value_ints=[128, 256]> ()
+                z = Expand (x, shape)
+            }
+        """
+        optimized = self._fold(model)
+        self.assertEqual(optimized.graph.node(-1).op_type, "Identity")


 if __name__ == "__main__":
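The new mask behavior can also be reproduced end to end with the public optimizer API. A minimal sketch mirroring the `_fold` helper above (editorial; assumes `onnx` and `onnxscript` are installed):

```python
import onnx.parser

from onnxscript import optimizer
from onnxscript.ir import serde

model_text = """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[N] input) => (float[N] output, bool[N] mask)
{
    output, mask = Dropout (input)
}
"""
model = serde.deserialize_model(onnx.parser.parse_model(model_text))
optimizer.fold_constants(model)
optimizer.remove_unused_nodes(model)

# The mask-producing Dropout is now rewritten rather than skipped:
print([node.op_type for node in model.graph])  # ['Identity', 'Shape', 'ConstantOfShape']
```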
7 changes: 4 additions & 3 deletions onnxscript/optimizer/_inliner.py
@@ -305,6 +305,7 @@ def inline_calls_in(self, graph: ir.Graph) -> None:

 def inline(model: ir.Model) -> None:
     """Inline all function calls (recursively) in the model."""
-    inliner = _Inliner(model)
-    inliner.inline_calls_in(model.graph)
-    model.functions.clear()
+    if model.functions:
+        inliner = _Inliner(model)
+        inliner.inline_calls_in(model.graph)
+        model.functions.clear()
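Why the guard matters: `inline` can be invoked repeatedly on the same model by different pipelines, and after the first call `model.functions` is empty. A minimal sketch of the effect (editorial; the import path follows the file in this diff):

```python
import onnx.parser

from onnxscript.ir import serde
from onnxscript.optimizer._inliner import inline

model_text = """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[N] x) => (float[N] y)
{
    y = Identity (x)
}
"""
model = serde.deserialize_model(onnx.parser.parse_model(model_text))

# This model has no functions, so each call now returns immediately;
# previously every call would still build an _Inliner and walk the graph.
inline(model)
inline(model)  # repeated (redundant) call stays cheap
```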