A couple of optimizer and rewriter extensions (#1937)
A few extensions motivated by ongoing transformer fusion optimizations:

Pattern matching:
* Extend the pattern-matching API so a pattern can declare that extra
(unmatched) inputs are allowed on a node (see the sketch below).
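
For context, here is a minimal, hypothetical sketch of how a rule might use the new flag. It is not part of this diff; it assumes the existing RewriteRule(target_pattern, replacement_pattern) API, and the rule and function names are made up for illustration.

    # Hypothetical rule: match Dropout(x, ...) even when the optional
    # ratio/training_mode inputs are present, and replace it with Identity.
    from onnxscript.rewriter import pattern

    def dropout_target(op, x):
        # _allow_other_inputs=True lets the match succeed despite extra inputs.
        return op.Dropout(x, _allow_other_inputs=True)

    def dropout_replacement(op, x):
        return op.Identity(x)

    dropout_rule = pattern.RewriteRule(dropout_target, dropout_replacement)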

Optimizations:
* A single-input Concat(x) can be replaced by Identity(x).
* The redundant-Cast optimization (a Cast to the input's existing dtype) was
missing from the core optimizer, though it exists as a llama rewrite rule.
* Dropout optimizations moved from a rewrite rule into the core optimizer.
The rewrite rule had an issue (it used an attribute where an input is
expected), and the core optimizer seemed a better home for it anyway.

In general, for optimizations involving a single node, the core optimizer is a
better place than rewrite rules, at least as long as they are generic and not
backend-specific; it is also more efficient. A usage sketch follows.
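
This is a minimal, hypothetical example (not part of this diff); it assumes the onnxscript.optimizer.optimize entry point, which includes the constant-folding pass where these single-node simplifications are registered.

    # Hypothetical usage: the optimizer should rewrite the inference-mode Dropout
    # and the single-input Concat below into Identity nodes (later cleanup passes
    # may remove the resulting Identity nodes as well).
    import onnx.parser
    import onnxscript.optimizer

    model = onnx.parser.parse_model("""
        <ir_version: 7, opset_import: [ "" : 17]>
        agraph (float[N] x) => (float[N] z)
        {
            y = Dropout(x)
            z = Concat <axis=-1> (y)
        }
    """)
    optimized = onnxscript.optimizer.optimize(model)
    print([n.op_type for n in optimized.graph.node])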

* Fix the input/output size limit of constant folding to be measured in number
of bytes. (It was previously inconsistent: number of bytes for one limit and
number of elements for the other.) A short illustration of the difference
follows.
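
The following snippet merely illustrates the difference between the two units (plain NumPy, nothing specific to this change):

    # For a float32 tensor, the byte count is four times the element count.
    import numpy as np

    t = np.zeros((1024, 1024), dtype=np.float32)
    print(t.size)    # 1048576 elements
    print(t.nbytes)  # 4194304 bytes
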
gramalingam authored Nov 8, 2024
1 parent edfa265 commit 32090a8
Showing 3 changed files with 124 additions and 17 deletions.
65 changes: 54 additions & 11 deletions onnxscript/optimizer/_constant_folding.py
@@ -292,20 +292,29 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) ->
     return default
 
 
-# TODO(rama): The following should not be necessary. Generic incremental shape-inference
-# should handle this. This essentially implements type/shape-inference for Cast op.
 @register("Cast")
 def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     input = _get_input(node, 0)
     output = _get_output(node, 0)
-    if input is not None and output is not None:
-        input_shape = input.shape
-        if input_shape is not None:
-            output.shape = input_shape.copy()
-    if output is not None:
-        output_dtype = _get_int_attribute(node, "to", None)
-        if output_dtype is not None:
-            output.type = ir.TensorType(ir.DataType(output_dtype))
+
+    if input is None or output is None:
+        return None
+
+    # TODO(rama): Parts of the following logic (implementing type/shape inference
+    # for Cast op) should be unnecessary. Generic incremental shape-inference
+    # should handle this. Only the optimization to eliminate redundant Cast ops
+    # should be needed here.
+
+    input_shape = input.shape
+    if input_shape is not None:
+        output.shape = input_shape.copy()
+
+    input_dtype = _get_input_element_type(node, 0)
+    output_dtype = _get_int_attribute(node, "to", None)
+    if output_dtype is not None:
+        if input_dtype == output_dtype:
+            return op.Identity(input)
+        output.type = ir.TensorType(ir.DataType(output_dtype))
     return None
 
 
@@ -413,6 +422,40 @@ def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     return None
 
 
+@register("Concat")
+def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+    """Replace a Concat node with a single input by Identity"""
+    inputs = node.inputs
+    if len(inputs) == 1:
+        return op.Identity(inputs[0])
+    return None
+
+
+@register("Dropout", version=(12, None))
+def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+    """Replace a Dropout by Identity when applicable."""
+    if len(node.outputs) != 1:
+        # If output mask is requested, optimization is more complex.
+        # TODO: handle this case. But unlikely to be needed in practice.
+        return None
+    inputs = node.inputs
+    if (len(inputs) <= 2) or inputs[2] is None:
+        # No training_mode specified:
+        return op.Identity(inputs[0])
+    if _get_bool_value(inputs[2]) is False:
+        # training_mode is False: dropout is not applied.
+        return op.Identity(inputs[0])
+    ratio = _get_numpy_value(inputs[1])
+    if ratio is None:
+        return None
+    if ratio.size != 1:  # Only scalar dropout ratio is supported.
+        return None
+    if ratio.item() == 0:
+        # dropout ratio is 0: dropout is not applied.
+        return op.Identity(inputs[0])
+    return None
+
+
 @register("ConcatFromSequence")
 def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     input = node.inputs[0]
@@ -711,7 +754,7 @@ def process_node(self, node: ir.Node):
         if any(x is None for x in input_values):
             return None
 
-        if any(input.size > self._input_size_limit for input in input_values):  # type: ignore[union-attr]
+        if any(input.nbytes > self._input_size_limit for input in input_values):  # type: ignore[union-attr]
             if logger.isEnabledFor(logging.DEBUG):
                 input_sizes = [input.size for input in input_values]  # type: ignore[union-attr]
                 logger.debug(
39 changes: 39 additions & 0 deletions onnxscript/optimizer/_constant_folding_test.py
@@ -394,6 +394,45 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1(
         self.assertEqual(optimized.graph.node[6].op_type, "Concat")
         onnx.checker.check_model(optimized)
 
+    @parameterized.parameterized.expand(
+        [
+            ("output = Dropout(input)",),
+            ("output = Dropout(input, zero, true)",),
+            ("output = Dropout(input, half)",),
+            ("output = Dropout(input, half, false)",),
+        ]
+    )
+    def test_dropout_identity(self, dropout_node: str):
+        if not self.using_ir:
+            self.skipTest("New optimizations not supported for legacy optimizer")
+        model = onnx.parser.parse_model(f"""
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[N] input) => (float[N] output)
+            <float zero = {{0.0}}, float half = {{0.5}}, bool true = {{1}}, bool false = {{0}}>
+            {{
+                {dropout_node}
+            }}
+        """)
+        optimized = self._fold(model)
+        self.assertEqual(len(optimized.graph.node), 1)
+        self.assertEqual(optimized.graph.node[0].op_type, "Identity")
+
+    def test_concat_identity(self):
+        if not self.using_ir:
+            self.skipTest("New optimizations not supported for legacy optimizer")
+        model = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[N] x) => (float[N] z)
+            {
+                z = Concat <axis=-1> (x)
+            }
+            """
+        )
+        optimized = self._fold(model)
+        self.assertEqual(len(optimized.graph.node), 1)
+        self.assertEqual(optimized.graph.node[0].op_type, "Identity")
+
 
 if __name__ == "__main__":
     unittest.main()
37 changes: 31 additions & 6 deletions onnxscript/rewriter/pattern.py
@@ -225,6 +225,7 @@ def __call__(
         _version: int | None = None,
         _outputs: int | list[str | None] = 1,
         _allow_other_attributes: bool | None = None,
+        _allow_other_inputs: bool | None = None,
         **kwargs,
     ):
         if _version is not None:
@@ -249,7 +250,13 @@ def __call__(
         inputs = [_to_value_pattern(x) for x in args]
         attributes = {name: _to_attr_pattern(value) for (name, value) in kwargs.items()}
         node_pattern = NodePattern(
-            opset_pattern, self.op_name, inputs, attributes, _outputs, _allow_other_attributes
+            opset_pattern,
+            self.op_name,
+            inputs,
+            attributes,
+            _outputs,
+            allow_other_attributes=_allow_other_attributes,
+            allow_other_inputs=_allow_other_inputs,
         )
         self.pattern_builder.add_node(node_pattern)
         output_values = node_pattern.outputs
@@ -471,16 +478,22 @@ def __init__(
         inputs: Sequence[int | float | ValuePattern | None],
         attributes: dict[str, AttrPattern],
         outputs: Sequence[str | None],
+        *,
         allow_other_attributes: bool | None,
+        allow_other_inputs: bool | None,
     ):
         if allow_other_attributes is None:
             # Default behavior: allow other unmatched attributes in the node.
             allow_other_attributes = True
+        if allow_other_inputs is None:
+            # TODO(rama): Should we default to True? For now, we preserve the current behavior.
+            allow_other_inputs = False
         self.domain = domain
         self.op = StringConstantPattern(op) if isinstance(op, str) else op
         self.inputs = [_to_value_pattern(x) for x in inputs]
         self.attributes = attributes
         self.allow_other_attributes = allow_other_attributes
+        self.allow_other_inputs = allow_other_inputs
         # In the common case, domain and op are constants, which can be used to optimize matching.
         if isinstance(op, str) and isinstance(domain, StringConstantPattern):
             # TODO(rama): support overloaded operators.
@@ -557,7 +570,13 @@ def clone(self, node_map: dict[NodePattern, NodePattern], swap: bool) -> NodePat
             inputs = [inputs[1], inputs[0]]
         outputs = [value.name for value in self.outputs]
         copied = NodePattern(
-            self.domain, self.op, inputs, self.attributes, outputs, self.allow_other_attributes
+            self.domain,
+            self.op,
+            inputs,
+            self.attributes,
+            outputs,
+            allow_other_attributes=self.allow_other_attributes,
+            allow_other_inputs=self.allow_other_inputs,
         )
         node_map[self] = copied
         return copied
@@ -1022,10 +1041,16 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool:
         self._matched[pattern_node] = node
 
         # TODO: Revisit this to handle optional trailing inputs better.
-        if len(node.inputs) != len(pattern_node.inputs):
-            return self.fail(
-                "Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}"
-            )
+        if pattern_node.allow_other_inputs:
+            if len(node.inputs) < len(pattern_node.inputs):
+                return self.fail(
+                    f"Number of inputs ({len(node.inputs)}) is less than expected ({len(pattern_node.inputs)})"
+                )
+        else:
+            if len(node.inputs) != len(pattern_node.inputs):
+                return self.fail(
+                    f"Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}"
+                )
 
         for arg_value, arg_pattern in zip(node.inputs, pattern_node.inputs):
             # arg_pattern could be a Var, if it's the original arg.
