-
Notifications
You must be signed in to change notification settings - Fork 57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DRAFT] First version of fusion optimizations for transformers #1938
Changes from 28 commits
d751d37
8c4dff5
5d3c9af
4d3ff90
4a667f9
33c3753
404e5c3
e98682f
001bb59
40b9052
94ce2f3
bf3b64a
0491366
3fb7cd1
f25b669
73723f0
bb977ec
7f1606f
a3e0d1d
0879934
a8ac3ee
044a638
b6f0071
b985bb1
0f35c45
91bf47a
fa7ba33
8b496af
19de794
4f0cca7
585d8bc
555f56f
3fc12b8
f09567e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
from onnxscript.rewriter.onnxruntime.xformers.multi_head_attention import ( | ||
mha_rules as mha_rules, | ||
) | ||
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import ( | ||
rms_normalization_rules as rms_normalization_rules, | ||
) | ||
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import ( | ||
rotary_embedding_rules as rotary_embedding_rules, | ||
) | ||
from onnxscript.rewriter.onnxruntime.xformers.sdpa import sdpa_rules as sdpa_rules | ||
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import ( | ||
skip_normalization_rules as skip_normalization_rules, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
Check warning Code scanning / lintrunner RUFF-FORMAT/format Warning
Run lintrunner -a to apply this patch.
|
||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import onnxscript.ir as ir | ||
from onnxscript.optimizer import fold_constants_ir, remove_unused_nodes | ||
from onnxscript.rewriter.onnxruntime.xformers import ( | ||
mha_rules, | ||
rms_normalization_rules, | ||
rotary_embedding_rules, | ||
sdpa_rules, | ||
skip_normalization_rules, | ||
) | ||
|
||
|
||
def fuse_rotary_embedding(irmodel: ir.Model) -> None: | ||
count = rotary_embedding_rules.apply_to_model(irmodel) | ||
print(f"RotaryEmbedding count: {count}") | ||
|
||
def optimize(irmodel: ir.Model, verbose: int = 0) -> None: | ||
def apply(rulename: str, rule): | ||
count = rule.apply_to_model(irmodel, verbose=verbose) | ||
print(f"{rulename} count: {count}") | ||
|
||
fold_constants_ir(irmodel, input_size_limit=5120000 * 4, output_size_limit=5120000 * 4) | ||
remove_unused_nodes(irmodel) | ||
|
||
apply("RMS Normalization", rms_normalization_rules) | ||
apply("Skip Normalization", skip_normalization_rules) | ||
|
||
fold_constants_ir(irmodel) | ||
remove_unused_nodes(irmodel) | ||
|
||
apply("SDPA-Attention", sdpa_rules) | ||
apply("RotaryEmbedding", rotary_embedding_rules) | ||
apply("Multi-Head-Attention", mha_rules) | ||
|
||
remove_unused_nodes(irmodel) | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
|
||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
|
||
import os | ||
import tempfile | ||
import unittest | ||
|
||
import numpy as np | ||
import onnxruntime | ||
import torch | ||
import transformers.models.llama.modeling_llama as modeling_llama | ||
from parameterized import parameterized | ||
from transformers import LlamaConfig | ||
|
||
import onnxscript.ir._io as io | ||
import onnxscript.optimizer | ||
from onnxscript.rewriter.onnxruntime.xformers import ( | ||
_optimize_transformers as optimize_transformers, | ||
) | ||
|
||
# Create a LlamaConfig object with the desired parameters | ||
_config = LlamaConfig( | ||
_name_or_path="HuggingFaceTB/SmolLM-1.7B", | ||
architectures=["LlamaForCausalLM"], | ||
attention_bias=False, | ||
attention_dropout=0.0, | ||
bos_token_id=0, | ||
eos_token_id=0, | ||
hidden_act="silu", | ||
hidden_size=2048, | ||
initializer_range=0.02, | ||
intermediate_size=8192, | ||
max_position_embeddings=2048, | ||
model_type="llama", | ||
num_attention_heads=32, | ||
num_hidden_layers=24, | ||
num_key_value_heads=32, | ||
pretraining_tp=1, | ||
rms_norm_eps=1e-05, | ||
rope_scaling=None, | ||
rope_theta=10000.0, | ||
tie_word_embeddings=True, | ||
torch_dtype="float32", | ||
transformers_version="4.37.2", | ||
use_cache=True, | ||
vocab_size=49152, | ||
) | ||
|
||
# Dimensions for inputs: | ||
_batch_size = 1 | ||
_seq_len = 10 | ||
_hidden_size = _config.hidden_size | ||
_num_attention_heads = _config.num_attention_heads | ||
dim = _hidden_size // _num_attention_heads | ||
|
||
# Generate inputs: | ||
_hidden_states = torch.rand(_batch_size, _seq_len, _hidden_size, dtype=torch.float32) | ||
_causal_mask = torch.tril(torch.ones(_seq_len, _seq_len, dtype=torch.float32)) | ||
_attention_mask = _causal_mask.unsqueeze(0).unsqueeze(0).expand(_batch_size, 1, _seq_len, _seq_len) | ||
_position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64).reshape(1, 10) | ||
|
||
# Get model in ONNX format | ||
# def _get_model(llama_attention_class, with_mask: bool): | ||
# model = llama_attention_class(_config, 0) | ||
# if with_mask: | ||
# inputs = (_hidden_states, _attention_mask, _position_ids) | ||
# else: | ||
# inputs = (_hidden_states, None, _position_ids) | ||
# exported = torch.onnx.export(model, inputs, dynamo=True) | ||
# # ORT Transformer optimizations are applied after basic optimization. | ||
# onnxscript.optimizer.optimize(exported.model) | ||
# return exported.model | ||
|
||
class _TestData: | ||
def __init__(self, name: str, attention_class, with_mask: bool): | ||
self.name = name | ||
self.attention_class = attention_class | ||
self.with_mask = with_mask | ||
|
||
def get_torch_model(self): | ||
return self.attention_class(_config, 0) | ||
|
||
def get_onnx_model(self): | ||
model = self.get_torch_model() | ||
inputs = self.get_inputs() | ||
input_names = ["input" + str(i) for i in range(len(inputs)) if inputs[i] is not None] | ||
exported = torch.onnx.export(model, inputs, input_names=input_names, dynamo=True) | ||
# ORT Transformer optimizations are applied after basic optimization. | ||
onnxscript.optimizer.optimize(exported.model) | ||
return exported.model | ||
|
||
def get_inputs(self): | ||
if self.with_mask: | ||
return (_hidden_states, _attention_mask, _position_ids) | ||
else: | ||
return (_hidden_states, None, _position_ids) | ||
|
||
def get_torch_outputs(self): | ||
return self.get_torch_model()(*self.get_inputs()) | ||
|
||
def get_ort_inputs(self): | ||
inputs = self.get_inputs() | ||
return {f"input{i}": input for i, input in enumerate(inputs) if input is not None} | ||
|
||
_test_cases = [ | ||
_TestData("attention", modeling_llama.LlamaAttention, False), | ||
_TestData("masked_attention", modeling_llama.LlamaAttention, True), | ||
_TestData("sdpa_attention", modeling_llama.LlamaSdpaAttention, False), | ||
_TestData("masked_sdpa_attention", modeling_llama.LlamaSdpaAttention, True), | ||
] | ||
|
||
_test_case_tuples = [ (t,) for t in _test_cases] | ||
|
||
def _ort_check(model_name: str, model, inputs, expected_outputs, rtol=1e-2, atol=1e-2): | ||
providers = ["CPUExecutionProvider"] | ||
with tempfile.TemporaryDirectory() as temp_dir: | ||
model_path = os.path.join(temp_dir, f"{model_name}.onnx") | ||
io.save(model, model_path) | ||
# Run optimized model | ||
session = onnxruntime.InferenceSession(model_path, providers=providers) | ||
ort_outputs = session.run(None, inputs) | ||
|
||
for i, (baseline_output, optimized_output) in enumerate( | ||
zip(expected_outputs, ort_outputs) | ||
): | ||
try: | ||
np.testing.assert_equal(baseline_output.shape, optimized_output.shape) | ||
np.testing.assert_allclose( | ||
baseline_output, optimized_output, rtol=rtol, atol=atol | ||
) | ||
except AssertionError as e: | ||
print( | ||
f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}" | ||
) | ||
raise | ||
|
||
class TestOptimizeTransformers(unittest.TestCase): | ||
@parameterized.expand(_test_case_tuples) | ||
def test_attention_optimization(self, test_data: _TestData): | ||
model = test_data.get_onnx_model() | ||
# io.save(model, os.path.join(r"C:\repos\onnxscript\smy\Models", f"{test_data.name}.onnx")) | ||
# model.display() | ||
# print("======>") | ||
optimize_transformers.fuse_rotary_embedding(model) | ||
# model.display() | ||
op_types = [n.op_type for n in model.graph] | ||
self.assertIn("RotaryEmbedding", op_types) | ||
# _ort_check(test_data.name, model, test_data.get_ort_inputs(), test_data.get_torch_outputs()) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
|
||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
Check warning Code scanning / lintrunner RUFF/I001 Warning
Import block is un-sorted or un-formatted.
See https://docs.astral.sh/ruff/rules/unsorted-imports |
||
|
||
from onnxscript.rewriter import pattern | ||
|
||
""" | ||
The MultiHeadAttention pattern: | ||
|
||
B: Batch size | ||
S: Sequence length | ||
D: input embedding dimension | ||
H: number of heads | ||
d_h: head size (usually, D = H * d_h) | ||
|
||
thus, weights are usually of shape (D, D) and (D, D) and (D, D) | ||
|
||
for each of Q, K, and V, we have the following pattern: | ||
MatMul (Input, W), producing output of shape (B, S, D) | ||
Reshape to produce a matrix of shape (B, S, H, d_h) | ||
Transpose middle two axes to produce a matrix of shape (B, H, S, d_h) | ||
|
||
This is followed by a RotaryEmbedding pattern for Q and K | ||
|
||
The last two axes of the key-embedding are then swapped (using a Reshape/Transpose/Reshape sequence) | ||
|
||
The dot-product attention is then computed using SDPA | ||
|
||
Check warning Code scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning
Trailing whitespace
Check warning Code scanning / lintrunner RUFF/W293 Warning
Blank line contains whitespace.
See https://docs.astral.sh/ruff/rules/blank-line-with-whitespace |
||
Finally, the output is transposed and reshaped back to (B, S, D) shape | ||
""" | ||
|
||
|
||
def _project_transpose_head(op, input, weight): | ||
"""Applied to each of Q, K, and V.""" | ||
projected = op.MatMul(input, weight) | ||
# Reshape from (B, S, D) to (B, S, H, D/H) | ||
reshaped = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True) | ||
# Transpose from (B, S, H, D/H) to (B, H, S, D/H) | ||
transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3]) | ||
return transposed | ||
|
||
|
||
def _multi_head_attention_pattern(op, input, query_weight, key_weight, value_weight, cos, sin): | ||
query = _project_transpose_head(op, input, query_weight) | ||
query_rope = op.RotaryEmbedding(query, cos, sin, _domain="local") | ||
key = _project_transpose_head(op, input, key_weight) | ||
key_rope = op.RotaryEmbedding(key, cos, sin, _domain="local") | ||
# Transpose last two axes of key_rope to compute dot-product via matmul. | ||
key_reshaped = op.Reshape(key_rope, _allow_other_inputs=True) | ||
key_reshaped_transposed = op.Transpose(key_reshaped) | ||
key_transposed = op.Reshape(key_reshaped_transposed, _allow_other_inputs=True) | ||
value = _project_transpose_head(op, input, value_weight) | ||
attention = op.SDPA( | ||
query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local" | ||
) | ||
# Transpose back to (B, S, H, D/H) | ||
attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) | ||
# Reshape back to (B, S, D) | ||
attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True) | ||
return attention_reshaped, key_rope, value | ||
|
||
|
||
def _multi_head_attention_pattern2( | ||
op, input, query_weight, key_weight, value_weight, cos, sin | ||
): | ||
"""Variation of first pattern with Reshape omitted.""" | ||
query = _project_transpose_head(op, input, query_weight) | ||
Check failure Code scanning / lintrunner MYPY/call-arg Error
Missing positional argument "reshape_var" in call to "_project_transpose_head"
To disable, use # type: ignore[call-arg]
|
||
query_rope = op.RotaryEmbedding(query, cos, sin, _domain="local") | ||
key = _project_transpose_head(op, input, key_weight) | ||
Check failure Code scanning / lintrunner MYPY/call-arg Error
Missing positional argument "reshape_var" in call to "_project_transpose_head"
To disable, use # type: ignore[call-arg]
|
||
key_rope = op.RotaryEmbedding(key, cos, sin, _domain="local") | ||
# Transpose last two axes of key_rope to compute dot-product via matmul. | ||
# Reshape omitted here. | ||
key_transposed = op.Transpose(key_rope) | ||
# Reshape omitted here | ||
value = _project_transpose_head(op, input, value_weight) | ||
Check failure Code scanning / lintrunner MYPY/call-arg Error
Missing positional argument "reshape_var" in call to "_project_transpose_head"
To disable, use # type: ignore[call-arg]
|
||
attention = op.SDPA( | ||
query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local" | ||
) | ||
# Transpose back to (B, S, H, D/H) | ||
attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) | ||
# Reshape back to (B, S, D) | ||
attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True) | ||
return attention_reshaped, key_rope, value | ||
|
||
|
||
def _multi_head_attention( | ||
op, | ||
input, | ||
query_weight, | ||
key_weight, | ||
value_weight, | ||
cos, | ||
sin, | ||
): | ||
# TODO: other checks and concatenation of weights | ||
return op.MultiHeadAttention( | ||
input, query_weight, key_weight, value_weight, cos, sin, _domain="local", _outputs=3 | ||
) | ||
|
||
|
||
_rule1 = pattern.RewriteRule(_multi_head_attention_pattern, _multi_head_attention) | ||
_rule2 = pattern.RewriteRule(_multi_head_attention_pattern2, _multi_head_attention) | ||
|
||
mha_rules = pattern.RewriteRuleSet([_rule1, _rule2]) |
Check failure
Code scanning / lintrunner
MYPY/arg-type Error