From d751d37fbc0298119f0b9d78b009d06708998b4d Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 23 Oct 2024 13:21:05 -0700 Subject: [PATCH 01/28] Fusions --- onnxscript/optimizer/_constant_folding.py | 2 +- onnxscript/rewriter/attention.py | 23 ++++++++ onnxscript/rewriter/group_query_attention.py | 16 ++++++ onnxscript/rewriter/pattern.py | 2 +- onnxscript/rewriter/rms_normalization.py | 56 ++++++++++++++++++++ onnxscript/rewriter/rotary_embedding.py | 29 ++++++++++ onnxscript/rewriter/skip_normalization.py | 32 +++++++++++ 7 files changed, 158 insertions(+), 2 deletions(-) create mode 100644 onnxscript/rewriter/attention.py create mode 100644 onnxscript/rewriter/group_query_attention.py create mode 100644 onnxscript/rewriter/rms_normalization.py create mode 100644 onnxscript/rewriter/rotary_embedding.py create mode 100644 onnxscript/rewriter/skip_normalization.py diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 6a37efa16..c3c061281 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -711,7 +711,7 @@ def process_node(self, node: ir.Node): if any(x is None for x in input_values): return None - if any(input.size > self._input_size_limit for input in input_values): # type: ignore[union-attr] + if any(input.nbytes > self._input_size_limit for input in input_values): # type: ignore[union-attr] if logger.isEnabledFor(logging.DEBUG): input_sizes = [input.size for input in input_values] # type: ignore[union-attr] logger.debug( diff --git a/onnxscript/rewriter/attention.py b/onnxscript/rewriter/attention.py new file mode 100644 index 000000000..dae2357b1 --- /dev/null +++ b/onnxscript/rewriter/attention.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+from __future__ import annotations + +import numpy as np +import onnxscript.ir as ir +from onnxscript.rewriter import pattern + +def sdpa_pattern(op, query, key_transposed, value, query_scale, key_scale, mask): + scaled_query = op.Mul(query, query_scale) + scaled_key_transposed = op.Mul(key_transposed, key_scale) + attn_score = op.MatMul(scaled_query, scaled_key_transposed) + masked_score = op.Add(attn_score, mask) + attn_weight = op.Softmax(masked_score, axis=-1) + attn_output = op.MatMul(attn_weight, value) + return attn_output + +def sdpa(op, query, key_transposed, value, query_scale, key_scale, mask): + # TODO + # check if query_scale and key_scale are scalars == sqrt(sqrt(dimsize)) + return op.SDPA(query, key_transposed, value, query_scale, key_scale, mask, _domain="local") + +rule = pattern.RewriteRule(sdpa_pattern, sdpa) \ No newline at end of file diff --git a/onnxscript/rewriter/group_query_attention.py b/onnxscript/rewriter/group_query_attention.py new file mode 100644 index 000000000..8a00b2b7b --- /dev/null +++ b/onnxscript/rewriter/group_query_attention.py @@ -0,0 +1,16 @@ +""" +MultiHeadAttention: + +for Q, K, V: + MatMul + Reshape to B, S, 32, 64 + Transpose to B, 32, S, 64 + +Here, 32 is the number of heads and 64 is the head size + +Embed Q and K + +One of the embeddings (namely ) is also output of layer +and last two axes are transposed for SDPA + +""" \ No newline at end of file diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 059895ea8..d25e09a93 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1024,7 +1024,7 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool: # TODO: Revisit this to handle optional trailing inputs better. if len(node.inputs) != len(pattern_node.inputs): return self.fail( - "Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}" + f"Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}" ) for arg_value, arg_pattern in zip(node.inputs, pattern_node.inputs): diff --git a/onnxscript/rewriter/rms_normalization.py b/onnxscript/rewriter/rms_normalization.py new file mode 100644 index 000000000..02d9553ca --- /dev/null +++ b/onnxscript/rewriter/rms_normalization.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import numpy as np +import onnxscript.ir as ir +from onnxscript.rewriter import pattern + +def _get_numpy_value(val: ir.Value | None) -> np.ndarray | None: + if val is None: + return None + const_value = val.const_value + if const_value is not None: + try: + return const_value.numpy() + except FileNotFoundError: + # External data is not available. 
+ return None + return None + +def _get_scalar_value(val: ir.Value | None): + np_val = _get_numpy_value(val) + if np_val is not None and np_val.size == 1: + return np_val.item() + return None + +# Pattern to match against +def rms_norm_pattern(op, x, scale, epsilon, compute_dtype, target_dtype): + x_cast = op.Cast(x, to=compute_dtype) + x_square = op.Pow(x_cast, 2.0) + mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0) + mean_square_plus_epsilon = op.Add(mean_square, epsilon) + rms = op.Sqrt(mean_square_plus_epsilon) + reciprocal_rms = op.Reciprocal(rms) + normalized = op.Mul(x_cast, reciprocal_rms) + normalized_cast = op.Cast(normalized, to=target_dtype) + return op.Mul(scale, normalized_cast) + +# Replacement +def simplified_layer_norm(op, x, scale, epsilon, compute_dtype, target_dtype): + epsilon_value = _get_scalar_value(epsilon) + if not isinstance(epsilon_value, float): + return None + source_dtype = x.dtype + if source_dtype is None or source_dtype != target_dtype.value: + return None + return op.SimplifiedLayerNormalization ( + x, + scale, + axis=-1, + epsilon=epsilon_value, + stash_type=compute_dtype.value, + _domain="com.microsoft") + + +rule = pattern.RewriteRule(rms_norm_pattern, simplified_layer_norm) diff --git a/onnxscript/rewriter/rotary_embedding.py b/onnxscript/rewriter/rotary_embedding.py new file mode 100644 index 000000000..720bbc983 --- /dev/null +++ b/onnxscript/rewriter/rotary_embedding.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import numpy as np +import onnxscript.ir as ir +from onnxscript.rewriter import pattern + +def rotate_half_pattern(op, x, start1, end1, start2, end2): + # Slice(input, starts, ends, axes, steps) + x1 = op.Slice(x, start1, end1, [3], [1]) + x2 = op.Slice(x, start2, end2, [3], [1]) + minus_x2 = op.Neg(x2) + rotated_x = op.Concat(minus_x2, x1, axis=-1) + return rotated_x + +def rotate_half(op, x, start1, end1, start2, end2): + # TODO: check if start1, end1, start2, end2 are valid + return op.RotateHalf(x, _domain="local") + +def embed_pattern(op, x, cos, sin, dc1, dc2, dc3, dc4): + return x * cos + op.RotateHalf(x, dc1, dc2, dc3, dc4, _domain="local") * sin + +def embed(op, x, cos, sin, **_): + return op.Embed(x, _domain="local") + +rule = pattern.RewriteRule(rotate_half_pattern, rotate_half) + +embed_rule = pattern.RewriteRule(embed_pattern, embed) \ No newline at end of file diff --git a/onnxscript/rewriter/skip_normalization.py b/onnxscript/rewriter/skip_normalization.py new file mode 100644 index 000000000..eecb9a60b --- /dev/null +++ b/onnxscript/rewriter/skip_normalization.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+from __future__ import annotations + +import numpy as np +import onnxscript.ir as ir +from onnxscript.rewriter import pattern + +def skip_norm_pattern(op, input, skip, gamma, epsilon, stash_type): + skip_sum = op.Add(input, skip) + normalized = op.SimplifiedLayerNormalization( + skip_sum, + gamma, + axis=-1, + epsilon=epsilon, + stash_type=stash_type, + _domain="com.microsoft") + return normalized, skip_sum + +def skip_normalization(op, input, skip, gamma, epsilon, stash_type): + normalized, mean, inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization( + input, + skip, + gamma, + epsilon=epsilon, + stash_type=stash_type, + _domain="com.microsoft", + _outputs=4 + ) + return normalized, skip_sum + +rule = pattern.RewriteRule(skip_norm_pattern, skip_normalization, matcher=pattern.SimplePatternMatcher) \ No newline at end of file From 8c4dff5505d75df34a3e5163e416128fa59c7814 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 23 Oct 2024 17:26:22 -0700 Subject: [PATCH 02/28] MultiHeadAttention fusion --- onnxscript/rewriter/group_query_attention.py | 16 ------ onnxscript/rewriter/multi_head_attention.py | 55 ++++++++++++++++++++ onnxscript/rewriter/pattern.py | 24 ++++++--- 3 files changed, 73 insertions(+), 22 deletions(-) delete mode 100644 onnxscript/rewriter/group_query_attention.py create mode 100644 onnxscript/rewriter/multi_head_attention.py diff --git a/onnxscript/rewriter/group_query_attention.py b/onnxscript/rewriter/group_query_attention.py deleted file mode 100644 index 8a00b2b7b..000000000 --- a/onnxscript/rewriter/group_query_attention.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -MultiHeadAttention: - -for Q, K, V: - MatMul - Reshape to B, S, 32, 64 - Transpose to B, 32, S, 64 - -Here, 32 is the number of heads and 64 is the head size - -Embed Q and K - -One of the embeddings (namely ) is also output of layer -and last two axes are transposed for SDPA - -""" \ No newline at end of file diff --git a/onnxscript/rewriter/multi_head_attention.py b/onnxscript/rewriter/multi_head_attention.py new file mode 100644 index 000000000..91fbf3334 --- /dev/null +++ b/onnxscript/rewriter/multi_head_attention.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import numpy as np +import onnxscript.ir as ir +from onnxscript.rewriter import pattern + +""" +MultiHeadAttention: + +for Q, K, V: + MatMul + Reshape to B, S, 32, 64 + Transpose to B, 32, S, 64 + +Here, 32 is the number of heads and 64 is the head size + +Embed Q and K + +One of the embeddings (namely ) is also output of layer +and last two axes are transposed for SDPA + +""" + +def project_transpose_head(op, input, weight): + projected = op.MatMul(input, weight) + # Reshape from (B, S, D) to (B, S, H, D/H) + reshaped = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True) + # Transpose from (B, S, H, D/H) to (B, H, S, D/H) + transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3]) + return transposed + +def multi_head_attention_pattern (op, input, query_weight, key_weight, value_weight): + query = project_transpose_head(op, input, query_weight) + query_rope = op.Embed(query, _domain="local") + key = project_transpose_head(op, input, key_weight) + key_rope = op.Embed(key, _domain="local") + # Transpose last two axes of key_rope to compute dot-product via matmul. 
+ key_reshaped = op.Reshape(key_rope, _allow_other_inputs=True) + key_reshaped_transposed = op.Transpose(key_reshaped) + key_transposed = op.Reshape(key_reshaped_transposed, _allow_other_inputs=True) + value = project_transpose_head(op, input, value_weight) + attention = op.SDPA(query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local") + # Transpose back to (B, S, H, D/H) + attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) + # Reshape back to (B, S, D) + attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True) + return attention_reshaped, value, key_rope + +def multi_head_attention(op, input, query_weight, key_weight, value_weight): + # TODO: other checks and concatenation of weights + return op.MultiHeadAttention(input, query_weight, key_weight, value_weight, _domain="local", _outputs=3) + +rule = pattern.RewriteRule(multi_head_attention_pattern, multi_head_attention) \ No newline at end of file diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index d25e09a93..408cf276d 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -225,6 +225,7 @@ def __call__( _version: int | None = None, _outputs: int | list[str | None] = 1, _allow_other_attributes: bool | None = None, + _allow_other_inputs: bool | None = None, **kwargs, ): if _version is not None: @@ -249,7 +250,7 @@ def __call__( inputs = [_to_value_pattern(x) for x in args] attributes = {name: _to_attr_pattern(value) for (name, value) in kwargs.items()} node_pattern = NodePattern( - opset_pattern, self.op_name, inputs, attributes, _outputs, _allow_other_attributes + opset_pattern, self.op_name, inputs, attributes, _outputs, _allow_other_attributes, _allow_other_inputs ) self.pattern_builder.add_node(node_pattern) output_values = node_pattern.outputs @@ -472,15 +473,20 @@ def __init__( attributes: dict[str, AttrPattern], outputs: Sequence[str | None], allow_other_attributes: bool | None, + _allow_other_inputs: bool | None, ): if allow_other_attributes is None: # Default behavior: allow other unmatched attributes in the node. allow_other_attributes = True + if _allow_other_inputs is None: + # TODO(rama): Should we default to True? For now, we preserve the current behavior. + _allow_other_inputs = False self.domain = domain self.op = StringConstantPattern(op) if isinstance(op, str) else op self.inputs = [_to_value_pattern(x) for x in inputs] self.attributes = attributes self.allow_other_attributes = allow_other_attributes + self.allow_other_inputs = _allow_other_inputs # In the common case, domain and op are constants, which can be used to optimize matching. if isinstance(op, str) and isinstance(domain, StringConstantPattern): # TODO(rama): support overloaded operators. @@ -557,7 +563,7 @@ def clone(self, node_map: dict[NodePattern, NodePattern], swap: bool) -> NodePat inputs = [inputs[1], inputs[0]] outputs = [value.name for value in self.outputs] copied = NodePattern( - self.domain, self.op, inputs, self.attributes, outputs, self.allow_other_attributes + self.domain, self.op, inputs, self.attributes, outputs, self.allow_other_attributes, self.allow_other_inputs ) node_map[self] = copied return copied @@ -1022,10 +1028,16 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool: self._matched[pattern_node] = node # TODO: Revisit this to handle optional trailing inputs better. - if len(node.inputs) != len(pattern_node.inputs): - return self.fail( - f"Input nums mismatch. 
{len(node.inputs)} vs {len(pattern_node.inputs)}" - ) + if pattern_node.allow_other_inputs: + if len(node.inputs) < len(pattern_node.inputs): + return self.fail( + f"Number of inputs ({len(node.inputs)}) is less than expected ({len(pattern_node.inputs)})" + ) + else: + if len(node.inputs) != len(pattern_node.inputs): + return self.fail( + f"Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}" + ) for arg_value, arg_pattern in zip(node.inputs, pattern_node.inputs): # arg_pattern could be a Var, if it's the original arg. From 4d3ff901063cd6a1a07bc0d9b2c819525e7f10bc Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 6 Nov 2024 08:55:32 -0800 Subject: [PATCH 03/28] Move transformers optimization into onnxruntime folder --- onnxscript/rewriter/_ir_utils.py | 20 ++++++++++++ onnxscript/rewriter/no_op.py | 2 +- .../onnxruntime/_optimize_transformers.py | 30 ++++++++++++++++++ .../_optimize_transformers_test.py | 31 +++++++++++++++++++ .../rewriter/{ => onnxruntime}/attention.py | 0 .../{ => onnxruntime}/multi_head_attention.py | 0 .../{ => onnxruntime}/rms_normalization.py | 20 ++---------- .../{ => onnxruntime}/rotary_embedding.py | 4 +-- .../{ => onnxruntime}/skip_normalization.py | 0 9 files changed, 86 insertions(+), 21 deletions(-) create mode 100644 onnxscript/rewriter/onnxruntime/_optimize_transformers.py create mode 100644 onnxscript/rewriter/onnxruntime/_optimize_transformers_test.py rename onnxscript/rewriter/{ => onnxruntime}/attention.py (100%) rename onnxscript/rewriter/{ => onnxruntime}/multi_head_attention.py (100%) rename onnxscript/rewriter/{ => onnxruntime}/rms_normalization.py (68%) rename onnxscript/rewriter/{ => onnxruntime}/rotary_embedding.py (85%) rename onnxscript/rewriter/{ => onnxruntime}/skip_normalization.py (100%) diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index bd353f388..8d4f05901 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. from __future__ import annotations +import numpy as np import onnxscript.ir as ir from onnxscript.optimizer import basic_constant_propagation @@ -11,3 +12,22 @@ def get_const_value(value: ir.Value) -> ir.TensorProtocol | None: if node is not None: basic_constant_propagation([node]) return value.const_value + +def get_numpy_value(val: ir.Value | None) -> np.ndarray | None: + if val is None: + return None + const_value = val.const_value + if const_value is not None: + try: + return const_value.numpy() + except FileNotFoundError: + # External data is not available. 
+ return None + return None + +def get_singleton_value(val: ir.Value | None): + '''Returns element of a single element tensor constant value, and None otherwise.''' + np_val = get_numpy_value(val) + if np_val is not None and np_val.size == 1: + return np_val.item() + return None \ No newline at end of file diff --git a/onnxscript/rewriter/no_op.py b/onnxscript/rewriter/no_op.py index 6d25b0ed3..7c2d91635 100644 --- a/onnxscript/rewriter/no_op.py +++ b/onnxscript/rewriter/no_op.py @@ -24,7 +24,7 @@ def div_by_1(op, x): def dropout_zero(op, x): - return op.Dropout(x, ratio=0.0) + return op.Dropout(x, 0.0) def dropout_inference(op, x): diff --git a/onnxscript/rewriter/onnxruntime/_optimize_transformers.py b/onnxscript/rewriter/onnxruntime/_optimize_transformers.py new file mode 100644 index 000000000..1eef9ccc6 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/_optimize_transformers.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import onnxscript.ir as ir +from onnxscript.rewriter.onnxruntime import attention, multi_head_attention, rms_normalization, rotary_embedding, skip_normalization +from onnxscript.rewriter import no_op +from onnxscript.optimizer import _constant_folding, remove_unused_nodes + +def optimize(irmodel: ir.Model) -> None: + + def apply(rulename: str, rule): + count = rule.apply_to_model(irmodel) + print(f"{rulename} count: {count}") + + _constant_folding.fold_constants(irmodel, input_size_limit=5120000*4, output_size_limit=5120000*4) + + apply("Dropout", no_op.dropout_zero_rule) + remove_unused_nodes(irmodel) + + apply("RMS Normalization", rms_normalization.rule) + apply("Skip Normalization", skip_normalization.rule) + + _constant_folding.fold_constants(irmodel) + remove_unused_nodes(irmodel) + + apply("Attention", attention.rule) + apply("Rotate", rotary_embedding.rule) + apply("Embed", rotary_embedding.embed_rule) + apply("Multi-Head-Attention", multi_head_attention.rule) diff --git a/onnxscript/rewriter/onnxruntime/_optimize_transformers_test.py b/onnxscript/rewriter/onnxruntime/_optimize_transformers_test.py new file mode 100644 index 000000000..14d5e5450 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/_optimize_transformers_test.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+from __future__ import annotations + +import unittest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +import onnxscript.ir as ir +import onnxscript.optimizer +import onnxscript.rewriter.onnxruntime._optimize_transformers as optimize_transformers + +def _get_smollm_model() -> ir.Model: + checkpoint = "HuggingFaceTB/SmolLM-1.7B" + device = "cpu" + tokenizer = AutoTokenizer.from_pretrained(checkpoint) + model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device) + inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to(device) + program = torch.onnx.export(model, inputs, "tsmodel.onnx", dynamo=True) + model = program.model + onnxscript.optimizer.optimize_ir(model) + return model + +class TestOptimizeTransformers(unittest.TestCase): + + def test_optimize_transformers(self): + model = _get_smollm_model() + optimize_transformers.optimize(model) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/onnxscript/rewriter/attention.py b/onnxscript/rewriter/onnxruntime/attention.py similarity index 100% rename from onnxscript/rewriter/attention.py rename to onnxscript/rewriter/onnxruntime/attention.py diff --git a/onnxscript/rewriter/multi_head_attention.py b/onnxscript/rewriter/onnxruntime/multi_head_attention.py similarity index 100% rename from onnxscript/rewriter/multi_head_attention.py rename to onnxscript/rewriter/onnxruntime/multi_head_attention.py diff --git a/onnxscript/rewriter/rms_normalization.py b/onnxscript/rewriter/onnxruntime/rms_normalization.py similarity index 68% rename from onnxscript/rewriter/rms_normalization.py rename to onnxscript/rewriter/onnxruntime/rms_normalization.py index 02d9553ca..322297378 100644 --- a/onnxscript/rewriter/rms_normalization.py +++ b/onnxscript/rewriter/onnxruntime/rms_normalization.py @@ -4,25 +4,9 @@ import numpy as np import onnxscript.ir as ir -from onnxscript.rewriter import pattern +from onnxscript.rewriter import _ir_utils, pattern -def _get_numpy_value(val: ir.Value | None) -> np.ndarray | None: - if val is None: - return None - const_value = val.const_value - if const_value is not None: - try: - return const_value.numpy() - except FileNotFoundError: - # External data is not available. 
- return None - return None -def _get_scalar_value(val: ir.Value | None): - np_val = _get_numpy_value(val) - if np_val is not None and np_val.size == 1: - return np_val.item() - return None # Pattern to match against def rms_norm_pattern(op, x, scale, epsilon, compute_dtype, target_dtype): @@ -38,7 +22,7 @@ def rms_norm_pattern(op, x, scale, epsilon, compute_dtype, target_dtype): # Replacement def simplified_layer_norm(op, x, scale, epsilon, compute_dtype, target_dtype): - epsilon_value = _get_scalar_value(epsilon) + epsilon_value = _ir_utils.get_singleton_value(epsilon) if not isinstance(epsilon_value, float): return None source_dtype = x.dtype diff --git a/onnxscript/rewriter/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/rotary_embedding.py similarity index 85% rename from onnxscript/rewriter/rotary_embedding.py rename to onnxscript/rewriter/onnxruntime/rotary_embedding.py index 720bbc983..f99d329a5 100644 --- a/onnxscript/rewriter/rotary_embedding.py +++ b/onnxscript/rewriter/onnxruntime/rotary_embedding.py @@ -18,8 +18,8 @@ def rotate_half(op, x, start1, end1, start2, end2): # TODO: check if start1, end1, start2, end2 are valid return op.RotateHalf(x, _domain="local") -def embed_pattern(op, x, cos, sin, dc1, dc2, dc3, dc4): - return x * cos + op.RotateHalf(x, dc1, dc2, dc3, dc4, _domain="local") * sin +def embed_pattern(op, x, cos, sin): + return x * cos + op.RotateHalf(x, _domain="local") * sin def embed(op, x, cos, sin, **_): return op.Embed(x, _domain="local") diff --git a/onnxscript/rewriter/skip_normalization.py b/onnxscript/rewriter/onnxruntime/skip_normalization.py similarity index 100% rename from onnxscript/rewriter/skip_normalization.py rename to onnxscript/rewriter/onnxruntime/skip_normalization.py From 4a667f960f5419ecca18d6090b8c5a4674403dd2 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 6 Nov 2024 16:39:59 -0800 Subject: [PATCH 04/28] Support some SDPA variations --- onnxscript/optimizer/_constant_folding.py | 30 ++++++++++++------- .../onnxruntime/_optimize_transformers.py | 7 ++++- onnxscript/rewriter/onnxruntime/attention.py | 17 ++++++++++- .../onnxruntime/multi_head_attention.py | 2 +- 4 files changed, 43 insertions(+), 13 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index c3c061281..6c1844b18 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -292,20 +292,30 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> return default -# TODO(rama): The following should not be necessary. Generic incremental shape-inference -# should handle this. This essentially implements type/shape-inference for Cast op. + @register("Cast") def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = _get_input(node, 0) output = _get_output(node, 0) - if input is not None and output is not None: - input_shape = input.shape - if input_shape is not None: - output.shape = input_shape.copy() - if output is not None: - output_dtype = _get_int_attribute(node, "to", None) - if output_dtype is not None: - output.type = ir.TensorType(ir.DataType(output_dtype)) + + if input is None or output is None: + return None + + # TODO(rama): Parts of the following logic (implementing type/shape inference + # for Cast op) should be unnecessary. Generic incremental shape-inference + # should handle this. Only the optimization to eliminate redundant Cast ops + # should be needed here. 
+ + input_shape = input.shape + if input_shape is not None: + output.shape = input_shape.copy() + + input_dtype = _get_input_element_type(node, 0) + output_dtype = _get_int_attribute(node, "to", None) + if output_dtype is not None: + if input_dtype == output_dtype: + return op.Identity(input) + output.type = ir.TensorType(ir.DataType(output_dtype)) return None diff --git a/onnxscript/rewriter/onnxruntime/_optimize_transformers.py b/onnxscript/rewriter/onnxruntime/_optimize_transformers.py index 1eef9ccc6..cbe96d3b6 100644 --- a/onnxscript/rewriter/onnxruntime/_optimize_transformers.py +++ b/onnxscript/rewriter/onnxruntime/_optimize_transformers.py @@ -6,6 +6,10 @@ from onnxscript.rewriter.onnxruntime import attention, multi_head_attention, rms_normalization, rotary_embedding, skip_normalization from onnxscript.rewriter import no_op from onnxscript.optimizer import _constant_folding, remove_unused_nodes +from onnxscript.rewriter.llama_rule_sets import ExpandIdentity +import onnxscript.rewriter.pattern as pattern + +expand_rule = pattern.make_rewrite_rule_from_class(ExpandIdentity) def optimize(irmodel: ir.Model) -> None: @@ -16,6 +20,7 @@ def apply(rulename: str, rule): _constant_folding.fold_constants(irmodel, input_size_limit=5120000*4, output_size_limit=5120000*4) apply("Dropout", no_op.dropout_zero_rule) + apply("Expand", expand_rule) remove_unused_nodes(irmodel) apply("RMS Normalization", rms_normalization.rule) @@ -24,7 +29,7 @@ def apply(rulename: str, rule): _constant_folding.fold_constants(irmodel) remove_unused_nodes(irmodel) - apply("Attention", attention.rule) + apply("Attention", attention.rules) apply("Rotate", rotary_embedding.rule) apply("Embed", rotary_embedding.embed_rule) apply("Multi-Head-Attention", multi_head_attention.rule) diff --git a/onnxscript/rewriter/onnxruntime/attention.py b/onnxscript/rewriter/onnxruntime/attention.py index dae2357b1..3522b0bc7 100644 --- a/onnxscript/rewriter/onnxruntime/attention.py +++ b/onnxscript/rewriter/onnxruntime/attention.py @@ -20,4 +20,19 @@ def sdpa(op, query, key_transposed, value, query_scale, key_scale, mask): # check if query_scale and key_scale are scalars == sqrt(sqrt(dimsize)) return op.SDPA(query, key_transposed, value, query_scale, key_scale, mask, _domain="local") -rule = pattern.RewriteRule(sdpa_pattern, sdpa) \ No newline at end of file +def sdpa_pattern2(op, query, key_transposed, value, scale): + attn_score = op.MatMul(query, key_transposed) + masked_score = op.Div(attn_score, scale) + attn_weight = op.Softmax(masked_score, axis=-1) + attn_output = op.MatMul(attn_weight, value) + return attn_output + +def sdpa2(op, query, key_transposed, value, scale): + # TODO + # check if scale == (sqrt(dimsize)) + return op.SDPA2(query, key_transposed, value, scale, _domain="local") + +rule = pattern.RewriteRule(sdpa_pattern, sdpa) +rule2 = pattern.RewriteRule(sdpa_pattern2, sdpa2) + +rules = pattern.RewriteRuleSet([rule, rule2]) \ No newline at end of file diff --git a/onnxscript/rewriter/onnxruntime/multi_head_attention.py b/onnxscript/rewriter/onnxruntime/multi_head_attention.py index 91fbf3334..fe22981f6 100644 --- a/onnxscript/rewriter/onnxruntime/multi_head_attention.py +++ b/onnxscript/rewriter/onnxruntime/multi_head_attention.py @@ -46,7 +46,7 @@ def multi_head_attention_pattern (op, input, query_weight, key_weight, value_wei attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) # Reshape back to (B, S, D) attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True) - return 
attention_reshaped, value, key_rope + return attention_reshaped, key_rope, value def multi_head_attention(op, input, query_weight, key_weight, value_weight): # TODO: other checks and concatenation of weights From 33c37531c6256e7e61473c1656c346a9d95d282b Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 7 Nov 2024 16:53:52 -0800 Subject: [PATCH 05/28] Add variations of rules for SDPA --- onnxscript/optimizer/_constant_folding.py | 6 ++++ .../onnxruntime/_optimize_transformers.py | 18 ++++++++-- onnxscript/rewriter/onnxruntime/attention.py | 18 ++++++++-- .../onnxruntime/multi_head_attention.py | 33 +++++++++++++++---- .../rewriter/onnxruntime/rotary_embedding.py | 2 +- 5 files changed, 66 insertions(+), 11 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 6c1844b18..28a165d6c 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -422,6 +422,12 @@ def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue: state.set_sym_value(output, list(node.inputs)) return None +@register("Concat") +def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + inputs = node.inputs + if (len(inputs) == 1): + return op.Identity(inputs[0]) + return None @register("ConcatFromSequence") def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: diff --git a/onnxscript/rewriter/onnxruntime/_optimize_transformers.py b/onnxscript/rewriter/onnxruntime/_optimize_transformers.py index cbe96d3b6..4641a7e40 100644 --- a/onnxscript/rewriter/onnxruntime/_optimize_transformers.py +++ b/onnxscript/rewriter/onnxruntime/_optimize_transformers.py @@ -6,10 +6,24 @@ from onnxscript.rewriter.onnxruntime import attention, multi_head_attention, rms_normalization, rotary_embedding, skip_normalization from onnxscript.rewriter import no_op from onnxscript.optimizer import _constant_folding, remove_unused_nodes -from onnxscript.rewriter.llama_rule_sets import ExpandIdentity +from onnxscript.rewriter.llama_rule_sets import ExpandIdentity, TransposeIdentity import onnxscript.rewriter.pattern as pattern expand_rule = pattern.make_rewrite_rule_from_class(ExpandIdentity) +transpose_rule = pattern.make_rewrite_rule_from_class(TransposeIdentity) + +def basic_optimize(irmodel: ir.Model) -> None: + + def apply(rulename: str, rule): + count = rule.apply_to_model(irmodel) + print(f"{rulename} count: {count}") + + _constant_folding.fold_constants(irmodel, input_size_limit=5120000*4, output_size_limit=5120000*4) + + apply("Dropout", no_op.dropout_zero_rule) + apply("Expand", expand_rule) + apply("Transpose", transpose_rule) + remove_unused_nodes(irmodel) def optimize(irmodel: ir.Model) -> None: @@ -32,4 +46,4 @@ def apply(rulename: str, rule): apply("Attention", attention.rules) apply("Rotate", rotary_embedding.rule) apply("Embed", rotary_embedding.embed_rule) - apply("Multi-Head-Attention", multi_head_attention.rule) + apply("Multi-Head-Attention", multi_head_attention.rules) diff --git a/onnxscript/rewriter/onnxruntime/attention.py b/onnxscript/rewriter/onnxruntime/attention.py index 3522b0bc7..5d089d18a 100644 --- a/onnxscript/rewriter/onnxruntime/attention.py +++ b/onnxscript/rewriter/onnxruntime/attention.py @@ -30,9 +30,23 @@ def sdpa_pattern2(op, query, key_transposed, value, scale): def sdpa2(op, query, key_transposed, value, scale): # TODO # check if scale == (sqrt(dimsize)) - return op.SDPA2(query, key_transposed, value, scale, _domain="local") + return 
op.SDPA(query, key_transposed, value, scale, _domain="local") + +def sdpa_pattern3(op, query, key_transposed, value, scale, mask): + attn_score = op.MatMul(query, key_transposed) + scaled_score = op.Div(attn_score, scale) + masked_score = op.Add(scaled_score, mask) + attn_weight = op.Softmax(masked_score, axis=-1) + attn_output = op.MatMul(attn_weight, value) + return attn_output + +def sdpa3(op, query, key_transposed, value, scale, mask): + # TODO + # check if scale == (sqrt(dimsize)) + return op.SDPA(query, key_transposed, value, scale, mask, _domain="local") rule = pattern.RewriteRule(sdpa_pattern, sdpa) rule2 = pattern.RewriteRule(sdpa_pattern2, sdpa2) +rule3 = pattern.RewriteRule(sdpa_pattern3, sdpa3) -rules = pattern.RewriteRuleSet([rule, rule2]) \ No newline at end of file +rules = pattern.RewriteRuleSet([rule, rule2, rule3]) \ No newline at end of file diff --git a/onnxscript/rewriter/onnxruntime/multi_head_attention.py b/onnxscript/rewriter/onnxruntime/multi_head_attention.py index fe22981f6..1895cca3d 100644 --- a/onnxscript/rewriter/onnxruntime/multi_head_attention.py +++ b/onnxscript/rewriter/onnxruntime/multi_head_attention.py @@ -31,11 +31,11 @@ def project_transpose_head(op, input, weight): transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3]) return transposed -def multi_head_attention_pattern (op, input, query_weight, key_weight, value_weight): +def multi_head_attention_pattern (op, input, query_weight, key_weight, value_weight, cos, sin): query = project_transpose_head(op, input, query_weight) - query_rope = op.Embed(query, _domain="local") + query_rope = op.Embed(query, cos, sin, _domain="local") key = project_transpose_head(op, input, key_weight) - key_rope = op.Embed(key, _domain="local") + key_rope = op.Embed(key, cos, sin, _domain="local") # Transpose last two axes of key_rope to compute dot-product via matmul. key_reshaped = op.Reshape(key_rope, _allow_other_inputs=True) key_reshaped_transposed = op.Transpose(key_reshaped) @@ -48,8 +48,29 @@ def multi_head_attention_pattern (op, input, query_weight, key_weight, value_wei attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True) return attention_reshaped, key_rope, value -def multi_head_attention(op, input, query_weight, key_weight, value_weight): +def multi_head_attention_pattern2 (op, input, query_weight, key_weight, value_weight, cos, sin): + """Variation of first pattern with Reshape omitted.""" + query = project_transpose_head(op, input, query_weight) + query_rope = op.Embed(query, cos, sin, _domain="local") + key = project_transpose_head(op, input, key_weight) + key_rope = op.Embed(key, cos, sin, _domain="local") + # Transpose last two axes of key_rope to compute dot-product via matmul. + # Reshape omitted here. 
+ key_transposed = op.Transpose(key_rope) + # Reshape omitted here + value = project_transpose_head(op, input, value_weight) + attention = op.SDPA(query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local") + # Transpose back to (B, S, H, D/H) + attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) + # Reshape back to (B, S, D) + attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True) + return attention_reshaped, key_rope, value + +def multi_head_attention(op, input, query_weight, key_weight, value_weight, cos, sin,): # TODO: other checks and concatenation of weights - return op.MultiHeadAttention(input, query_weight, key_weight, value_weight, _domain="local", _outputs=3) + return op.MultiHeadAttention(input, query_weight, key_weight, value_weight, cos, sin, _domain="local", _outputs=3) + +rule = pattern.RewriteRule(multi_head_attention_pattern, multi_head_attention) +rule2 = pattern.RewriteRule(multi_head_attention_pattern2, multi_head_attention) -rule = pattern.RewriteRule(multi_head_attention_pattern, multi_head_attention) \ No newline at end of file +rules = pattern.RewriteRuleSet([rule, rule2]) \ No newline at end of file diff --git a/onnxscript/rewriter/onnxruntime/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/rotary_embedding.py index f99d329a5..872405106 100644 --- a/onnxscript/rewriter/onnxruntime/rotary_embedding.py +++ b/onnxscript/rewriter/onnxruntime/rotary_embedding.py @@ -22,7 +22,7 @@ def embed_pattern(op, x, cos, sin): return x * cos + op.RotateHalf(x, _domain="local") * sin def embed(op, x, cos, sin, **_): - return op.Embed(x, _domain="local") + return op.Embed(x, cos, sin, _domain="local") rule = pattern.RewriteRule(rotate_half_pattern, rotate_half) From 404e5c3a8cebeff55bfba0284d20058f9bee11ed Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 7 Nov 2024 17:35:42 -0800 Subject: [PATCH 06/28] Add attention scale validation --- onnxscript/rewriter/onnxruntime/attention.py | 43 +++++++++++++++---- .../onnxruntime/multi_head_attention.py | 13 ++++-- 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/onnxscript/rewriter/onnxruntime/attention.py b/onnxscript/rewriter/onnxruntime/attention.py index 5d089d18a..310d6d7b6 100644 --- a/onnxscript/rewriter/onnxruntime/attention.py +++ b/onnxscript/rewriter/onnxruntime/attention.py @@ -2,9 +2,10 @@ # Licensed under the MIT License. 
from __future__ import annotations +import math import numpy as np import onnxscript.ir as ir -from onnxscript.rewriter import pattern +from onnxscript.rewriter import pattern, _ir_utils def sdpa_pattern(op, query, key_transposed, value, query_scale, key_scale, mask): scaled_query = op.Mul(query, query_scale) @@ -16,9 +17,21 @@ def sdpa_pattern(op, query, key_transposed, value, query_scale, key_scale, mask) return attn_output def sdpa(op, query, key_transposed, value, query_scale, key_scale, mask): - # TODO - # check if query_scale and key_scale are scalars == sqrt(sqrt(dimsize)) - return op.SDPA(query, key_transposed, value, query_scale, key_scale, mask, _domain="local") + # Check if query_scale and key_scale are scalars == 1/sqrt(sqrt(dimsize)) + query_scale_value = _ir_utils.get_singleton_value(query_scale) + key_scale_value = _ir_utils.get_singleton_value(key_scale) + if not isinstance(query_scale_value, float) or not isinstance(key_scale_value, float): + return None + scaling_factor = query_scale_value * key_scale_value + scaling_factor = 1.0 / (scaling_factor * scaling_factor) + # If the dim_size is not statically known, we cannot check if the scale is correct: + if query is None or query.shape is None or len(query.shape) < 2: + return None + dimsize = query.shape[-1] + if not isinstance(dimsize, int) or not math.isclose(scaling_factor, dimsize, abs_tol=1e-3): + return None + return op.SDPA(query, key_transposed, value, mask, _domain="local") + def sdpa_pattern2(op, query, key_transposed, value, scale): attn_score = op.MatMul(query, key_transposed) @@ -27,9 +40,23 @@ def sdpa_pattern2(op, query, key_transposed, value, scale): attn_output = op.MatMul(attn_weight, value) return attn_output +def valid_post_scale(scale, query) -> bool: + # Checks if scale == (sqrt(dimsize)) + scale_value = _ir_utils.get_singleton_value(scale) + if not isinstance(scale_value, float): + return False + scaling_factor = scale_value * scale_value + # If the dim_size is not statically known, we cannot check if the scale is correct: + if query is None or query.shape is None or len(query.shape) < 2: + return False + dimsize = query.shape[-1] + if not isinstance(dimsize, int) or not math.isclose(scaling_factor, dimsize, abs_tol=1e-3): + return False + return True + def sdpa2(op, query, key_transposed, value, scale): - # TODO - # check if scale == (sqrt(dimsize)) + if not valid_post_scale(scale, query): + return None return op.SDPA(query, key_transposed, value, scale, _domain="local") def sdpa_pattern3(op, query, key_transposed, value, scale, mask): @@ -41,8 +68,8 @@ def sdpa_pattern3(op, query, key_transposed, value, scale, mask): return attn_output def sdpa3(op, query, key_transposed, value, scale, mask): - # TODO - # check if scale == (sqrt(dimsize)) + if not valid_post_scale(scale, query): + return None return op.SDPA(query, key_transposed, value, scale, mask, _domain="local") rule = pattern.RewriteRule(sdpa_pattern, sdpa) diff --git a/onnxscript/rewriter/onnxruntime/multi_head_attention.py b/onnxscript/rewriter/onnxruntime/multi_head_attention.py index 1895cca3d..08d1c7c6c 100644 --- a/onnxscript/rewriter/onnxruntime/multi_head_attention.py +++ b/onnxscript/rewriter/onnxruntime/multi_head_attention.py @@ -9,10 +9,17 @@ """ MultiHeadAttention: +D: input embedding dimension +H: number of heads +d_h: head size +usually, D = H * d_h + +thus, weights are usually of shape (D, D) and (D, D) and (D, D) + for Q, K, V: - MatMul - Reshape to B, S, 32, 64 - Transpose to B, 32, S, 64 + MatMul (Input, W for Q, K, V) => B, 
S, D + Reshape to B, S, 32, 64 (that is, B, S, H, d_h) + Transpose to B, 32, S, 64 (that is, B, H, S, d_h) Here, 32 is the number of heads and 64 is the head size From e98682fae97d756e9d60338ad1a7b3b9f6922223 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 7 Nov 2024 19:11:46 -0800 Subject: [PATCH 07/28] Add validation conditions for rotary embedding --- .../rewriter/onnxruntime/rotary_embedding.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/onnxscript/rewriter/onnxruntime/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/rotary_embedding.py index 872405106..3732a653f 100644 --- a/onnxscript/rewriter/onnxruntime/rotary_embedding.py +++ b/onnxscript/rewriter/onnxruntime/rotary_embedding.py @@ -4,7 +4,7 @@ import numpy as np import onnxscript.ir as ir -from onnxscript.rewriter import pattern +from onnxscript.rewriter import pattern, _ir_utils def rotate_half_pattern(op, x, start1, end1, start2, end2): # Slice(input, starts, ends, axes, steps) @@ -15,8 +15,19 @@ def rotate_half_pattern(op, x, start1, end1, start2, end2): return rotated_x def rotate_half(op, x, start1, end1, start2, end2): - # TODO: check if start1, end1, start2, end2 are valid - return op.RotateHalf(x, _domain="local") + # Check that x is being split into two equal halves: + start1_val = _ir_utils.get_singleton_value(start1) + end1_val = _ir_utils.get_singleton_value(end1) + start2_val = _ir_utils.get_singleton_value(start2) + end2_val = _ir_utils.get_singleton_value(end2) + + if x is None or x.shape is None or len(x.shape) != 4: + return None + dim_size = x.shape[3] + half_dim_size = dim_size // 2 + if start1_val == 0 and end1_val == half_dim_size and start2_val == half_dim_size and end2_val >= dim_size: + return op.RotateHalf(x, _domain="local") + return None def embed_pattern(op, x, cos, sin): return x * cos + op.RotateHalf(x, _domain="local") * sin From 001bb59bca40828a4c488f9be2061bdae25d1874 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 7 Nov 2024 22:09:32 -0800 Subject: [PATCH 08/28] Add tests --- .../_optimize_transformers_test.py | 31 ------ .../xformers/_optimize_transformers_test.py | 95 +++++++++++++++++++ 2 files changed, 95 insertions(+), 31 deletions(-) delete mode 100644 onnxscript/rewriter/onnxruntime/_optimize_transformers_test.py create mode 100644 onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py diff --git a/onnxscript/rewriter/onnxruntime/_optimize_transformers_test.py b/onnxscript/rewriter/onnxruntime/_optimize_transformers_test.py deleted file mode 100644 index 14d5e5450..000000000 --- a/onnxscript/rewriter/onnxruntime/_optimize_transformers_test.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
-from __future__ import annotations - -import unittest -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer -import onnxscript.ir as ir -import onnxscript.optimizer -import onnxscript.rewriter.onnxruntime._optimize_transformers as optimize_transformers - -def _get_smollm_model() -> ir.Model: - checkpoint = "HuggingFaceTB/SmolLM-1.7B" - device = "cpu" - tokenizer = AutoTokenizer.from_pretrained(checkpoint) - model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device) - inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to(device) - program = torch.onnx.export(model, inputs, "tsmodel.onnx", dynamo=True) - model = program.model - onnxscript.optimizer.optimize_ir(model) - return model - -class TestOptimizeTransformers(unittest.TestCase): - - def test_optimize_transformers(self): - model = _get_smollm_model() - optimize_transformers.optimize(model) - - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py new file mode 100644 index 000000000..5d8f1e0b6 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest +from transformers import LlamaConfig +import transformers.models.llama.modeling_llama as modeling_llama +import torch +import onnxscript.optimizer +from onnxscript.rewriter.onnxruntime import _optimize_transformers as optimize_transformers + +# Create a LlamaConfig object with the desired parameters +_config = LlamaConfig( + _name_or_path="HuggingFaceTB/SmolLM-1.7B", + architectures=["LlamaForCausalLM"], + attention_bias=False, + attention_dropout=0.0, + bos_token_id=0, + eos_token_id=0, + hidden_act="silu", + hidden_size=2048, + initializer_range=0.02, + intermediate_size=8192, + max_position_embeddings=2048, + model_type="llama", + num_attention_heads=32, + num_hidden_layers=24, + num_key_value_heads=32, + pretraining_tp=1, + rms_norm_eps=1e-05, + rope_scaling=None, + rope_theta=10000.0, + tie_word_embeddings=True, + torch_dtype="float32", + transformers_version="4.37.2", + use_cache=True, + vocab_size=49152 +) + +# Create a LlamaAttention object with the desired parameters +# model = modeling_llama.LlamaAttention(_config, 0) +model = modeling_llama.LlamaSdpaAttention(_config, 0) + +# Dimensions for inputs: +_batch_size = 1 +_seq_len = 10 +_hidden_size = _config.hidden_size +_num_attention_heads = _config.num_attention_heads +dim = _hidden_size // _num_attention_heads + +# Generate inputs: +_hidden_states = torch.rand(_batch_size, _seq_len, _hidden_size, dtype=torch.float32) +_attention_mask = torch.rand(_batch_size, 1, _seq_len, _seq_len, dtype=torch.float32) +_position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64).reshape(1, 10) + +# Get model in ONNX format +def _get_model(llama_attention_class, with_mask: bool): + model = llama_attention_class(_config, 0) + if with_mask: + inputs = (_hidden_states, _attention_mask, _position_ids) + else: + inputs = (_hidden_states, None, _position_ids) + exported = torch.onnx.export(model, inputs, dynamo=True) + onnxscript.optimizer.optimize(exported.model) + # optimize_transformers.basic_optimize(exported.model) + return exported.model + +class TestOptimizeTransformers(unittest.TestCase): + + def test_attention(self): 
+ model = _get_model(modeling_llama.LlamaAttention, with_mask=False) + optimize_transformers.optimize(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("MultiHeadAttention", op_types) + + def test_masked_attention(self): + model = _get_model(modeling_llama.LlamaAttention, with_mask=True) + optimize_transformers.optimize(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("MultiHeadAttention", op_types) + + def test_sdpa_attention(self): + model = _get_model(modeling_llama.LlamaSdpaAttention, with_mask=False) + optimize_transformers.optimize(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("MultiHeadAttention", op_types) + + def test_masked_sdpa_attention(self): + model = _get_model(modeling_llama.LlamaSdpaAttention, with_mask=True) + optimize_transformers.optimize(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("MultiHeadAttention", op_types) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 40b9052f72ab722634f65a2da54dbd6dbce65905 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 8 Nov 2024 07:24:07 -0800 Subject: [PATCH 09/28] Move into new xformers folder --- .../onnxruntime/{ => xformers}/_optimize_transformers.py | 2 +- .../onnxruntime/xformers/_optimize_transformers_test.py | 2 +- onnxscript/rewriter/onnxruntime/{ => xformers}/attention.py | 0 .../rewriter/onnxruntime/{ => xformers}/multi_head_attention.py | 0 .../rewriter/onnxruntime/{ => xformers}/rms_normalization.py | 0 .../rewriter/onnxruntime/{ => xformers}/rotary_embedding.py | 0 .../rewriter/onnxruntime/{ => xformers}/skip_normalization.py | 0 7 files changed, 2 insertions(+), 2 deletions(-) rename onnxscript/rewriter/onnxruntime/{ => xformers}/_optimize_transformers.py (92%) rename onnxscript/rewriter/onnxruntime/{ => xformers}/attention.py (100%) rename onnxscript/rewriter/onnxruntime/{ => xformers}/multi_head_attention.py (100%) rename onnxscript/rewriter/onnxruntime/{ => xformers}/rms_normalization.py (100%) rename onnxscript/rewriter/onnxruntime/{ => xformers}/rotary_embedding.py (100%) rename onnxscript/rewriter/onnxruntime/{ => xformers}/skip_normalization.py (100%) diff --git a/onnxscript/rewriter/onnxruntime/_optimize_transformers.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py similarity index 92% rename from onnxscript/rewriter/onnxruntime/_optimize_transformers.py rename to onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py index 4641a7e40..57a01f135 100644 --- a/onnxscript/rewriter/onnxruntime/_optimize_transformers.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py @@ -3,10 +3,10 @@ from __future__ import annotations import onnxscript.ir as ir -from onnxscript.rewriter.onnxruntime import attention, multi_head_attention, rms_normalization, rotary_embedding, skip_normalization from onnxscript.rewriter import no_op from onnxscript.optimizer import _constant_folding, remove_unused_nodes from onnxscript.rewriter.llama_rule_sets import ExpandIdentity, TransposeIdentity +from onnxscript.rewriter.onnxruntime.xformers import attention, multi_head_attention, rms_normalization, rotary_embedding, skip_normalization import onnxscript.rewriter.pattern as pattern expand_rule = pattern.make_rewrite_rule_from_class(ExpandIdentity) diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py index 5d8f1e0b6..67a72a311 100644 --- 
a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py @@ -7,7 +7,7 @@ import transformers.models.llama.modeling_llama as modeling_llama import torch import onnxscript.optimizer -from onnxscript.rewriter.onnxruntime import _optimize_transformers as optimize_transformers +from onnxscript.rewriter.onnxruntime.xformers import _optimize_transformers as optimize_transformers # Create a LlamaConfig object with the desired parameters _config = LlamaConfig( diff --git a/onnxscript/rewriter/onnxruntime/attention.py b/onnxscript/rewriter/onnxruntime/xformers/attention.py similarity index 100% rename from onnxscript/rewriter/onnxruntime/attention.py rename to onnxscript/rewriter/onnxruntime/xformers/attention.py diff --git a/onnxscript/rewriter/onnxruntime/multi_head_attention.py b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py similarity index 100% rename from onnxscript/rewriter/onnxruntime/multi_head_attention.py rename to onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py diff --git a/onnxscript/rewriter/onnxruntime/rms_normalization.py b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py similarity index 100% rename from onnxscript/rewriter/onnxruntime/rms_normalization.py rename to onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py diff --git a/onnxscript/rewriter/onnxruntime/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py similarity index 100% rename from onnxscript/rewriter/onnxruntime/rotary_embedding.py rename to onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py diff --git a/onnxscript/rewriter/onnxruntime/skip_normalization.py b/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py similarity index 100% rename from onnxscript/rewriter/onnxruntime/skip_normalization.py rename to onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py From 94ce2f301c632bf8f9a2d6dab16015d347ed1f35 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 8 Nov 2024 08:38:03 -0800 Subject: [PATCH 10/28] Add dropout to optimizer --- onnxscript/optimizer/_constant_folding.py | 25 +++++++++++++++++++ .../optimizer/_constant_folding_test.py | 22 ++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 28a165d6c..35432de77 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -424,11 +424,36 @@ def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue: @register("Concat") def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Replace a Concat node with a single input by Identity""" inputs = node.inputs if (len(inputs) == 1): return op.Identity(inputs[0]) return None +@register("Dropout", version=(12, None)) +def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Replace a Dropout by Identity when applicable.""" + if len(node.outputs) != 1: + # If output mask is requested, optimization is more complex. + # TODO: handle this case. But unlikely to be needed in practice. + return None + inputs = node.inputs + if (len(inputs) <= 2) or inputs[2] is None: + # No training_mode specified: + return op.Identity(inputs[0]) + if _get_bool_value(inputs[2]) is False: + # training_mode is False: dropout is not applied. 
+ return op.Identity(inputs[0]) + ratio = _get_numpy_value(inputs[1]) + if ratio is None: + return None + if ratio.size != 1: # Only scalar dropout ratio is supported. + return None + if ratio.item() == 0: + # dropout ratio is 0: dropout is not applied. + return op.Identity(inputs[0]) + return None + @register("ConcatFromSequence") def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = node.inputs[0] diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index b80f01c8f..60f18316b 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -394,6 +394,28 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1( self.assertEqual(optimized.graph.node[6].op_type, "Concat") onnx.checker.check_model(optimized) + @parameterized.parameterized.expand( + [ + ("output = Dropout(input)",), + ("output = Dropout(input, zero, true)",), + ("output = Dropout(input, half)",), + ("output = Dropout(input, half, false)",), + ] + ) + def test_dropout_identity(self, dropout_node: str): + if not self.using_ir: + return + model = onnx.parser.parse_model(f""" + + agraph (float[N] input) => (float[N] output) + + {{ + {dropout_node} + }} + """) + optimized = self._fold(model) + self.assertEqual(len(optimized.graph.node), 1) + self.assertEqual(optimized.graph.node[0].op_type, "Identity") if __name__ == "__main__": unittest.main() From bf3b64af4054d1c6ffda0df15bf9aa78de657f5d Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 8 Nov 2024 08:38:54 -0800 Subject: [PATCH 11/28] Run lint --- onnxscript/optimizer/_constant_folding.py | 12 +- .../optimizer/_constant_folding_test.py | 1 + onnxscript/rewriter/_ir_utils.py | 7 +- .../xformers/_optimize_transformers.py | 24 ++-- .../xformers/_optimize_transformers_test.py | 24 ++-- .../onnxruntime/xformers/attention.py | 18 ++- .../xformers/multi_head_attention.py | 111 ++++++++++-------- .../onnxruntime/xformers/rms_normalization.py | 13 +- .../onnxruntime/xformers/rotary_embedding.py | 18 ++- .../xformers/skip_normalization.py | 16 ++- onnxscript/rewriter/pattern.py | 16 ++- 11 files changed, 164 insertions(+), 96 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 35432de77..e9276cb32 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -292,15 +292,14 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> return default - @register("Cast") def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = _get_input(node, 0) output = _get_output(node, 0) - + if input is None or output is None: return None - + # TODO(rama): Parts of the following logic (implementing type/shape inference # for Cast op) should be unnecessary. Generic incremental shape-inference # should handle this. 
Only the optimization to eliminate redundant Cast ops @@ -422,14 +421,16 @@ def sequence_construct(node: ir.Node, op, state: OptimizerState) -> ReturnValue: state.set_sym_value(output, list(node.inputs)) return None + @register("Concat") def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue: """Replace a Concat node with a single input by Identity""" inputs = node.inputs - if (len(inputs) == 1): + if len(inputs) == 1: return op.Identity(inputs[0]) return None + @register("Dropout", version=(12, None)) def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue: """Replace a Dropout by Identity when applicable.""" @@ -447,13 +448,14 @@ def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue: ratio = _get_numpy_value(inputs[1]) if ratio is None: return None - if ratio.size != 1: # Only scalar dropout ratio is supported. + if ratio.size != 1: # Only scalar dropout ratio is supported. return None if ratio.item() == 0: # dropout ratio is 0: dropout is not applied. return op.Identity(inputs[0]) return None + @register("ConcatFromSequence") def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = node.inputs[0] diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index 60f18316b..58d37b952 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -417,5 +417,6 @@ def test_dropout_identity(self, dropout_node: str): self.assertEqual(len(optimized.graph.node), 1) self.assertEqual(optimized.graph.node[0].op_type, "Identity") + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index 8d4f05901..eadb67f0a 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -3,6 +3,7 @@ from __future__ import annotations import numpy as np + import onnxscript.ir as ir from onnxscript.optimizer import basic_constant_propagation @@ -13,6 +14,7 @@ def get_const_value(value: ir.Value) -> ir.TensorProtocol | None: basic_constant_propagation([node]) return value.const_value + def get_numpy_value(val: ir.Value | None) -> np.ndarray | None: if val is None: return None @@ -25,9 +27,10 @@ def get_numpy_value(val: ir.Value | None) -> np.ndarray | None: return None return None + def get_singleton_value(val: ir.Value | None): - '''Returns element of a single element tensor constant value, and None otherwise.''' + """Returns element of a single element tensor constant value, and None otherwise.""" np_val = get_numpy_value(val) if np_val is not None and np_val.size == 1: return np_val.item() - return None \ No newline at end of file + return None diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py index 57a01f135..d637d75fa 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py @@ -3,35 +3,45 @@ from __future__ import annotations import onnxscript.ir as ir -from onnxscript.rewriter import no_op +import onnxscript.rewriter.pattern as pattern from onnxscript.optimizer import _constant_folding, remove_unused_nodes +from onnxscript.rewriter import no_op from onnxscript.rewriter.llama_rule_sets import ExpandIdentity, TransposeIdentity -from onnxscript.rewriter.onnxruntime.xformers import attention, multi_head_attention, rms_normalization, rotary_embedding, 
skip_normalization -import onnxscript.rewriter.pattern as pattern +from onnxscript.rewriter.onnxruntime.xformers import ( + attention, + multi_head_attention, + rms_normalization, + rotary_embedding, + skip_normalization, +) expand_rule = pattern.make_rewrite_rule_from_class(ExpandIdentity) transpose_rule = pattern.make_rewrite_rule_from_class(TransposeIdentity) -def basic_optimize(irmodel: ir.Model) -> None: +def basic_optimize(irmodel: ir.Model) -> None: def apply(rulename: str, rule): count = rule.apply_to_model(irmodel) print(f"{rulename} count: {count}") - _constant_folding.fold_constants(irmodel, input_size_limit=5120000*4, output_size_limit=5120000*4) + _constant_folding.fold_constants( + irmodel, input_size_limit=5120000 * 4, output_size_limit=5120000 * 4 + ) apply("Dropout", no_op.dropout_zero_rule) apply("Expand", expand_rule) apply("Transpose", transpose_rule) remove_unused_nodes(irmodel) -def optimize(irmodel: ir.Model) -> None: +def optimize(irmodel: ir.Model) -> None: def apply(rulename: str, rule): count = rule.apply_to_model(irmodel) print(f"{rulename} count: {count}") - _constant_folding.fold_constants(irmodel, input_size_limit=5120000*4, output_size_limit=5120000*4) + _constant_folding.fold_constants( + irmodel, input_size_limit=5120000 * 4, output_size_limit=5120000 * 4 + ) apply("Dropout", no_op.dropout_zero_rule) apply("Expand", expand_rule) diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py index 67a72a311..04ff18fa5 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py @@ -3,11 +3,15 @@ from __future__ import annotations import unittest -from transformers import LlamaConfig -import transformers.models.llama.modeling_llama as modeling_llama + import torch +import transformers.models.llama.modeling_llama as modeling_llama +from transformers import LlamaConfig + import onnxscript.optimizer -from onnxscript.rewriter.onnxruntime.xformers import _optimize_transformers as optimize_transformers +from onnxscript.rewriter.onnxruntime.xformers import ( + _optimize_transformers as optimize_transformers, +) # Create a LlamaConfig object with the desired parameters _config = LlamaConfig( @@ -34,7 +38,7 @@ torch_dtype="float32", transformers_version="4.37.2", use_cache=True, - vocab_size=49152 + vocab_size=49152, ) # Create a LlamaAttention object with the desired parameters @@ -53,6 +57,7 @@ _attention_mask = torch.rand(_batch_size, 1, _seq_len, _seq_len, dtype=torch.float32) _position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64).reshape(1, 10) + # Get model in ONNX format def _get_model(llama_attention_class, with_mask: bool): model = llama_attention_class(_config, 0) @@ -65,8 +70,8 @@ def _get_model(llama_attention_class, with_mask: bool): # optimize_transformers.basic_optimize(exported.model) return exported.model -class TestOptimizeTransformers(unittest.TestCase): +class TestOptimizeTransformers(unittest.TestCase): def test_attention(self): model = _get_model(modeling_llama.LlamaAttention, with_mask=False) optimize_transformers.optimize(model) @@ -77,19 +82,20 @@ def test_masked_attention(self): model = _get_model(modeling_llama.LlamaAttention, with_mask=True) optimize_transformers.optimize(model) op_types = [n.op_type for n in model.graph] - self.assertIn("MultiHeadAttention", op_types) + self.assertIn("MultiHeadAttention", op_types) def 
test_sdpa_attention(self): model = _get_model(modeling_llama.LlamaSdpaAttention, with_mask=False) optimize_transformers.optimize(model) op_types = [n.op_type for n in model.graph] - self.assertIn("MultiHeadAttention", op_types) + self.assertIn("MultiHeadAttention", op_types) def test_masked_sdpa_attention(self): model = _get_model(modeling_llama.LlamaSdpaAttention, with_mask=True) optimize_transformers.optimize(model) op_types = [n.op_type for n in model.graph] - self.assertIn("MultiHeadAttention", op_types) + self.assertIn("MultiHeadAttention", op_types) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/xformers/attention.py b/onnxscript/rewriter/onnxruntime/xformers/attention.py index 310d6d7b6..f88698a82 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/attention.py +++ b/onnxscript/rewriter/onnxruntime/xformers/attention.py @@ -3,9 +3,9 @@ from __future__ import annotations import math -import numpy as np -import onnxscript.ir as ir -from onnxscript.rewriter import pattern, _ir_utils + +from onnxscript.rewriter import _ir_utils, pattern + def sdpa_pattern(op, query, key_transposed, value, query_scale, key_scale, mask): scaled_query = op.Mul(query, query_scale) @@ -16,6 +16,7 @@ def sdpa_pattern(op, query, key_transposed, value, query_scale, key_scale, mask) attn_output = op.MatMul(attn_weight, value) return attn_output + def sdpa(op, query, key_transposed, value, query_scale, key_scale, mask): # Check if query_scale and key_scale are scalars == 1/sqrt(sqrt(dimsize)) query_scale_value = _ir_utils.get_singleton_value(query_scale) @@ -40,6 +41,7 @@ def sdpa_pattern2(op, query, key_transposed, value, scale): attn_output = op.MatMul(attn_weight, value) return attn_output + def valid_post_scale(scale, query) -> bool: # Checks if scale == (sqrt(dimsize)) scale_value = _ir_utils.get_singleton_value(scale) @@ -52,13 +54,15 @@ def valid_post_scale(scale, query) -> bool: dimsize = query.shape[-1] if not isinstance(dimsize, int) or not math.isclose(scaling_factor, dimsize, abs_tol=1e-3): return False - return True - + return True + + def sdpa2(op, query, key_transposed, value, scale): if not valid_post_scale(scale, query): return None return op.SDPA(query, key_transposed, value, scale, _domain="local") + def sdpa_pattern3(op, query, key_transposed, value, scale, mask): attn_score = op.MatMul(query, key_transposed) scaled_score = op.Div(attn_score, scale) @@ -67,13 +71,15 @@ def sdpa_pattern3(op, query, key_transposed, value, scale, mask): attn_output = op.MatMul(attn_weight, value) return attn_output + def sdpa3(op, query, key_transposed, value, scale, mask): if not valid_post_scale(scale, query): return None return op.SDPA(query, key_transposed, value, scale, mask, _domain="local") + rule = pattern.RewriteRule(sdpa_pattern, sdpa) rule2 = pattern.RewriteRule(sdpa_pattern2, sdpa2) rule3 = pattern.RewriteRule(sdpa_pattern3, sdpa3) -rules = pattern.RewriteRuleSet([rule, rule2, rule3]) \ No newline at end of file +rules = pattern.RewriteRuleSet([rule, rule2, rule3]) diff --git a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py index 08d1c7c6c..20a8d6295 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py +++ b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py @@ -2,8 +2,6 @@ # Licensed under the MIT License. 
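The tests above all follow the same recipe; a hedged end-to-end sketch of that flow is shown below. The module paths follow the layout introduced in this patch series and are not a stable public API.

# Hedged usage sketch: export with the dynamo-based exporter, run the generic
# optimizer, then apply the ORT transformer fusions from this patch series.
import torch

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers import (
    _optimize_transformers as optimize_transformers,
)


def export_and_fuse(module: torch.nn.Module, example_inputs: tuple):
    exported = torch.onnx.export(module, example_inputs, dynamo=True)
    ir_model = exported.model                    # IR model produced by the exporter
    onnxscript.optimizer.optimize(ir_model)      # generic optimizations first
    optimize_transformers.optimize(ir_model)     # then the fusion rules
    return ir_model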
from __future__ import annotations -import numpy as np -import onnxscript.ir as ir from onnxscript.rewriter import pattern """ @@ -30,54 +28,73 @@ """ + def project_transpose_head(op, input, weight): - projected = op.MatMul(input, weight) - # Reshape from (B, S, D) to (B, S, H, D/H) - reshaped = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True) - # Transpose from (B, S, H, D/H) to (B, H, S, D/H) - transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3]) - return transposed - -def multi_head_attention_pattern (op, input, query_weight, key_weight, value_weight, cos, sin): - query = project_transpose_head(op, input, query_weight) - query_rope = op.Embed(query, cos, sin, _domain="local") - key = project_transpose_head(op, input, key_weight) - key_rope = op.Embed(key, cos, sin, _domain="local") - # Transpose last two axes of key_rope to compute dot-product via matmul. - key_reshaped = op.Reshape(key_rope, _allow_other_inputs=True) - key_reshaped_transposed = op.Transpose(key_reshaped) - key_transposed = op.Reshape(key_reshaped_transposed, _allow_other_inputs=True) - value = project_transpose_head(op, input, value_weight) - attention = op.SDPA(query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local") - # Transpose back to (B, S, H, D/H) - attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) - # Reshape back to (B, S, D) - attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True) - return attention_reshaped, key_rope, value - -def multi_head_attention_pattern2 (op, input, query_weight, key_weight, value_weight, cos, sin): - """Variation of first pattern with Reshape omitted.""" - query = project_transpose_head(op, input, query_weight) - query_rope = op.Embed(query, cos, sin, _domain="local") - key = project_transpose_head(op, input, key_weight) - key_rope = op.Embed(key, cos, sin, _domain="local") - # Transpose last two axes of key_rope to compute dot-product via matmul. - # Reshape omitted here. - key_transposed = op.Transpose(key_rope) - # Reshape omitted here - value = project_transpose_head(op, input, value_weight) - attention = op.SDPA(query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local") - # Transpose back to (B, S, H, D/H) - attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) - # Reshape back to (B, S, D) - attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True) - return attention_reshaped, key_rope, value - -def multi_head_attention(op, input, query_weight, key_weight, value_weight, cos, sin,): + projected = op.MatMul(input, weight) + # Reshape from (B, S, D) to (B, S, H, D/H) + reshaped = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True) + # Transpose from (B, S, H, D/H) to (B, H, S, D/H) + transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3]) + return transposed + + +def multi_head_attention_pattern(op, input, query_weight, key_weight, value_weight, cos, sin): + query = project_transpose_head(op, input, query_weight) + query_rope = op.Embed(query, cos, sin, _domain="local") + key = project_transpose_head(op, input, key_weight) + key_rope = op.Embed(key, cos, sin, _domain="local") + # Transpose last two axes of key_rope to compute dot-product via matmul. 
+ key_reshaped = op.Reshape(key_rope, _allow_other_inputs=True) + key_reshaped_transposed = op.Transpose(key_reshaped) + key_transposed = op.Reshape(key_reshaped_transposed, _allow_other_inputs=True) + value = project_transpose_head(op, input, value_weight) + attention = op.SDPA( + query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local" + ) + # Transpose back to (B, S, H, D/H) + attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) + # Reshape back to (B, S, D) + attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True) + return attention_reshaped, key_rope, value + + +def multi_head_attention_pattern2(op, input, query_weight, key_weight, value_weight, cos, sin): + """Variation of first pattern with Reshape omitted.""" + query = project_transpose_head(op, input, query_weight) + query_rope = op.Embed(query, cos, sin, _domain="local") + key = project_transpose_head(op, input, key_weight) + key_rope = op.Embed(key, cos, sin, _domain="local") + # Transpose last two axes of key_rope to compute dot-product via matmul. + # Reshape omitted here. + key_transposed = op.Transpose(key_rope) + # Reshape omitted here + value = project_transpose_head(op, input, value_weight) + attention = op.SDPA( + query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local" + ) + # Transpose back to (B, S, H, D/H) + attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) + # Reshape back to (B, S, D) + attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True) + return attention_reshaped, key_rope, value + + +def multi_head_attention( + op, + input, + query_weight, + key_weight, + value_weight, + cos, + sin, +): # TODO: other checks and concatenation of weights - return op.MultiHeadAttention(input, query_weight, key_weight, value_weight, cos, sin, _domain="local", _outputs=3) + return op.MultiHeadAttention( + input, query_weight, key_weight, value_weight, cos, sin, _domain="local", _outputs=3 + ) + rule = pattern.RewriteRule(multi_head_attention_pattern, multi_head_attention) rule2 = pattern.RewriteRule(multi_head_attention_pattern2, multi_head_attention) -rules = pattern.RewriteRuleSet([rule, rule2]) \ No newline at end of file +rules = pattern.RewriteRuleSet([rule, rule2]) diff --git a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py index 322297378..f5517ec39 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py @@ -2,12 +2,9 @@ # Licensed under the MIT License. 
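The local SDPA op that these attention patterns feed into was matched earlier in three scaling variants; they are interchangeable because scaling Q and K each by 1/sqrt(sqrt(d)) gives the same scores as dividing Q @ K^T by sqrt(d). A small numpy sanity check of that equivalence, illustrative only:

# Numeric check (illustrative only) of the scaling equivalence behind the three
# SDPA patterns: pre-scaling Q and K by 1/sqrt(sqrt(d)) equals post-dividing the
# attention scores by sqrt(d).
import numpy as np

rng = np.random.default_rng(0)
d = 64                                               # head size
q = rng.standard_normal((2, 8, 10, d)).astype(np.float32)
k_t = rng.standard_normal((2, 8, d, 10)).astype(np.float32)

pre_scale = 1.0 / np.sqrt(np.sqrt(d))
scores_pre = (q * pre_scale) @ (k_t * pre_scale)     # Mul/Mul/MatMul variant
scores_post = (q @ k_t) / np.sqrt(d)                 # MatMul/Div variants

assert np.allclose(scores_pre, scores_post, atol=1e-3)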
from __future__ import annotations -import numpy as np -import onnxscript.ir as ir from onnxscript.rewriter import _ir_utils, pattern - # Pattern to match against def rms_norm_pattern(op, x, scale, epsilon, compute_dtype, target_dtype): x_cast = op.Cast(x, to=compute_dtype) @@ -20,6 +17,7 @@ def rms_norm_pattern(op, x, scale, epsilon, compute_dtype, target_dtype): normalized_cast = op.Cast(normalized, to=target_dtype) return op.Mul(scale, normalized_cast) + # Replacement def simplified_layer_norm(op, x, scale, epsilon, compute_dtype, target_dtype): epsilon_value = _ir_utils.get_singleton_value(epsilon) @@ -28,13 +26,14 @@ def simplified_layer_norm(op, x, scale, epsilon, compute_dtype, target_dtype): source_dtype = x.dtype if source_dtype is None or source_dtype != target_dtype.value: return None - return op.SimplifiedLayerNormalization ( + return op.SimplifiedLayerNormalization( x, scale, - axis=-1, + axis=-1, epsilon=epsilon_value, - stash_type=compute_dtype.value, - _domain="com.microsoft") + stash_type=compute_dtype.value, + _domain="com.microsoft", + ) rule = pattern.RewriteRule(rms_norm_pattern, simplified_layer_norm) diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py index 3732a653f..598f90cba 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py @@ -2,9 +2,8 @@ # Licensed under the MIT License. from __future__ import annotations -import numpy as np -import onnxscript.ir as ir -from onnxscript.rewriter import pattern, _ir_utils +from onnxscript.rewriter import _ir_utils, pattern + def rotate_half_pattern(op, x, start1, end1, start2, end2): # Slice(input, starts, ends, axes, steps) @@ -14,6 +13,7 @@ def rotate_half_pattern(op, x, start1, end1, start2, end2): rotated_x = op.Concat(minus_x2, x1, axis=-1) return rotated_x + def rotate_half(op, x, start1, end1, start2, end2): # Check that x is being split into two equal halves: start1_val = _ir_utils.get_singleton_value(start1) @@ -25,16 +25,24 @@ def rotate_half(op, x, start1, end1, start2, end2): return None dim_size = x.shape[3] half_dim_size = dim_size // 2 - if start1_val == 0 and end1_val == half_dim_size and start2_val == half_dim_size and end2_val >= dim_size: + if ( + start1_val == 0 + and end1_val == half_dim_size + and start2_val == half_dim_size + and end2_val >= dim_size + ): return op.RotateHalf(x, _domain="local") return None + def embed_pattern(op, x, cos, sin): return x * cos + op.RotateHalf(x, _domain="local") * sin + def embed(op, x, cos, sin, **_): return op.Embed(x, cos, sin, _domain="local") + rule = pattern.RewriteRule(rotate_half_pattern, rotate_half) -embed_rule = pattern.RewriteRule(embed_pattern, embed) \ No newline at end of file +embed_rule = pattern.RewriteRule(embed_pattern, embed) diff --git a/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py b/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py index eecb9a60b..e927e3726 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py +++ b/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py @@ -2,10 +2,9 @@ # Licensed under the MIT License. 
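For reference, the computation this pattern matches is the usual RMS normalization: cast to a higher-precision compute type, divide by the root mean square over the last axis, cast back, and multiply by the scale weight. A numpy sketch, illustrative only (the epsilon default here is an assumption, not taken from the patch):

# Reference-only numpy version of the computation matched by the RMS-norm pattern.
import numpy as np

def rms_norm_reference(x: np.ndarray, scale: np.ndarray, epsilon: float = 1e-5) -> np.ndarray:
    x32 = x.astype(np.float32)                                   # compute dtype
    mean_square = np.mean(x32 * x32, axis=-1, keepdims=True)
    normalized = x32 / np.sqrt(mean_square + epsilon)
    return scale * normalized.astype(x.dtype)                    # back to the target dtype

x = np.random.rand(2, 10, 8).astype(np.float16)
gamma = np.ones(8, dtype=np.float16)
out = rms_norm_reference(x, gamma)
assert out.shape == x.shape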
from __future__ import annotations -import numpy as np -import onnxscript.ir as ir from onnxscript.rewriter import pattern + def skip_norm_pattern(op, input, skip, gamma, epsilon, stash_type): skip_sum = op.Add(input, skip) normalized = op.SimplifiedLayerNormalization( @@ -14,9 +13,11 @@ def skip_norm_pattern(op, input, skip, gamma, epsilon, stash_type): axis=-1, epsilon=epsilon, stash_type=stash_type, - _domain="com.microsoft") + _domain="com.microsoft", + ) return normalized, skip_sum + def skip_normalization(op, input, skip, gamma, epsilon, stash_type): normalized, mean, inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization( input, @@ -25,8 +26,11 @@ def skip_normalization(op, input, skip, gamma, epsilon, stash_type): epsilon=epsilon, stash_type=stash_type, _domain="com.microsoft", - _outputs=4 - ) + _outputs=4, + ) return normalized, skip_sum -rule = pattern.RewriteRule(skip_norm_pattern, skip_normalization, matcher=pattern.SimplePatternMatcher) \ No newline at end of file + +rule = pattern.RewriteRule( + skip_norm_pattern, skip_normalization, matcher=pattern.SimplePatternMatcher +) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 408cf276d..239c5c719 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -250,7 +250,13 @@ def __call__( inputs = [_to_value_pattern(x) for x in args] attributes = {name: _to_attr_pattern(value) for (name, value) in kwargs.items()} node_pattern = NodePattern( - opset_pattern, self.op_name, inputs, attributes, _outputs, _allow_other_attributes, _allow_other_inputs + opset_pattern, + self.op_name, + inputs, + attributes, + _outputs, + _allow_other_attributes, + _allow_other_inputs, ) self.pattern_builder.add_node(node_pattern) output_values = node_pattern.outputs @@ -563,7 +569,13 @@ def clone(self, node_map: dict[NodePattern, NodePattern], swap: bool) -> NodePat inputs = [inputs[1], inputs[0]] outputs = [value.name for value in self.outputs] copied = NodePattern( - self.domain, self.op, inputs, self.attributes, outputs, self.allow_other_attributes, self.allow_other_inputs + self.domain, + self.op, + inputs, + self.attributes, + outputs, + self.allow_other_attributes, + self.allow_other_inputs, ) node_map[self] = copied return copied From 0491366e2428b0ef170d801c2f9f9dd4a2a26671 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 8 Nov 2024 09:07:51 -0800 Subject: [PATCH 12/28] Undo dropout rewrite rule change --- onnxscript/rewriter/no_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/no_op.py b/onnxscript/rewriter/no_op.py index 7c2d91635..6d25b0ed3 100644 --- a/onnxscript/rewriter/no_op.py +++ b/onnxscript/rewriter/no_op.py @@ -24,7 +24,7 @@ def div_by_1(op, x): def dropout_zero(op, x): - return op.Dropout(x, 0.0) + return op.Dropout(x, ratio=0.0) def dropout_inference(op, x): From 3fb7cd1c2ba4b9d59c56dc1035c18efee7c1bfe7 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 8 Nov 2024 09:11:43 -0800 Subject: [PATCH 13/28] Add concat test --- onnxscript/optimizer/_constant_folding_test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index 58d37b952..9f3ef20ea 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -417,6 +417,22 @@ def test_dropout_identity(self, dropout_node: str): self.assertEqual(len(optimized.graph.node), 1) 
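The SkipSimplifiedLayerNormalization fusion above folds the residual Add into the normalization and returns the pre-normalization sum as a second output so later residual connections can reuse it (the fused op's mean and inverse-std outputs are requested via _outputs=4 but discarded). A numpy sketch of the two outputs the rewrite keeps, illustrative only:

# Illustrative numpy equivalent of the two outputs kept by the skip-normalization
# rewrite above: the normalized value and the residual sum.
import numpy as np

def skip_simplified_layer_norm_reference(x, skip, gamma, epsilon=1e-5):
    skip_sum = x + skip                              # the residual Add
    mean_square = np.mean(skip_sum * skip_sum, axis=-1, keepdims=True)
    normalized = gamma * (skip_sum / np.sqrt(mean_square + epsilon))
    return normalized, skip_sum

x = np.random.rand(2, 10, 8).astype(np.float32)
skip = np.random.rand(2, 10, 8).astype(np.float32)
gamma = np.ones(8, dtype=np.float32)
normalized, skip_sum = skip_simplified_layer_norm_reference(x, skip, gamma)
assert normalized.shape == skip_sum.shape == x.shape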
self.assertEqual(optimized.graph.node[0].op_type, "Identity") + def test_concat_identity(self): + if not self.using_ir: + return + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) + { + z = Concat (x) + } + """ + ) + optimized = self._fold(model) + self.assertEqual(len(optimized.graph.node), 1) + self.assertEqual(optimized.graph.node[0].op_type, "Identity") + if __name__ == "__main__": unittest.main() From 73723f07851be650406f81612f9f21ea8af61cf9 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 8 Nov 2024 13:46:38 -0800 Subject: [PATCH 14/28] Add expand identity optimization --- onnxscript/optimizer/_constant_folding.py | 19 +++++++++++++++ .../optimizer/_constant_folding_test.py | 15 ++++++++++++ .../xformers/_optimize_transformers.py | 21 ---------------- onnxscript/rewriter/pattern.py | 24 ------------------- 4 files changed, 34 insertions(+), 45 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index e9276cb32..fca172afe 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -455,6 +455,25 @@ def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue: return op.Identity(inputs[0]) return None +@register("Expand") +def expand(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Replace an Expand node by Identity when applicable.""" + if len(node.inputs) != 2: + return None + if (input := node.inputs[0]) is None: + return None + if (input_shape := input.shape) is None: + # Input shape is not known. + return None + if (expanded_shape := _get_numpy_value(node.inputs[1])) is None: + # Target shape is not known. + return None + if expanded_shape.ndim != 1: + # Target shape must be a 1D tensor. Erroneous model. 
+ return None + if input_shape.dims == tuple(expanded_shape.tolist()): + return op.Identity(input) + return None @register("ConcatFromSequence") def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index 52e06bd56..ec5171523 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -433,6 +433,21 @@ def test_concat_identity(self): self.assertEqual(len(optimized.graph.node), 1) self.assertEqual(optimized.graph.node[0].op_type, "Identity") + def test_expand_identity(self): + if not self.using_ir: + self.skipTest("New optimizations not supported for legacy optimizer") + model = onnx.parser.parse_model( + """ + + agraph (float[128, 256] x) => (float[128, 256] z) + { + shape = Constant () + z = Expand (x, shape) + } + """ + ) + optimized = self._fold(model) + self.assertEqual(optimized.graph.node[-1].op_type, "Identity") if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py index d637d75fa..5ecfe1772 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py @@ -15,24 +15,6 @@ skip_normalization, ) -expand_rule = pattern.make_rewrite_rule_from_class(ExpandIdentity) -transpose_rule = pattern.make_rewrite_rule_from_class(TransposeIdentity) - - -def basic_optimize(irmodel: ir.Model) -> None: - def apply(rulename: str, rule): - count = rule.apply_to_model(irmodel) - print(f"{rulename} count: {count}") - - _constant_folding.fold_constants( - irmodel, input_size_limit=5120000 * 4, output_size_limit=5120000 * 4 - ) - - apply("Dropout", no_op.dropout_zero_rule) - apply("Expand", expand_rule) - apply("Transpose", transpose_rule) - remove_unused_nodes(irmodel) - def optimize(irmodel: ir.Model) -> None: def apply(rulename: str, rule): @@ -42,9 +24,6 @@ def apply(rulename: str, rule): _constant_folding.fold_constants( irmodel, input_size_limit=5120000 * 4, output_size_limit=5120000 * 4 ) - - apply("Dropout", no_op.dropout_zero_rule) - apply("Expand", expand_rule) remove_unused_nodes(irmodel) apply("RMS Normalization", rms_normalization.rule) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index ab4145609..66d9b3196 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -255,13 +255,8 @@ def __call__( inputs, attributes, _outputs, -<<<<<<< HEAD - _allow_other_attributes, - _allow_other_inputs, -======= allow_other_attributes=_allow_other_attributes, allow_other_inputs=_allow_other_inputs, ->>>>>>> main ) self.pattern_builder.add_node(node_pattern) output_values = node_pattern.outputs @@ -485,34 +480,20 @@ def __init__( outputs: Sequence[str | None], *, allow_other_attributes: bool | None, -<<<<<<< HEAD - _allow_other_inputs: bool | None, -======= allow_other_inputs: bool | None, ->>>>>>> main ): if allow_other_attributes is None: # Default behavior: allow other unmatched attributes in the node. allow_other_attributes = True -<<<<<<< HEAD - if _allow_other_inputs is None: - # TODO(rama): Should we default to True? For now, we preserve the current behavior. - _allow_other_inputs = False -======= if allow_other_inputs is None: # TODO(rama): Should we default to True? For now, we preserve the current behavior. 
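The Expand rule above is deliberately conservative: it only rewrites when the target shape is a known 1-D constant exactly equal to the input's static shape; other targets that happen to broadcast back to the same shape are left alone. A standalone sketch of that check (function name hypothetical, illustrative only):

# Hypothetical standalone version of the shape test used by the Expand folding above.
import numpy as np

def expand_is_identity(input_shape: tuple, target_shape: np.ndarray) -> bool:
    if target_shape.ndim != 1:
        # The Expand shape input must be a 1-D tensor.
        return False
    return tuple(target_shape.tolist()) == input_shape

assert expand_is_identity((128, 256), np.array([128, 256], dtype=np.int64))
assert not expand_is_identity((128, 256), np.array([1, 128, 256], dtype=np.int64))
assert not expand_is_identity((128, 256), np.array([1, 256], dtype=np.int64))  # broadcasts, but not rewritten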
allow_other_inputs = False ->>>>>>> main self.domain = domain self.op = StringConstantPattern(op) if isinstance(op, str) else op self.inputs = [_to_value_pattern(x) for x in inputs] self.attributes = attributes self.allow_other_attributes = allow_other_attributes -<<<<<<< HEAD - self.allow_other_inputs = _allow_other_inputs -======= self.allow_other_inputs = allow_other_inputs ->>>>>>> main # In the common case, domain and op are constants, which can be used to optimize matching. if isinstance(op, str) and isinstance(domain, StringConstantPattern): # TODO(rama): support overloaded operators. @@ -594,13 +575,8 @@ def clone(self, node_map: dict[NodePattern, NodePattern], swap: bool) -> NodePat inputs, self.attributes, outputs, -<<<<<<< HEAD - self.allow_other_attributes, - self.allow_other_inputs, -======= allow_other_attributes=self.allow_other_attributes, allow_other_inputs=self.allow_other_inputs, ->>>>>>> main ) node_map[self] = copied return copied From bb977ec3dc959ca1486f2d863aad155f57fb9a82 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 8 Nov 2024 15:34:24 -0800 Subject: [PATCH 15/28] Some cleanup --- onnxscript/optimizer/_constant_folding.py | 2 ++ .../optimizer/_constant_folding_test.py | 1 + .../rewriter/onnxruntime/xformers/__init__.py | 10 +++++++ .../xformers/_optimize_transformers.py | 28 ++++++++----------- .../xformers/_optimize_transformers_test.py | 4 +-- .../xformers/multi_head_attention.py | 2 +- .../onnxruntime/xformers/rms_normalization.py | 2 +- .../onnxruntime/xformers/rotary_embedding.py | 20 +++++-------- .../xformers/{attention.py => sdpa.py} | 2 +- .../xformers/skip_normalization.py | 2 +- 10 files changed, 38 insertions(+), 35 deletions(-) create mode 100644 onnxscript/rewriter/onnxruntime/xformers/__init__.py rename onnxscript/rewriter/onnxruntime/xformers/{attention.py => sdpa.py} (98%) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index fca172afe..c49b18de7 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -455,6 +455,7 @@ def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue: return op.Identity(inputs[0]) return None + @register("Expand") def expand(node: ir.Node, op, state: OptimizerState) -> ReturnValue: """Replace an Expand node by Identity when applicable.""" @@ -475,6 +476,7 @@ def expand(node: ir.Node, op, state: OptimizerState) -> ReturnValue: return op.Identity(input) return None + @register("ConcatFromSequence") def concat_from_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = node.inputs[0] diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index ec5171523..9276a5901 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -449,5 +449,6 @@ def test_expand_identity(self): optimized = self._fold(model) self.assertEqual(optimized.graph.node[-1].op_type, "Identity") + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/xformers/__init__.py b/onnxscript/rewriter/onnxruntime/xformers/__init__.py new file mode 100644 index 000000000..c9f16a55c --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
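The NodePattern plumbing above backs the pattern-function DSL used by all of the fusions in this series: a pattern function describes the subgraph to match, a replacement function builds its substitute, and the pair is wrapped in a RewriteRule (optionally grouped into a RewriteRuleSet) and applied to a model. A hedged sketch with a deliberately trivial rule; the Mul/Reciprocal example is illustrative only and not part of the patch.

# Hedged sketch of the rewrite-rule API as used throughout these patches; the
# rule shown (x * (1/y) -> x / y) is only an example.
from onnxscript.rewriter import pattern


def _mul_by_reciprocal_pattern(op, x, y):
    return op.Mul(x, op.Reciprocal(y))


def _mul_by_reciprocal(op, x, y):
    return op.Div(x, y)


example_rule = pattern.RewriteRule(_mul_by_reciprocal_pattern, _mul_by_reciprocal)
example_rules = pattern.RewriteRuleSet([example_rule])

# Applied to an ir.Model the same way the rule sets above are:
#     count = example_rules.apply_to_model(ir_model)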
+from __future__ import annotations + +from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import rotary_embedding_rules +from onnxscript.rewriter.onnxruntime.xformers.sdpa import sdpa_rules +from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import rms_normalization_rules +from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import skip_normalization_rules +from onnxscript.rewriter.onnxruntime.xformers.multi_head_attention import mha_rules + diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py index 5ecfe1772..c65bff8b7 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py @@ -3,22 +3,19 @@ from __future__ import annotations import onnxscript.ir as ir -import onnxscript.rewriter.pattern as pattern from onnxscript.optimizer import _constant_folding, remove_unused_nodes -from onnxscript.rewriter import no_op -from onnxscript.rewriter.llama_rule_sets import ExpandIdentity, TransposeIdentity from onnxscript.rewriter.onnxruntime.xformers import ( - attention, - multi_head_attention, - rms_normalization, - rotary_embedding, - skip_normalization, + mha_rules, + rms_normalization_rules, + rotary_embedding_rules, + sdpa_rules, + skip_normalization_rules, ) -def optimize(irmodel: ir.Model) -> None: +def optimize(irmodel: ir.Model, verbose: int = 0) -> None: def apply(rulename: str, rule): - count = rule.apply_to_model(irmodel) + count = rule.apply_to_model(irmodel, verbose=verbose) print(f"{rulename} count: {count}") _constant_folding.fold_constants( @@ -26,13 +23,12 @@ def apply(rulename: str, rule): ) remove_unused_nodes(irmodel) - apply("RMS Normalization", rms_normalization.rule) - apply("Skip Normalization", skip_normalization.rule) + apply("RMS Normalization", rms_normalization_rules) + apply("Skip Normalization", skip_normalization_rules) _constant_folding.fold_constants(irmodel) remove_unused_nodes(irmodel) - apply("Attention", attention.rules) - apply("Rotate", rotary_embedding.rule) - apply("Embed", rotary_embedding.embed_rule) - apply("Multi-Head-Attention", multi_head_attention.rules) + apply("SDPA-Attention", sdpa_rules) + apply("RotaryEmbedding", rotary_embedding_rules) + apply("Multi-Head-Attention", mha_rules) diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py index 04ff18fa5..7bc23fc31 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py @@ -86,13 +86,13 @@ def test_masked_attention(self): def test_sdpa_attention(self): model = _get_model(modeling_llama.LlamaSdpaAttention, with_mask=False) - optimize_transformers.optimize(model) + optimize_transformers.optimize(model, verbose=10) op_types = [n.op_type for n in model.graph] self.assertIn("MultiHeadAttention", op_types) def test_masked_sdpa_attention(self): model = _get_model(modeling_llama.LlamaSdpaAttention, with_mask=True) - optimize_transformers.optimize(model) + optimize_transformers.optimize(model, verbose=10) op_types = [n.op_type for n in model.graph] self.assertIn("MultiHeadAttention", op_types) diff --git a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py index 20a8d6295..5b41a22d2 100644 --- 
a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py +++ b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py @@ -97,4 +97,4 @@ def multi_head_attention( rule = pattern.RewriteRule(multi_head_attention_pattern, multi_head_attention) rule2 = pattern.RewriteRule(multi_head_attention_pattern2, multi_head_attention) -rules = pattern.RewriteRuleSet([rule, rule2]) +mha_rules = pattern.RewriteRuleSet([rule, rule2]) diff --git a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py index f5517ec39..d4a7c359f 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py @@ -36,4 +36,4 @@ def simplified_layer_norm(op, x, scale, epsilon, compute_dtype, target_dtype): ) -rule = pattern.RewriteRule(rms_norm_pattern, simplified_layer_norm) +rms_normalization_rules = pattern.RewriteRule(rms_norm_pattern, simplified_layer_norm) diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py index 598f90cba..764482abd 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py @@ -14,7 +14,11 @@ def rotate_half_pattern(op, x, start1, end1, start2, end2): return rotated_x -def rotate_half(op, x, start1, end1, start2, end2): +def embed_pattern(op, x, cos, sin, start1, end1, start2, end2): + return x * cos + rotate_half_pattern(op, x, start1, end1, start2, end2) * sin + + +def embed(op, x, cos, sin, start1, end1, start2, end2): # Check that x is being split into two equal halves: start1_val = _ir_utils.get_singleton_value(start1) end1_val = _ir_utils.get_singleton_value(end1) @@ -31,18 +35,8 @@ def rotate_half(op, x, start1, end1, start2, end2): and start2_val == half_dim_size and end2_val >= dim_size ): - return op.RotateHalf(x, _domain="local") + return op.Embed(x, cos, sin, _domain="local") return None -def embed_pattern(op, x, cos, sin): - return x * cos + op.RotateHalf(x, _domain="local") * sin - - -def embed(op, x, cos, sin, **_): - return op.Embed(x, cos, sin, _domain="local") - - -rule = pattern.RewriteRule(rotate_half_pattern, rotate_half) - -embed_rule = pattern.RewriteRule(embed_pattern, embed) +rotary_embedding_rules = pattern.RewriteRule(embed_pattern, embed) diff --git a/onnxscript/rewriter/onnxruntime/xformers/attention.py b/onnxscript/rewriter/onnxruntime/xformers/sdpa.py similarity index 98% rename from onnxscript/rewriter/onnxruntime/xformers/attention.py rename to onnxscript/rewriter/onnxruntime/xformers/sdpa.py index f88698a82..93d093695 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/attention.py +++ b/onnxscript/rewriter/onnxruntime/xformers/sdpa.py @@ -82,4 +82,4 @@ def sdpa3(op, query, key_transposed, value, scale, mask): rule2 = pattern.RewriteRule(sdpa_pattern2, sdpa2) rule3 = pattern.RewriteRule(sdpa_pattern3, sdpa3) -rules = pattern.RewriteRuleSet([rule, rule2, rule3]) +sdpa_rules = pattern.RewriteRuleSet([rule, rule2, rule3]) diff --git a/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py b/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py index e927e3726..5c0f99efb 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py +++ b/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py @@ -31,6 +31,6 @@ def skip_normalization(op, input, skip, gamma, epsilon, stash_type): return normalized, skip_sum 
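For reference, the rotate-half subgraph that the Slice/Neg/Concat pattern recognizes, and the rotary embedding built from it, look like this in numpy. Illustrative only; shapes follow the (B, H, S, d_h) layout these patterns operate on.

# Reference-only numpy version of the subgraph matched by the rotary-embedding rule:
# split the last axis in half, negate the second half, swap the halves, then
# combine with the precomputed cos/sin tables.
import numpy as np

def rotate_half(x: np.ndarray) -> np.ndarray:
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([-x2, x1], axis=-1)

def rotary_embedding(x: np.ndarray, cos: np.ndarray, sin: np.ndarray) -> np.ndarray:
    return x * cos + rotate_half(x) * sin

B, H, S, d_h = 1, 2, 10, 8
x = np.random.rand(B, H, S, d_h).astype(np.float32)
angles = np.random.rand(S, d_h).astype(np.float32)
out = rotary_embedding(x, np.cos(angles), np.sin(angles))
assert out.shape == x.shape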
-rule = pattern.RewriteRule( +skip_normalization_rules = pattern.RewriteRule( skip_norm_pattern, skip_normalization, matcher=pattern.SimplePatternMatcher ) From 7f1606f9a3620576c947204c7e2273c2fe21ba28 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 8 Nov 2024 16:23:30 -0800 Subject: [PATCH 16/28] Fix dropout optimization --- onnxscript/optimizer/_constant_folding.py | 23 +++++++++++++------ .../rewriter/onnxruntime/xformers/__init__.py | 19 ++++++++++----- .../xformers/_optimize_transformers_test.py | 4 ++-- 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index c49b18de7..507a83175 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -13,6 +13,7 @@ import numpy as np import onnx +import onnx.helper import onnx.reference.ops import onnxscript.ir as ir @@ -434,17 +435,25 @@ def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue: @register("Dropout", version=(12, None)) def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue: """Replace a Dropout by Identity when applicable.""" - if len(node.outputs) != 1: - # If output mask is requested, optimization is more complex. - # TODO: handle this case. But unlikely to be needed in practice. - return None + + def optimized_dropout(): + input = node.inputs[0] + output = op.Identity(input) + if len(node.outputs) == 1: + return output + else: + true_tensor = onnx.helper.make_tensor("true", onnx.TensorProto.BOOL, [1], [True]) + input_shape = op.Shape(input) + mask = op.ConstantOfShape(input_shape, value=true_tensor) + return output, mask + inputs = node.inputs if (len(inputs) <= 2) or inputs[2] is None: # No training_mode specified: - return op.Identity(inputs[0]) + return optimized_dropout() if _get_bool_value(inputs[2]) is False: # training_mode is False: dropout is not applied. - return op.Identity(inputs[0]) + return optimized_dropout() ratio = _get_numpy_value(inputs[1]) if ratio is None: return None @@ -452,7 +461,7 @@ def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue: return None if ratio.item() == 0: # dropout ratio is 0: dropout is not applied. - return op.Identity(inputs[0]) + return optimized_dropout() return None diff --git a/onnxscript/rewriter/onnxruntime/xformers/__init__.py b/onnxscript/rewriter/onnxruntime/xformers/__init__.py index c9f16a55c..dfd0df8da 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/__init__.py +++ b/onnxscript/rewriter/onnxruntime/xformers/__init__.py @@ -2,9 +2,16 @@ # Licensed under the MIT License. 
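The reworked Dropout folding above now also covers the optional mask output: when dropout is inactive, the data output is just the input, and the mask is an all-True tensor of the input's shape, which is what the Shape + ConstantOfShape pair constructs. A numpy sketch of the two outputs, illustrative only:

# Illustrative numpy equivalent of the optimized Dropout replacement above when
# dropout is inactive: the data output is the input itself, and the mask output
# (if requested) is an all-True boolean tensor of the same shape.
import numpy as np

def inactive_dropout_reference(x: np.ndarray):
    output = x                                   # Identity
    mask = np.ones(x.shape, dtype=bool)          # ConstantOfShape(Shape(x), value=True)
    return output, mask

x = np.random.rand(3, 4).astype(np.float32)
out, mask = inactive_dropout_reference(x)
assert np.array_equal(out, x) and mask.all()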
from __future__ import annotations -from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import rotary_embedding_rules -from onnxscript.rewriter.onnxruntime.xformers.sdpa import sdpa_rules -from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import rms_normalization_rules -from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import skip_normalization_rules -from onnxscript.rewriter.onnxruntime.xformers.multi_head_attention import mha_rules - +from onnxscript.rewriter.onnxruntime.xformers.multi_head_attention import ( + mha_rules as mha_rules, +) +from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import ( + rms_normalization_rules as rms_normalization_rules, +) +from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import ( + rotary_embedding_rules as rotary_embedding_rules, +) +from onnxscript.rewriter.onnxruntime.xformers.sdpa import sdpa_rules as sdpa_rules +from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import ( + skip_normalization_rules as skip_normalization_rules, +) diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py index 7bc23fc31..04ff18fa5 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py @@ -86,13 +86,13 @@ def test_masked_attention(self): def test_sdpa_attention(self): model = _get_model(modeling_llama.LlamaSdpaAttention, with_mask=False) - optimize_transformers.optimize(model, verbose=10) + optimize_transformers.optimize(model) op_types = [n.op_type for n in model.graph] self.assertIn("MultiHeadAttention", op_types) def test_masked_sdpa_attention(self): model = _get_model(modeling_llama.LlamaSdpaAttention, with_mask=True) - optimize_transformers.optimize(model, verbose=10) + optimize_transformers.optimize(model) op_types = [n.op_type for n in model.graph] self.assertIn("MultiHeadAttention", op_types) From a3e0d1d3bb26e2632e24ead91a85147ba36b6ec3 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 8 Nov 2024 16:43:14 -0800 Subject: [PATCH 17/28] Some more cleanup --- onnxscript/optimizer/__init__.py | 14 ++++++++++++-- onnxscript/optimizer/_inliner.py | 7 ++++--- .../onnxruntime/xformers/_optimize_transformers.py | 8 +++----- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index f30976c24..ad5053e10 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -4,13 +4,15 @@ import onnx +import onnxscript.optimizer._constant_folding as constant_folding import onnxscript.optimizer._legacy._optimizer as legacy_optimizer +import onnxscript.optimizer._legacy.constant_folding as legacy_constant_folding from onnxscript import ir -from onnxscript.optimizer._constant_folding import basic_constant_propagation -from onnxscript.optimizer._legacy.constant_folding import fold_constants from onnxscript.optimizer._optimizer import optimize_ir from onnxscript.optimizer._remove_unused import remove_unused_nodes +basic_constant_propagation = constant_folding.basic_constant_propagation +fold_constants_ir = constant_folding.fold_constants def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs): if isinstance(model, ir.Model): @@ -18,9 +20,17 @@ def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs): else: return legacy_optimizer.optimize(model, *args, **kwargs) +def 
fold_constants(model: ir.Model | onnx.ModelProto, *args, **kwargs): + if isinstance(model, ir.Model): + return constant_folding.fold_constants(model, *args, **kwargs) + else: + return legacy_constant_folding.fold_constants(model, *args, **kwargs) + + __all__ = [ "fold_constants", + "fold_constants_ir", "remove_unused_nodes", "optimize", "optimize_ir", diff --git a/onnxscript/optimizer/_inliner.py b/onnxscript/optimizer/_inliner.py index 590937397..ebae203d7 100644 --- a/onnxscript/optimizer/_inliner.py +++ b/onnxscript/optimizer/_inliner.py @@ -305,6 +305,7 @@ def inline_calls_in(self, graph: ir.Graph) -> None: def inline(model: ir.Model) -> None: """Inline all function calls (recursively) in the model.""" - inliner = _Inliner(model) - inliner.inline_calls_in(model.graph) - model.functions.clear() + if model.functions: + inliner = _Inliner(model) + inliner.inline_calls_in(model.graph) + model.functions.clear() diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py index c65bff8b7..2d2c475f0 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py @@ -3,7 +3,7 @@ from __future__ import annotations import onnxscript.ir as ir -from onnxscript.optimizer import _constant_folding, remove_unused_nodes +from onnxscript.optimizer import fold_constants_ir, remove_unused_nodes from onnxscript.rewriter.onnxruntime.xformers import ( mha_rules, rms_normalization_rules, @@ -18,15 +18,13 @@ def apply(rulename: str, rule): count = rule.apply_to_model(irmodel, verbose=verbose) print(f"{rulename} count: {count}") - _constant_folding.fold_constants( - irmodel, input_size_limit=5120000 * 4, output_size_limit=5120000 * 4 - ) + fold_constants_ir(irmodel, input_size_limit=5120000 * 4, output_size_limit=5120000 * 4) remove_unused_nodes(irmodel) apply("RMS Normalization", rms_normalization_rules) apply("Skip Normalization", skip_normalization_rules) - _constant_folding.fold_constants(irmodel) + fold_constants_ir(irmodel) remove_unused_nodes(irmodel) apply("SDPA-Attention", sdpa_rules) From 087993423c1bb0b87232570d7a269df76213c5c3 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 8 Nov 2024 17:14:47 -0800 Subject: [PATCH 18/28] Cleanup --- onnxscript/optimizer/__init__.py | 3 +- .../xformers/multi_head_attention.py | 61 ++++++++++--------- .../onnxruntime/xformers/rms_normalization.py | 7 ++- .../onnxruntime/xformers/rotary_embedding.py | 10 +-- .../xformers/skip_normalization.py | 10 +-- 5 files changed, 50 insertions(+), 41 deletions(-) diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index ad5053e10..8ba6229c1 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -14,12 +14,14 @@ basic_constant_propagation = constant_folding.basic_constant_propagation fold_constants_ir = constant_folding.fold_constants + def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs): if isinstance(model, ir.Model): return optimize_ir(model, *args, **kwargs) else: return legacy_optimizer.optimize(model, *args, **kwargs) + def fold_constants(model: ir.Model | onnx.ModelProto, *args, **kwargs): if isinstance(model, ir.Model): return constant_folding.fold_constants(model, *args, **kwargs) @@ -27,7 +29,6 @@ def fold_constants(model: ir.Model | onnx.ModelProto, *args, **kwargs): return legacy_constant_folding.fold_constants(model, *args, **kwargs) - __all__ = [ 
"fold_constants", "fold_constants_ir", diff --git a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py index 5b41a22d2..37862d324 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py +++ b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py @@ -5,31 +5,32 @@ from onnxscript.rewriter import pattern """ -MultiHeadAttention: +The MultiHeadAttention pattern: +B: Batch size +S: Sequence length D: input embedding dimension H: number of heads -d_h: head size -usually, D = H * d_h +d_h: head size (usually, D = H * d_h) thus, weights are usually of shape (D, D) and (D, D) and (D, D) -for Q, K, V: - MatMul (Input, W for Q, K, V) => B, S, D - Reshape to B, S, 32, 64 (that is, B, S, H, d_h) - Transpose to B, 32, S, 64 (that is, B, H, S, d_h) +for each of Q, K, and V, we have the following pattern: + MatMul (Input, W), producing output of shape (B, S, D) + Reshape to produce a matrix of shape (B, S, H, d_h) + Transpose middle two axes to produce a matrix of shape (B, H, S, d_h) -Here, 32 is the number of heads and 64 is the head size +This is followed by a RotaryEmbedding pattern for Q and K -Embed Q and K - -One of the embeddings (namely ) is also output of layer -and last two axes are transposed for SDPA +The last two axes of the key-embedding are then swapped (using a Reshape/Transpose/Reshape sequence) +The dot-product attention is then computed using SDPA + +Finally, the output is transposed and reshaped back to (B, S, D) shape """ -def project_transpose_head(op, input, weight): +def _project_transpose_head(op, input, weight): projected = op.MatMul(input, weight) # Reshape from (B, S, D) to (B, S, H, D/H) reshaped = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True) @@ -38,16 +39,16 @@ def project_transpose_head(op, input, weight): return transposed -def multi_head_attention_pattern(op, input, query_weight, key_weight, value_weight, cos, sin): - query = project_transpose_head(op, input, query_weight) - query_rope = op.Embed(query, cos, sin, _domain="local") - key = project_transpose_head(op, input, key_weight) - key_rope = op.Embed(key, cos, sin, _domain="local") +def _multi_head_attention_pattern(op, input, query_weight, key_weight, value_weight, cos, sin): + query = _project_transpose_head(op, input, query_weight) + query_rope = op.RotaryEmbedding(query, cos, sin, _domain="local") + key = _project_transpose_head(op, input, key_weight) + key_rope = op.RotaryEmbedding(key, cos, sin, _domain="local") # Transpose last two axes of key_rope to compute dot-product via matmul. 
key_reshaped = op.Reshape(key_rope, _allow_other_inputs=True) key_reshaped_transposed = op.Transpose(key_reshaped) key_transposed = op.Reshape(key_reshaped_transposed, _allow_other_inputs=True) - value = project_transpose_head(op, input, value_weight) + value = _project_transpose_head(op, input, value_weight) attention = op.SDPA( query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local" ) @@ -58,17 +59,19 @@ def multi_head_attention_pattern(op, input, query_weight, key_weight, value_weig return attention_reshaped, key_rope, value -def multi_head_attention_pattern2(op, input, query_weight, key_weight, value_weight, cos, sin): +def _multi_head_attention_pattern2( + op, input, query_weight, key_weight, value_weight, cos, sin +): """Variation of first pattern with Reshape omitted.""" - query = project_transpose_head(op, input, query_weight) - query_rope = op.Embed(query, cos, sin, _domain="local") - key = project_transpose_head(op, input, key_weight) - key_rope = op.Embed(key, cos, sin, _domain="local") + query = _project_transpose_head(op, input, query_weight) + query_rope = op.RotaryEmbedding(query, cos, sin, _domain="local") + key = _project_transpose_head(op, input, key_weight) + key_rope = op.RotaryEmbedding(key, cos, sin, _domain="local") # Transpose last two axes of key_rope to compute dot-product via matmul. # Reshape omitted here. key_transposed = op.Transpose(key_rope) # Reshape omitted here - value = project_transpose_head(op, input, value_weight) + value = _project_transpose_head(op, input, value_weight) attention = op.SDPA( query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local" ) @@ -79,7 +82,7 @@ def multi_head_attention_pattern2(op, input, query_weight, key_weight, value_wei return attention_reshaped, key_rope, value -def multi_head_attention( +def _multi_head_attention( op, input, query_weight, @@ -94,7 +97,7 @@ def multi_head_attention( ) -rule = pattern.RewriteRule(multi_head_attention_pattern, multi_head_attention) -rule2 = pattern.RewriteRule(multi_head_attention_pattern2, multi_head_attention) +_rule1 = pattern.RewriteRule(_multi_head_attention_pattern, _multi_head_attention) +_rule2 = pattern.RewriteRule(_multi_head_attention_pattern2, _multi_head_attention) -mha_rules = pattern.RewriteRuleSet([rule, rule2]) +mha_rules = pattern.RewriteRuleSet([_rule1, _rule2]) diff --git a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py index d4a7c359f..b0527111b 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py @@ -6,7 +6,7 @@ # Pattern to match against -def rms_norm_pattern(op, x, scale, epsilon, compute_dtype, target_dtype): +def _rms_norm_pattern(op, x, scale, epsilon, compute_dtype, target_dtype): x_cast = op.Cast(x, to=compute_dtype) x_square = op.Pow(x_cast, 2.0) mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0) @@ -19,7 +19,7 @@ def rms_norm_pattern(op, x, scale, epsilon, compute_dtype, target_dtype): # Replacement -def simplified_layer_norm(op, x, scale, epsilon, compute_dtype, target_dtype): +def _simplified_layer_norm(op, x, scale, epsilon, compute_dtype, target_dtype): epsilon_value = _ir_utils.get_singleton_value(epsilon) if not isinstance(epsilon_value, float): return None @@ -36,4 +36,5 @@ def simplified_layer_norm(op, x, scale, epsilon, compute_dtype, target_dtype): ) -rms_normalization_rules = pattern.RewriteRule(rms_norm_pattern, 
simplified_layer_norm) +_rule = pattern.RewriteRule(_rms_norm_pattern, _simplified_layer_norm) +rms_normalization_rules = pattern.RewriteRuleSet([_rule]) diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py index 764482abd..3312e3062 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py @@ -14,11 +14,11 @@ def rotate_half_pattern(op, x, start1, end1, start2, end2): return rotated_x -def embed_pattern(op, x, cos, sin, start1, end1, start2, end2): +def _rotary_embedding_pattern(op, x, cos, sin, start1, end1, start2, end2): return x * cos + rotate_half_pattern(op, x, start1, end1, start2, end2) * sin -def embed(op, x, cos, sin, start1, end1, start2, end2): +def _rotary_embedding(op, x, cos, sin, start1, end1, start2, end2): # Check that x is being split into two equal halves: start1_val = _ir_utils.get_singleton_value(start1) end1_val = _ir_utils.get_singleton_value(end1) @@ -35,8 +35,10 @@ def embed(op, x, cos, sin, start1, end1, start2, end2): and start2_val == half_dim_size and end2_val >= dim_size ): - return op.Embed(x, cos, sin, _domain="local") + return op.RotaryEmbedding(x, cos, sin, _domain="local") return None -rotary_embedding_rules = pattern.RewriteRule(embed_pattern, embed) +_rule = pattern.RewriteRule(_rotary_embedding_pattern, _rotary_embedding) + +rotary_embedding_rules = pattern.RewriteRuleSet([_rule]) diff --git a/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py b/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py index 5c0f99efb..38f4281d5 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py +++ b/onnxscript/rewriter/onnxruntime/xformers/skip_normalization.py @@ -5,7 +5,7 @@ from onnxscript.rewriter import pattern -def skip_norm_pattern(op, input, skip, gamma, epsilon, stash_type): +def _skip_norm_pattern(op, input, skip, gamma, epsilon, stash_type): skip_sum = op.Add(input, skip) normalized = op.SimplifiedLayerNormalization( skip_sum, @@ -18,7 +18,7 @@ def skip_norm_pattern(op, input, skip, gamma, epsilon, stash_type): return normalized, skip_sum -def skip_normalization(op, input, skip, gamma, epsilon, stash_type): +def _skip_normalization(op, input, skip, gamma, epsilon, stash_type): normalized, mean, inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization( input, skip, @@ -31,6 +31,8 @@ def skip_normalization(op, input, skip, gamma, epsilon, stash_type): return normalized, skip_sum -skip_normalization_rules = pattern.RewriteRule( - skip_norm_pattern, skip_normalization, matcher=pattern.SimplePatternMatcher +_rule = pattern.RewriteRule( + _skip_norm_pattern, _skip_normalization, matcher=pattern.SimplePatternMatcher ) + +skip_normalization_rules = pattern.RewriteRuleSet([_rule]) From a8ac3ee2c93a9df92841b12108aad23f6c77d435 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 12 Nov 2024 17:46:44 -0800 Subject: [PATCH 19/28] Minor fixes --- .../xformers/_optimize_transformers.py | 2 + .../xformers/_optimize_transformers_test.py | 38 +++++++------------ 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py index 2d2c475f0..1eb179086 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py @@ -30,3 +30,5 @@ 
def apply(rulename: str, rule): apply("SDPA-Attention", sdpa_rules) apply("RotaryEmbedding", rotary_embedding_rules) apply("Multi-Head-Attention", mha_rules) + + remove_unused_nodes(irmodel) diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py index 04ff18fa5..43a1bfbd3 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. from __future__ import annotations +from parameterized import parameterized import unittest import torch @@ -41,10 +42,6 @@ vocab_size=49152, ) -# Create a LlamaAttention object with the desired parameters -# model = modeling_llama.LlamaAttention(_config, 0) -model = modeling_llama.LlamaSdpaAttention(_config, 0) - # Dimensions for inputs: _batch_size = 1 _seq_len = 10 @@ -66,33 +63,24 @@ def _get_model(llama_attention_class, with_mask: bool): else: inputs = (_hidden_states, None, _position_ids) exported = torch.onnx.export(model, inputs, dynamo=True) + # ORT Transformer optimizations are applied after basic optimization. onnxscript.optimizer.optimize(exported.model) - # optimize_transformers.basic_optimize(exported.model) return exported.model class TestOptimizeTransformers(unittest.TestCase): - def test_attention(self): - model = _get_model(modeling_llama.LlamaAttention, with_mask=False) - optimize_transformers.optimize(model) - op_types = [n.op_type for n in model.graph] - self.assertIn("MultiHeadAttention", op_types) - - def test_masked_attention(self): - model = _get_model(modeling_llama.LlamaAttention, with_mask=True) - optimize_transformers.optimize(model) - op_types = [n.op_type for n in model.graph] - self.assertIn("MultiHeadAttention", op_types) - - def test_sdpa_attention(self): - model = _get_model(modeling_llama.LlamaSdpaAttention, with_mask=False) - optimize_transformers.optimize(model) - op_types = [n.op_type for n in model.graph] - self.assertIn("MultiHeadAttention", op_types) - - def test_masked_sdpa_attention(self): - model = _get_model(modeling_llama.LlamaSdpaAttention, with_mask=True) + @parameterized.expand([ + ("attention", modeling_llama.LlamaAttention, False), + ("masked_attention", modeling_llama.LlamaAttention, True), + ("sdpa_attention", modeling_llama.LlamaSdpaAttention, False), + ("masked_sdpa_attention", modeling_llama.LlamaSdpaAttention, True), + ]) + def test_attention_optimization(self, name, attention_class, with_mask): + model = _get_model(attention_class, with_mask) + model.display() + print("======>") optimize_transformers.optimize(model) + model.display() op_types = [n.op_type for n in model.graph] self.assertIn("MultiHeadAttention", op_types) From 044a6382490036903b569fe3314dd30ad6e9f904 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 13 Nov 2024 01:54:09 -0800 Subject: [PATCH 20/28] Add ort check to test --- .../xformers/_optimize_transformers.py | 4 + .../xformers/_optimize_transformers_test.py | 115 ++++++++++++++---- .../xformers/multi_head_attention.py | 1 + 3 files changed, 94 insertions(+), 26 deletions(-) diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py index 1eb179086..7c5852c72 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py @@ -13,6 +13,10 @@ ) +def 
fuse_rotary_embedding(irmodel: ir.Model) -> None: + count = rotary_embedding_rules.apply_to_model(irmodel) + print(f"RotaryEmbedding count: {count}") + def optimize(irmodel: ir.Model, verbose: int = 0) -> None: def apply(rulename: str, rule): count = rule.apply_to_model(irmodel, verbose=verbose) diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py index 43a1bfbd3..73df98914 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py @@ -2,13 +2,18 @@ # Licensed under the MIT License. from __future__ import annotations -from parameterized import parameterized +import os +import tempfile import unittest +import numpy as np +import onnxruntime import torch import transformers.models.llama.modeling_llama as modeling_llama +from parameterized import parameterized from transformers import LlamaConfig +import onnxscript.ir._io as io import onnxscript.optimizer from onnxscript.rewriter.onnxruntime.xformers import ( _optimize_transformers as optimize_transformers, @@ -51,38 +56,96 @@ # Generate inputs: _hidden_states = torch.rand(_batch_size, _seq_len, _hidden_size, dtype=torch.float32) -_attention_mask = torch.rand(_batch_size, 1, _seq_len, _seq_len, dtype=torch.float32) +_causal_mask = torch.tril(torch.ones(_seq_len, _seq_len, dtype=torch.float32)) +_attention_mask = _causal_mask.unsqueeze(0).unsqueeze(0).expand(_batch_size, 1, _seq_len, _seq_len) _position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64).reshape(1, 10) - # Get model in ONNX format -def _get_model(llama_attention_class, with_mask: bool): - model = llama_attention_class(_config, 0) - if with_mask: - inputs = (_hidden_states, _attention_mask, _position_ids) - else: - inputs = (_hidden_states, None, _position_ids) - exported = torch.onnx.export(model, inputs, dynamo=True) - # ORT Transformer optimizations are applied after basic optimization. - onnxscript.optimizer.optimize(exported.model) - return exported.model +# def _get_model(llama_attention_class, with_mask: bool): +# model = llama_attention_class(_config, 0) +# if with_mask: +# inputs = (_hidden_states, _attention_mask, _position_ids) +# else: +# inputs = (_hidden_states, None, _position_ids) +# exported = torch.onnx.export(model, inputs, dynamo=True) +# # ORT Transformer optimizations are applied after basic optimization. +# onnxscript.optimizer.optimize(exported.model) +# return exported.model + +class _TestData: + def __init__(self, name: str, attention_class, with_mask: bool): + self.name = name + self.attention_class = attention_class + self.with_mask = with_mask + + def get_torch_model(self): + return self.attention_class(_config, 0) + + def get_onnx_model(self): + model = self.get_torch_model() + inputs = self.get_inputs() + input_names = ["input" + str(i) for i in range(len(inputs)) if inputs[i] is not None] + exported = torch.onnx.export(model, inputs, input_names=input_names, dynamo=True) + # ORT Transformer optimizations are applied after basic optimization. 
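+        # Constant folding in this pass makes values such as the Slice start/end
+        # indices available as constants, which the fusion rules' condition checks
+        # (for example get_singleton_value in the rotary-embedding rule) inspect.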
+ onnxscript.optimizer.optimize(exported.model) + return exported.model + + def get_inputs(self): + if self.with_mask: + return (_hidden_states, _attention_mask, _position_ids) + else: + return (_hidden_states, None, _position_ids) + + def get_torch_outputs(self): + return self.get_torch_model()(*self.get_inputs()) + + def get_ort_inputs(self): + inputs = self.get_inputs() + return {f"input{i}": input for i, input in enumerate(inputs) if input is not None} + +_test_cases = [ + _TestData("attention", modeling_llama.LlamaAttention, False), + _TestData("masked_attention", modeling_llama.LlamaAttention, True), + _TestData("sdpa_attention", modeling_llama.LlamaSdpaAttention, False), + _TestData("masked_sdpa_attention", modeling_llama.LlamaSdpaAttention, True), +] + +_test_case_tuples = [ (t,) for t in _test_cases] + +def _ort_check(model_name: str, model, inputs, expected_outputs, rtol=1e-2, atol=1e-2): + providers = ["CPUExecutionProvider"] + with tempfile.TemporaryDirectory() as temp_dir: + model_path = os.path.join(temp_dir, f"{model_name}.onnx") + io.save(model, model_path) + # Run optimized model + session = onnxruntime.InferenceSession(model_path, providers=providers) + ort_outputs = session.run(None, inputs) + for i, (baseline_output, optimized_output) in enumerate( + zip(expected_outputs, ort_outputs) + ): + try: + np.testing.assert_equal(baseline_output.shape, optimized_output.shape) + np.testing.assert_allclose( + baseline_output, optimized_output, rtol=rtol, atol=atol + ) + except AssertionError as e: + print( + f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}" + ) + raise class TestOptimizeTransformers(unittest.TestCase): - @parameterized.expand([ - ("attention", modeling_llama.LlamaAttention, False), - ("masked_attention", modeling_llama.LlamaAttention, True), - ("sdpa_attention", modeling_llama.LlamaSdpaAttention, False), - ("masked_sdpa_attention", modeling_llama.LlamaSdpaAttention, True), - ]) - def test_attention_optimization(self, name, attention_class, with_mask): - model = _get_model(attention_class, with_mask) - model.display() - print("======>") - optimize_transformers.optimize(model) - model.display() + @parameterized.expand(_test_case_tuples) + def test_attention_optimization(self, test_data: _TestData): + model = test_data.get_onnx_model() + # model.display() + # print("======>") + optimize_transformers.fuse_rotary_embedding(model) + # model.display() op_types = [n.op_type for n in model.graph] - self.assertIn("MultiHeadAttention", op_types) + self.assertIn("RotaryEmbedding", op_types) + # _ort_check(test_data.name, model, test_data.get_ort_inputs(), test_data.get_torch_outputs()) if __name__ == "__main__": diff --git a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py index 37862d324..301606a53 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py +++ b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py @@ -31,6 +31,7 @@ def _project_transpose_head(op, input, weight): + """Applied to each of Q, K, and V.""" projected = op.MatMul(input, weight) # Reshape from (B, S, D) to (B, S, H, D/H) reshaped = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True) From b6f007187f255540a80dd31f52147532f5f24e0a Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 13 Nov 2024 11:10:05 -0800 Subject: [PATCH 21/28] Testing changes --- .../onnxruntime/xformers/_optimize_transformers_test.py | 1 + 
onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py index 73df98914..fef86f1fe 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py @@ -139,6 +139,7 @@ class TestOptimizeTransformers(unittest.TestCase): @parameterized.expand(_test_case_tuples) def test_attention_optimization(self, test_data: _TestData): model = test_data.get_onnx_model() + # io.save(model, os.path.join(r"C:\repos\onnxscript\smy\Models", f"{test_data.name}.onnx")) # model.display() # print("======>") optimize_transformers.fuse_rotary_embedding(model) diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py index 3312e3062..0eadaf280 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py @@ -35,7 +35,7 @@ def _rotary_embedding(op, x, cos, sin, start1, end1, start2, end2): and start2_val == half_dim_size and end2_val >= dim_size ): - return op.RotaryEmbedding(x, cos, sin, _domain="local") + return op.RotaryEmbedding(x, cos, sin, interleaved=0, _domain="com.microsoft") return None From 91bf47a7b95e4d3d3700a10dda3b6f6f02979ecf Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 15 Nov 2024 13:54:39 -0800 Subject: [PATCH 22/28] Check for dynamic dim --- onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py index 0eadaf280..43360c4c2 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py @@ -28,6 +28,8 @@ def _rotary_embedding(op, x, cos, sin, start1, end1, start2, end2): if x is None or x.shape is None or len(x.shape) != 4: return None dim_size = x.shape[3] + if not isinstance(dim_size, int): + return None half_dim_size = dim_size // 2 if ( start1_val == 0 From 8b496afbd4ec1dd255196069b847fe35c96652af Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 22 Nov 2024 06:56:31 -0800 Subject: [PATCH 23/28] Debugging --- onnxscript/rewriter/_ir_utils.py | 23 +++++++++++++++++++ .../onnxruntime/xformers/rotary_embedding.py | 6 +++++ onnxscript/rewriter/pattern.py | 2 ++ 3 files changed, 31 insertions(+) diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index eadb67f0a..9f47c6680 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -8,6 +8,29 @@ from onnxscript.optimizer import basic_constant_propagation +def display_slice(x: ir.Value | ir.Node, backward: bool = True, depth_limit: int = 5): + """Display the subgraph computing a given value or node upto a certain depth.""" + slice = [] + def visit(node: ir.Node, depth): + if node in slice: + return + slice.append(node) + if (depth < depth_limit): + if backward: + for inp in node.inputs: + if inp is not None and inp.producer() is not None: + visit(inp.producer(), depth + 1) + else: + for out in node.outputs: + for consumer, _ in out.uses(): + visit(consumer, depth + 1) + if isinstance(x, ir.Node): + visit(x, 0) + elif isinstance(x, ir.Value) and 
x.producer() is not None: + visit(x.producer(), 0) + for node in reversed(slice): + node.display() + def get_const_value(value: ir.Value) -> ir.TensorProtocol | None: node = value.producer() if node is not None: diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py index 43360c4c2..1ecafb79b 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py @@ -29,6 +29,8 @@ def _rotary_embedding(op, x, cos, sin, start1, end1, start2, end2): return None dim_size = x.shape[3] if not isinstance(dim_size, int): + import onnxscript.rewriter._ir_utils as ir_utils + ir_utils.display_slice(x) return None half_dim_size = dim_size // 2 if ( @@ -37,6 +39,10 @@ def _rotary_embedding(op, x, cos, sin, start1, end1, start2, end2): and start2_val == half_dim_size and end2_val >= dim_size ): + import onnxscript.rewriter._ir_utils as ir_utils + ir_utils.display_slice(cos) + ir_utils.display_slice(cos, backward=False) + ir_utils.display_slice(sin) return op.RotaryEmbedding(x, cos, sin, interleaved=0, _domain="com.microsoft") return None diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 66d9b3196..9a4a6a86a 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1319,6 +1319,8 @@ def try_rewrite( verbose = verbose if verbose is not None else self._verbose match = self._matcher.match(model, graph_or_function, node, verbose=verbose) if match: + for n in reversed(match.nodes): + n.display() context = None # TODO(rama) if not self._condition_function(context, **match.bindings): return None From 4f0cca79fd6b38362a1167bd1e8284926e13d7bd Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 3 Dec 2024 13:45:56 -0800 Subject: [PATCH 24/28] Remove unused import --- onnxscript/optimizer/_constant_folding.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index fde8ec418..4053bb2a1 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -13,7 +13,6 @@ import numpy as np import onnx -import onnx.helper import onnx.reference.ops import onnxscript.ir as ir From 585d8bc7c7363dc420993ecb766b3a3aaeeb645a Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 5 Dec 2024 15:03:50 -0800 Subject: [PATCH 25/28] Various fixes --- onnxscript/rewriter/_ir_utils.py | 16 ++++++-- .../xformers/_optimize_transformers.py | 14 +++++++ .../xformers/multi_head_attention.py | 7 +++- .../onnxruntime/xformers/rms_normalization.py | 39 +++++++++++++++++-- .../onnxruntime/xformers/rotary_embedding.py | 14 +++---- onnxscript/rewriter/pattern.py | 4 +- 6 files changed, 76 insertions(+), 18 deletions(-) diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index 9f47c6680..ec157f419 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -26,10 +26,18 @@ def visit(node: ir.Node, depth): visit(consumer, depth + 1) if isinstance(x, ir.Node): visit(x, 0) - elif isinstance(x, ir.Value) and x.producer() is not None: - visit(x.producer(), 0) - for node in reversed(slice): - node.display() + elif isinstance(x, ir.Value): + if backward and x.producer() is not None: + visit(x.producer(), 0) + elif not backward: + for consumer, _ in x.uses(): + visit(consumer, 0) + if slice: + graph = slice[0].graph + if graph: + for n in graph: + if n in 
slice: + n.display() def get_const_value(value: ir.Value) -> ir.TensorProtocol | None: node = value.producer() diff --git a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py index 7c5852c72..be2e85b6a 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py @@ -17,6 +17,20 @@ def fuse_rotary_embedding(irmodel: ir.Model) -> None: count = rotary_embedding_rules.apply_to_model(irmodel) print(f"RotaryEmbedding count: {count}") +def fuse_rms_normalization(irmodel: ir.Model) -> None: + count = rms_normalization_rules.apply_to_model(irmodel) + print(f"RMS Normalization count: {count}") + count = skip_normalization_rules.apply_to_model(irmodel) + print(f"Skip Normalization count: {count}") + +def fuse_attention(irmodel: ir.Model) -> None: + count = sdpa_rules.apply_to_model(irmodel) + print(f"SDPA-Attention count: {count}") + +def fuse_mha(irmodel: ir.Model) -> None: + count = mha_rules.apply_to_model(irmodel) + print(f"Multi-Head-Attention count: {count}") + def optimize(irmodel: ir.Model, verbose: int = 0) -> None: def apply(rulename: str, rule): count = rule.apply_to_model(irmodel, verbose=verbose) diff --git a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py index 301606a53..68372f90a 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py +++ b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py @@ -32,9 +32,12 @@ def _project_transpose_head(op, input, weight): """Applied to each of Q, K, and V.""" - projected = op.MatMul(input, weight) + input_2d = op.Reshape(input, _allow_other_inputs=True, _allow_other_attributes=True) + projected = op.MatMul(input_2d, weight) + # Reshape into 3D tensor (B, S, D) + reshaped_3d = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True) # Reshape from (B, S, D) to (B, S, H, D/H) - reshaped = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True) + reshaped = op.Reshape(reshaped_3d, _allow_other_inputs=True, _allow_other_attributes=True) # Transpose from (B, S, H, D/H) to (B, H, S, D/H) transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3]) return transposed diff --git a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py index b0527111b..fd01c8ebf 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py @@ -17,7 +17,6 @@ def _rms_norm_pattern(op, x, scale, epsilon, compute_dtype, target_dtype): normalized_cast = op.Cast(normalized, to=target_dtype) return op.Mul(scale, normalized_cast) - # Replacement def _simplified_layer_norm(op, x, scale, epsilon, compute_dtype, target_dtype): epsilon_value = _ir_utils.get_singleton_value(epsilon) @@ -35,6 +34,40 @@ def _simplified_layer_norm(op, x, scale, epsilon, compute_dtype, target_dtype): _domain="com.microsoft", ) - _rule = pattern.RewriteRule(_rms_norm_pattern, _simplified_layer_norm) -rms_normalization_rules = pattern.RewriteRuleSet([_rule]) + + +# Pattern to match against +def _rms_norm_pattern_no_cast(op, x, scale, epsilon): + # x_cast = op.Cast(x, to=compute_dtype) + x_cast = x + x_square = op.Pow(x_cast, 2.0) + mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0) + 
mean_square_plus_epsilon = op.Add(mean_square, epsilon) + rms = op.Sqrt(mean_square_plus_epsilon) + reciprocal_rms = op.Reciprocal(rms) + normalized = op.Mul(x_cast, reciprocal_rms) + # normalized_cast = op.Cast(normalized, to=target_dtype) + normalized_cast = normalized + return op.Mul(scale, normalized_cast) + +# Replacement +def _simplified_layer_norm_no_cast(op, x, scale, epsilon): + epsilon_value = _ir_utils.get_singleton_value(epsilon) + if not isinstance(epsilon_value, float): + return None + source_dtype = x.dtype + if source_dtype is None: + return None + return op.SimplifiedLayerNormalization( + x, + scale, + axis=-1, + epsilon=epsilon_value, + stash_type=source_dtype.value, + _domain="com.microsoft", + ) + +_rule_no_cast = pattern.RewriteRule(_rms_norm_pattern_no_cast, _simplified_layer_norm_no_cast) + +rms_normalization_rules = pattern.RewriteRuleSet([_rule, _rule_no_cast]) diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py index 1ecafb79b..41e3340e3 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py @@ -29,8 +29,8 @@ def _rotary_embedding(op, x, cos, sin, start1, end1, start2, end2): return None dim_size = x.shape[3] if not isinstance(dim_size, int): - import onnxscript.rewriter._ir_utils as ir_utils - ir_utils.display_slice(x) + # import onnxscript.rewriter._ir_utils as ir_utils + # ir_utils.display_slice(x) return None half_dim_size = dim_size // 2 if ( @@ -39,11 +39,11 @@ def _rotary_embedding(op, x, cos, sin, start1, end1, start2, end2): and start2_val == half_dim_size and end2_val >= dim_size ): - import onnxscript.rewriter._ir_utils as ir_utils - ir_utils.display_slice(cos) - ir_utils.display_slice(cos, backward=False) - ir_utils.display_slice(sin) - return op.RotaryEmbedding(x, cos, sin, interleaved=0, _domain="com.microsoft") + # import onnxscript.rewriter._ir_utils as ir_utils + # ir_utils.display_slice(cos) + # ir_utils.display_slice(cos, backward=False) + # ir_utils.display_slice(sin) + return op.RotaryEmbedding(x, cos, sin, interleaved=0, _domain="local") return None diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 9a4a6a86a..a7218defe 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1319,8 +1319,8 @@ def try_rewrite( verbose = verbose if verbose is not None else self._verbose match = self._matcher.match(model, graph_or_function, node, verbose=verbose) if match: - for n in reversed(match.nodes): - n.display() + # for n in reversed(match.nodes): + # n.display() context = None # TODO(rama) if not self._condition_function(context, **match.bindings): return None From 555f56f0730936af3eaaf92d69162acb9268ec58 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 5 Dec 2024 15:32:15 -0800 Subject: [PATCH 26/28] Undo MHA change --- .../rewriter/onnxruntime/xformers/multi_head_attention.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py index 68372f90a..4bb952c87 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py +++ b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py @@ -32,12 +32,12 @@ def _project_transpose_head(op, input, weight): """Applied to each of Q, K, and V.""" - input_2d = op.Reshape(input, _allow_other_inputs=True, 
_allow_other_attributes=True) - projected = op.MatMul(input_2d, weight) + # input_2d = op.Reshape(input, _allow_other_inputs=True, _allow_other_attributes=True) + projected = op.MatMul(input, weight) # Reshape into 3D tensor (B, S, D) - reshaped_3d = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True) + # reshaped_3d = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True) # Reshape from (B, S, D) to (B, S, H, D/H) - reshaped = op.Reshape(reshaped_3d, _allow_other_inputs=True, _allow_other_attributes=True) + reshaped = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True) # Transpose from (B, S, H, D/H) to (B, H, S, D/H) transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3]) return transposed From 3fc12b8067db362bfab81dfc4c073ccf2b601c27 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 5 Dec 2024 16:06:44 -0800 Subject: [PATCH 27/28] MHA validation part 1 --- .../xformers/multi_head_attention.py | 37 +++++++++++++------ 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py index 4bb952c87..187620ed5 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py +++ b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. from __future__ import annotations +import onnxscript.ir as ir from onnxscript.rewriter import pattern """ @@ -30,38 +31,49 @@ """ -def _project_transpose_head(op, input, weight): +def _project_transpose_head(op, input, weight, reshape_var: str): """Applied to each of Q, K, and V.""" # input_2d = op.Reshape(input, _allow_other_inputs=True, _allow_other_attributes=True) projected = op.MatMul(input, weight) # Reshape into 3D tensor (B, S, D) # reshaped_3d = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True) # Reshape from (B, S, D) to (B, S, H, D/H) - reshaped = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True) + reshaped = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True, _outputs=[reshape_var]) # Transpose from (B, S, H, D/H) to (B, H, S, D/H) transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3]) return transposed def _multi_head_attention_pattern(op, input, query_weight, key_weight, value_weight, cos, sin): - query = _project_transpose_head(op, input, query_weight) + query = _project_transpose_head(op, input, query_weight, "query_mm_reshaped") query_rope = op.RotaryEmbedding(query, cos, sin, _domain="local") - key = _project_transpose_head(op, input, key_weight) + key = _project_transpose_head(op, input, key_weight, "key_mm_reshaped") key_rope = op.RotaryEmbedding(key, cos, sin, _domain="local") # Transpose last two axes of key_rope to compute dot-product via matmul. 
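     # (Per the shape checks added in the final patch below, key_reshaped is
     # expected to have shape (B*H, S, d_h) and key_transposed (B, H, d_h, S):
     # the (B, H) axes are flattened before the transpose and restored after it.)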
- key_reshaped = op.Reshape(key_rope, _allow_other_inputs=True) + key_reshaped = op.Reshape(key_rope, _allow_other_inputs=True, _outputs=["key_reshaped"]) key_reshaped_transposed = op.Transpose(key_reshaped) - key_transposed = op.Reshape(key_reshaped_transposed, _allow_other_inputs=True) - value = _project_transpose_head(op, input, value_weight) + key_transposed = op.Reshape(key_reshaped_transposed, _allow_other_inputs=True, _outputs=["key_transposed"]) + value = _project_transpose_head(op, input, value_weight, "value_mm_reshaped") attention = op.SDPA( query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local" ) # Transpose back to (B, S, H, D/H) attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) # Reshape back to (B, S, D) - attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True) + attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True, _outputs=["attention_reshaped"]) return attention_reshaped, key_rope, value +def _check_shape(reshaped_value: ir.Value): + print(reshaped_value.shape) + +def _mha_validation(op, query_mm_reshaped, key_mm_reshaped, value_mm_reshaped, key_reshaped, key_transposed, attention_reshaped, **_): + _check_shape(query_mm_reshaped) + _check_shape(key_mm_reshaped) + _check_shape(value_mm_reshaped) + _check_shape(key_reshaped) + _check_shape(key_transposed) + _check_shape(attention_reshaped) + return True def _multi_head_attention_pattern2( op, input, query_weight, key_weight, value_weight, cos, sin @@ -94,6 +106,7 @@ def _multi_head_attention( value_weight, cos, sin, + **_ ): # TODO: other checks and concatenation of weights return op.MultiHeadAttention( @@ -101,7 +114,9 @@ def _multi_head_attention( ) -_rule1 = pattern.RewriteRule(_multi_head_attention_pattern, _multi_head_attention) -_rule2 = pattern.RewriteRule(_multi_head_attention_pattern2, _multi_head_attention) +_rule1 = pattern.RewriteRule(_multi_head_attention_pattern, _multi_head_attention, _mha_validation) + +# TODO: _rule2 validation conditions +# _rule2 = pattern.RewriteRule(_multi_head_attention_pattern2, _multi_head_attention) -mha_rules = pattern.RewriteRuleSet([_rule1, _rule2]) +mha_rules = pattern.RewriteRuleSet([_rule1]) From f09567e594e9eeb53803b27f94c18f9dffa8999f Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 6 Dec 2024 10:56:58 -0800 Subject: [PATCH 28/28] MHA validation --- .../xformers/multi_head_attention.py | 35 ++++++++++++++----- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py index 187620ed5..8870913f2 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py +++ b/onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. 
from __future__ import annotations +from typing import Iterable import onnxscript.ir as ir from onnxscript.rewriter import pattern @@ -63,16 +64,34 @@ def _multi_head_attention_pattern(op, input, query_weight, key_weight, value_wei attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True, _outputs=["attention_reshaped"]) return attention_reshaped, key_rope, value -def _check_shape(reshaped_value: ir.Value): - print(reshaped_value.shape) +def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Iterable[str]) -> bool: + if val.shape is None: + return False + if val.shape.rank() != len(shape): + return False + for actual, expected in zip(val.shape, shape): + if expected not in bindings: + bindings[expected] = actual + elif actual != bindings[expected]: + return False + return True def _mha_validation(op, query_mm_reshaped, key_mm_reshaped, value_mm_reshaped, key_reshaped, key_transposed, attention_reshaped, **_): - _check_shape(query_mm_reshaped) - _check_shape(key_mm_reshaped) - _check_shape(value_mm_reshaped) - _check_shape(key_reshaped) - _check_shape(key_transposed) - _check_shape(attention_reshaped) + bindings : dict[str, int] = {} + check = ( + _check_shape(bindings, query_mm_reshaped, ["B", "S", "H", "d_h"]) and + _check_shape(bindings, key_mm_reshaped, ["B", "S", "H", "d_h"]) and + _check_shape(bindings, value_mm_reshaped, ["B", "S", "H", "d_h"]) and + _check_shape(bindings, key_reshaped, ["B*H", "S", "d_h"]) and + _check_shape(bindings, key_transposed, ["B", "H", "d_h", "S"]) and + _check_shape(bindings, attention_reshaped, ["B", "S", "H*d_h"]) + ) + if not check: + return False + if bindings["B"] * bindings["H"] != bindings["B*H"]: + return False + if bindings["H"] * bindings["d_h"] != bindings["H*d_h"]: + return False return True def _multi_head_attention_pattern2(
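For reference, a minimal standalone sketch of the shape-binding check introduced in the last patch: _check_shape binds each symbolic dimension name (B, S, H, d_h) the first time it is seen and requires every later occurrence to match, so a single bindings dictionary shared across all checked values ties the shapes together. The helper below operates on plain tuples rather than ir.Value shapes; its name and the example dimensions are illustrative only.

from __future__ import annotations

def check_shape(bindings: dict[str, int], shape: tuple[int, ...], expected: list[str]) -> bool:
    if len(shape) != len(expected):
        return False
    for actual, symbol in zip(shape, expected):
        if symbol not in bindings:
            bindings[symbol] = actual        # first occurrence binds the symbol
        elif bindings[symbol] != actual:     # later occurrences must agree
            return False
    return True

bindings: dict[str, int] = {}
assert check_shape(bindings, (1, 10, 32, 64), ["B", "S", "H", "d_h"])        # query: binds B, S, H, d_h
assert check_shape(bindings, (32, 10, 64), ["B*H", "S", "d_h"])              # key_reshaped: binds B*H, reuses S and d_h
assert not check_shape(bindings, (1, 12, 32, 64), ["B", "S", "H", "d_h"])    # an inconsistent S is rejected
# Products such as B * H == B*H are verified separately, as _mha_validation does.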