diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index 6a37efa16..e9276cb32 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -292,20 +292,29 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) ->
     return default
 
 
-# TODO(rama): The following should not be necessary. Generic incremental shape-inference
-# should handle this. This essentially implements type/shape-inference for Cast op.
 @register("Cast")
 def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     input = _get_input(node, 0)
     output = _get_output(node, 0)
-    if input is not None and output is not None:
-        input_shape = input.shape
-        if input_shape is not None:
-            output.shape = input_shape.copy()
-    if output is not None:
-        output_dtype = _get_int_attribute(node, "to", None)
-        if output_dtype is not None:
-            output.type = ir.TensorType(ir.DataType(output_dtype))
+
+    if input is None or output is None:
+        return None
+
+    # TODO(rama): Parts of the following logic (implementing type/shape inference
+    # for Cast op) should be unnecessary. Generic incremental shape-inference
+    # should handle this. Only the optimization to eliminate redundant Cast ops
+    # should be needed here.
+
+    input_shape = input.shape
+    if input_shape is not None:
+        output.shape = input_shape.copy()
+
+    input_dtype = _get_input_element_type(node, 0)
+    output_dtype = _get_int_attribute(node, "to", None)
+    if output_dtype is not None:
+        if input_dtype == output_dtype:
+            return op.Identity(input)
+        output.type = ir.TensorType(ir.DataType(output_dtype))
     return None
 
 
@@ -413,6 +422,40 @@ def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     return None
 
 
+@register("Concat")
+def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+    """Replace a Concat node with a single input by Identity"""
+    inputs = node.inputs
+    if len(inputs) == 1:
+        return op.Identity(inputs[0])
+    return None
+
+
+@register("Dropout", version=(12, None))
+def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+    """Replace a Dropout by Identity when applicable."""
+    if len(node.outputs) != 1:
+        # If the output mask is requested, the optimization is more complex.
+        # TODO: handle this case. But it is unlikely to be needed in practice.
+        return None
+    inputs = node.inputs
+    if (len(inputs) <= 2) or inputs[2] is None:
+        # No training_mode specified:
+        return op.Identity(inputs[0])
+    if _get_bool_value(inputs[2]) is False:
+        # training_mode is False: dropout is not applied.
+        return op.Identity(inputs[0])
+    ratio = _get_numpy_value(inputs[1])
+    if ratio is None:
+        return None
+    if ratio.size != 1:  # Only a scalar dropout ratio is supported.
+        return None
+    if ratio.item() == 0:
+        # dropout ratio is 0: dropout is not applied.
+        return op.Identity(inputs[0])
+    return None
+
+
 @register("ConcatFromSequence")
 def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     input = node.inputs[0]
@@ -711,7 +754,7 @@ def process_node(self, node: ir.Node):
         if any(x is None for x in input_values):
             return None
 
-        if any(input.size > self._input_size_limit for input in input_values):  # type: ignore[union-attr]
+        if any(input.nbytes > self._input_size_limit for input in input_values):  # type: ignore[union-attr]
            if logger.isEnabledFor(logging.DEBUG):
                input_sizes = [input.size for input in input_values]  # type: ignore[union-attr]
                logger.debug(
diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py
index b80f01c8f..52e06bd56 100644
--- a/onnxscript/optimizer/_constant_folding_test.py
+++ b/onnxscript/optimizer/_constant_folding_test.py
@@ -394,6 +394,45 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1(
         self.assertEqual(optimized.graph.node[6].op_type, "Concat")
         onnx.checker.check_model(optimized)
 
+    @parameterized.parameterized.expand(
+        [
+            ("output = Dropout(input)",),
+            ("output = Dropout(input, zero, true)",),
+            ("output = Dropout(input, half)",),
+            ("output = Dropout(input, half, false)",),
+        ]
+    )
+    def test_dropout_identity(self, dropout_node: str):
+        if not self.using_ir:
+            self.skipTest("New optimizations not supported for legacy optimizer")
+        model = onnx.parser.parse_model(f"""
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[N] input) => (float[N] output)
+            <float zero = {{0.0}}, float half = {{0.5}}, bool true = {{1}}, bool false = {{0}}>
+            {{
+                {dropout_node}
+            }}
+        """)
+        optimized = self._fold(model)
+        self.assertEqual(len(optimized.graph.node), 1)
+        self.assertEqual(optimized.graph.node[0].op_type, "Identity")
+
+    def test_concat_identity(self):
+        if not self.using_ir:
+            self.skipTest("New optimizations not supported for legacy optimizer")
+        model = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[N] x) => (float[N] z)
+            {
+                z = Concat <axis = 0> (x)
+            }
+            """
+        )
+        optimized = self._fold(model)
+        self.assertEqual(len(optimized.graph.node), 1)
+        self.assertEqual(optimized.graph.node[0].op_type, "Identity")
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py
index 059895ea8..66d9b3196 100644
--- a/onnxscript/rewriter/pattern.py
+++ b/onnxscript/rewriter/pattern.py
@@ -225,6 +225,7 @@ def __call__(
         _version: int | None = None,
         _outputs: int | list[str | None] = 1,
         _allow_other_attributes: bool | None = None,
+        _allow_other_inputs: bool | None = None,
         **kwargs,
     ):
         if _version is not None:
@@ -249,7 +250,13 @@ def __call__(
         inputs = [_to_value_pattern(x) for x in args]
         attributes = {name: _to_attr_pattern(value) for (name, value) in kwargs.items()}
         node_pattern = NodePattern(
-            opset_pattern, self.op_name, inputs, attributes, _outputs, _allow_other_attributes
+            opset_pattern,
+            self.op_name,
+            inputs,
+            attributes,
+            _outputs,
+            allow_other_attributes=_allow_other_attributes,
+            allow_other_inputs=_allow_other_inputs,
         )
         self.pattern_builder.add_node(node_pattern)
         output_values = node_pattern.outputs
@@ -471,16 +478,22 @@ def __init__(
         inputs: Sequence[int | float | ValuePattern | None],
         attributes: dict[str, AttrPattern],
         outputs: Sequence[str | None],
+        *,
         allow_other_attributes: bool | None,
+        allow_other_inputs: bool | None,
     ):
         if allow_other_attributes is None:
             # Default behavior: allow other unmatched attributes in the node.
             allow_other_attributes = True
+        if allow_other_inputs is None:
+            # TODO(rama): Should we default to True? For now, we preserve the current behavior.
+            allow_other_inputs = False
         self.domain = domain
         self.op = StringConstantPattern(op) if isinstance(op, str) else op
         self.inputs = [_to_value_pattern(x) for x in inputs]
         self.attributes = attributes
         self.allow_other_attributes = allow_other_attributes
+        self.allow_other_inputs = allow_other_inputs
         # In the common case, domain and op are constants, which can be used to optimize matching.
         if isinstance(op, str) and isinstance(domain, StringConstantPattern):
             # TODO(rama): support overloaded operators.
@@ -557,7 +570,13 @@ def clone(self, node_map: dict[NodePattern, NodePattern], swap: bool) -> NodePat
             inputs = [inputs[1], inputs[0]]
         outputs = [value.name for value in self.outputs]
         copied = NodePattern(
-            self.domain, self.op, inputs, self.attributes, outputs, self.allow_other_attributes
+            self.domain,
+            self.op,
+            inputs,
+            self.attributes,
+            outputs,
+            allow_other_attributes=self.allow_other_attributes,
+            allow_other_inputs=self.allow_other_inputs,
         )
         node_map[self] = copied
         return copied
@@ -1022,10 +1041,16 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool:
         self._matched[pattern_node] = node
 
         # TODO: Revisit this to handle optional trailing inputs better.
-        if len(node.inputs) != len(pattern_node.inputs):
-            return self.fail(
-                "Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}"
-            )
+        if pattern_node.allow_other_inputs:
+            if len(node.inputs) < len(pattern_node.inputs):
+                return self.fail(
+                    f"Number of inputs ({len(node.inputs)}) is less than expected ({len(pattern_node.inputs)})"
+                )
+        else:
+            if len(node.inputs) != len(pattern_node.inputs):
+                return self.fail(
+                    f"Input count mismatch: {len(node.inputs)} vs {len(pattern_node.inputs)}"
+                )
 
         for arg_value, arg_pattern in zip(node.inputs, pattern_node.inputs):
             # arg_pattern could be a Var, if it's the original arg.
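
Usage sketch (not part of the diff): how the new `_allow_other_inputs` option is meant
to be used from the pattern-builder side. This assumes the `RewriteRule` API defined in
onnxscript/rewriter/pattern.py; the `dropout_pattern` and `use_identity` function names
below are illustrative, not part of this change.

    from onnxscript.rewriter import pattern

    def dropout_pattern(op, x):
        # Only the data input is named in the pattern. With _allow_other_inputs=True,
        # the matcher also accepts nodes that carry extra (e.g. optional trailing)
        # inputs, such as Dropout(x, ratio) or Dropout(x, ratio, training_mode).
        # Without the flag, input counts must match exactly (the default kept above).
        return op.Dropout(x, _allow_other_inputs=True, _allow_other_attributes=True)

    def use_identity(op, x):
        # Illustrative replacement only: a real rule should also verify, via a match
        # condition, that the matched Dropout is a no-op (as the constant-folding
        # logic in this diff does).
        return op.Identity(x)

    rule = pattern.RewriteRule(dropout_pattern, use_identity)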