diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 9b68aef57656e..81ae3eb9e7e5d 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -147,6 +147,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GatherElements": self._infer_GatherElements, "GatherND": self._infer_GatherND, "Identity": self._pass_on_shape_and_type, + "AllReduce": self._pass_on_shape_and_type, "If": self._infer_If, "Loop": self._infer_Loop, "MatMul": self._infer_MatMul, diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index b32ae64c5b0c0..7aca5e8526a23 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1272,7 +1272,7 @@ def find_past_seq_len_usage(subg: GraphProto): return tensor_names_to_rename, nodes_to_remove -def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0): +def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0, world_size: int = 1): past_seq_len = past_seq_len_input if past_seq_len not in model.get_graphs_input_names(): # Add model input for past sequence length @@ -1282,6 +1282,10 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads # Replace MultiHeadAttention with GroupQueryAttention for node in model.model.graph.node: if node.op_type == "MultiHeadAttention": + num_heads_mha = 0 + for att in node.attribute: + if att.name == "num_heads": + num_heads_mha = att.i gqa_node = onnx.helper.make_node( "GroupQueryAttention", inputs=[ @@ -1295,8 +1299,8 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads outputs=node.output, name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"), domain="com.microsoft", - num_heads=node.attribute[0].i, - kv_num_heads=node.attribute[0].i if kv_num_heads == 0 else kv_num_heads, + num_heads=num_heads_mha // world_size, + kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size, is_past_bsnh=0, ) model.model.graph.node.remove(node) diff --git a/onnxruntime/python/tools/transformers/fusion_base.py b/onnxruntime/python/tools/transformers/fusion_base.py index c5d7bc16d64f7..67f4f0b55cff8 100644 --- a/onnxruntime/python/tools/transformers/fusion_base.py +++ b/onnxruntime/python/tools/transformers/fusion_base.py @@ -130,3 +130,8 @@ def add_nodes_to_remove(self, nodes: List[NodeProto]): for node in nodes: if node not in self.nodes_to_remove: self.nodes_to_remove.append(node) + + def add_nodes_to_remove_with_nodes_to_keep(self, nodes: List[NodeProto], nodes_to_keep: List[NodeProto]): + for node in nodes: + if node not in self.nodes_to_remove and node not in nodes_to_keep: + self.nodes_to_remove.append(node) diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py index 44d15b619ec7a..ceee836e33f77 100644 --- a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py @@ -323,6 +323,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # qkv_nodes_1 is for LLaMA-2 Microsoft # qkv_nodes_2 is for LLaMA-2 Hugging Face + # qkv_nodes_3 is for LLaMA-2 distribute Hugging Face model qkv_nodes = 
None qkv_nodes_1 = self.model.match_parent_path( normalize_node, @@ -334,18 +335,27 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["MatMul", "Reshape", "Transpose", "MatMul"], [1, 0, 0, 0], ) + qkv_nodes_3 = self.model.match_parent_path( + normalize_node, + ["AllReduce", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 0, 0, 0, 0], + ) if qkv_nodes_1 is not None: _, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1 qkv_nodes = qkv_nodes_1 elif qkv_nodes_2 is not None: _, reshape_qkv, _, matmul_qkv = qkv_nodes_2 qkv_nodes = qkv_nodes_2 + elif qkv_nodes_3 is not None: + _, _, reshape_qkv, _, matmul_qkv = qkv_nodes_3 + qkv_nodes = qkv_nodes_3 else: logger.debug("fuse_rotary_attention: failed to match qkv nodes") return # v_nodes_1 is for LLaMA-2 Microsoft # v_nodes_3 is for LLaMA-2 Hugging Face + # v_nodes_4 is for LLaMA-2 70B model past_v, present_v, past_seq_len = "", "", "" v_nodes = None v_nodes_1 = self.model.match_parent_path( @@ -363,6 +373,118 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Reshape", "MatMul"], [1, 0, 0], ) + _, v_nodes_4, _ = self.model.match_parent_paths_all( + matmul_qkv, + [ + ( + ["Reshape", "Expand", "Unsqueeze", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Expand", + "Where", + "Equal", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Expand", + "Where", + "Equal", + "Mul", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Expand", + "Where", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Expand", + "Where", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 0, 0, 0, 0, 1, 0, 0], + ), + ( + [ + "Reshape", + "Concat", + "Unsqueeze", + "Mul", + "Gather", + "Shape", + "Concat", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 1, 1, 0, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 2, 0, 0, 0, 1, 0, 0], + ), + ( + ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 3, 0, 0, 0, 1, 0, 0], + ), + ], + output_name_to_node=None, + ) if v_nodes_1 is not None: reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1 v_nodes = v_nodes_1 @@ -388,6 +510,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): transpose_v, reshape_v, matmul_v = v_nodes_3 v_nodes = v_nodes_3 present_v = transpose_v.output[0] + elif v_nodes_4 is not None and len(v_nodes_4) == 9: + concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:] + v_nodes = v_nodes_4 + past_v = concat_v.input[0] + present_v = concat_v.output[0] else: logger.debug("fuse_rotary_attention: failed to match v path") return @@ -461,6 +588,7 @@ def fuse(self, normalize_node, 
input_name_to_nodes, output_name_to_node): # k_nodes_1 is for LLaMA-2 Microsoft # k_nodes_2 is for LLaMA-2 Hugging Face + # k_nodes_4 is for LLaMA-2 70B Hugging Face past_k, present_k = "", "" k_nodes = None k_nodes_1 = self.model.match_parent_path( @@ -478,6 +606,174 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], [1, 0, 1, 0, 0, 0], ) + _, k_nodes_4, _ = self.model.match_parent_paths_all( + matmul_qk, + [ + ( + [ + "Transpose", + "Reshape", + "Expand", + "Unsqueeze", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "Equal", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "Equal", + "Mul", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "ConstantOfShape", + "Shape", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Expand", + "Where", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Mul", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0], + ), + ( + [ + "Transpose", + "Reshape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + "Concat", + "RotaryEmbedding", + "Transpose", + "Reshape", + "MatMul", + ], + [1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0], + ), + ], + output_name_to_node=None, + ) if k_nodes_1 is not None: reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1 k_nodes = k_nodes_1 @@ -505,6 +801,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): k_nodes = k_nodes_3 past_k = concat_k.input[0] present_k = concat_k.output[0] + elif k_nodes_4 is not None and len(k_nodes_4) == 9: + reshape_k, matmul_k = k_nodes_4[0][-2:] + concat_k, rotary_k = k_nodes_4[0][-5:-3] + k_nodes = k_nodes_4 + past_k = concat_k.input[0] + present_k = concat_k.output[0] else: logger.debug("fuse_rotary_attention: failed to match k nodes") return @@ -552,7 +854,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return root_output = reshape_qkv_2.output[0] - elif qkv_nodes == qkv_nodes_2: + elif qkv_nodes in (qkv_nodes_2, qkv_nodes_3): if not self.check_runtime_shape_paths_for_nodes( reshape_qkv, 
reshape_q, @@ -573,6 +875,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key) rotary_k.output[0] = rotary_k.name + "_output_0" + if qkv_nodes == qkv_nodes_3: + qkv_nodes = qkv_nodes[1:] + new_node = self.create_mha_node( matmul_q.input[0], root_output, @@ -594,7 +899,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.node_name_to_graph_name[new_node.name] = self.this_graph_name self.nodes_to_remove.extend(qkv_nodes[1:]) - self.nodes_to_remove.extend(v_nodes[:-1]) + + if v_nodes != v_nodes_4: + self.nodes_to_remove.extend(v_nodes[:-1]) + else: + nodes_to_keep = [v_nodes[0][-1]] + for temp_path in v_nodes: + self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep) + self.nodes_to_remove.extend(qk_nodes) if k_nodes == k_nodes_1: @@ -608,6 +920,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.nodes_to_remove.append(k_nodes[1]) self.nodes_to_remove.append(k_nodes[3]) self.nodes_to_remove.append(k_nodes[4]) + elif k_nodes == k_nodes_4: + nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]] + for temp_path in k_nodes: + self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep) if q_nodes == q_nodes_1: self.nodes_to_remove.extend(q_nodes[:-2]) diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index 9619e6cb52a91..1bb6940d1cd74 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -10,6 +10,8 @@ Please note the package versions needed for using LLaMA-2 in the `requirements.t - Note that `torch` with CUDA enabled is not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file. - `requirements-quant.txt` - For running the SmoothQuant algorithm using [Intel's Neural Compressor](https://github.com/intel/neural-compressor) +- `requirements-70b-model.txt` + - For running the LLaMA-2 70B model on multiple GPUs - `requirements.txt` - Package versions needed in each of the above files @@ -79,6 +81,15 @@ model.save_pretrained(name.split("/")[-1] + "-onnx") Here are some additional examples for exporting LLaMA-2. +Export Model with Different GPU Device Ids +``` +# From source using first GPU: +$ CUDA_VISIBLE_DEVICES=0 python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b + +# From wheel using second GPU: +$ CUDA_VISIBLE_DEVICES=1 python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b +``` + Export Saved Model on Disk ``` # From source: @@ -153,6 +164,19 @@ $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output l $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu ``` +Export LLaMA-2 70B sharded model into 4 partitions +``` +# From source: +# 1. Install necessary packages from requirements-70b-model.txt + +# 2. 
Build ONNX Runtime from source with NCCL enabled. Here is a sample command: +$ ./build.sh --config RelWithDebInfo --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/ + +# 3. Shard and export the LLaMA-2 70B model. With FP16, you will need at least 140GB of GPU memory to load the model. Therefore, you will need at least 4 40GB A100 GPUs or 2 80GB A100 GPUs to shard the PyTorch model and export each shard to ONNX. Here is an example command: +$ CUDA_VISIBLE_DEVICES=0,1,2,3 bash convert_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-dis --precision fp16 --execution_provider cuda + +``` + ## Benchmark LLaMA-2 Here are some examples of how you can benchmark LLaMA-2. @@ -220,11 +244,11 @@ python3 -m models.llama.benchmark \ --device cuda ``` -6. ONNX Runtime, FP32, convert_to_onnx +6. ONNX Runtime, FP32, convert_to_onnx, use 2nd GPU ``` -python3 -m models.llama.benchmark \ +CUDA_VISIBLE_DEVICES=1 python3 -m models.llama.benchmark \ --benchmark-type ort-convert-to-onnx \ - --ort-model-path ./llama2-7b/Llama-2-7b-hf_decoder_merged_model_fp32.onnx \ + --ort-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp32.onnx \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp32 \ --batch-sizes "1 2" \ @@ -232,11 +256,11 @@ python3 -m models.llama.benchmark \ --device cpu ``` -7. ONNX Runtime, FP16, convert_to_onnx +7. ONNX Runtime, FP16, convert_to_onnx, use 5th GPU ``` -python3 -m models.llama.benchmark \ +CUDA_VISIBLE_DEVICES=4 python3 -m models.llama.benchmark \ --benchmark-type ort-convert-to-onnx \ - --ort-model-path ./llama2-7b/Llama-2-7b-hf_decoder_merged_model_fp16.onnx \ + --ort-model-path ./llama2-7b/rank_0_Llama-2-7b-hf_decoder_merged_model_fp16.onnx \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp16 \ --batch-sizes "1 2" \ diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index 245ff3dfe7f9d..be678931de5d1 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -11,6 +11,7 @@ import onnx import psutil import torch +from dist_settings import get_rank, get_size from llama_inputs import ( add_io_bindings, get_merged_sample_with_past_kv_inputs, @@ -133,6 +134,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): use_fp16=args.use_fp16, engine="ort", return_dict=True, + world_size=args.world_size, ) iter_inputs = get_merged_sample_with_past_kv_inputs( args.config, @@ -144,6 +146,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): use_fp16=args.use_fp16, engine="ort", return_dict=True, + world_size=args.world_size, ) elif args.benchmark_type == "ort-msft": @@ -244,10 +247,10 @@ def get_model(args: argparse.Namespace): if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}: # Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx - logger.info(f"Loading model from {args.ort_model_path}") + logger.info(f"Loading model from {args.ort_model_path.format(args.rank)}") start_time = time.time() model = ort.InferenceSession( - args.ort_model_path, + args.ort_model_path.format(args.rank), sess_options, providers=[args.execution_provider], ) @@ -315,10 +318,11 @@ def time_fn(args, fn, inputs): latency = total_time / args.num_runs throughput = 
args.batch_size / latency - logger.info(f"Batch Size: {args.batch_size}") - logger.info(f"Sequence Length: {args.sequence_length}") - logger.info(f"Latency: {latency} s") - logger.info(f"Throughput: {throughput} tps") + if args.rank == 0: + logger.info(f"Batch Size: {args.batch_size}") + logger.info(f"Sequence Length: {args.sequence_length}") + logger.info(f"Latency: {latency} s") + logger.info(f"Throughput: {throughput} tps") return @@ -358,7 +362,8 @@ def measure_fn(args, fn, inputs): process.cpu_percent(interval=0.1) fn(inputs) - logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%") + if args.rank == 0: + logger.info(f"CPU usage: {process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)}%") # Measure memory usage gc.collect() @@ -451,7 +456,7 @@ def prepare_ort_inputs(inputs, kv_cache_ortvalues): # Add IO bindings for non-CPU execution providers if args.device != "cpu": io_binding, kv_cache_ortvalues = add_io_bindings( - model, inputs, args.device, int(args.device_id), kv_cache_ortvalues + model, inputs, args.device, int(args.rank), kv_cache_ortvalues ) setattr(args, "io_binding", io_binding) # noqa: B010 return io_binding, kv_cache_ortvalues @@ -511,7 +516,7 @@ def run_inference(args, init_inputs, iter_inputs, model): raise Exception(f"Cannot recognize {args.benchmark_type}") -def get_args(): +def get_args(rank=0): parser = argparse.ArgumentParser() parser.add_argument( "-bt", @@ -569,7 +574,7 @@ def get_args(): parser.add_argument( "-s", "--sequence-lengths", - default="8 16 32 64 128 256 512", + default="32 64 128 256 512", ) parser.add_argument( "-d", @@ -606,9 +611,9 @@ def get_args(): if "ort" in args.benchmark_type: setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010 if args.execution_provider == "CUDAExecutionProvider": - args.execution_provider = (args.execution_provider, {"device_id": args.device_id}) + args.execution_provider = (args.execution_provider, {"device_id": rank}) elif args.execution_provider == "ROCMExecutionProvider": - args.execution_provider = (args.execution_provider, {"device_id": args.device_id}) + args.execution_provider = (args.execution_provider, {"device_id": rank}) args.device = "cuda" # Check that paths have been specified for any benchmarking with ORT @@ -635,14 +640,19 @@ def get_args(): def main(): - args = get_args() + rank = get_rank() + world_size = get_size() + + args = get_args(rank) setup_logger(args.verbose) logger.info(args.__dict__) torch.backends.cudnn.benchmark = True + args.rank = rank + args.world_size = world_size tokenizer = LlamaTokenizer.from_pretrained(args.model_name) config = LlamaConfig.from_pretrained(args.model_name) - target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device + target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device use_fp16 = args.precision == "fp16" setattr(args, "tokenizer", tokenizer) # noqa: B010 @@ -656,7 +666,7 @@ def main(): # Check if past_present_share_buffer can be enabled (only for FP16 models with GQA) if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}: - onnx_model = onnx.load_model(args.ort_model_path, load_external_data=False) + onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False) gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node)) use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu" @@ -666,7 +676,8 @@ def main(): # Measure prompt cost (init_inputs) and generated 
token cost (iter_inputs) for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths): - logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...") + if args.rank == 0: + logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...") setattr(args, "batch_size", int(batch_size)) # noqa: B010 setattr(args, "sequence_length", int(sequence_length)) # noqa: B010 diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_70b_model.sh b/onnxruntime/python/tools/transformers/models/llama/benchmark_70b_model.sh new file mode 100644 index 0000000000000..38f1916456658 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_70b_model.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +NUM_GPUS=${1:-1} + +MPI="mpirun --allow-run-as-root + -mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 + --tag-output --npernode $NUM_GPUS --bind-to numa + -x MIOPEN_FIND_MODE=1" + +CMD="$MPI python benchmark.py ${@:2}" + +$CMD \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py index 951b2549368f7..b35a5e27f9ea3 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py @@ -247,6 +247,7 @@ def main(): torch.backends.cudnn.benchmark = True all_results = [] + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id) # Benchmark PyTorch without torch.compile if args.hf_pt_eager: @@ -266,8 +267,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", @@ -298,8 +297,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", @@ -332,8 +329,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", @@ -366,8 +361,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", @@ -399,8 +392,6 @@ def main(): args.sequence_lengths, "--device", args.device, - "--device-id", - str(args.device_id), "--warmup-runs", str(args.warmup_runs), "--num-runs", diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_70b_model.sh b/onnxruntime/python/tools/transformers/models/llama/convert_70b_model.sh new file mode 100644 index 0000000000000..637d15c10e0c7 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/convert_70b_model.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +NUM_GPUS=${1:-1} + +MPI="mpirun --allow-run-as-root + -mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 + --tag-output --npernode $NUM_GPUS --bind-to numa + -x MIOPEN_FIND_MODE=1" + +CMD="$MPI python convert_to_onnx.py ${@:2}" + +$CMD \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 3f05be53c6729..b0e0b41e75d3d 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -1,16 +1,16 @@ import 
argparse import logging import os -import tempfile +import shutil from itertools import chain from typing import List import onnx import torch -from benchmark_helper import Precision, prepare_environment, setup_logger -from convert_generation import replace_mha_with_gqa +from dist_settings import barrier, get_rank, get_size, init_dist from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs from llama_parity import main as parity_check +from llama_torch import setup_torch_model from onnx_model import OnnxModel from optimizer import optimize_model from packaging import version @@ -18,8 +18,11 @@ from onnxruntime import quantization as ort_quantization from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer +from onnxruntime.transformers.benchmark_helper import Precision, prepare_environment, setup_logger +from onnxruntime.transformers.convert_generation import replace_mha_with_gqa logger = logging.getLogger("") +init_dist() def get_model_dynamic_axes(input_names: List[str], output_names: List[str]): @@ -129,7 +132,9 @@ def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: st # del onnx_model # temp_dir.cleanup() # -def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): +def run_dynamo_export( + args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1 +): from torch._dynamo import config config.capture_scalar_outputs = True @@ -150,9 +155,9 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) - save_onnx_model(onnx_model, output_path, f"{args.model_name}_decoder_model_fp32.onnx.data") + save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data") del onnx_model os.system( f"rm {os.path.join(temp_dir, 'model.*')} && rm {os.path.join(temp_dir, '*.weight')} && rm {temp_path}" @@ -160,7 +165,7 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll # Export decoder_with_past_model.onnx input_ids, attn_mask, pos_ids, past_kv = get_sample_with_past_kv_inputs( - l_config, device, batch_size, sequence_length + l_config, device, batch_size, sequence_length, world_size=world_size ) temp_dir = args.output # tempfile.TemporaryDirectory() temp_path = os.path.join(temp_dir, "temp.onnx") # os.path.join(temp_dir.name, "temp.onnx") @@ -172,9 +177,9 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) - save_onnx_model(onnx_model, output_path, f"{args.model_name}_decoder_with_past_model_fp32.onnx.data") + save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data") del onnx_model os.system( f"rm {os.path.join(temp_dir, 'model.*')} && rm 
{os.path.join(temp_dir, '*.weight')} && rm {temp_path}" @@ -183,10 +188,21 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!") -def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): +def _prepare_dir(dir_path): + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + +def run_torchscript_separate_export( + args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1 +): # Dummy values for export batch_size, sequence_length = 2, 8 - device = torch.device("cpu") + + # set device used to export model + # for llama-2-70b we will use current gpus to speed up export process + # for other models, we will use CPU to make sure we have enough memory to do export + device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu") # Export decoder_model.onnx decoder_inputs = get_sample_inputs(l_config, device, batch_size, sequence_length) @@ -199,8 +215,12 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon ), ] dynamic_axes = get_model_dynamic_axes(input_names, output_names) - temp_dir = tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir.name, "temp.onnx") + + # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large. + # Use temp folder per rank to avoid race condition here. + temp_dir = f"./temp_{rank}" + _prepare_dir(temp_dir) + temp_path = os.path.join(temp_dir, "temp.onnx") torch.onnx.export( llama, args=decoder_inputs, @@ -218,18 +238,25 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) save_onnx_model( onnx_model, output_path, - f"{args.model_name}_decoder_model_fp32.onnx.data", + f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data", ) del onnx_model - temp_dir.cleanup() + shutil.rmtree(temp_dir) # Export decoder_with_past_model.onnx - decoder_with_past_inputs = get_sample_with_past_kv_inputs(l_config, device, batch_size, sequence_length) + decoder_with_past_inputs = get_sample_with_past_kv_inputs( + l_config, + device, + batch_size, + sequence_length, + use_fp16=args.precision == Precision.FLOAT16, + world_size=world_size, + ) input_names = [ "input_ids", "attention_mask", @@ -247,8 +274,12 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon ), ] dynamic_axes = get_model_with_past_kv_dynamic_axes(input_names, output_names) - temp_dir = tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir.name, "temp.onnx") + + # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large. + # Use temp folder per rank to avoid race condition here. 
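
The `world_size` argument threaded through `get_sample_with_past_kv_inputs` just above (and through `get_past_kv_inputs` and `replace_mha_with_gqa` elsewhere in this diff) simply divides the head counts across tensor-parallel ranks. A small worked sketch of that arithmetic, using the published LLaMA-2 70B configuration values rather than anything taken from this change:

```python
# Illustration only: per-rank head/shape arithmetic for a 4-way sharded LLaMA-2 70B.
# Config numbers are the published Hugging Face values, not read from this diff.
hidden_size = 8192
num_attention_heads = 64
num_key_value_heads = 8          # GQA: 8 KV heads serve 64 query heads
world_size = 4                   # number of GPU shards

head_size = hidden_size // num_attention_heads          # 128
q_heads_per_rank = num_attention_heads // world_size    # 16 -> GroupQueryAttention num_heads
kv_heads_per_rank = num_key_value_heads // world_size   # 2  -> GroupQueryAttention kv_num_heads

batch_size, past_seq_len = 2, 8
# Shape of each past_key/past_value tensor fed to one rank's ONNX model:
past_kv_shape = (batch_size, kv_heads_per_rank, past_seq_len, head_size)
print(past_kv_shape)  # (2, 2, 8, 128)
```

These per-rank values are the same ones `replace_mha_with_gqa` writes onto the `GroupQueryAttention` nodes (`num_heads` and `kv_num_heads`) earlier in this diff.
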
+ temp_dir = f"./temp_past_{rank}" + _prepare_dir(temp_dir) + temp_path = os.path.join(temp_dir, "temp.onnx") torch.onnx.export( llama, args=decoder_with_past_inputs, @@ -266,27 +297,45 @@ def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaCon onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) save_onnx_model( onnx_model, output_path, - f"{args.model_name}_decoder_with_past_model_fp32.onnx.data", + f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data", ) del onnx_model - temp_dir.cleanup() + shutil.rmtree(temp_dir) - logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!") + logger.info( + f"The {args.model_name} separate ONNX model has been successfully created with the TorchScript exporter!" + ) -def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): +def run_torchscript_merged_export( + args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1 +): # Dummy values for export batch_size, sequence_length, past_sequence_length = 2, 8, 0 - device = torch.device("cpu") + + # set device used to export model + # for llama-2-70b we will use current gpus to speed up export process + # for other models, we will use CPU to make sure we have enough memory to do export + device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu") + + temp_name = args.model_name.lower().replace("-", "").replace("_", "") + max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048 # Export decoder_merged_model.onnx decoder_merged_inputs = get_merged_sample_with_past_kv_inputs( - l_config, device, batch_size, sequence_length, past_sequence_length + l_config, + device, + batch_size, + sequence_length, + past_sequence_length, + max_seq_len=max_sequence_length, + use_fp16=args.precision == Precision.FLOAT16, + world_size=world_size, ) input_names = [ "input_ids", @@ -305,8 +354,12 @@ def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfi ), ] dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names) - temp_dir = tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir.name, "temp.onnx") + + # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large. + # Use temp folder per rank to avoid race condition here. 
+ temp_dir = f"./temp_{rank}" + _prepare_dir(temp_dir) + temp_path = os.path.join(temp_dir, "temp.onnx") torch.onnx.export( llama, args=decoder_merged_inputs, @@ -324,17 +377,17 @@ def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfi onnx.checker.check_model(temp_path) onnx.shape_inference.infer_shapes_path(temp_path) - output_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx") + output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx") onnx_model = onnx.load_model(temp_path, load_external_data=True) save_onnx_model( onnx_model, output_path, - f"{args.model_name}_decoder_merged_model_fp32.onnx.data", + f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx.data", ) del onnx_model - temp_dir.cleanup() + shutil.rmtree(temp_dir) - logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!") + logger.info(f"The {args.model_name} merged ONNX model has been successfully created with the TorchScript exporter!") # Optimize the model as FP32 @@ -357,12 +410,16 @@ def optimize_export(config: LlamaConfig, input_path: str, output_path: str): remove_existing_model(input_path) -def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: List[str]): - decoder_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp16.onnx") +def convert_to_float16( + args: argparse.Namespace, config: LlamaConfig, old_paths: List[str], rank: int = 0, world_size: int = 1 +): + decoder_model_fp16_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp16.onnx") decoder_with_past_model_fp16_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_fp16.onnx" + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp16.onnx" + ) + decoder_merged_model_fp16_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp16.onnx" ) - decoder_merged_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp16.onnx") new_paths = [decoder_model_fp16_path, decoder_with_past_model_fp16_path, decoder_merged_model_fp16_path] logger.info("Converting to float16...") @@ -370,7 +427,7 @@ def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: if os.path.exists(fp32_path): model = OnnxModel(onnx.load_model(fp32_path, load_external_data=True)) model.convert_float_to_float16(keep_io_types=False) - model = use_group_query_attention(config, model) + model = use_group_query_attention(config, model, world_size) model.save_model_to_file(fp16_path, use_external_data_format=True) del model logger.info(f"The ONNX model at {fp32_path} has been converted to float16 and saved at {fp16_path}!") @@ -380,9 +437,11 @@ def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: return new_paths -def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel): +def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel, world_size: int = 1): # Replace MultiHeadAttention with GroupQueryAttention and remove attention mask nodes - fp16_model_opt = replace_mha_with_gqa(fp16_model_opt, "past_sequence_length", config.num_key_value_heads) + fp16_model_opt = replace_mha_with_gqa( + fp16_model_opt, "past_sequence_length", config.num_key_value_heads, world_size + ) fp16_model_opt.prune_graph() fp16_model_opt.update_graph(allow_remove_graph_inputs=True) return 
fp16_model_opt @@ -406,7 +465,7 @@ def smooth_quant( calibration_sampling_size=[args.calibration_sampling_size], recipes={ "optypes_to_exclude_output_quant": ["MatMul"], - "smooth_quant": args.smooth_quant, + "smooth_quant": True, "smooth_quant_args": {"alpha": args.smooth_quant_alpha}, }, op_type_dict={ @@ -526,15 +585,6 @@ def get_args(): help="Execution provider to verify parity with", ) - parser.add_argument( - "-id", - "--device-id", - required=False, - type=str, - default="0", - help="Device ID for GPUs", - ) - parser.add_argument( "-r", "--reexport", @@ -655,6 +705,14 @@ def get_args(): ) parser.set_defaults(use_dynamo_export=False) + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default="./model_cache", + help="model cache dir to override default HF cache dir to avoid overflood the /home dir", + ) + args = parser.parse_args() return args @@ -673,144 +731,182 @@ def main(): remove_existing_files(args.output) logger.info(f"Arguments: {args}") + world_size = get_size() + rank = get_rank() + # Load model and config use_auth_token = args.input == os.path.join(".") setattr(args, "use_auth_token", use_auth_token) # noqa: B010 - location = args.model_name if use_auth_token else args.input - l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token) - llama = LlamaForCausalLM.from_pretrained(location, use_auth_token=use_auth_token, use_cache=True) original_model_name = args.model_name setattr(args, "original_model_name", original_model_name) # noqa: B010 args.model_name = args.model_name.split("/")[-1] - # Set model paths for FP32 model - decoder_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx") - decoder_with_past_model_fp32_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx" - ) - decoder_merged_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx") - old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] - - missing_separate_exports = ( - args.no_merged - and not os.path.exists(decoder_model_fp32_path) - and not os.path.exists(decoder_with_past_model_fp32_path) - ) - missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path) - - # Export to ONNX - if missing_separate_exports or missing_merged_export: - if args.use_dynamo_export and missing_separate_exports: - logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.") - logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/") - logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/") - logger.warning( - "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script" - ) - logger.warning( - "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step." 
- ) - run_dynamo_export(args, l_config, llama) - elif args.no_merged: - run_torchscript_separate_export(args, l_config, llama) - else: - run_torchscript_merged_export(args, l_config, llama) - del llama # Delete LLaMA model from memory since it will be loaded again during parity check + setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010 + setattr(args, "device", torch.device(args.device_name)) # noqa: B010 - # Set model paths to store FP32 optimized model - decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx") - decoder_with_past_model_fp32_opt_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_fp32_opt.onnx" - ) - decoder_merged_model_fp32_opt_path = os.path.join( - args.output, f"{args.model_name}_decoder_merged_model_fp32_opt.onnx" - ) - new_paths = [decoder_model_fp32_opt_path, decoder_with_past_model_fp32_opt_path, decoder_merged_model_fp32_opt_path] + location = args.original_model_name if use_auth_token else args.input - # Run the optimizer script - logger.info("Optimizing models...") - for orig_path, opt_path in zip(old_paths, new_paths): - if os.path.exists(orig_path): - optimize_export(l_config, input_path=orig_path, output_path=opt_path) + # use cuda for Llama-2-70b to speedup export, other models use CPU by default + l_config, llama = setup_torch_model( + args, location, use_auth_token, device=args.device if args.model_name == "Llama-2-70b-hf" else None + ) - # Re-assign default FP32 model paths as their optimized versions - decoder_model_fp32_path = decoder_model_fp32_opt_path - decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path - decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path - old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + assert l_config.num_attention_heads % world_size == 0 and l_config.num_key_value_heads % world_size == 0 - logger.info( - f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" - ) - - # Change precision of exported models from FP32 - if args.precision == Precision.FLOAT16: - new_paths = convert_to_float16(args, l_config, old_paths) - - elif args.precision == Precision.INT8: - decoder_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int8.onnx") - decoder_with_past_model_int8_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_int8.onnx" - ) - decoder_merged_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int8.onnx") - new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] - - if args.quantization_method == "smooth_quant": - if not args.no_merged: - logger.error("SmoothQuant must be used on separately exported models") - else: - logger.info(f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8") - smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) - - elif args.quantization_method == "quantize_dynamic": - logger.warning( - "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`." 
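
The restructured `main()` added just below (the `barrier()` / `for i in range(world_size)` block) relies on the new `dist_settings` module introduced later in this diff for rank discovery and synchronization. A minimal usage sketch, assuming the same single-process fallback that `dist_settings.py` implements:

```python
# Minimal sketch of the dist_settings helpers added in this diff.
# Under mpirun/torchrun they are meant to report the launcher-provided rank and
# world size; in a plain single process they fall back to rank 0 of world size 1.
from dist_settings import barrier, get_rank, get_size, init_dist

init_dist()                              # no-op for a single process
rank, world_size = get_rank(), get_size()

if rank == 0:                            # mirrors the rank-0-only logging used in benchmark.py
    print(f"running with {world_size} rank(s)")
barrier()                                # keep ranks together around heavy export steps
```
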
+ barrier() + for i in range(world_size): + if i == rank: + # Set model paths for FP32 model + decoder_model_fp32_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx" + ) + decoder_with_past_model_fp32_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx" + ) + decoder_merged_model_fp32_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx" ) + old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] - logger.info("Quantizing to int8...") - for fp32_path, int8_path in zip(old_paths, new_paths): - if os.path.exists(fp32_path): - ort_quantization.quantize_dynamic( - fp32_path, - int8_path, - op_types_to_quantize=["MatMul", "Gemm", "Gather"] - if args.quantize_embedding_layer - else ["MatMul", "Gemm"], - per_channel=args.quantize_per_channel, - reduce_range=args.quantize_reduce_range, - use_external_data_format=True, - extra_options={"MatMulConstBOnly": True}, + missing_separate_exports = ( + args.no_merged + and not os.path.exists(decoder_model_fp32_path) + and not os.path.exists(decoder_with_past_model_fp32_path) + ) + missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path) + + # Export to ONNX + if missing_separate_exports or missing_merged_export: + if args.use_dynamo_export and missing_separate_exports: + logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.") + logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/") + logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/") + logger.warning( + "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script" + ) + logger.warning( + "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step." 
) - logger.info(f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!") - remove_existing_model(decoder_model_fp32_path) + run_dynamo_export(args, l_config, llama) + elif args.no_merged: + run_torchscript_separate_export(args, l_config, llama, rank, world_size) + else: + run_torchscript_merged_export(args, l_config, llama, rank, world_size) + del llama # Delete LLaMA model from memory since it will be loaded again during parity check + + # Set model paths to store FP32 optimized model + decoder_model_fp32_opt_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32_opt.onnx" + ) + decoder_with_past_model_fp32_opt_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32_opt.onnx" + ) + decoder_merged_model_fp32_opt_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32_opt.onnx" + ) + new_paths = [ + decoder_model_fp32_opt_path, + decoder_with_past_model_fp32_opt_path, + decoder_merged_model_fp32_opt_path, + ] + + # Run the optimizer script + logger.info("Optimizing models...") + for orig_path, opt_path in zip(old_paths, new_paths): + if os.path.exists(orig_path): + optimize_export(l_config, input_path=orig_path, output_path=opt_path) + + # Re-assign default FP32 model paths as their optimized versions + decoder_model_fp32_path = decoder_model_fp32_opt_path + decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path + decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path + old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + + logger.info( + f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" + ) - logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") + # Change precision of exported models from FP32 + if args.precision == Precision.FLOAT16: + new_paths = convert_to_float16(args, l_config, old_paths, rank, world_size) + + elif args.precision == Precision.INT8: + decoder_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx" + ) + decoder_with_past_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx" + ) + decoder_merged_model_int8_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx" + ) + new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] + + if args.quantization_method == "smooth_quant": + if not args.no_merged: + logger.error("SmoothQuant must be used on separately exported models") + else: + logger.info( + f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8" + ) + smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) + + elif args.quantization_method == "quantize_dynamic": + logger.warning( + "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`." 
+ ) - else: - raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") - - elif args.precision == Precision.INT4: - if args.execution_provider != "cpu": - old_paths = convert_to_float16(args, l_config, old_paths) - - decoder_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int4.onnx") - decoder_with_past_model_int4_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_int4.onnx" - ) - decoder_merged_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int4.onnx") - new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] - - for fp_path, int4_path in zip(old_paths, new_paths): - if os.path.exists(fp_path): - model = onnx.load_model(fp_path, load_external_data=True) - quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[]) - quant.process() - quant.model.save_model_to_file(int4_path, use_external_data_format=True) - del model - del quant - logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") - remove_existing_model(fp_path) + logger.info("Quantizing to int8...") + for fp32_path, int8_path in zip(old_paths, new_paths): + if os.path.exists(fp32_path): + ort_quantization.quantize_dynamic( + fp32_path, + int8_path, + op_types_to_quantize=["MatMul", "Gemm", "Gather"] + if args.quantize_embedding_layer + else ["MatMul", "Gemm"], + per_channel=args.quantize_per_channel, + reduce_range=args.quantize_reduce_range, + use_external_data_format=True, + extra_options={"MatMulConstBOnly": True}, + ) + logger.info( + f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!" + ) + remove_existing_model(decoder_model_fp32_path) + + logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") + + else: + raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") + + elif args.precision == Precision.INT4: + if args.execution_provider != "cpu": + old_paths = convert_to_float16(args, l_config, old_paths, rank, world_size) + + decoder_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx" + ) + decoder_with_past_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx" + ) + decoder_merged_model_int4_path = os.path.join( + args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx" + ) + new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] + + for fp_path, int4_path in zip(old_paths, new_paths): + if os.path.exists(fp_path): + model = onnx.load_model(fp_path, load_external_data=True) + quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[]) + quant.process() + quant.model.save_model_to_file(int4_path, use_external_data_format=True) + del model + del quant + logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") + remove_existing_model(fp_path) + barrier() logger.info("Verifying parity on all ONNX models created") @@ -824,7 +920,12 @@ def main(): # Verify parity on all saved ONNX models for filename in os.listdir(args.output): - if ".data" in filename or ".onnx" not in filename: + if ( + ".data" in filename + or ".onnx" not in filename + or args.precision not in filename + or f"rank_{rank}" not in filename + ): continue parity_cmd 
= [ @@ -834,10 +935,10 @@ def main(): os.path.join(args.output, filename), "-ep", args.execution_provider, - "-id", - args.device_id, "-fp", args.precision, + "--cache_dir", + args.cache_dir, ] if "with_past" in filename: parity_cmd.append("--use_past_kv") @@ -845,6 +946,7 @@ def main(): parity_cmd.append("--merged") try: + logger.debug(f"check parity with cmd: {parity_cmd}") parity_check(parity_cmd) except Exception as e: logger.warning(f"An error occurred while verifying parity: {e}", exc_info=True) diff --git a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py new file mode 100644 index 0000000000000..50b0669d6d83a --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py @@ -0,0 +1,45 @@ +import os + +import torch.distributed as dist + +comm = None + + +def init_dist(): + if "LOCAL_RANK" in os.environ: + int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank) + elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: + from mpi4py import MPI + + comm = MPI.COMM_WORLD # noqa: F841 + + int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0)) + rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0)) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1)) + + dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank) + else: + # don't need to do init for single process + pass + + +def get_rank(): + return comm.Get_rank() if comm is not None else 0 + + +def get_size(): + return comm.Get_size() if comm is not None else 1 + + +def barrier(): + if comm is not None: + comm.Barrier() + + +def print_out(*args): + if get_rank() == 0: + print(*args) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index f7a1b05249abf..6530eead55f03 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -66,12 +66,13 @@ def get_sample_with_past_kv_inputs( use_fp16: bool = False, engine: str = "pt", return_dict: bool = False, + world_size: int = 1, ): input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64) attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64) # position_ids is of shape (batch_size, 1) position_ids = get_position_ids(attention_mask, use_past_kv=True) - past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16) + past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size) # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) @@ -123,12 +124,13 @@ def get_merged_sample_with_past_kv_inputs( use_fp16: bool = False, engine: str = "pt", return_dict: bool = False, + world_size: int = 1, ): input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64) attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64) # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0)) - past_kv = get_past_kv_inputs(config, 
batch_size, past_seq_len, use_fp16) + past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size) # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) @@ -220,8 +222,8 @@ def get_msft_sample_inputs( # Create past_key_values # Each is of shape (batch_size, num_heads, past_sequence_length, head_size) -def get_past_kv_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, use_fp16: bool): - num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_key_value_heads +def get_past_kv_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1): + num_heads, head_size = config.num_key_value_heads // world_size, config.hidden_size // config.num_attention_heads torch_dtype = torch.float16 if use_fp16 else torch.float32 past_kv = [ ( diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index c1c5d3c412f2a..42581caf3bb9e 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -6,7 +6,7 @@ import numpy as np import torch -from benchmark_helper import setup_logger +from dist_settings import get_rank, get_size from llama_inputs import ( add_io_bindings, convert_inputs_for_ort, @@ -14,9 +14,11 @@ get_sample_inputs, get_sample_with_past_kv_inputs, ) +from llama_torch import setup_torch_model from transformers import LlamaConfig, LlamaForCausalLM import onnxruntime as ort +from onnxruntime.transformers.benchmark_helper import setup_logger logger = logging.getLogger("") @@ -30,6 +32,7 @@ def get_sequence_lengths(args: argparse.Namespace): def get_inputs(args: argparse.Namespace, config: LlamaConfig): # Dummy values for parity + world_size = get_size() batch_size = 2 past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args) @@ -43,10 +46,17 @@ def get_inputs(args: argparse.Namespace, config: LlamaConfig): max_seq_len=max_sequence_length, use_fp16=args.use_fp16, return_dict=True, + world_size=world_size, ) elif args.use_past_kv: inputs = get_sample_with_past_kv_inputs( - config, args.device, batch_size, sequence_length, use_fp16=args.use_fp16, return_dict=True + config, + args.device, + batch_size, + sequence_length, + use_fp16=args.use_fp16, + return_dict=True, + world_size=world_size, ) else: inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True) @@ -66,6 +76,7 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama torch.cuda.synchronize() end_time = time.time() logger.info(f"PyTorch took {end_time - start_time} s") + del pt_model # Run inference with ORT past_sequence_length, _, max_sequence_length = get_sequence_lengths(args) @@ -76,12 +87,12 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama past_seq_len=past_sequence_length, max_seq_len=max_sequence_length, device=args.execution_provider, - device_id=int(args.device_id), + device_id=int(args.rank), ) ep = f"{args.execution_provider.upper()}ExecutionProvider" if ep == "CUDAExecutionProvider": - ep = (ep, {"device_id": args.device_id}) + ep = (ep, {"device_id": args.rank}) ort_model = ort.InferenceSession( args.onnx_model_path, sess_options=ort.SessionOptions(), @@ -91,7 +102,7 @@ def verify_parity(args: argparse.Namespace, config: 
LlamaConfig, pt_model: Llama # Add IO bindings for non-CPU execution providers if args.execution_provider != "cpu": io_binding, kv_cache_ortvalues = add_io_bindings( - ort_model, inputs, args.execution_provider, int(args.device_id), kv_cache_ortvalues + ort_model, inputs, args.execution_provider, int(args.rank), kv_cache_ortvalues ) io_binding.synchronize_inputs() @@ -101,6 +112,7 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama end_time = time.time() ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits + del ort_model else: start_time = time.time() @@ -155,15 +167,6 @@ def get_args(argv: List[str]): help="Execution provider to verify parity with", ) - parser.add_argument( - "-id", - "--device-id", - required=False, - type=str, - default="0", - help="Device ID for GPUs", - ) - parser.add_argument( "-v", "--verbose", @@ -195,6 +198,14 @@ def get_args(argv: List[str]): help="Precision of model", ) + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default="./model_cache", + help="model cache dir to override default HF cache dir to avoid overflood the /home dir", + ) + args = parser.parse_args() if argv == [] else parser.parse_args(argv) # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models @@ -210,21 +221,23 @@ def main(argv: List[str] = []): # noqa: B006 args = get_args(argv) setup_logger(args.verbose) logger.info(f"Arguments: {args}") + rank = get_rank() # Load model and config setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010 - setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{args.device_id}") # noqa: B010 + args.rank = rank + setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010 setattr(args, "device", torch.device(args.device_name)) # noqa: B010 use_auth_token = args.torch_model_directory == os.path.join(".") location = args.model_name if use_auth_token else args.torch_model_directory - config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token) - llama = LlamaForCausalLM.from_pretrained( + config, llama = setup_torch_model( + args, location, + use_auth_token, torch_dtype=(torch.float16 if args.use_fp16 else torch.float32), - use_auth_token=use_auth_token, - use_cache=True, - ).to(args.device) + device=args.device, + ) kv_cache_ortvalues = {} if not args.merged: diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py new file mode 100644 index 0000000000000..cf6406dde5be0 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py @@ -0,0 +1,38 @@ +import logging +import os + +import torch +from dist_settings import barrier, get_rank, get_size +from transformers import LlamaConfig, LlamaForCausalLM + +logger = logging.getLogger("") + + +def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32, device=None): + world_size = get_size() + logger.info(f"world_size: {world_size}") + rank = get_rank() + barrier() + + if not os.path.exists(args.cache_dir): + os.makedirs(args.cache_dir, exist_ok=True) + + for i in range(world_size): + if i == rank % (world_size): + l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir) + l_config.use_cache = True + llama = LlamaForCausalLM.from_pretrained( + location, + use_auth_token=use_auth_token, + config=l_config, + torch_dtype=torch_dtype, 
+ cache_dir=args.cache_dir, + ) + if world_size > 1: + llama.parallel_model() + if device: + llama.to(device) + llama.eval() + llama.requires_grad_(False) + barrier() + return l_config, llama diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt new file mode 100644 index 0000000000000..572cfdb71be4a --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-70b-model.txt @@ -0,0 +1,4 @@ +-r requirements.txt +git+https://github.com/frankdongms/transformers.git@frdong/shard_llama +mpi4py +psutil \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 07870373e90b0..66ec0de88b44c 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -337,6 +337,18 @@ def match_parent_paths(self, node, paths, output_name_to_node): return i, matched, return_indice return -1, None, None + def match_parent_paths_all(self, node, paths, output_name_to_node): + match_i, matches, return_indices = [], [], [] + for i, path in enumerate(paths): + assert isinstance(path, (List, Tuple)) + return_indice = [] + matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice) + if matched: + match_i.append(i) + matches.append(matched) + return_indices.append(return_indice) + return match_i, matches, return_indices + def match_parent_path( self, node, diff --git a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py index fedba2a25dfc2..373ad86ced1a7 100644 --- a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py +++ b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py @@ -96,7 +96,7 @@ def create_inputs_and_outputs(self, model_type: str): helper.make_tensor_value_info("position_ids", TensorProto.INT64, [self.batch_size, self.sequence_length]), helper.make_tensor_value_info("attn_mask", TensorProto.INT64, attn_mask_size), ] - if model_type in {"past", "merged", "llama2_msft"}: + if model_type in {"past", "merged", "llama2_msft", "70b_distributed_merged"}: inputs.extend( [ helper.make_tensor_value_info( @@ -164,14 +164,14 @@ def get_first_rope_input(node_type: str): if is_fused or model_type == "llama2_msft": # q_out/k_out return f"{node_type}_out" - if model_type in {"no_past", "past", "merged"}: + if model_type in {"no_past", "past", "merged", "70b_distributed_merged"}: if node_type == "k": return "k_before_rope" return "q_before_rope" return "" def get_first_rope_output(node_type: str): - if is_fused or model_type in {"llama2_msft", "past", "merged"}: + if is_fused or model_type in {"llama2_msft", "past", "merged", "70b_distributed_merged"}: if node_type == "q": return "q_rope" return "k_rope" @@ -295,23 +295,225 @@ def create_k_path_hf(self, model_type: str): ) k_nodes = [reshape_k_node, transpose_k_1_node] - if model_type in {"past", "merged"}: + if model_type == "70b_distributed_merged": concat_k_node = helper.make_node( "Concat", inputs=["past_key", "k_rope"], outputs=["present_key"], axis=2, ) - k_nodes.append(concat_k_node) + shape_k1 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k1_out"], name="Shape_k1") + shape_k2 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k2_out"], name="Shape_k2") + shape_k3 = helper.make_node("Shape", 
inputs=["present_value"], outputs=["shape_k3_out"], name="Shape_k3") + shape_k4 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_k4_out"], name="Shape_k4") + + gather_k_1 = helper.make_node( + "Gather", + inputs=["shape_k1_out", "one"], + outputs=["gather_k1_out"], + name="Gather_k_1", + axis=0, + ) + gather_k_2 = helper.make_node( + "Gather", + inputs=["shape_k2_out", "one"], + outputs=["gather_k2_out"], + name="Gather_k_2", + axis=0, + ) + gather_k_3 = helper.make_node( + "Gather", + inputs=["shape_k3_out", "one"], + outputs=["gather_k3_out"], + name="Gather_k_3", + axis=0, + ) + gather_k_4 = helper.make_node( + "Gather", + inputs=["shape_k4_out", "one"], + outputs=["gather_k4_out"], + name="Gather_k_4", + axis=0, + ) - transpose_k_2_node = helper.make_node( - "Transpose", - inputs=["present_key"], - outputs=["k"], - name="Transpose_k_2", - perm=[0, 1, 3, 2], - ) - return k_nodes + [transpose_k_2_node] # noqa: RUF005 + unsqueeze_k_1 = helper.make_node( + "Unsqueeze", + inputs=["present_value", "zero"], + outputs=["unsqueeze_k1_out"], + name="Unsqueeze_k1", + ) + unsqueeze_k_2 = helper.make_node( + "Unsqueeze", + inputs=["gather_k1_out", "zero"], + outputs=["unsqueeze_k2_out"], + name="Unsqueeze_k2", + ) + unsqueeze_k_3 = helper.make_node( + "Unsqueeze", + inputs=["gather_k2_out", "zero"], + outputs=["unsqueeze_k3_out"], + name="Unsqueeze_k3", + ) + unsqueeze_k_4 = helper.make_node( + "Unsqueeze", + inputs=["gather_k3_out", "zero"], + outputs=["unsqueeze_k4_out"], + name="Unsqueeze_k4", + ) + unsqueeze_k_5 = helper.make_node( + "Unsqueeze", + inputs=["gather_k4_out", "zero"], + outputs=["unsqueeze_k5_out"], + name="Unsqueeze_k5", + ) + + concat_k_2 = helper.make_node( + "Concat", + inputs=["unsqueeze_k2_out", "unsqueeze_k3_out", "One", "unsqueeze_k4_out", "unsqueeze_k5_out"], + outputs=["concat_k2_ouot"], + name="Concat_k2", + axis=0, + ) + reshape_k_2 = helper.make_node( + "Reshape", + inputs=["concat_k2_ouot", "One"], + outputs=["reshape_k2_out"], + name="Reshape_k_2", + ) + shape_k5 = helper.make_node("Shape", inputs=["reshape_k2_out"], outputs=["shape_k5_out"], name="Shape_k5") + constant_of_shape_k_1 = helper.make_node( + "ConstantOfShape", + inputs=["shape_k5_out"], + outputs=["constant_of_shape_k1_out"], + name="ConstantOfShape_k1", + ) + mul_k_1 = helper.make_node( + "Mul", + inputs=["constant_of_shape_k1_out", "One"], + outputs=["mul_k1_out"], + name="mul_k1", + ) + equal_k_1 = helper.make_node( + "Equal", + inputs=["reshape_k2_out", "mul_k1_out"], + outputs=["equal_k_1_out"], + name="equal_k1", + ) + where_k_1 = helper.make_node( + "Where", + inputs=["equal_k_1_out", "constant_of_shape_k1_out", "reshape_k2_out"], + outputs=["where_k_1_out"], + name="where_k1", + ) + unsqueeze_k_6 = helper.make_node( + "Unsqueeze", + inputs=["gather_k1_out", "zero"], + outputs=["unsqueeze_k6_out"], + name="Unsqueeze_k6", + ) + mul_k_2 = helper.make_node( + "Mul", + inputs=["gather_k2_out", "One"], + outputs=["mul_k2_out"], + name="mul_k2", + ) + unsqueeze_k_7 = helper.make_node( + "Unsqueeze", + inputs=["mul_k2_out", "zero"], + outputs=["unsqueeze_k7_out"], + name="Unsqueeze_k7", + ) + unsqueeze_k_8 = helper.make_node( + "Unsqueeze", + inputs=["gather_k3_out", "zero"], + outputs=["unsqueeze_k8_out"], + name="Unsqueeze_k8", + ) + unsqueeze_k_9 = helper.make_node( + "Unsqueeze", + inputs=["gather_k4_out", "zero"], + outputs=["unsqueeze_k9_out"], + name="Unsqueeze_k9", + ) + concat_k_3 = helper.make_node( + "Concat", + inputs=["unsqueeze_k6_out", "unsqueeze_k7_out", 
"unsqueeze_k8_out", "unsqueeze_k9_out"], + outputs=["concat_k3_out"], + name="Concat_k3", + axis=0, + ) + expand_k_1 = helper.make_node( + "Expand", + inputs=["unsqueeze_k1_out", "where_k_1_out"], + outputs=["expand_k1_out"], + name="expand_k1", + ) + reshape_k_3 = helper.make_node( + "Reshape", + inputs=["expand_k1_out", "concat_k3_out"], + outputs=["reshape_k3_out"], + name="Reshape_k_3", + ) + transpose_k_2_node = helper.make_node( + "Transpose", + inputs=["reshape_k3_out"], + outputs=["k"], + name="Transpose_k_2", + perm=[0, 1, 3, 2], + ) + + k_nodes_for_70b_model = [ + concat_k_node, + shape_k1, + shape_k2, + shape_k3, + shape_k4, + gather_k_1, + gather_k_2, + gather_k_3, + gather_k_4, + unsqueeze_k_1, + unsqueeze_k_2, + unsqueeze_k_3, + unsqueeze_k_4, + unsqueeze_k_5, + concat_k_2, + reshape_k_2, + shape_k5, + constant_of_shape_k_1, + mul_k_1, + equal_k_1, + where_k_1, + unsqueeze_k_6, + mul_k_2, + unsqueeze_k_7, + unsqueeze_k_8, + unsqueeze_k_9, + concat_k_3, + expand_k_1, + reshape_k_3, + transpose_k_2_node, + ] + k_nodes.extend(k_nodes_for_70b_model) + return k_nodes + else: + if model_type in {"past", "merged"}: + concat_k_node = helper.make_node( + "Concat", + inputs=["past_key", "k_rope"], + outputs=["present_key"], + axis=2, + ) + k_nodes.append(concat_k_node) + + transpose_k_2_node = helper.make_node( + "Transpose", + inputs=["present_key"], + outputs=["k"], + name="Transpose_k_2", + perm=[0, 1, 3, 2], + ) + return k_nodes + [transpose_k_2_node] # noqa: RUF005 def create_k_path(self, model_type: str): if model_type == "llama2_msft": @@ -505,7 +707,7 @@ def create_v_path(self, model_type: str): if model_type == "no_past": return v_nodes - if model_type in {"past", "merged"}: + if model_type in {"past", "merged", "70b_distributed_merged"}: concat_v_node = helper.make_node( "Concat", inputs=["past_value", "transpose_v_1_out"], @@ -513,7 +715,194 @@ def create_v_path(self, model_type: str): name="Concat_v", axis=2, ) - return v_nodes + [concat_v_node] # noqa: RUF005 + + if model_type != "70b_distributed_merged": + return v_nodes + [concat_v_node] # noqa: RUF005 + + shape_v1 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_1_out"], name="Shape_v1") + shape_v2 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_2_out"], name="Shape_v2") + shape_v3 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_3_out"], name="Shape_v3") + shape_v4 = helper.make_node("Shape", inputs=["present_value"], outputs=["shape_4_out"], name="Shape_v4") + gather_v_1 = helper.make_node( + "Gather", + inputs=["shape_1_out", "one"], + outputs=["gather_1_out"], + name="Gather_v1", + axis=0, + ) + gather_v_2 = helper.make_node( + "Gather", + inputs=["shape_2_out", "one"], + outputs=["gather_2_out"], + name="Gather_v2", + axis=0, + ) + gather_v_3 = helper.make_node( + "Gather", + inputs=["shape_3_out", "one"], + outputs=["gather_3_out"], + name="Gather_v3", + axis=0, + ) + gather_v_4 = helper.make_node( + "Gather", + inputs=["shape_4_out", "one"], + outputs=["gather_4_out"], + name="Gather_v4", + axis=0, + ) + unsqueeze_v_1 = helper.make_node( + "Unsqueeze", + inputs=["present_value", "zero"], + outputs=["unsqueeze_v1_out"], + name="Unsqueeze_v1", + ) + unsqueeze_v_2 = helper.make_node( + "Unsqueeze", + inputs=["gather_1_out", "zero"], + outputs=["unsqueeze_v2_out"], + name="Unsqueeze_v2", + ) + unsqueeze_v_3 = helper.make_node( + "Unsqueeze", + inputs=["gather_2_out", "zero"], + outputs=["unsqueeze_v3_out"], + name="Unsqueeze_v3", + ) + unsqueeze_v_4 = 
helper.make_node( + "Unsqueeze", + inputs=["gather_3_out", "zero"], + outputs=["unsqueeze_v4_out"], + name="Unsqueeze_v4", + ) + unsqueeze_v_5 = helper.make_node( + "Unsqueeze", + inputs=["gather_4_out", "zero"], + outputs=["unsqueeze_v5_out"], + name="Unsqueeze_v5", + ) + concat_v_2 = helper.make_node( + "Concat", + inputs=["unsqueeze_v2_out", "unsqueeze_v3_out", "One", "unsqueeze_v4_out", "unsqueeze_v5_out"], + outputs=["concat_v2_ouot"], + name="Concat_v2", + axis=0, + ) + reshape_v_2 = helper.make_node( + "Reshape", + inputs=["concat_v2_ouot", "One"], + outputs=["reshape_v2_out"], + name="Reshape_v2", + ) + shape_v5 = helper.make_node("Shape", inputs=["reshape_v2_out"], outputs=["shape_5_out"], name="Shape_v5") + constant_of_shape_v_1 = helper.make_node( + "ConstantOfShape", + inputs=["shape_5_out"], + outputs=["constant_of_shape_v1_out"], + name="ConstantOfShape_v1", + ) + mul_v_1 = helper.make_node( + "Mul", + inputs=["constant_of_shape_v1_out", "One"], + outputs=["mul_v1_out"], + name="mul_v1", + ) + equal_v_1 = helper.make_node( + "Equal", + inputs=["reshape_v2_out", "mul_v1_out"], + outputs=["equal_v_1_out"], + name="equal_v1", + ) + where_v_1 = helper.make_node( + "Where", + inputs=["equal_v_1_out", "constant_of_shape_v1_out", "reshape_v2_out"], + outputs=["where_v_1_out"], + name="where_v1", + ) + unsqueeze_v_6 = helper.make_node( + "Unsqueeze", + inputs=["gather_1_out", "zero"], + outputs=["unsqueeze_v6_out"], + name="Unsqueeze_v6", + ) + mul_v_2 = helper.make_node( + "Mul", + inputs=["gather_2_out", "One"], + outputs=["mul_v2_out"], + name="mul_v2", + ) + unsqueeze_v_7 = helper.make_node( + "Unsqueeze", + inputs=["mul_v2_out", "zero"], + outputs=["unsqueeze_v7_out"], + name="Unsqueeze_v7", + ) + unsqueeze_v_8 = helper.make_node( + "Unsqueeze", + inputs=["gather_3_out", "zero"], + outputs=["unsqueeze_v8_out"], + name="Unsqueeze_v8", + ) + unsqueeze_v_9 = helper.make_node( + "Unsqueeze", + inputs=["gather_4_out", "zero"], + outputs=["unsqueeze_v9_out"], + name="Unsqueeze_v9", + ) + concat_v_3 = helper.make_node( + "Concat", + inputs=["unsqueeze_v6_out", "unsqueeze_v7_out", "unsqueeze_v8_out", "unsqueeze_v9_out"], + outputs=["concat_v3_out"], + name="Concat_v3", + axis=0, + ) + expand_v_1 = helper.make_node( + "Expand", + inputs=["unsqueeze_v1_out", "where_v_1_out"], + outputs=["expand_v1_out"], + name="expand_v1", + ) + reshape_v_3 = helper.make_node( + "Reshape", + inputs=["expand_v1_out", "concat_v3_out"], + outputs=["reshape_v3_out"], + name="Reshape_v3", + ) + + v_nodes_for_70b_model = [ + concat_v_node, + shape_v1, + shape_v2, + shape_v3, + shape_v4, + gather_v_1, + gather_v_2, + gather_v_3, + gather_v_4, + unsqueeze_v_1, + unsqueeze_v_2, + unsqueeze_v_3, + unsqueeze_v_4, + unsqueeze_v_5, + concat_v_2, + reshape_v_2, + shape_v5, + constant_of_shape_v_1, + mul_v_1, + equal_v_1, + where_v_1, + unsqueeze_v_6, + mul_v_2, + unsqueeze_v_7, + unsqueeze_v_8, + unsqueeze_v_9, + concat_v_3, + expand_v_1, + reshape_v_3, + ] + v_nodes.extend(v_nodes_for_70b_model) + + return v_nodes # Create extra nodes for `position_ids` unsqueeze_v_node = helper.make_node( @@ -672,7 +1061,28 @@ def create_concat_unsqueeze_paths(self, model_type: str, reshape_nodes: List[Nod return extra_nodes - def create_end_nodes(self): + def create_end_nodes(self, model_type): + if model_type == "70b_distributed_merged": + matmul_o_node = helper.make_node( + "MatMul", + inputs=["attn_output", "o_weight"], + outputs=["output_proj"], + name="MatMul_o_proj", + ) + all_reduce = helper.make_node( + "AllReduce", + 
inputs=["output_proj"], + outputs=["allreduce_proj"], + name="allreduce_proj", + ) + end_node = helper.make_node( + "Add", + inputs=["zero", "allreduce_proj"], + outputs=["output_0"], + name="Add_normalize_node", + ) + return [matmul_o_node, all_reduce, end_node] + matmul_o_node = helper.make_node( "MatMul", inputs=["attn_output", "o_weight"], @@ -711,7 +1121,7 @@ def create_fused_model(self, model_type: str, interleaved: bool, initializers: L num_heads=self.num_heads, ) - end_nodes = self.create_end_nodes() + end_nodes = self.create_end_nodes(model_type) graph = helper.make_graph( nodes=matmul_nodes + rope_nodes + attn_mask_nodes + [mha_node] + end_nodes, @@ -740,7 +1150,7 @@ def create_test_model(self, model_type: str, interleaved: bool, initializers: Li reshape_nodes = list(filter(lambda node: node.op_type == "Reshape", q_nodes + k_nodes + v_nodes + qkv_nodes)) extra_nodes = self.create_concat_unsqueeze_paths(model_type, reshape_nodes) - end_nodes = self.create_end_nodes() + end_nodes = self.create_end_nodes(model_type) first_set_of_nodes = matmul_nodes + rope_nodes + q_nodes + k_nodes + attn_mask_nodes second_set_of_nodes = qk_nodes + v_nodes + qkv_nodes + extra_nodes + end_nodes @@ -790,6 +1200,11 @@ def test_hf_decoder_merged_model(self): interleaved = False self.check_models(model_type, interleaved) + def test_hf_70b_distributed_decoder_merged_model(self): + model_type = "70b_distributed_merged" + interleaved = False + self.check_models(model_type, interleaved) + if __name__ == "__main__": unittest.main()