Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] First version of fusion optimizations for transformers #1938

Closed
wants to merge 34 commits into from
Closed
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
d751d37
Fusions
gramalingam Oct 23, 2024
8c4dff5
MultiHeadAttention fusion
gramalingam Oct 24, 2024
5d3c9af
Merge branch 'main' into rama/fusions
gramalingam Nov 5, 2024
4d3ff90
Move transformers optimization into onnxruntime folder
gramalingam Nov 6, 2024
4a667f9
Support some SDPA variations
gramalingam Nov 7, 2024
33c3753
Add variations of rules for SDPA
gramalingam Nov 8, 2024
404e5c3
Add attention scale validation
gramalingam Nov 8, 2024
e98682f
Add validation conditions for rotary embedding
gramalingam Nov 8, 2024
001bb59
Add tests
gramalingam Nov 8, 2024
40b9052
Move into new xformers folder
gramalingam Nov 8, 2024
94ce2f3
Add dropout to optimizer
gramalingam Nov 8, 2024
bf3b64a
Run lint
gramalingam Nov 8, 2024
0491366
Undo dropout rewrite rule change
gramalingam Nov 8, 2024
3fb7cd1
Add concat test
gramalingam Nov 8, 2024
f25b669
Merge with main
gramalingam Nov 8, 2024
73723f0
Add expand identity optimization
gramalingam Nov 8, 2024
bb977ec
Some cleanup
gramalingam Nov 8, 2024
7f1606f
Fix dropout optimization
gramalingam Nov 9, 2024
a3e0d1d
Some more cleanup
gramalingam Nov 9, 2024
0879934
Cleanup
gramalingam Nov 9, 2024
a8ac3ee
Minor fixes
gramalingam Nov 13, 2024
044a638
Add ort check to test
gramalingam Nov 13, 2024
b6f0071
Testing changes
gramalingam Nov 13, 2024
b985bb1
Merge branch 'main' into rama/fusions
gramalingam Nov 15, 2024
0f35c45
Merge with main
gramalingam Nov 15, 2024
91bf47a
Check for dynamic dim
gramalingam Nov 15, 2024
fa7ba33
Merge branch 'main' into rama/fusions
gramalingam Nov 16, 2024
8b496af
Debugging
gramalingam Nov 22, 2024
19de794
Merge branch 'main' into rama/fusions
gramalingam Dec 3, 2024
4f0cca7
Remove unused import
gramalingam Dec 3, 2024
585d8bc
Various fixes
gramalingam Dec 5, 2024
555f56f
Undo MHA change
gramalingam Dec 5, 2024
3fc12b8
MHA validation part 1
gramalingam Dec 6, 2024
f09567e
MHA validation
gramalingam Dec 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions onnxscript/rewriter/_ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,66 @@
# Licensed under the MIT License.
from __future__ import annotations

import numpy as np

import onnxscript.ir as ir
from onnxscript.optimizer import basic_constant_propagation


def display_slice(x: ir.Value | ir.Node, backward: bool = True, depth_limit: int = 5) -> None:
    """Display the subgraph around a given value or node, up to a certain depth.

    Args:
        x: The value or node to start from.
        backward: If True, walk from each node to the producers of its inputs;
            otherwise walk from each node to the consumers of its outputs.
        depth_limit: Nodes reached at this depth are still collected, but the
            walk does not continue past them.

    The collected nodes are displayed in the order in which they occur in the
    graph of the first collected node. Nothing is displayed if no node was
    collected or if that node does not belong to a graph.
    """
    collected: list[ir.Node] = []

    def visit(node: ir.Node, depth: int) -> None:
        if node in collected:
            return
        collected.append(node)
        if depth >= depth_limit:
            return
        if backward:
            for inp in node.inputs:
                if inp is None:
                    continue
                # Bind the producer once: checking one producer() call and
                # passing the result of another does not narrow the
                # "Node | None" type for the type checker.
                producer = inp.producer()
                if producer is not None:
                    visit(producer, depth + 1)
        else:
            for out in node.outputs:
                for consumer, _ in out.uses():
                    visit(consumer, depth + 1)

    if isinstance(x, ir.Node):
        visit(x, 0)
    elif isinstance(x, ir.Value):
        if backward:
            producer = x.producer()
            if producer is not None:
                visit(producer, 0)
        else:
            for consumer, _ in x.uses():
                visit(consumer, 0)

    if not collected:
        return
    graph = collected[0].graph
    if graph:
        # Iterate the graph (not the collected list) so that nodes are shown
        # in graph order rather than in visiting order.
        for n in graph:
            if n in collected:
                n.display()

