llama 70b model fusion and sharding (#18175)
### Description
Support llama-70b model fusion and sharding.



### Motivation and Context
This change enables sharding the llama-70b model and exporting it to ONNX, since the model is too large for a single GPU.
It also adds fusion support for llama-70b, whose repeat_kv pattern differs from the one in llama-7b and llama-13b; a sketch of that pattern follows below.
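
For context, llama-70b uses grouped-query attention, so each key/value head is repeated to cover several query heads before the attention matmul; llama-7b and llama-13b have as many KV heads as query heads and skip this step. Below is a minimal sketch of the repeat_kv helper, following the shape convention of the Hugging Face llama implementation (the function name and shapes are assumptions based on that codebase, not taken from this commit):

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:  # llama-7b/13b case: nothing to repeat
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
```

The Expand/Reshape subgraph this produces in the exported ONNX graph is presumably the extra pattern the 70b fusion has to recognize.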
frank-dong-ms authored Nov 2, 2023
1 parent 178f7ca commit dabd395
Showing 17 changed files with 1,259 additions and 252 deletions.
1 change: 1 addition & 0 deletions onnxruntime/python/tools/symbolic_shape_infer.py
@@ -147,6 +147,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
             "GatherElements": self._infer_GatherElements,
             "GatherND": self._infer_GatherND,
             "Identity": self._pass_on_shape_and_type,
+            "AllReduce": self._pass_on_shape_and_type,
             "If": self._infer_If,
             "Loop": self._infer_Loop,
             "MatMul": self._infer_MatMul,
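
The new entry routes AllReduce through the pass-through handler: an elementwise collective sums identically shaped tensors across ranks, so the output keeps the input's shape and element type. Here is a rough sketch of what such a handler does, using helper names that appear elsewhere in symbolic_shape_infer.py (an approximation, not the file's exact implementation):

```python
from onnx import helper

def _pass_on_shape_and_type(self, node):
    # Copy the first input's shape and element type to the output unchanged.
    in_shape = self._get_shape(node, 0)
    in_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], in_dtype, in_shape))
```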
10 changes: 7 additions & 3 deletions onnxruntime/python/tools/transformers/convert_generation.py
@@ -1272,7 +1272,7 @@ def find_past_seq_len_usage(subg: GraphProto):
     return tensor_names_to_rename, nodes_to_remove
 
 
-def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0):
+def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0, world_size: int = 1):
     past_seq_len = past_seq_len_input
     if past_seq_len not in model.get_graphs_input_names():
         # Add model input for past sequence length
@@ -1282,6 +1282,10 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0, world_size: int = 1):
     # Replace MultiHeadAttention with GroupQueryAttention
     for node in model.model.graph.node:
         if node.op_type == "MultiHeadAttention":
+            num_heads_mha = 0
+            for att in node.attribute:
+                if att.name == "num_heads":
+                    num_heads_mha = att.i
             gqa_node = onnx.helper.make_node(
                 "GroupQueryAttention",
                 inputs=[
@@ -1295,8 +1299,8 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0, world_size: int = 1):
                 outputs=node.output,
                 name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"),
                 domain="com.microsoft",
-                num_heads=node.attribute[0].i,
-                kv_num_heads=node.attribute[0].i if kv_num_heads == 0 else kv_num_heads,
+                num_heads=num_heads_mha // world_size,
+                kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
                 is_past_bsnh=0,
             )
             model.model.graph.node.remove(node)
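
Note the robustness fix folded into this hunk: the old code read node.attribute[0].i, which assumes num_heads is the first attribute on the MultiHeadAttention node; the new code scans the attribute list by name before dividing by world_size.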
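A hypothetical call site for the new world_size parameter (the file names, head counts, and past-sequence-length input name are illustrative, and we assume the function returns the modified model; replace_mha_with_gqa and OnnxModel are the real identifiers from this repository):

```python
import onnx

from convert_generation import replace_mha_with_gqa  # onnxruntime/python/tools/transformers
from onnx_model import OnnxModel

# Illustrative: one rank shard of llama-70b (64 query heads, 8 KV heads)
# split across 8 GPUs -> each GQA node gets 8 query heads and 1 KV head.
model = OnnxModel(onnx.load("llama70b_rank0.onnx"))
model = replace_mha_with_gqa(model, past_seq_len_input="past_sequence_length", kv_num_heads=8, world_size=8)
model.save_model_to_file("llama70b_rank0_gqa.onnx")
```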
5 changes: 5 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_base.py
@@ -130,3 +130,8 @@ def add_nodes_to_remove(self, nodes: List[NodeProto]):
         for node in nodes:
             if node not in self.nodes_to_remove:
                 self.nodes_to_remove.append(node)
+
+    def add_nodes_to_remove_with_nodes_to_keep(self, nodes: List[NodeProto], nodes_to_keep: List[NodeProto]):
+        for node in nodes:
+            if node not in self.nodes_to_remove and node not in nodes_to_keep:
+                self.nodes_to_remove.append(node)
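
The new helper lets a fusion delete a matched subgraph while sparing nodes that something outside the match still consumes, which plausibly matters when the repeated repeat_kv subgraphs in llama-70b share intermediate nodes. A hypothetical fragment, as it would appear inside a Fusion subclass method (self is the fusion instance; the nodes are stand-ins):

```python
from onnx import helper

# Illustrative matched subgraph: an Expand feeding a Reshape.
expand = helper.make_node("Expand", ["kv", "shape"], ["kv_expanded"], name="repeat_kv/Expand")
reshape = helper.make_node("Reshape", ["kv_expanded", "new_shape"], ["kv_repeated"], name="repeat_kv/Reshape")

# The Reshape is still consumed by a branch that has not been fused yet,
# so queue only the Expand for deletion.
self.add_nodes_to_remove_with_nodes_to_keep([expand, reshape], nodes_to_keep=[reshape])
```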