Update attention fusion to support SDPA pattern (#22629)
### Description
Match the new SDPA pattern for HuggingFace BERT models exported from the latest transformers package.

Changes to the transformers tests in the CI pipeline:
(1) Enable tests for the bert, distilbert, and roberta models in CI.
(2) Remove out-of-date tests for HuggingFace models that were marked as slow and not enabled in the CI pipeline.
(3) Upgrade the transformers package to the latest version.

### Motivation and Context

Recent HuggingFace transformers use torch SDPA in BERT modeling. The resulting change in the exported graph pattern breaks the existing attention fusion. This PR updates the fusion script to match the new pattern.
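
For context, a minimal sketch of the scenario this PR addresses (the checkpoint name, shapes, opset, and head counts below are illustrative assumptions, not taken from this PR): export a HuggingFace BERT model whose attention uses torch SDPA, then run the transformers optimizer so the new pattern is fused into Attention nodes.

```python
# Illustrative sketch (assumptions: bert-base-uncased checkpoint, opset 17, 12 heads / hidden size 768).
import torch
from transformers import AutoTokenizer, BertModel

from onnxruntime.transformers import optimizer

model = BertModel.from_pretrained("bert-base-uncased", attn_implementation="sdpa")  # SDPA attention path
model.eval()

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoded = tokenizer("hello world", return_tensors="pt")

# Recent transformers versions produce the SDPA-based attention subgraph that this PR teaches the fusion to match.
torch.onnx.export(
    model,
    (encoded["input_ids"], encoded["attention_mask"], encoded["token_type_ids"]),
    "bert.onnx",
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["last_hidden_state", "pooler_output"],
    dynamic_axes={name: {0: "batch", 1: "seq"} for name in ["input_ids", "attention_mask", "token_type_ids"]},
    opset_version=17,
)

# Run the fusion script; with this change the fused Attention count should be non-zero again.
opt = optimizer.optimize_model("bert.onnx", model_type="bert", num_heads=12, hidden_size=768)
print(opt.get_fused_operator_statistics())
opt.save_model_to_file("bert_opt.onnx")
```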
tianleiwu authored Nov 21, 2024
1 parent 1e605be commit 55f0559
Showing 13 changed files with 355 additions and 450 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/python/tools/transformers/bert_test_data.py
@@ -250,6 +250,7 @@ def generate_test_data(
average_sequence_length: int,
random_sequence_length: bool,
mask_type: int,
dictionary_size: int = 10000,
):
"""Create given number of input data for testing
@@ -270,7 +271,6 @@ def generate_test_data(
List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
with input name as key and a tensor as value
"""
dictionary_size = 10000
all_inputs = fake_test_data(
batch_size,
sequence_length,
2 changes: 2 additions & 0 deletions onnxruntime/python/tools/transformers/compare_bert_results.py
@@ -85,6 +85,7 @@ def run_test(
segment_ids_name,
input_mask_name,
mask_type,
dictionary_size: int = 1024,
):
# Try deduce input names from optimized model.
input_ids, segment_ids, input_mask = get_bert_inputs(
@@ -105,6 +106,7 @@ def run_test(
average_sequence_length,
True, # random sequence length
mask_type,
dictionary_size=dictionary_size,
)

baseline_results, baseline_latency, output_names = run_model(
270 changes: 129 additions & 141 deletions onnxruntime/python/tools/transformers/fusion_attention.py

Large diffs are not rendered by default.

@@ -239,9 +239,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
v_add=add_v,
num_heads=num_heads,
hidden_size=hidden_size,
input=root_input,
first_input=root_input,
output=attention_last_node.output[0],
add_qk_str=None,
add_qk_str="",
scale=None,
causal=(add_mask is not None),
)
43 changes: 22 additions & 21 deletions onnxruntime/python/tools/transformers/fusion_bart_attention.py
@@ -564,15 +564,15 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
# value whereas attention supports concatenated past key and past value.
new_node = (
self.create_multihead_attention_node(
matmul_q,
matmul_k if decoder_cross_attention or decoder_attention_with_past else past_k,
matmul_v if decoder_cross_attention or decoder_attention_with_past else past_v,
add_q,
add_k if decoder_cross_attention or decoder_attention_with_past else None,
add_v if decoder_cross_attention or decoder_attention_with_past else None,
num_heads,
hidden_size,
attention_last_node.output[0],
q_matmul=matmul_q,
k_matmul=matmul_k if decoder_cross_attention or decoder_attention_with_past else past_k,
v_matmul=matmul_v if decoder_cross_attention or decoder_attention_with_past else past_v,
q_add=add_q,
k_add=add_k if decoder_cross_attention or decoder_attention_with_past else None,
v_add=add_v if decoder_cross_attention or decoder_attention_with_past else None,
num_heads=num_heads,
hidden_size=hidden_size,
output=attention_last_node.output[0],
past_k=past_k if decoder_attention_with_past else "",
past_v=past_v if decoder_attention_with_past else "",
present_k=present_k,
@@ -586,19 +586,20 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
# Temporarily set multihead attention flag to false
use_multi_head_attention_ground_truth = self.use_multi_head_attention
self.use_multi_head_attention = False
add_qk_str = mask_index if decoder_attention and mask_index else ""
new_node = self.create_attention_node(
None,
matmul_q,
matmul_k,
matmul_v,
add_q,
add_k,
add_v,
num_heads,
hidden_size,
root_input,
attention_last_node.output[0],
add_qk_str=mask_index if decoder_attention else None,
mask_index=None,
q_matmul=matmul_q,
k_matmul=matmul_k,
v_matmul=matmul_v,
q_add=add_q,
k_add=add_k,
v_add=add_v,
num_heads=num_heads,
hidden_size=hidden_size,
first_input=root_input,
output=attention_last_node.output[0],
add_qk_str=add_qk_str,
past_k=past_k,
past_v=past_v,
present_k=present_k,
@@ -102,15 +102,15 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
return

new_node = self.create_multihead_attention_node(
matmul_q,
matmul_k,
matmul_v,
add_q,
add_k,
add_v,
num_heads,
hidden_size,
attention_last_node.output[0],
q_matmul=matmul_q,
k_matmul=matmul_k,
v_matmul=matmul_v,
q_add=add_q,
k_add=add_k,
v_add=add_v,
num_heads=num_heads,
hidden_size=hidden_size,
output=attention_last_node.output[0],
add_qk=add_qk.input[1],
past_k=past_k,
past_v=past_v,
8 changes: 5 additions & 3 deletions onnxruntime/python/tools/transformers/onnx_exporter.py
@@ -392,11 +392,13 @@ def validate_and_optimize_onnx(
False,
output_names,
)
if optimize_info == OptimizerInfo.NOOPT:
if optimize_info.name == OptimizerInfo.NOOPT.name:
return onnx_model_path, is_valid_onnx_model, config.vocab_size

if (
optimize_info == OptimizerInfo.BYSCRIPT or precision == Precision.FLOAT16 or precision == Precision.INT8
optimize_info.name == OptimizerInfo.BYSCRIPT.name
or precision == Precision.FLOAT16
or precision == Precision.INT8
): # Use script (optimizer.py) to optimize
optimized_model_path = get_onnx_file_path(
onnx_dir,
@@ -439,7 +441,7 @@ def validate_and_optimize_onnx(
QuantizeHelper.quantize_onnx_model(onnx_model_path, onnx_model_path, use_external_data_format)
logger.info(f"Finished quantizing model: {onnx_model_path}")

if optimize_info == OptimizerInfo.BYORT: # Use OnnxRuntime to optimize
if optimize_info.name == OptimizerInfo.BYORT.name: # Use OnnxRuntime to optimize
if is_valid_onnx_model:
ort_model_path = add_filename_suffix(onnx_model_path, "_ort")
optimize_onnx_model_by_ort(
23 changes: 11 additions & 12 deletions onnxruntime/python/tools/transformers/onnx_model_bert_keras.py
@@ -178,18 +178,17 @@ def fuse_attention(self):
mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
logger.debug("Create an Attention node.")
attention_node = self.attention_fusion.create_attention_node(
mask_index,
matmul_q,
matmul_k,
matmul_v,
add_q,
add_k,
add_v,
self.num_heads,
self.hidden_size,
parent.output[0],
reshape_qkv.output[0],
None,
mask_index=mask_index,
q_matmul=matmul_q,
k_matmul=matmul_k,
v_matmul=matmul_v,
q_add=add_q,
k_add=add_k,
v_add=add_v,
num_heads=self.num_heads,
hidden_size=self.hidden_size,
first_input=parent.output[0],
output=reshape_qkv.output[0],
)
if attention_node is None:
continue
23 changes: 11 additions & 12 deletions onnxruntime/python/tools/transformers/onnx_model_bert_tf.py
@@ -480,18 +480,17 @@ def fuse_attention(self):

# For tf models, q and v are flipped.
attention_node = self.attention_fusion.create_attention_node(
mask_index,
matmul_k,
matmul_q,
matmul_v,
add_k,
add_q,
add_v,
self.num_heads,
self.hidden_size,
parent.output[0],
qkv_nodes[2].output[0],
None,
mask_index=mask_index,
q_matmul=matmul_k,
k_matmul=matmul_q,
v_matmul=matmul_v,
q_add=add_k,
k_add=add_q,
v_add=add_v,
num_heads=self.num_heads,
hidden_size=self.hidden_size,
first_input=parent.output[0],
output=qkv_nodes[2].output[0],
)
if attention_node is None:
continue