[DRAFT] First version of fusion optimizations for transformers #1938

Closed
gramalingam wants to merge 34 commits into main from rama/fusions
Changes from 28 commits
Commits (34)
d751d37
Fusions
gramalingam Oct 23, 2024
8c4dff5
MultiHeadAttention fusion
gramalingam Oct 24, 2024
5d3c9af
Merge branch 'main' into rama/fusions
gramalingam Nov 5, 2024
4d3ff90
Move transformers optimization into onnxruntime folder
gramalingam Nov 6, 2024
4a667f9
Support some SDPA variations
gramalingam Nov 7, 2024
33c3753
Add variations of rules for SDPA
gramalingam Nov 8, 2024
404e5c3
Add attention scale validation
gramalingam Nov 8, 2024
e98682f
Add validation conditions for rotary embedding
gramalingam Nov 8, 2024
001bb59
Add tests
gramalingam Nov 8, 2024
40b9052
Move into new xformers folder
gramalingam Nov 8, 2024
94ce2f3
Add dropout to optimizer
gramalingam Nov 8, 2024
bf3b64a
Run lint
gramalingam Nov 8, 2024
0491366
Undo dropout rewrite rule change
gramalingam Nov 8, 2024
3fb7cd1
Add concat test
gramalingam Nov 8, 2024
f25b669
Merge with main
gramalingam Nov 8, 2024
73723f0
Add expand identity optimization
gramalingam Nov 8, 2024
bb977ec
Some cleanup
gramalingam Nov 8, 2024
7f1606f
Fix dropout optimization
gramalingam Nov 9, 2024
a3e0d1d
Some more cleanup
gramalingam Nov 9, 2024
0879934
Cleanup
gramalingam Nov 9, 2024
a8ac3ee
Minor fixes
gramalingam Nov 13, 2024
044a638
Add ort check to test
gramalingam Nov 13, 2024
b6f0071
Testing changes
gramalingam Nov 13, 2024
b985bb1
Merge branch 'main' into rama/fusions
gramalingam Nov 15, 2024
0f35c45
Merge with main
gramalingam Nov 15, 2024
91bf47a
Check for dynamic dim
gramalingam Nov 15, 2024
fa7ba33
Merge branch 'main' into rama/fusions
gramalingam Nov 16, 2024
8b496af
Debugging
gramalingam Nov 22, 2024
19de794
Merge branch 'main' into rama/fusions
gramalingam Dec 3, 2024
4f0cca7
Remove unused import
gramalingam Dec 3, 2024
585d8bc
Various fixes
gramalingam Dec 5, 2024
555f56f
Undo MHA change
gramalingam Dec 5, 2024
3fc12b8
MHA validation part 1
gramalingam Dec 6, 2024
f09567e
MHA validation
gramalingam Dec 6, 2024
1 change: 1 addition & 0 deletions onnxscript/optimizer/_constant_folding.py
@@ -13,6 +13,7 @@

import numpy as np
import onnx
import onnx.helper
import onnx.reference.ops

import onnxscript.ir as ir
46 changes: 46 additions & 0 deletions onnxscript/rewriter/_ir_utils.py
@@ -2,12 +2,58 @@
# Licensed under the MIT License.
from __future__ import annotations

import numpy as np

import onnxscript.ir as ir
from onnxscript.optimizer import basic_constant_propagation


def display_slice(x: ir.Value | ir.Node, backward: bool = True, depth_limit: int = 5):
    """Display the subgraph computing a given value or node, up to a certain depth."""
    slice = []

    def visit(node: ir.Node, depth):
        if node in slice:
            return
        slice.append(node)
        if depth < depth_limit:
            if backward:
                for inp in node.inputs:
                    if inp is not None and (producer := inp.producer()) is not None:
                        visit(producer, depth + 1)
            else:
                for out in node.outputs:
                    for consumer, _ in out.uses():
                        visit(consumer, depth + 1)

    if isinstance(x, ir.Node):
        visit(x, 0)
    elif isinstance(x, ir.Value) and (producer := x.producer()) is not None:
        visit(producer, 0)
    for node in reversed(slice):
        node.display()

def get_const_value(value: ir.Value) -> ir.TensorProtocol | None:
    node = value.producer()
    if node is not None:
        basic_constant_propagation([node])
    return value.const_value


def get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
    if val is None:
        return None
    const_value = val.const_value
    if const_value is not None:
        try:
            return const_value.numpy()
        except FileNotFoundError:
            # External data is not available.
            return None
    return None


def get_singleton_value(val: ir.Value | None):
    """Returns the element of a single-element tensor constant, or None otherwise."""
    np_val = get_numpy_value(val)
    if np_val is not None and np_val.size == 1:
        return np_val.item()
    return None
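Note: a minimal usage sketch for the helpers above, assuming a model already loaded into the IR (the file path, the chosen value, and the use of ir.load are illustrative assumptions, not part of this PR):

import onnxscript.ir as ir
from onnxscript.rewriter import _ir_utils

model = ir.load("model.onnx")  # assumed loading entry point; path is illustrative
value = model.graph.outputs[0]

# get_singleton_value returns the scalar held by a single-element constant, else None.
scale = _ir_utils.get_singleton_value(value)
if scale is not None:
    print(f"constant scalar: {scale}")

# display_slice prints the subgraph that produces a value, up to a depth limit;
# handy when debugging why a fusion pattern did not match.
_ir_utils.display_slice(value, backward=True, depth_limit=3)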
17 changes: 17 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from onnxscript.rewriter.onnxruntime.xformers.multi_head_attention import (
    mha_rules as mha_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import (
    rms_normalization_rules as rms_normalization_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import (
    rotary_embedding_rules as rotary_embedding_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.sdpa import sdpa_rules as sdpa_rules
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import (
    skip_normalization_rules as skip_normalization_rules,
)
38 changes: 38 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py
@@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnxscript.ir as ir
from onnxscript.optimizer import fold_constants_ir, remove_unused_nodes
from onnxscript.rewriter.onnxruntime.xformers import (
    mha_rules,
    rms_normalization_rules,
    rotary_embedding_rules,
    sdpa_rules,
    skip_normalization_rules,
)


def fuse_rotary_embedding(irmodel: ir.Model) -> None:
    count = rotary_embedding_rules.apply_to_model(irmodel)
    print(f"RotaryEmbedding count: {count}")


def optimize(irmodel: ir.Model, verbose: int = 0) -> None:
    def apply(rulename: str, rule):
        count = rule.apply_to_model(irmodel, verbose=verbose)
        print(f"{rulename} count: {count}")

    fold_constants_ir(irmodel, input_size_limit=5120000 * 4, output_size_limit=5120000 * 4)
    remove_unused_nodes(irmodel)

    apply("RMS Normalization", rms_normalization_rules)
    apply("Skip Normalization", skip_normalization_rules)

    fold_constants_ir(irmodel)
    remove_unused_nodes(irmodel)

    apply("SDPA-Attention", sdpa_rules)
    apply("RotaryEmbedding", rotary_embedding_rules)
    apply("Multi-Head-Attention", mha_rules)

    remove_unused_nodes(irmodel)
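For orientation, a hedged sketch of how these entry points might be driven on an exported model (the file names, and running the basic optimizer before fusion, are assumptions that mirror the test file below rather than something this diff prescribes):

import onnxscript.ir as ir
import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers import _optimize_transformers

irmodel = ir.load("llama_attention.onnx")              # assumed: any exported transformer block
onnxscript.optimizer.optimize(irmodel)                  # basic optimizations before fusion
_optimize_transformers.fuse_rotary_embedding(irmodel)   # rotary-embedding fusion only, or ...
_optimize_transformers.optimize(irmodel, verbose=0)     # ... the full fusion pipeline
ir.save(irmodel, "llama_attention_fused.onnx")          # assumed public alias of ir._io.save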
153 changes: 153 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py
@@ -0,0 +1,153 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import os
import tempfile
import unittest

import numpy as np
import onnxruntime
import torch
import transformers.models.llama.modeling_llama as modeling_llama
from parameterized import parameterized
from transformers import LlamaConfig

import onnxscript.ir._io as io
import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers import (
    _optimize_transformers as optimize_transformers,
)

# Create a LlamaConfig object with the desired parameters
_config = LlamaConfig(
    _name_or_path="HuggingFaceTB/SmolLM-1.7B",
    architectures=["LlamaForCausalLM"],
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=0,
    eos_token_id=0,
    hidden_act="silu",
    hidden_size=2048,
    initializer_range=0.02,
    intermediate_size=8192,
    max_position_embeddings=2048,
    model_type="llama",
    num_attention_heads=32,
    num_hidden_layers=24,
    num_key_value_heads=32,
    pretraining_tp=1,
    rms_norm_eps=1e-05,
    rope_scaling=None,
    rope_theta=10000.0,
    tie_word_embeddings=True,
    torch_dtype="float32",
    transformers_version="4.37.2",
    use_cache=True,
    vocab_size=49152,
)

# Dimensions for inputs:
_batch_size = 1
_seq_len = 10
_hidden_size = _config.hidden_size
_num_attention_heads = _config.num_attention_heads
dim = _hidden_size // _num_attention_heads

# Generate inputs:
_hidden_states = torch.rand(_batch_size, _seq_len, _hidden_size, dtype=torch.float32)
_causal_mask = torch.tril(torch.ones(_seq_len, _seq_len, dtype=torch.float32))
_attention_mask = _causal_mask.unsqueeze(0).unsqueeze(0).expand(_batch_size, 1, _seq_len, _seq_len)
_position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64).reshape(1, 10)

# Get model in ONNX format
# def _get_model(llama_attention_class, with_mask: bool):
#     model = llama_attention_class(_config, 0)
#     if with_mask:
#         inputs = (_hidden_states, _attention_mask, _position_ids)
#     else:
#         inputs = (_hidden_states, None, _position_ids)
#     exported = torch.onnx.export(model, inputs, dynamo=True)
#     # ORT Transformer optimizations are applied after basic optimization.
#     onnxscript.optimizer.optimize(exported.model)
#     return exported.model


class _TestData:
    def __init__(self, name: str, attention_class, with_mask: bool):
        self.name = name
        self.attention_class = attention_class
        self.with_mask = with_mask

    def get_torch_model(self):
        return self.attention_class(_config, 0)

    def get_onnx_model(self):
        model = self.get_torch_model()
        inputs = self.get_inputs()
        input_names = ["input" + str(i) for i in range(len(inputs)) if inputs[i] is not None]
        exported = torch.onnx.export(model, inputs, input_names=input_names, dynamo=True)
        # ORT Transformer optimizations are applied after basic optimization.
        onnxscript.optimizer.optimize(exported.model)
        return exported.model

    def get_inputs(self):
        if self.with_mask:
            return (_hidden_states, _attention_mask, _position_ids)
        else:
            return (_hidden_states, None, _position_ids)

    def get_torch_outputs(self):
        return self.get_torch_model()(*self.get_inputs())

    def get_ort_inputs(self):
        inputs = self.get_inputs()
        return {f"input{i}": input for i, input in enumerate(inputs) if input is not None}


_test_cases = [
    _TestData("attention", modeling_llama.LlamaAttention, False),
    _TestData("masked_attention", modeling_llama.LlamaAttention, True),
    _TestData("sdpa_attention", modeling_llama.LlamaSdpaAttention, False),
    _TestData("masked_sdpa_attention", modeling_llama.LlamaSdpaAttention, True),
]

_test_case_tuples = [(t,) for t in _test_cases]


def _ort_check(model_name: str, model, inputs, expected_outputs, rtol=1e-2, atol=1e-2):
    providers = ["CPUExecutionProvider"]
    with tempfile.TemporaryDirectory() as temp_dir:
        model_path = os.path.join(temp_dir, f"{model_name}.onnx")
        io.save(model, model_path)
        # Run optimized model
        session = onnxruntime.InferenceSession(model_path, providers=providers)
        ort_outputs = session.run(None, inputs)

        for i, (baseline_output, optimized_output) in enumerate(
            zip(expected_outputs, ort_outputs)
        ):
            try:
                np.testing.assert_equal(baseline_output.shape, optimized_output.shape)
                np.testing.assert_allclose(
                    baseline_output, optimized_output, rtol=rtol, atol=atol
                )
            except AssertionError as e:
                print(
                    f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}"
                )
                raise


class TestOptimizeTransformers(unittest.TestCase):
    @parameterized.expand(_test_case_tuples)
    def test_attention_optimization(self, test_data: _TestData):
        model = test_data.get_onnx_model()
        # io.save(model, os.path.join(r"C:\repos\onnxscript\smy\Models", f"{test_data.name}.onnx"))
        # model.display()
        # print("======>")
        optimize_transformers.fuse_rotary_embedding(model)
        # model.display()
        op_types = [n.op_type for n in model.graph]
        self.assertIn("RotaryEmbedding", op_types)
        # _ort_check(test_data.name, model, test_data.get_ort_inputs(), test_data.get_torch_outputs())


if __name__ == "__main__":
    unittest.main()
104 changes: 104 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/multi_head_attention.py
@@ -0,0 +1,104 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from onnxscript.rewriter import pattern

"""
The MultiHeadAttention pattern:

B: batch size
S: sequence length
D: input embedding dimension
H: number of heads
d_h: head size (usually, D = H * d_h)

Thus, the Q, K, and V weights are usually each of shape (D, D).

For each of Q, K, and V, we have the following pattern:
    MatMul (Input, W), producing an output of shape (B, S, D)
    Reshape to produce a tensor of shape (B, S, H, d_h)
    Transpose of the middle two axes to produce a tensor of shape (B, H, S, d_h)

This is followed by a RotaryEmbedding pattern for Q and K.

The last two axes of the key embedding are then swapped (using a Reshape/Transpose/Reshape sequence).

The dot-product attention is then computed using SDPA.

Finally, the output is transposed and reshaped back to (B, S, D) shape.
"""


def _project_transpose_head(op, input, weight):
    """Applied to each of Q, K, and V."""
    projected = op.MatMul(input, weight)
    # Reshape from (B, S, D) to (B, S, H, D/H)
    reshaped = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True)
    # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
    transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3])
    return transposed


def _multi_head_attention_pattern(op, input, query_weight, key_weight, value_weight, cos, sin):
    query = _project_transpose_head(op, input, query_weight)
    query_rope = op.RotaryEmbedding(query, cos, sin, _domain="local")
    key = _project_transpose_head(op, input, key_weight)
    key_rope = op.RotaryEmbedding(key, cos, sin, _domain="local")
    # Transpose last two axes of key_rope to compute dot-product via matmul.
    key_reshaped = op.Reshape(key_rope, _allow_other_inputs=True)
    key_reshaped_transposed = op.Transpose(key_reshaped)
    key_transposed = op.Reshape(key_reshaped_transposed, _allow_other_inputs=True)
    value = _project_transpose_head(op, input, value_weight)
    attention = op.SDPA(
        query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local"
    )
    # Transpose back to (B, S, H, D/H)
    attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3])
    # Reshape back to (B, S, D)
    attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True)
    return attention_reshaped, key_rope, value


def _multi_head_attention_pattern2(
    op, input, query_weight, key_weight, value_weight, cos, sin
):
    """Variation of first pattern with Reshape omitted."""
    query = _project_transpose_head(op, input, query_weight)
    query_rope = op.RotaryEmbedding(query, cos, sin, _domain="local")
    key = _project_transpose_head(op, input, key_weight)
    key_rope = op.RotaryEmbedding(key, cos, sin, _domain="local")
    # Transpose last two axes of key_rope to compute dot-product via matmul.
    # Reshape omitted here.
    key_transposed = op.Transpose(key_rope)
    # Reshape omitted here.
    value = _project_transpose_head(op, input, value_weight)
    attention = op.SDPA(
        query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local"
    )
    # Transpose back to (B, S, H, D/H)
    attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3])
    # Reshape back to (B, S, D)
    attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True)
    return attention_reshaped, key_rope, value


def _multi_head_attention(
    op,
    input,
    query_weight,
    key_weight,
    value_weight,
    cos,
    sin,
):
    # TODO: other checks and concatenation of weights
    return op.MultiHeadAttention(
        input, query_weight, key_weight, value_weight, cos, sin, _domain="local", _outputs=3
    )


_rule1 = pattern.RewriteRule(_multi_head_attention_pattern, _multi_head_attention)
_rule2 = pattern.RewriteRule(_multi_head_attention_pattern2, _multi_head_attention)

mha_rules = pattern.RewriteRuleSet([_rule1, _rule2])
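To make the shape bookkeeping in the module docstring above concrete, here is a small PyTorch sketch of the computation the pattern is meant to match (sizes are illustrative, rotary embedding is elided, and this is not code from the PR):

import torch

B, S, D, H = 1, 10, 2048, 32   # batch, sequence, embedding dim, number of heads
d_h = D // H                   # head size

x = torch.rand(B, S, D)
w_q, w_k, w_v = (torch.rand(D, D) for _ in range(3))

def project_transpose_head(x, w):
    # MatMul: (B, S, D) -> Reshape: (B, S, H, d_h) -> Transpose: (B, H, S, d_h)
    return (x @ w).reshape(B, S, H, d_h).transpose(1, 2)

q = project_transpose_head(x, w_q)  # RotaryEmbedding would be applied to q and k here
k = project_transpose_head(x, w_k)
v = project_transpose_head(x, w_v)

scores = (q @ k.transpose(-2, -1)) / d_h**0.5   # the key transpose captured by the pattern
attn = torch.softmax(scores, dim=-1) @ v        # the SDPA core
out = attn.transpose(1, 2).reshape(B, S, D)     # transpose and reshape back to (B, S, D)
print(out.shape)                                # torch.Size([1, 10, 2048])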