[DRAFT] First version of fusion optimizations for transformers #1938

Closed · wants to merge 34 commits into main from rama/fusions

Changes shown are from 22 of the 34 commits.

Commits
d751d37  Fusions (gramalingam, Oct 23, 2024)
8c4dff5  MultiHeadAttention fusion (gramalingam, Oct 24, 2024)
5d3c9af  Merge branch 'main' into rama/fusions (gramalingam, Nov 5, 2024)
4d3ff90  Move transformers optimization into onnxruntime folder (gramalingam, Nov 6, 2024)
4a667f9  Support some SDPA variations (gramalingam, Nov 7, 2024)
33c3753  Add variations of rules for SDPA (gramalingam, Nov 8, 2024)
404e5c3  Add attention scale validation (gramalingam, Nov 8, 2024)
e98682f  Add validation conditions for rotary embedding (gramalingam, Nov 8, 2024)
001bb59  Add tests (gramalingam, Nov 8, 2024)
40b9052  Move into new xformers folder (gramalingam, Nov 8, 2024)
94ce2f3  Add dropout to optimizer (gramalingam, Nov 8, 2024)
bf3b64a  Run lint (gramalingam, Nov 8, 2024)
0491366  Undo dropout rewrite rule change (gramalingam, Nov 8, 2024)
3fb7cd1  Add concat test (gramalingam, Nov 8, 2024)
f25b669  Merge with main (gramalingam, Nov 8, 2024)
73723f0  Add expand identity optimization (gramalingam, Nov 8, 2024)
bb977ec  Some cleanup (gramalingam, Nov 8, 2024)
7f1606f  Fix dropout optimization (gramalingam, Nov 9, 2024)
a3e0d1d  Some more cleanup (gramalingam, Nov 9, 2024)
0879934  Cleanup (gramalingam, Nov 9, 2024)
a8ac3ee  Minor fixes (gramalingam, Nov 13, 2024)
044a638  Add ort check to test (gramalingam, Nov 13, 2024)
b6f0071  Testing changes (gramalingam, Nov 13, 2024)
b985bb1  Merge branch 'main' into rama/fusions (gramalingam, Nov 15, 2024)
0f35c45  Merge with main (gramalingam, Nov 15, 2024)
91bf47a  Check for dynamic dim (gramalingam, Nov 15, 2024)
fa7ba33  Merge branch 'main' into rama/fusions (gramalingam, Nov 16, 2024)
8b496af  Debugging (gramalingam, Nov 22, 2024)
19de794  Merge branch 'main' into rama/fusions (gramalingam, Dec 3, 2024)
4f0cca7  Remove unused import (gramalingam, Dec 3, 2024)
585d8bc  Various fixes (gramalingam, Dec 5, 2024)
555f56f  Undo MHA change (gramalingam, Dec 5, 2024)
3fc12b8  MHA validation part 1 (gramalingam, Dec 6, 2024)
f09567e  MHA validation (gramalingam, Dec 6, 2024)
15 changes: 13 additions & 2 deletions onnxscript/optimizer/__init__.py
@@ -4,13 +4,16 @@

import onnx

import onnxscript.optimizer._constant_folding as constant_folding
import onnxscript.optimizer._legacy._optimizer as legacy_optimizer
import onnxscript.optimizer._legacy.constant_folding as legacy_constant_folding
from onnxscript import ir
from onnxscript.optimizer._constant_folding import basic_constant_propagation
from onnxscript.optimizer._legacy.constant_folding import fold_constants
from onnxscript.optimizer._optimizer import optimize_ir
from onnxscript.optimizer._remove_unused import remove_unused_nodes

basic_constant_propagation = constant_folding.basic_constant_propagation
fold_constants_ir = constant_folding.fold_constants


def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs):
    if isinstance(model, ir.Model):
@@ -19,8 +22,16 @@ def optimize(model: ir.Model | onnx.ModelProto, *args, **kwargs):
        return legacy_optimizer.optimize(model, *args, **kwargs)


def fold_constants(model: ir.Model | onnx.ModelProto, *args, **kwargs):
    if isinstance(model, ir.Model):
        return constant_folding.fold_constants(model, *args, **kwargs)
    else:
        return legacy_constant_folding.fold_constants(model, *args, **kwargs)


__all__ = [
"fold_constants",
"fold_constants_ir",
"remove_unused_nodes",
"optimize",
"optimize_ir",
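A minimal usage sketch of the dispatch introduced above (not part of the diff; the model path and the `serde.deserialize_model` conversion are assumptions): `fold_constants` now accepts either an `onnx.ModelProto` or an `ir.Model`, while `fold_constants_ir` is the IR-only entry point.

```python
# Sketch (assumed usage, not from the PR): the updated fold_constants
# dispatches on the model type.
import onnx

import onnxscript.optimizer as optimizer
from onnxscript.ir import serde

proto = onnx.load("model.onnx")  # placeholder path
optimizer.fold_constants(proto)  # ModelProto -> legacy constant folding

ir_model = serde.deserialize_model(proto)
optimizer.fold_constants(ir_model)     # ir.Model -> new IR-based constant folding
optimizer.fold_constants_ir(ir_model)  # equivalent IR-only alias
```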
44 changes: 37 additions & 7 deletions onnxscript/optimizer/_constant_folding.py
@@ -13,6 +13,7 @@

import numpy as np
import onnx
import onnx.helper
import onnx.reference.ops

import onnxscript.ir as ir
@@ -434,25 +435,54 @@ def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
@register("Dropout", version=(12, None))
def dropout(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Replace a Dropout by Identity when applicable."""
if len(node.outputs) != 1:
# If output mask is requested, optimization is more complex.
# TODO: handle this case. But unlikely to be needed in practice.
return None

def optimized_dropout():
input = node.inputs[0]
output = op.Identity(input)
if len(node.outputs) == 1:
return output
else:
true_tensor = onnx.helper.make_tensor("true", onnx.TensorProto.BOOL, [1], [True])
justinchuby marked this conversation as resolved.
Show resolved Hide resolved
input_shape = op.Shape(input)
mask = op.ConstantOfShape(input_shape, value=true_tensor)
return output, mask

inputs = node.inputs
if (len(inputs) <= 2) or inputs[2] is None:
# No training_mode specified:
return op.Identity(inputs[0])
return optimized_dropout()
if _get_bool_value(inputs[2]) is False:
# training_mode is False: dropout is not applied.
return op.Identity(inputs[0])
return optimized_dropout()
ratio = _get_numpy_value(inputs[1])
if ratio is None:
return None
if ratio.size != 1: # Only scalar dropout ratio is supported.
return None
if ratio.item() == 0:
# dropout ratio is 0: dropout is not applied.
return op.Identity(inputs[0])
return optimized_dropout()
return None


@register("Expand")
def expand(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Replace an Expand node by Identity when applicable."""
if len(node.inputs) != 2:
return None
if (input := node.inputs[0]) is None:
return None
if (input_shape := input.shape) is None:
# Input shape is not known.
return None
if (expanded_shape := _get_numpy_value(node.inputs[1])) is None:
# Target shape is not known.
return None
if expanded_shape.ndim != 1:
# Target shape must be a 1D tensor. Erroneous model.
return None
if input_shape.dims == tuple(expanded_shape.tolist()):
return op.Identity(input)
return None


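For illustration, a hedged sketch (not from the PR) of the new Dropout rule, in the style of the optimizer tests below: a Dropout with no `training_mode` input should fold to an Identity under the IR-based constant folder. The `serde.deserialize_model` conversion is an assumption.

```python
# Sketch (assumed usage): a Dropout in inference mode folds to Identity.
import onnx.parser

from onnxscript.ir import serde
from onnxscript.optimizer import fold_constants_ir

proto = onnx.parser.parse_model(
    """
    <ir_version: 7, opset_import: [ "" : 17]>
    agraph (float[128, 256] x) => (float[128, 256] z)
    {
        z = Dropout (x)
    }
    """
)
model = serde.deserialize_model(proto)
fold_constants_ir(model)
print([node.op_type for node in model.graph])  # expected: ["Identity"]
```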
16 changes: 16 additions & 0 deletions onnxscript/optimizer/_constant_folding_test.py
@@ -433,6 +433,22 @@ def test_concat_identity(self):
        self.assertEqual(len(optimized.graph.node), 1)
        self.assertEqual(optimized.graph.node[0].op_type, "Identity")

    def test_expand_identity(self):
        if not self.using_ir:
            self.skipTest("New optimizations not supported for legacy optimizer")
        model = onnx.parser.parse_model(
            """
            <ir_version: 7, opset_import: [ "" : 17]>
            agraph (float[128, 256] x) => (float[128, 256] z)
            {
                shape = Constant <value_ints=[128, 256]> ()
                z = Expand (x, shape)
            }
            """
        )
        optimized = self._fold(model)
        self.assertEqual(optimized.graph.node[-1].op_type, "Identity")


if __name__ == "__main__":
    unittest.main()
7 changes: 4 additions & 3 deletions onnxscript/optimizer/_inliner.py
@@ -305,6 +305,7 @@ def inline_calls_in(self, graph: ir.Graph) -> None:

 def inline(model: ir.Model) -> None:
     """Inline all function calls (recursively) in the model."""
-    inliner = _Inliner(model)
-    inliner.inline_calls_in(model.graph)
-    model.functions.clear()
+    if model.functions:
+        inliner = _Inliner(model)
+        inliner.inline_calls_in(model.graph)
+        model.functions.clear()
23 changes: 23 additions & 0 deletions onnxscript/rewriter/_ir_utils.py
@@ -2,6 +2,8 @@
# Licensed under the MIT License.
from __future__ import annotations

import numpy as np

import onnxscript.ir as ir
from onnxscript.optimizer import basic_constant_propagation

@@ -11,3 +13,24 @@ def get_const_value(value: ir.Value) -> ir.TensorProtocol | None:
    if node is not None:
        basic_constant_propagation([node])
    return value.const_value


def get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
    if val is None:
        return None
    const_value = val.const_value
    if const_value is not None:
        try:
            return const_value.numpy()
        except FileNotFoundError:
            # External data is not available.
            return None
    return None


def get_singleton_value(val: ir.Value | None):
    """Return the element of a single-element constant tensor, or None otherwise."""
    np_val = get_numpy_value(val)
    if np_val is not None and np_val.size == 1:
        return np_val.item()
    return None
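These helpers are presumably what the fusion rules' validation conditions build on (for example, the attention-scale validation mentioned in the commit list). A hedged sketch of such a check; the function name and tolerance are illustrative, not taken from the PR.

```python
# Illustrative sketch (not from the PR): validate that an SDPA scale input is
# the constant 1/sqrt(head_dim) before letting a fusion rule fire.
import math

from onnxscript import ir
from onnxscript.rewriter import _ir_utils


def _is_inverse_sqrt_head_dim(scale: ir.Value | None, head_dim: int) -> bool:
    value = _ir_utils.get_singleton_value(scale)
    if value is None:
        return False  # not a scalar constant; skip the fusion
    return math.isclose(value, 1.0 / math.sqrt(head_dim), rel_tol=1e-5)
```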
17 changes: 17 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from onnxscript.rewriter.onnxruntime.xformers.multi_head_attention import (
    mha_rules as mha_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import (
    rms_normalization_rules as rms_normalization_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import (
    rotary_embedding_rules as rotary_embedding_rules,
)
from onnxscript.rewriter.onnxruntime.xformers.sdpa import sdpa_rules as sdpa_rules
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import (
    skip_normalization_rules as skip_normalization_rules,
)
38 changes: 38 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/_optimize_transformers.py
@@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnxscript.ir as ir
from onnxscript.optimizer import fold_constants_ir, remove_unused_nodes
from onnxscript.rewriter.onnxruntime.xformers import (
    mha_rules,
    rms_normalization_rules,
    rotary_embedding_rules,
    sdpa_rules,
    skip_normalization_rules,
)


def fuse_rotary_embedding(irmodel: ir.Model) -> None:
    count = rotary_embedding_rules.apply_to_model(irmodel)
    print(f"RotaryEmbedding count: {count}")


def optimize(irmodel: ir.Model, verbose: int = 0) -> None:
    def apply(rulename: str, rule):
        count = rule.apply_to_model(irmodel, verbose=verbose)
        print(f"{rulename} count: {count}")

    fold_constants_ir(irmodel, input_size_limit=5120000 * 4, output_size_limit=5120000 * 4)
    remove_unused_nodes(irmodel)

    apply("RMS Normalization", rms_normalization_rules)
    apply("Skip Normalization", skip_normalization_rules)

    fold_constants_ir(irmodel)
    remove_unused_nodes(irmodel)

    apply("SDPA-Attention", sdpa_rules)
    apply("RotaryEmbedding", rotary_embedding_rules)
    apply("Multi-Head-Attention", mha_rules)

    remove_unused_nodes(irmodel)
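A hedged sketch of how this entry point might be driven end to end. The `io.load`/`io.save` helpers mirror what the test below does with `onnxscript.ir._io`, while the file names and the preliminary `optimizer.optimize` call are assumptions, not part of the PR.

```python
# Sketch (assumed usage): run the transformer fusion pipeline on a saved model.
import onnxscript.ir._io as io
import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers import (
    _optimize_transformers as optimize_transformers,
)

model = io.load("llama_attention.onnx")  # placeholder input path; io.load is assumed
onnxscript.optimizer.optimize(model)     # basic optimizations first, as in the test
optimize_transformers.optimize(model, verbose=0)
io.save(model, "llama_attention_fused.onnx")  # placeholder output path
```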
@@ -0,0 +1,152 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import os
import tempfile
import unittest

import numpy as np
import onnxruntime
import torch
import transformers.models.llama.modeling_llama as modeling_llama
from parameterized import parameterized
from transformers import LlamaConfig

import onnxscript.ir._io as io
import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers import (
    _optimize_transformers as optimize_transformers,
)

# Create a LlamaConfig object with the desired parameters
_config = LlamaConfig(
    _name_or_path="HuggingFaceTB/SmolLM-1.7B",
    architectures=["LlamaForCausalLM"],
    attention_bias=False,
    attention_dropout=0.0,
    bos_token_id=0,
    eos_token_id=0,
    hidden_act="silu",
    hidden_size=2048,
    initializer_range=0.02,
    intermediate_size=8192,
    max_position_embeddings=2048,
    model_type="llama",
    num_attention_heads=32,
    num_hidden_layers=24,
    num_key_value_heads=32,
    pretraining_tp=1,
    rms_norm_eps=1e-05,
    rope_scaling=None,
    rope_theta=10000.0,
    tie_word_embeddings=True,
    torch_dtype="float32",
    transformers_version="4.37.2",
    use_cache=True,
    vocab_size=49152,
)

# Dimensions for inputs:
_batch_size = 1
_seq_len = 10
_hidden_size = _config.hidden_size
_num_attention_heads = _config.num_attention_heads
dim = _hidden_size // _num_attention_heads

# Generate inputs:
_hidden_states = torch.rand(_batch_size, _seq_len, _hidden_size, dtype=torch.float32)
_causal_mask = torch.tril(torch.ones(_seq_len, _seq_len, dtype=torch.float32))
_attention_mask = _causal_mask.unsqueeze(0).unsqueeze(0).expand(_batch_size, 1, _seq_len, _seq_len)
_position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64).reshape(1, 10)

# Get model in ONNX format
# def _get_model(llama_attention_class, with_mask: bool):
# model = llama_attention_class(_config, 0)
# if with_mask:
# inputs = (_hidden_states, _attention_mask, _position_ids)
# else:
# inputs = (_hidden_states, None, _position_ids)
# exported = torch.onnx.export(model, inputs, dynamo=True)
# # ORT Transformer optimizations are applied after basic optimization.
# onnxscript.optimizer.optimize(exported.model)
# return exported.model

class _TestData:
    def __init__(self, name: str, attention_class, with_mask: bool):
        self.name = name
        self.attention_class = attention_class
        self.with_mask = with_mask

    def get_torch_model(self):
        return self.attention_class(_config, 0)

    def get_onnx_model(self):
        model = self.get_torch_model()
        inputs = self.get_inputs()
        input_names = ["input" + str(i) for i in range(len(inputs)) if inputs[i] is not None]
        exported = torch.onnx.export(model, inputs, input_names=input_names, dynamo=True)
        # ORT Transformer optimizations are applied after basic optimization.
        onnxscript.optimizer.optimize(exported.model)
        return exported.model

    def get_inputs(self):
        if self.with_mask:
            return (_hidden_states, _attention_mask, _position_ids)
        else:
            return (_hidden_states, None, _position_ids)

    def get_torch_outputs(self):
        return self.get_torch_model()(*self.get_inputs())

    def get_ort_inputs(self):
        inputs = self.get_inputs()
        return {f"input{i}": input for i, input in enumerate(inputs) if input is not None}

_test_cases = [
    _TestData("attention", modeling_llama.LlamaAttention, False),
    _TestData("masked_attention", modeling_llama.LlamaAttention, True),
    _TestData("sdpa_attention", modeling_llama.LlamaSdpaAttention, False),
    _TestData("masked_sdpa_attention", modeling_llama.LlamaSdpaAttention, True),
]

_test_case_tuples = [(t,) for t in _test_cases]

def _ort_check(model_name: str, model, inputs, expected_outputs, rtol=1e-2, atol=1e-2):
    providers = ["CPUExecutionProvider"]
    with tempfile.TemporaryDirectory() as temp_dir:
        model_path = os.path.join(temp_dir, f"{model_name}.onnx")
        io.save(model, model_path)
        # Run optimized model
        session = onnxruntime.InferenceSession(model_path, providers=providers)
        ort_outputs = session.run(None, inputs)

        for i, (baseline_output, optimized_output) in enumerate(
            zip(expected_outputs, ort_outputs)
        ):
            try:
                np.testing.assert_equal(baseline_output.shape, optimized_output.shape)
                np.testing.assert_allclose(
                    baseline_output, optimized_output, rtol=rtol, atol=atol
                )
            except AssertionError as e:
                print(
                    f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}"
                )
                raise

class TestOptimizeTransformers(unittest.TestCase):
    @parameterized.expand(_test_case_tuples)
    def test_attention_optimization(self, test_data: _TestData):
        model = test_data.get_onnx_model()
        # model.display()
        # print("======>")
        optimize_transformers.fuse_rotary_embedding(model)
        # model.display()
        op_types = [n.op_type for n in model.graph]
        self.assertIn("RotaryEmbedding", op_types)
        # _ort_check(test_data.name, model, test_data.get_ort_inputs(), test_data.get_torch_outputs())


if __name__ == "__main__":
    unittest.main()