def get_const_value(value: ir.Value) -> ir.TensorProtocol | None:
    """Return the constant value recorded on `value`, if any.

    When `value` has a producer node, basic constant propagation is first run
    on that single node; the value's `const_value` is then returned as is.
    """
    producer = value.producer()
    if producer is None:
        return value.const_value
    basic_constant_propagation([producer])
    return value.const_value


def get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
    """Return the constant value of `val` as a numpy array.

    Returns None when `val` is None, when it has no constant value, or when
    converting the constant raises FileNotFoundError.
    """
    tensor = None if val is None else val.const_value
    if tensor is None:
        return None
    try:
        return tensor.numpy()
    except FileNotFoundError:
        # External data is not available.
        return None


def get_singleton_value(val: ir.Value | None):
    """Return the sole element of a single-element constant tensor, else None.

    None is returned when `val` has no numpy-convertible constant value or
    when that constant does not contain exactly one element.
    """
    array = get_numpy_value(val)
    if array is None or array.size != 1:
        return None
    return array.item()
17 changes: 17 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from onnxscript.rewriter.onnxruntime.xformers.multi_head_attention import (
mha_rules as mha_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import (
rms_normalization_rules as rms_normalization_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import (
rotary_embedding_rules as rotary_embedding_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.sdpa import sdpa_rules as sdpa_rules
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import (
skip_normalization_rules as skip_normalization_rules,
)
52 changes: 52 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

RUFF-FORMAT/format Warning

Run lintrunner -a to apply this patch.
# Licensed under the MIT License.
from __future__ import annotations

import onnxscript.ir as ir
from onnxscript.optimizer import fold_constants_ir, remove_unused_nodes
from onnxscript.rewriter.onnxruntime.xformers import (
mha_rules,
rms_normalization_rules,
rotary_embedding_rules,
sdpa_rules,
skip_normalization_rules,
)


def fuse_rotary_embedding(irmodel: ir.Model) -> None:
    """Apply the rotary-embedding rules to `irmodel` and print the returned count."""
    num_applied = rotary_embedding_rules.apply_to_model(irmodel)
    print(f"RotaryEmbedding count: {num_applied}")

def fuse_rms_normalization(irmodel: ir.Model) -> None:
    """Apply the RMS-normalization rules, then the skip-normalization rules.

    The count returned by each rule set is printed after it has been applied.
    """
    rms_applied = rms_normalization_rules.apply_to_model(irmodel)
    print(f"RMS Normalization count: {rms_applied}")

    # Skip-normalization runs second, on the model as rewritten above.
    skip_applied = skip_normalization_rules.apply_to_model(irmodel)
    print(f"Skip Normalization count: {skip_applied}")

def fuse_attention(irmodel: ir.Model) -> None:
    """Apply the SDPA rules to `irmodel` and print the returned count."""
    num_applied = sdpa_rules.apply_to_model(irmodel)
    print(f"SDPA-Attention count: {num_applied}")

def fuse_mha(irmodel: ir.Model) -> None:
    """Apply the multi-head-attention rules to `irmodel` and print the returned count."""
    num_applied = mha_rules.apply_to_model(irmodel)
    print(f"Multi-Head-Attention count: {num_applied}")

def optimize(irmodel: ir.Model, verbose: int = 0) -> None:
    """Run the full sequence of transformer fusions on `irmodel`.

    The steps are, in order: constant folding and unused-node removal, the two
    normalization rule sets, a second folding/removal pass, the SDPA,
    rotary-embedding and multi-head-attention rule sets, and a final
    unused-node removal. The count returned by each rule set is printed.

    Args:
        irmodel: The model to rewrite.
        verbose: Forwarded to each rule set's `apply_to_model`.
    """

    def run_rules(label: str, rules) -> None:
        applied = rules.apply_to_model(irmodel, verbose=verbose)
        print(f"{label} count: {applied}")

    # The first folding pass uses explicit (enlarged) size limits; the second
    # pass below relies on the defaults of fold_constants_ir.
    size_limit = 5120000 * 4
    fold_constants_ir(irmodel, input_size_limit=size_limit, output_size_limit=size_limit)
    remove_unused_nodes(irmodel)

    normalization_passes = (
        ("RMS Normalization", rms_normalization_rules),
        ("Skip Normalization", skip_normalization_rules),
    )
    for label, rules in normalization_passes:
        run_rules(label, rules)

    fold_constants_ir(irmodel)
    remove_unused_nodes(irmodel)

    attention_passes = (
        ("SDPA-Attention", sdpa_rules),
        ("RotaryEmbedding", rotary_embedding_rules),
        ("Multi-Head-Attention", mha_rules),
    )
    for label, rules in attention_passes:
        run_rules(label, rules)

    remove_unused_nodes(irmodel)
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright (c) Microsoft Corporation.
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed

Check warning

Code scanning / lintrunner

RUFF-FORMAT/format Warning

Run lintrunner -a to apply this patch.
# Licensed under the MIT License.
from __future__ import annotations
Fixed Show fixed Hide fixed

import os
import tempfile
import unittest

import numpy as np
import onnxruntime
import torch
import transformers.models.llama.modeling_llama as modeling_llama
from parameterized import parameterized
from transformers import LlamaConfig

import onnxscript.ir._io as io
import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers import (
_optimize_transformers as optimize_transformers,
)

# Create a LlamaConfig object with the desired parameters.
# NOTE(review): the values presumably mirror the published configuration of the
# checkpoint named in `_name_or_path` -- confirm before changing any of them.
# Only the configuration is built here; this statement loads no weights.
_config = LlamaConfig(
    _name_or_path="HuggingFaceTB/SmolLM-1.7B",
    architectures=["LlamaForCausalLM"],
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=0,
    eos_token_id=0,
    hidden_act="silu",
    hidden_size=2048,
    initializer_range=0.02,
    intermediate_size=8192,
    max_position_embeddings=2048,
    model_type="llama",
    num_attention_heads=32,
    num_hidden_layers=24,
    num_key_value_heads=32,
    pretraining_tp=1,
    rms_norm_eps=1e-05,
    rope_scaling=None,
    rope_theta=10000.0,
    tie_word_embeddings=True,
    torch_dtype="float32",
    transformers_version="4.37.2",
    use_cache=True,
    vocab_size=49152,
)

# Dimensions for inputs:
_batch_size = 1
_seq_len = 10
_hidden_size = _config.hidden_size
_num_attention_heads = _config.num_attention_heads
dim = _hidden_size // _num_attention_heads

# Generate inputs:
_hidden_states = torch.rand(_batch_size, _seq_len, _hidden_size, dtype=torch.float32)
# Lower-triangular (seq_len x seq_len) matrix of ones, expanded below to
# (batch_size, 1, seq_len, seq_len).
_causal_mask = torch.tril(torch.ones(_seq_len, _seq_len, dtype=torch.float32))
_attention_mask = _causal_mask.unsqueeze(0).unsqueeze(0).expand(_batch_size, 1, _seq_len, _seq_len)
# Positions 0 .. seq_len-1, derived from _seq_len so that changing the sequence
# length above cannot leave the position ids out of sync with the other inputs.
_position_ids = torch.arange(_seq_len, dtype=torch.int64).reshape(1, _seq_len)

# Get model in ONNX format
# def _get_model(llama_attention_class, with_mask: bool):
# model = llama_attention_class(_config, 0)
# if with_mask:
# inputs = (_hidden_states, _attention_mask, _position_ids)
# else:
# inputs = (_hidden_states, None, _position_ids)
# exported = torch.onnx.export(model, inputs, dynamo=True)
# # ORT Transformer optimizations are applied after basic optimization.
# onnxscript.optimizer.optimize(exported.model)
# return exported.model

class _TestData:
    """One attention test case: a module class plus the inputs to run it with.

    Attributes are set from the constructor arguments:
        name: Label of the test case.
        attention_class: Callable invoked as `attention_class(_config, 0)` to
            build the torch module.
        with_mask: Whether the attention mask is passed (otherwise None is
            passed in its place).
    """

    def __init__(self, name: str, attention_class, with_mask: bool):
        self.name = name
        self.attention_class = attention_class
        self.with_mask = with_mask
        # Built on first use by get_torch_model and reused afterwards.
        self._torch_model = None

    def get_torch_model(self):
        """Return the torch module for this test case, building it on first use.

        The same instance is returned on every call, so that the exported ONNX
        model and the torch reference outputs come from one module object
        (each newly built module would carry its own parameters).
        """
        if self._torch_model is None:
            self._torch_model = self.attention_class(_config, 0)
        return self._torch_model

    def get_onnx_model(self):
        """Export the torch module to ONNX and run the basic optimizer on it."""
        model = self.get_torch_model()
        inputs = self.get_inputs()
        # Names keep the positional index of each input, skipping None inputs;
        # get_ort_inputs builds its keys the same way.
        input_names = [f"input{i}" for i, value in enumerate(inputs) if value is not None]
        exported = torch.onnx.export(model, inputs, input_names=input_names, dynamo=True)
        # ORT Transformer optimizations are applied after basic optimization.
        onnxscript.optimizer.optimize(exported.model)
        return exported.model

    def get_inputs(self):
        """Return the positional inputs: (hidden states, mask or None, position ids)."""
        if self.with_mask:
            return (_hidden_states, _attention_mask, _position_ids)
        return (_hidden_states, None, _position_ids)

    def get_torch_outputs(self):
        """Run the (cached) torch module on the inputs and return its result."""
        return self.get_torch_model()(*self.get_inputs())

    def get_ort_inputs(self):
        """Return the non-None inputs as a name -> numpy array feed dictionary.

        The tensors are converted with `.numpy()` because ONNX Runtime feeds
        are numpy arrays rather than torch tensors.
        """
        inputs = self.get_inputs()
        return {
            f"input{i}": value.numpy() for i, value in enumerate(inputs) if value is not None
        }

# One _TestData per (name, attention class, with_mask) combination under test.
_test_cases = [
    _TestData(case_name, attention_class, with_mask)
    for case_name, attention_class, with_mask in (
        ("attention", modeling_llama.LlamaAttention, False),
        ("masked_attention", modeling_llama.LlamaAttention, True),
        ("sdpa_attention", modeling_llama.LlamaSdpaAttention, False),
        ("masked_sdpa_attention", modeling_llama.LlamaSdpaAttention, True),
    )
]

# Each case wrapped in a 1-tuple, i.e. one argument tuple per test case.
_test_case_tuples = [(case,) for case in _test_cases]

def _ort_check(model_name: str, model, inputs, expected_outputs, rtol=1e-2, atol=1e-2):
    """Run `model` with ONNX Runtime and compare its outputs to `expected_outputs`.

    The model is saved as `<model_name>.onnx` in a temporary directory and run
    on the CPU execution provider with `inputs` as the feed. Outputs are
    compared pairwise, first by shape and then by value within `rtol`/`atol`.

    Raises:
        AssertionError: On the first mismatching output, after printing which
            model and output index failed.
    """
    with tempfile.TemporaryDirectory() as work_dir:
        onnx_file = os.path.join(work_dir, f"{model_name}.onnx")
        io.save(model, onnx_file)
        # Run optimized model
        ort_session = onnxruntime.InferenceSession(onnx_file, providers=["CPUExecutionProvider"])
        actual_outputs = ort_session.run(None, inputs)

        for i, (expected, actual) in enumerate(zip(expected_outputs, actual_outputs)):
            try:
                np.testing.assert_equal(expected.shape, actual.shape)
                np.testing.assert_allclose(expected, actual, rtol=rtol, atol=atol)
            except AssertionError as e:
                print(
                    f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}"
                )
                raise

class TestOptimizeTransformers(unittest.TestCase):
    """Checks that rotary-embedding fusion fires on each exported attention model."""

    @parameterized.expand(_test_case_tuples)
    def test_attention_optimization(self, test_data: _TestData):
        onnx_model = test_data.get_onnx_model()
        optimize_transformers.fuse_rotary_embedding(onnx_model)
        fused_op_types = [node.op_type for node in onnx_model.graph]
        self.assertIn("RotaryEmbedding", fused_op_types)
        # The numerical comparison against torch is currently disabled:
        # _ort_check(test_data.name, model, test_data.get_ort_inputs(), test_data.get_torch_outputs())


if __name__ == "__main__":
unittest.main()
Loading
Loading