
[DRAFT] First version of fusion optimizations for transformers #1938

Status: Closed. Wants to merge 34 commits into main from rama/fusions.

Commits (34):
d751d37  Fusions (gramalingam, Oct 23, 2024)
8c4dff5  MultiHeadAttention fusion (gramalingam, Oct 24, 2024)
5d3c9af  Merge branch 'main' into rama/fusions (gramalingam, Nov 5, 2024)
4d3ff90  Move transformers optimization into onnxruntime folder (gramalingam, Nov 6, 2024)
4a667f9  Support some SDPA variations (gramalingam, Nov 7, 2024)
33c3753  Add variations of rules for SDPA (gramalingam, Nov 8, 2024)
404e5c3  Add attention scale validation (gramalingam, Nov 8, 2024)
e98682f  Add validation conditions for rotary embedding (gramalingam, Nov 8, 2024)
001bb59  Add tests (gramalingam, Nov 8, 2024)
40b9052  Move into new xformers folder (gramalingam, Nov 8, 2024)
94ce2f3  Add dropout to optimizer (gramalingam, Nov 8, 2024)
bf3b64a  Run lint (gramalingam, Nov 8, 2024)
0491366  Undo dropout rewrite rule change (gramalingam, Nov 8, 2024)
3fb7cd1  Add concat test (gramalingam, Nov 8, 2024)
f25b669  Merge with main (gramalingam, Nov 8, 2024)
73723f0  Add expand identity optimization (gramalingam, Nov 8, 2024)
bb977ec  Some cleanup (gramalingam, Nov 8, 2024)
7f1606f  Fix dropout optimization (gramalingam, Nov 9, 2024)
a3e0d1d  Some more cleanup (gramalingam, Nov 9, 2024)
0879934  Cleanup (gramalingam, Nov 9, 2024)
a8ac3ee  Minor fixes (gramalingam, Nov 13, 2024)
044a638  Add ort check to test (gramalingam, Nov 13, 2024)
b6f0071  Testing changes (gramalingam, Nov 13, 2024)
b985bb1  Merge branch 'main' into rama/fusions (gramalingam, Nov 15, 2024)
0f35c45  Merge with main (gramalingam, Nov 15, 2024)
91bf47a  Check for dynamic dim (gramalingam, Nov 15, 2024)
fa7ba33  Merge branch 'main' into rama/fusions (gramalingam, Nov 16, 2024)
8b496af  Debugging (gramalingam, Nov 22, 2024)
19de794  Merge branch 'main' into rama/fusions (gramalingam, Dec 3, 2024)
4f0cca7  Remove unused import (gramalingam, Dec 3, 2024)
585d8bc  Various fixes (gramalingam, Dec 5, 2024)
555f56f  Undo MHA change (gramalingam, Dec 5, 2024)
3fc12b8  MHA validation part 1 (gramalingam, Dec 6, 2024)
f09567e  MHA validation (gramalingam, Dec 6, 2024)
onnxscript/rewriter/_ir_utils.py (+54, -0)
@@ -2,12 +2,66 @@
# Licensed under the MIT License.
from __future__ import annotations

import numpy as np

import onnxscript.ir as ir
from onnxscript.optimizer import basic_constant_propagation


def display_slice(x: ir.Value | ir.Node, backward: bool = True, depth_limit: int = 5):
    """Display the subgraph computing a given value or node, up to a certain depth."""
    slice = []

    def visit(node: ir.Node, depth):
        if node in slice:
            return
        slice.append(node)
        if depth < depth_limit:
            if backward:
                # Walk towards the producers of this node's inputs.
                for inp in node.inputs:
                    if inp is not None:
                        producer = inp.producer()
                        if producer is not None:
                            visit(producer, depth + 1)
            else:
                # Walk towards the consumers of this node's outputs.
                for out in node.outputs:
                    for consumer, _ in out.uses():
                        visit(consumer, depth + 1)

    if isinstance(x, ir.Node):
        visit(x, 0)
    elif isinstance(x, ir.Value):
        producer = x.producer()
        if backward and producer is not None:
            visit(producer, 0)
        elif not backward:
            for consumer, _ in x.uses():
                visit(consumer, 0)
    if slice:
        graph = slice[0].graph
        if graph:
            # Display nodes in graph order, restricted to the computed slice.
            for n in graph:
                if n in slice:
                    n.display()


def get_const_value(value: ir.Value) -> ir.TensorProtocol | None:
    node = value.producer()
    if node is not None:
        basic_constant_propagation([node])
    return value.const_value


def get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
    if val is None:
        return None
    const_value = val.const_value
    if const_value is not None:
        try:
            return const_value.numpy()
        except FileNotFoundError:
            # External data is not available.
            return None
    return None


def get_singleton_value(val: ir.Value | None):
    """Returns the element of a single-element tensor constant, and None otherwise."""
    np_val = get_numpy_value(val)
    if np_val is not None and np_val.size == 1:
        return np_val.item()
    return None
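As a usage illustration (not part of the diff): a hedged sketch of how these helpers might back the scale-validation logic mentioned in the "Add attention scale validation" commit. The function name and the 1/sqrt(head_dim) check below are assumptions; only get_singleton_value comes from the code above.

import math

import onnxscript.ir as ir
from onnxscript.rewriter import _ir_utils


def _is_expected_sdpa_scale(scale_value: ir.Value, head_dim: int) -> bool:
    # Hypothetical helper: get_singleton_value yields the scalar held by a
    # single-element constant, or None when the value is not statically known.
    scale = _ir_utils.get_singleton_value(scale_value)
    if scale is None:
        return False
    # SDPA conventionally scales attention scores by 1/sqrt(head_dim).
    return math.isclose(scale, 1.0 / math.sqrt(head_dim), rel_tol=1e-5)

When a pattern unexpectedly fails to match, display_slice(value) prints the producing subgraph (up to five nodes deep by default), which is handy while debugging these rules.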
onnxscript/rewriter/onnxruntime/xformers/__init__.py (+17, -0)
@@ -0,0 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from onnxscript.rewriter.onnxruntime.xformers.multi_head_attention import (
mha_rules as mha_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import (
rms_normalization_rules as rms_normalization_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import (
rotary_embedding_rules as rotary_embedding_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.sdpa import sdpa_rules as sdpa_rules
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import (
skip_normalization_rules as skip_normalization_rules,
)
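Each name re-exported here is a rewrite rule set exposing apply_to_model (as used in _optimize_transformers.py below). A minimal sketch of applying one rule set in isolation; the model path is a placeholder, and obtaining the ir.Model via onnx.load plus ir.serde.deserialize_model is an assumption about the loading API:

import onnx

import onnxscript.ir as ir
from onnxscript.rewriter.onnxruntime.xformers import rms_normalization_rules

# Placeholder path; deserialize the ONNX protobuf into the IR the rules expect.
model = ir.serde.deserialize_model(onnx.load("model.onnx"))
count = rms_normalization_rules.apply_to_model(model)
print(f"RMS Normalization count: {count}")  # number of fusions applied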
onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py (+52, -0)
@@ -0,0 +1,52 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnxscript.ir as ir
from onnxscript.optimizer import fold_constants_ir, remove_unused_nodes
from onnxscript.rewriter.onnxruntime.xformers import (
    mha_rules,
    rms_normalization_rules,
    rotary_embedding_rules,
    sdpa_rules,
    skip_normalization_rules,
)


def fuse_rotary_embedding(irmodel: ir.Model) -> None:
    count = rotary_embedding_rules.apply_to_model(irmodel)
    print(f"RotaryEmbedding count: {count}")


def fuse_rms_normalization(irmodel: ir.Model) -> None:
    count = rms_normalization_rules.apply_to_model(irmodel)
    print(f"RMS Normalization count: {count}")
    count = skip_normalization_rules.apply_to_model(irmodel)
    print(f"Skip Normalization count: {count}")


def fuse_attention(irmodel: ir.Model) -> None:
    count = sdpa_rules.apply_to_model(irmodel)
    print(f"SDPA-Attention count: {count}")


def fuse_mha(irmodel: ir.Model) -> None:
    count = mha_rules.apply_to_model(irmodel)
    print(f"Multi-Head-Attention count: {count}")


def optimize(irmodel: ir.Model, verbose: int = 0) -> None:
    def apply(rulename: str, rule):
        count = rule.apply_to_model(irmodel, verbose=verbose)
        print(f"{rulename} count: {count}")

    # Fold constants first (with generous size limits) to expose the
    # normalization patterns, then drop dead nodes.
    fold_constants_ir(irmodel, input_size_limit=5120000 * 4, output_size_limit=5120000 * 4)
    remove_unused_nodes(irmodel)

    apply("RMS Normalization", rms_normalization_rules)
    apply("Skip Normalization", skip_normalization_rules)

    fold_constants_ir(irmodel)
    remove_unused_nodes(irmodel)

    apply("SDPA-Attention", sdpa_rules)
    apply("RotaryEmbedding", rotary_embedding_rules)
    apply("Multi-Head-Attention", mha_rules)

    remove_unused_nodes(irmodel)
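A hedged end-to-end sketch of driving this entry point (not part of the diff): the file names are placeholders, and loading via onnx.load plus ir.serde.deserialize_model is an assumed counterpart to the io.save used in the tests below.

import onnx

import onnxscript.ir as ir
import onnxscript.ir._io as io
from onnxscript.rewriter.onnxruntime.xformers import _optimize_transformers

model = ir.serde.deserialize_model(onnx.load("llama_attention.onnx"))  # placeholder input
_optimize_transformers.optimize(model, verbose=0)  # prints a fusion count per rule set
io.save(model, "llama_attention_fused.onnx")  # placeholder output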
onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers_test.py (+153, -0)
@@ -0,0 +1,153 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import os
import tempfile
import unittest

import numpy as np
import onnxruntime
import torch
import transformers.models.llama.modeling_llama as modeling_llama
from parameterized import parameterized
from transformers import LlamaConfig

import onnxscript.ir._io as io
import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers import (
    _optimize_transformers as optimize_transformers,
)

# Create a LlamaConfig object with the desired parameters
_config = LlamaConfig(
    _name_or_path="HuggingFaceTB/SmolLM-1.7B",
    architectures=["LlamaForCausalLM"],
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=0,
    eos_token_id=0,
    hidden_act="silu",
    hidden_size=2048,
    initializer_range=0.02,
    intermediate_size=8192,
    max_position_embeddings=2048,
    model_type="llama",
    num_attention_heads=32,
    num_hidden_layers=24,
    num_key_value_heads=32,
    pretraining_tp=1,
    rms_norm_eps=1e-05,
    rope_scaling=None,
    rope_theta=10000.0,
    tie_word_embeddings=True,
    torch_dtype="float32",
    transformers_version="4.37.2",
    use_cache=True,
    vocab_size=49152,
)

# Dimensions for inputs:
_batch_size = 1
_seq_len = 10
_hidden_size = _config.hidden_size
_num_attention_heads = _config.num_attention_heads
dim = _hidden_size // _num_attention_heads

# Generate inputs:
_hidden_states = torch.rand(_batch_size, _seq_len, _hidden_size, dtype=torch.float32)
_causal_mask = torch.tril(torch.ones(_seq_len, _seq_len, dtype=torch.float32))
_attention_mask = _causal_mask.unsqueeze(0).unsqueeze(0).expand(
    _batch_size, 1, _seq_len, _seq_len
)
_position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64).reshape(1, 10)

# Get model in ONNX format
# def _get_model(llama_attention_class, with_mask: bool):
#     model = llama_attention_class(_config, 0)
#     if with_mask:
#         inputs = (_hidden_states, _attention_mask, _position_ids)
#     else:
#         inputs = (_hidden_states, None, _position_ids)
#     exported = torch.onnx.export(model, inputs, dynamo=True)
#     # ORT Transformer optimizations are applied after basic optimization.
#     onnxscript.optimizer.optimize(exported.model)
#     return exported.model


class _TestData:
    def __init__(self, name: str, attention_class, with_mask: bool):
        self.name = name
        self.attention_class = attention_class
        self.with_mask = with_mask

    def get_torch_model(self):
        return self.attention_class(_config, 0)

    def get_onnx_model(self):
        model = self.get_torch_model()
        inputs = self.get_inputs()
        input_names = ["input" + str(i) for i in range(len(inputs)) if inputs[i] is not None]
        exported = torch.onnx.export(model, inputs, input_names=input_names, dynamo=True)
        # ORT Transformer optimizations are applied after basic optimization.
        onnxscript.optimizer.optimize(exported.model)
        return exported.model

    def get_inputs(self):
        if self.with_mask:
            return (_hidden_states, _attention_mask, _position_ids)
        else:
            return (_hidden_states, None, _position_ids)

    def get_torch_outputs(self):
        return self.get_torch_model()(*self.get_inputs())

    def get_ort_inputs(self):
        inputs = self.get_inputs()
        return {f"input{i}": input for i, input in enumerate(inputs) if input is not None}


_test_cases = [
    _TestData("attention", modeling_llama.LlamaAttention, False),
    _TestData("masked_attention", modeling_llama.LlamaAttention, True),
    _TestData("sdpa_attention", modeling_llama.LlamaSdpaAttention, False),
    _TestData("masked_sdpa_attention", modeling_llama.LlamaSdpaAttention, True),
]

_test_case_tuples = [(t,) for t in _test_cases]


def _ort_check(model_name: str, model, inputs, expected_outputs, rtol=1e-2, atol=1e-2):
    providers = ["CPUExecutionProvider"]
    with tempfile.TemporaryDirectory() as temp_dir:
        model_path = os.path.join(temp_dir, f"{model_name}.onnx")
        io.save(model, model_path)
        # Run optimized model
        session = onnxruntime.InferenceSession(model_path, providers=providers)
        ort_outputs = session.run(None, inputs)

        for i, (baseline_output, optimized_output) in enumerate(
            zip(expected_outputs, ort_outputs)
        ):
            try:
                np.testing.assert_equal(baseline_output.shape, optimized_output.shape)
                np.testing.assert_allclose(
                    baseline_output, optimized_output, rtol=rtol, atol=atol
                )
            except AssertionError as e:
                print(
                    f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}"
                )
                raise


class TestOptimizeTransformers(unittest.TestCase):
    @parameterized.expand(_test_case_tuples)
    def test_attention_optimization(self, test_data: _TestData):
        model = test_data.get_onnx_model()
        optimize_transformers.fuse_rotary_embedding(model)
        op_types = [n.op_type for n in model.graph]
        self.assertIn("RotaryEmbedding", op_types)
        # _ort_check(test_data.name, model, test_data.get_ort_inputs(), test_data.get_torch_outputs())


if __name__ == "__main__":
    unittest.main()
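The ORT numeric parity check is currently commented out in test_attention_optimization. A hedged sketch of enabling it for a single case once the fused model runs under onnxruntime, mirroring the commented-out call above:

# Hypothetical snippet: validate ORT outputs against the eager PyTorch baseline.
test_data = _test_cases[0]  # the "attention" case
model = test_data.get_onnx_model()
optimize_transformers.fuse_rotary_embedding(model)
_ort_check(
    test_data.name,
    model,
    test_data.get_ort_inputs(),
    test_data.get_torch_outputs(),
)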