-
Notifications
You must be signed in to change notification settings - Fork 57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DRAFT] First version of fusion optimizations for transformers #1938
Closed
Closed
Changes from 33 commits
Commits
Show all changes
34 commits
Select commit
Hold shift + click to select a range
d751d37
Fusions
gramalingam 8c4dff5
MultiHeadAttention fusion
gramalingam 5d3c9af
Merge branch 'main' into rama/fusions
gramalingam 4d3ff90
Move transformers optimization into onnxruntime folder
gramalingam 4a667f9
Support some SDPA variations
gramalingam 33c3753
Add variations of rules for SDPA
gramalingam 404e5c3
Add attention scale validation
gramalingam e98682f
Add validation conditions for rotary embedding
gramalingam 001bb59
Add tests
gramalingam 40b9052
Move into new xformers folder
gramalingam 94ce2f3
Add dropout to optimizer
gramalingam bf3b64a
Run lint
gramalingam 0491366
Undo dropout rewrite rule change
gramalingam 3fb7cd1
Add concat test
gramalingam f25b669
Merge with main
gramalingam 73723f0
Add expand identity optimization
gramalingam bb977ec
Some cleanup
gramalingam 7f1606f
Fix dropout optimization
gramalingam a3e0d1d
Some more cleanup
gramalingam 0879934
Cleanup
gramalingam a8ac3ee
Minor fixes
gramalingam 044a638
Add ort check to test
gramalingam b6f0071
Testing changes
gramalingam b985bb1
Merge branch 'main' into rama/fusions
gramalingam 0f35c45
Merge with main
gramalingam 91bf47a
Check for dynamic dim
gramalingam fa7ba33
Merge branch 'main' into rama/fusions
gramalingam 8b496af
Debugging
gramalingam 19de794
Merge branch 'main' into rama/fusions
gramalingam 4f0cca7
Remove unused import
gramalingam 585d8bc
Various fixes
gramalingam 555f56f
Undo MHA change
gramalingam 3fc12b8
MHA validation part 1
gramalingam f09567e
MHA validation
gramalingam File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
from onnxscript.rewriter.onnxruntime.xformers.multi_head_attention import ( | ||
mha_rules as mha_rules, | ||
) | ||
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import ( | ||
rms_normalization_rules as rms_normalization_rules, | ||
) | ||
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import ( | ||
rotary_embedding_rules as rotary_embedding_rules, | ||
) | ||
from onnxscript.rewriter.onnxruntime.xformers.sdpa import sdpa_rules as sdpa_rules | ||
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import ( | ||
skip_normalization_rules as skip_normalization_rules, | ||
) |
52 changes: 52 additions & 0 deletions
52
onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
Check warning Code scanning / lintrunner RUFF-FORMAT/format Warning
Run lintrunner -a to apply this patch.
|
||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import onnxscript.ir as ir | ||
from onnxscript.optimizer import fold_constants_ir, remove_unused_nodes | ||
from onnxscript.rewriter.onnxruntime.xformers import ( | ||
mha_rules, | ||
rms_normalization_rules, | ||
rotary_embedding_rules, | ||
sdpa_rules, | ||
skip_normalization_rules, | ||
) | ||
|
||
|
||
def fuse_rotary_embedding(irmodel: ir.Model) -> None: | ||
count = rotary_embedding_rules.apply_to_model(irmodel) | ||
print(f"RotaryEmbedding count: {count}") | ||
|
||
def fuse_rms_normalization(irmodel: ir.Model) -> None: | ||
count = rms_normalization_rules.apply_to_model(irmodel) | ||
print(f"RMS Normalization count: {count}") | ||
count = skip_normalization_rules.apply_to_model(irmodel) | ||
print(f"Skip Normalization count: {count}") | ||
|
||
def fuse_attention(irmodel: ir.Model) -> None: | ||
count = sdpa_rules.apply_to_model(irmodel) | ||
print(f"SDPA-Attention count: {count}") | ||
|
||
def fuse_mha(irmodel: ir.Model) -> None: | ||
count = mha_rules.apply_to_model(irmodel) | ||
print(f"Multi-Head-Attention count: {count}") | ||
|
||
def optimize(irmodel: ir.Model, verbose: int = 0) -> None: | ||
def apply(rulename: str, rule): | ||
count = rule.apply_to_model(irmodel, verbose=verbose) | ||
print(f"{rulename} count: {count}") | ||
|
||
fold_constants_ir(irmodel, input_size_limit=5120000 * 4, output_size_limit=5120000 * 4) | ||
remove_unused_nodes(irmodel) | ||
|
||
apply("RMS Normalization", rms_normalization_rules) | ||
apply("Skip Normalization", skip_normalization_rules) | ||
|
||
fold_constants_ir(irmodel) | ||
remove_unused_nodes(irmodel) | ||
|
||
apply("SDPA-Attention", sdpa_rules) | ||
apply("RotaryEmbedding", rotary_embedding_rules) | ||
apply("Multi-Head-Attention", mha_rules) | ||
|
||
remove_unused_nodes(irmodel) |
153 changes: 153 additions & 0 deletions
153
onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
|
||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
|
||
import os | ||
import tempfile | ||
import unittest | ||
|
||
import numpy as np | ||
import onnxruntime | ||
import torch | ||
import transformers.models.llama.modeling_llama as modeling_llama | ||
from parameterized import parameterized | ||
from transformers import LlamaConfig | ||
|
||
import onnxscript.ir._io as io | ||
import onnxscript.optimizer | ||
from onnxscript.rewriter.onnxruntime.xformers import ( | ||
_optimize_transformers as optimize_transformers, | ||
) | ||
|
||
# Create a LlamaConfig object with the desired parameters | ||
_config = LlamaConfig( | ||
_name_or_path="HuggingFaceTB/SmolLM-1.7B", | ||
architectures=["LlamaForCausalLM"], | ||
attention_bias=False, | ||
attention_dropout=0.0, | ||
bos_token_id=0, | ||
eos_token_id=0, | ||
hidden_act="silu", | ||
hidden_size=2048, | ||
initializer_range=0.02, | ||
intermediate_size=8192, | ||
max_position_embeddings=2048, | ||
model_type="llama", | ||
num_attention_heads=32, | ||
num_hidden_layers=24, | ||
num_key_value_heads=32, | ||
pretraining_tp=1, | ||
rms_norm_eps=1e-05, | ||
rope_scaling=None, | ||
rope_theta=10000.0, | ||
tie_word_embeddings=True, | ||
torch_dtype="float32", | ||
transformers_version="4.37.2", | ||
use_cache=True, | ||
vocab_size=49152, | ||
) | ||
|
||
# Dimensions for inputs: | ||
_batch_size = 1 | ||
_seq_len = 10 | ||
_hidden_size = _config.hidden_size | ||
_num_attention_heads = _config.num_attention_heads | ||
dim = _hidden_size // _num_attention_heads | ||
|
||
# Generate inputs: | ||
_hidden_states = torch.rand(_batch_size, _seq_len, _hidden_size, dtype=torch.float32) | ||
_causal_mask = torch.tril(torch.ones(_seq_len, _seq_len, dtype=torch.float32)) | ||
_attention_mask = _causal_mask.unsqueeze(0).unsqueeze(0).expand(_batch_size, 1, _seq_len, _seq_len) | ||
_position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64).reshape(1, 10) | ||
|
||
# Get model in ONNX format | ||
# def _get_model(llama_attention_class, with_mask: bool): | ||
# model = llama_attention_class(_config, 0) | ||
# if with_mask: | ||
# inputs = (_hidden_states, _attention_mask, _position_ids) | ||
# else: | ||
# inputs = (_hidden_states, None, _position_ids) | ||
# exported = torch.onnx.export(model, inputs, dynamo=True) | ||
# # ORT Transformer optimizations are applied after basic optimization. | ||
# onnxscript.optimizer.optimize(exported.model) | ||
# return exported.model | ||
|
||
class _TestData: | ||
def __init__(self, name: str, attention_class, with_mask: bool): | ||
self.name = name | ||
self.attention_class = attention_class | ||
self.with_mask = with_mask | ||
|
||
def get_torch_model(self): | ||
return self.attention_class(_config, 0) | ||
|
||
def get_onnx_model(self): | ||
model = self.get_torch_model() | ||
inputs = self.get_inputs() | ||
input_names = ["input" + str(i) for i in range(len(inputs)) if inputs[i] is not None] | ||
exported = torch.onnx.export(model, inputs, input_names=input_names, dynamo=True) | ||
# ORT Transformer optimizations are applied after basic optimization. | ||
onnxscript.optimizer.optimize(exported.model) | ||
return exported.model | ||
|
||
def get_inputs(self): | ||
if self.with_mask: | ||
return (_hidden_states, _attention_mask, _position_ids) | ||
else: | ||
return (_hidden_states, None, _position_ids) | ||
|
||
def get_torch_outputs(self): | ||
return self.get_torch_model()(*self.get_inputs()) | ||
|
||
def get_ort_inputs(self): | ||
inputs = self.get_inputs() | ||
return {f"input{i}": input for i, input in enumerate(inputs) if input is not None} | ||
|
||
_test_cases = [ | ||
_TestData("attention", modeling_llama.LlamaAttention, False), | ||
_TestData("masked_attention", modeling_llama.LlamaAttention, True), | ||
_TestData("sdpa_attention", modeling_llama.LlamaSdpaAttention, False), | ||
_TestData("masked_sdpa_attention", modeling_llama.LlamaSdpaAttention, True), | ||
] | ||
|
||
_test_case_tuples = [ (t,) for t in _test_cases] | ||
|
||
def _ort_check(model_name: str, model, inputs, expected_outputs, rtol=1e-2, atol=1e-2): | ||
providers = ["CPUExecutionProvider"] | ||
with tempfile.TemporaryDirectory() as temp_dir: | ||
model_path = os.path.join(temp_dir, f"{model_name}.onnx") | ||
io.save(model, model_path) | ||
# Run optimized model | ||
session = onnxruntime.InferenceSession(model_path, providers=providers) | ||
ort_outputs = session.run(None, inputs) | ||
|
||
for i, (baseline_output, optimized_output) in enumerate( | ||
zip(expected_outputs, ort_outputs) | ||
): | ||
try: | ||
np.testing.assert_equal(baseline_output.shape, optimized_output.shape) | ||
np.testing.assert_allclose( | ||
baseline_output, optimized_output, rtol=rtol, atol=atol | ||
) | ||
except AssertionError as e: | ||
print( | ||
f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}" | ||
) | ||
raise | ||
|
||
class TestOptimizeTransformers(unittest.TestCase): | ||
@parameterized.expand(_test_case_tuples) | ||
def test_attention_optimization(self, test_data: _TestData): | ||
model = test_data.get_onnx_model() | ||
# io.save(model, os.path.join(r"C:\repos\onnxscript\smy\Models", f"{test_data.name}.onnx")) | ||
# model.display() | ||
# print("======>") | ||
optimize_transformers.fuse_rotary_embedding(model) | ||
# model.display() | ||
op_types = [n.op_type for n in model.graph] | ||
self.assertIn("RotaryEmbedding", op_types) | ||
# _ort_check(test_data.name, model, test_data.get_ort_inputs(), test_data.get_torch_outputs()) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Check failure
Code scanning / lintrunner
MYPY/arg-type Error