rename input to first_input
tianleiwu committed Nov 20, 2024
1 parent 69cbc4a commit 2b61022
Showing 5 changed files with 41 additions and 42 deletions.
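
As background for the rename, here is a minimal illustrative sketch (not code from this commit; the functions below are hypothetical) of why a parameter named `input` is worth renaming in Python: it shadows the builtin of the same name, and `first_input` also says which of the fused Attention node's inputs the value becomes.

    # Hypothetical illustration only (not from this commit).
    def create_node(input: str) -> str:
        # The parameter shadows Python's builtin input(), so the builtin is
        # unreachable by its plain name inside this function.
        return f"Attention fed by {input}"

    def create_node_renamed(first_input: str) -> str:
        # After the rename the builtin stays available, and the name states
        # that this value becomes the first input of the fused node.
        return f"Attention fed by {first_input}"

    print(create_node("hidden_states"))          # Attention fed by hidden_states
    print(create_node_renamed("hidden_states"))  # Attention fed by hidden_states
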
8 changes: 4 additions & 4 deletions onnxruntime/python/tools/transformers/fusion_attention.py
@@ -695,7 +695,7 @@ def create_attention_node(
v_add: NodeProto,
num_heads: int,
hidden_size: int,
- input: str,
+ first_input: str,
output: str,
add_qk_str: str = "",
past_k: str = "",
@@ -717,7 +717,7 @@ def create_attention_node(
v_add (NodeProto): Add bias node in fully connection for V
num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
- input (str): input name
+ first_input (str): first input name
output (str): output name
add_qk_str (str): name of Add node after Q x K'
past_k (str): name of input for past K value
@@ -863,7 +863,7 @@ def create_attention_node(
)
else:
attention_inputs = [
- input,
+ first_input,
attention_node_name + "_qkv_weight",
attention_node_name + "_qkv_bias" if has_bias else "",
]
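
For orientation, a hedged sketch of what this list represents: in the fused Attention contrib op the first input is the layer's input tensor, followed by the packed QKV weight and the optional bias, with "" standing in for an omitted optional input as is usual in ONNX. The tensor and node names below are made up for illustration.

    # Illustrative only; names are hypothetical, not taken from a real model.
    attention_node_name = "Attention_0"
    first_input = "layer_norm_output"   # tensor feeding the attention block
    has_bias = False

    attention_inputs = [
        first_input,                                             # hidden states
        attention_node_name + "_qkv_weight",                     # packed QKV weight
        attention_node_name + "_qkv_bias" if has_bias else "",   # "" = no bias input
    ]
    print(attention_inputs)  # ['layer_norm_output', 'Attention_0_qkv_weight', '']
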
@@ -1177,7 +1177,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
v_add=add_v,
num_heads=q_num_heads,
hidden_size=q_hidden_size,
- input=root_input,
+ first_input=root_input,
output=attention_last_node.output[0],
add_qk_str=add_qk_str,
)
@@ -239,9 +239,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
v_add=add_v,
num_heads=num_heads,
hidden_size=hidden_size,
- input=root_input,
+ first_input=root_input,
output=attention_last_node.output[0],
- add_qk_str=None,
+ add_qk_str="",
scale=None,
causal=(add_mask is not None),
)
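
This hunk also switches add_qk_str from None to "", matching the add_qk_str: str = "" default shown in fusion_attention.py above. A small sketch of the assumption behind that swap (assuming, without verifying the callee, that downstream code only checks truthiness): "" and None behave the same in such checks, while "" satisfies the declared str type.

    # Hypothetical check, assuming the callee treats add_qk_str as "set if truthy".
    def uses_add_qk(add_qk_str: str = "") -> bool:
        return bool(add_qk_str)

    assert uses_add_qk("") is False             # same effect None had before
    assert uses_add_qk("mask_add_node") is True
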
25 changes: 13 additions & 12 deletions onnxruntime/python/tools/transformers/fusion_bart_attention.py
@@ -586,19 +586,20 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
# Temporarily set multihead attention flag to false
use_multi_head_attention_ground_truth = self.use_multi_head_attention
self.use_multi_head_attention = False
+ add_qk_str = mask_index if decoder_attention and mask_index else ""
new_node = self.create_attention_node(
- None,
- matmul_q,
- matmul_k,
- matmul_v,
- add_q,
- add_k,
- add_v,
- num_heads,
- hidden_size,
- root_input,
- attention_last_node.output[0],
- add_qk_str=mask_index if decoder_attention else None,
+ mask_index=None,
+ q_matmul=matmul_q,
+ k_matmul=matmul_k,
+ v_matmul=matmul_v,
+ q_add=add_q,
+ k_add=add_k,
+ v_add=add_v,
+ num_heads=num_heads,
+ hidden_size=hidden_size,
+ first_input=root_input,
+ output=attention_last_node.output[0],
+ add_qk_str=add_qk_str,
past_k=past_k,
past_v=past_v,
present_k=present_k,
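
Beyond the keyword-argument cleanup, the hoisted expression in fusion_bart_attention.py changes the fallback from None to "" and additionally requires mask_index itself to be non-empty. A small stand-alone sketch of that selection logic (the helper name is made up):

    from typing import Optional

    # Hypothetical helper mirroring the new expression in the diff above.
    def pick_add_qk(decoder_attention: bool, mask_index: Optional[str]) -> str:
        return mask_index if decoder_attention and mask_index else ""

    assert pick_add_qk(True, "attn_mask") == "attn_mask"
    assert pick_add_qk(True, None) == ""     # the old expression produced None here
    assert pick_add_qk(False, "attn_mask") == ""
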
23 changes: 11 additions & 12 deletions onnxruntime/python/tools/transformers/onnx_model_bert_keras.py
@@ -178,18 +178,17 @@ def fuse_attention(self):
mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
logger.debug("Create an Attention node.")
attention_node = self.attention_fusion.create_attention_node(
- mask_index,
- matmul_q,
- matmul_k,
- matmul_v,
- add_q,
- add_k,
- add_v,
- self.num_heads,
- self.hidden_size,
- parent.output[0],
- reshape_qkv.output[0],
- None,
+ mask_index=mask_index,
+ q_matmul=matmul_q,
+ k_matmul=matmul_k,
+ v_matmul=matmul_v,
+ q_add=add_q,
+ k_add=add_k,
+ v_add=add_v,
+ num_heads=self.num_heads,
+ hidden_size=self.hidden_size,
+ first_input=parent.output[0],
+ output=reshape_qkv.output[0],
)
if attention_node is None:
continue
23 changes: 11 additions & 12 deletions onnxruntime/python/tools/transformers/onnx_model_bert_tf.py
@@ -480,18 +480,17 @@ def fuse_attention(self):

# For tf models, q and v are flipped.
attention_node = self.attention_fusion.create_attention_node(
- mask_index,
- matmul_k,
- matmul_q,
- matmul_v,
- add_k,
- add_q,
- add_v,
- self.num_heads,
- self.hidden_size,
- parent.output[0],
- qkv_nodes[2].output[0],
- None,
+ mask_index=mask_index,
+ q_matmul=matmul_k,
+ k_matmul=matmul_q,
+ v_matmul=matmul_v,
+ q_add=add_k,
+ k_add=add_q,
+ v_add=add_v,
+ num_heads=self.num_heads,
+ hidden_size=self.hidden_size,
+ first_input=parent.output[0],
+ output=qkv_nodes[2].output[0],
)
if attention_node is None:
continue
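
One side effect of switching to keyword arguments here: the intentional cross-wiring for TF-exported models (matmul_k passed as q_matmul and matmul_q as k_matmul, as the diff above shows) is now visible at the call site instead of hidden in positional order. A toy sketch of the difference, using a hypothetical dataclass rather than the real NodeProto arguments:

    from dataclasses import dataclass

    @dataclass
    class QKVWiring:  # hypothetical stand-in for the real MatMul node arguments
        q_matmul: str
        k_matmul: str
        v_matmul: str

    # Positional call: the swap is invisible without re-reading the signature.
    positional = QKVWiring("matmul_k", "matmul_q", "matmul_v")

    # Keyword call: the swap is spelled out and safe against parameter reordering.
    keyword = QKVWiring(q_matmul="matmul_k", k_matmul="matmul_q", v_matmul="matmul_v")

    assert positional == keyword
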
