From ddfb4625ba7cdd5da2969bd4aa20d502d8a52da7 Mon Sep 17 00:00:00 2001
From: Huiqiang Jiang
Date: Thu, 18 Jul 2024 13:46:57 +0800
Subject: [PATCH] Feature(MInference): add e2e benchmark using vllm (#49)

* Feature(MInference): add e2e benchmark using vllm

* Feature(MInference): change the guideline

Co-authored-by: Yucheng Li
Co-authored-by: Chengruidong Zhang
---
 README.md                                    |  6 +-
 experiments/README.md                        | 30 +++++++-
 experiments/benchmarks/benchmark_e2e.py      |  4 +-
 experiments/benchmarks/benchmark_e2e_vllm.py | 73 ++++++++++++++++++++
 experiments/benchmarks/run_e2e_vllm.sh       |  9 +++
 5 files changed, 117 insertions(+), 5 deletions(-)
 create mode 100644 experiments/benchmarks/benchmark_e2e_vllm.py
 create mode 100644 experiments/benchmarks/run_e2e_vllm.sh

diff --git a/README.md b/README.md
index 8eb3b77..2f5722b 100644
--- a/README.md
+++ b/README.md
@@ -75,7 +75,8 @@ from transformers import pipeline
 
 pipe = pipeline("text-generation", model=model_name, torch_dtype="auto", device_map="auto")
 
-# Patch MInference Module
+# Patch MInference Module.
+# If you load the model from a local path, please still pass the HF model_name when initializing MInference.
 minference_patch = MInference("minference", model_name)
 pipe.model = minference_patch(pipe.model)
 
@@ -91,7 +92,8 @@ from vllm import LLM, SamplingParams
 
 llm = LLM(model_name, max_num_seqs=1, enforce_eager=True, max_model_len=128000)
 
-# Patch MInference Module
+# Patch MInference Module.
+# If you load the model from a local path, please still pass the HF model_name when initializing MInference.
 minference_patch = MInference("vllm", model_name)
 llm = minference_patch(llm)
 
diff --git a/experiments/README.md b/experiments/README.md
index 715f2b3..415c1a1 100644
--- a/experiments/README.md
+++ b/experiments/README.md
@@ -51,7 +51,7 @@ wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/
 
 2. Run a single context window size test using one method:
 
 ```bash
-python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 1000000
+python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 1_000_000
 ```
 
@@ -77,6 +77,34 @@ python experiments/benchmarks/benchmark_e2e.py --run_benchmark
 
 > [!TIP]
 > Based on our tests, **a single A100 can support up to 1.8M** context prompts during the pre-filling stage using LLaMA-3-8B-4M with **bf16**.
 
+#### End-to-End Benchmark using vLLM
+
+We also provide an end-to-end benchmark built on vLLM. You can run it with the following steps:
+
+1. Download the prompt:
+
+```bash
+wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt
+```
+
+2. Run a single context window size test using one method:
+
+```bash
+python experiments/benchmarks/benchmark_e2e_vllm.py --attn_type minference --context_window 100_000
+```
+
+For reference, here is the end-to-end vLLM latency (in seconds) measured on a single A100:
+```
+Context    FlashAttention-2    MInference
+1K                  0.08062       3.01744
+10K                 0.83215       2.76216
+50K                 7.71675       7.53989
+100K               21.73080      14.08111
+128K               32.86252      18.82662
+```
+
+Please note that the current vLLM integration **only supports** the MInference and FlashAttention modes. Due to vLLM's PagedAttention management of the KV cache, a single A100 can handle a maximum context window size of about **130K**.
+
 ### Micro-Benchmark
 
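
For a quick read of the latency table added above, the illustrative snippet below (editor's addition, not part of the patch) copies those A100 numbers and prints MInference's end-to-end speedup over FlashAttention-2 at each context length: on these figures the two roughly break even around 50K tokens, with MInference about 1.5x faster at 100K and about 1.7x at 128K.

```python
# Illustrative only: speedup of MInference over FlashAttention-2,
# computed from the A100 latency table in the README hunk above.
latency_s = {  # context length -> (FlashAttention-2, MInference), end-to-end seconds
    "1K": (0.08062, 3.01744),
    "10K": (0.83215, 2.76216),
    "50K": (7.71675, 7.53989),
    "100K": (21.73080, 14.08111),
    "128K": (32.86252, 18.82662),
}

for ctx, (fa2, minf) in latency_s.items():
    # Values above 1.0 mean MInference is faster at that context length.
    print(f"{ctx:>5}: {fa2 / minf:.2f}x")
```
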
diff --git a/experiments/benchmarks/benchmark_e2e.py b/experiments/benchmarks/benchmark_e2e.py
index dcfd7e5..b67ed5c 100644
--- a/experiments/benchmarks/benchmark_e2e.py
+++ b/experiments/benchmarks/benchmark_e2e.py
@@ -11,7 +11,7 @@
 from minference import MInference
 
 
-def run_target_length(m: int, model, attn_type):
+def run_target_length(m: int, model, attn_type: str):
     # wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt
     prompt_complex = open("./prompt_hardest.txt").read()
     input_ids = tokenizer(prompt_complex)["input_ids"]
@@ -37,7 +37,7 @@ def run_target_length(m: int, model, attn_type):
         )
         torch.cuda.synchronize()
         s += time.time() - start
-    print(m, s / T)
+    print(attn_type, m, s / T)
     return s / T
 
 
diff --git a/experiments/benchmarks/benchmark_e2e_vllm.py b/experiments/benchmarks/benchmark_e2e_vllm.py
new file mode 100644
index 0000000..d0c62b5
--- /dev/null
+++ b/experiments/benchmarks/benchmark_e2e_vllm.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2024 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import argparse
+import time
+
+import torch
+from transformers import AutoTokenizer
+from vllm import LLM, SamplingParams
+
+from minference import MInference
+
+
+def run_target_length(m: int, model, sampling_params, attn_type: str):
+    # wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt
+    prompt_complex = open("./prompt_hardest.txt").read()
+    input_ids = tokenizer(prompt_complex)["input_ids"]
+    n = len(input_ids)
+    b = m // n + 1
+
+    new_input_ids = (input_ids * b)[:m]
+    prompt = tokenizer.decode(new_input_ids)
+
+    s = 0
+    T = 10
+    for _ in range(T):
+        torch.cuda.synchronize()
+        start = time.time()
+        with torch.no_grad():
+            outputs = model.generate([prompt], sampling_params)
+        torch.cuda.synchronize()
+        s += time.time() - start
+    print(attn_type, m, s / T)
+    return s / T
+
+
+if __name__ == "__main__":
+    args = argparse.ArgumentParser()
+    args.add_argument(
+        "--model_name",
+        type=str,
+        default="gradientai/Llama-3-8B-Instruct-Gradient-1048k",
+    )
+    args.add_argument(
+        "--attn_type",
+        type=str,
+        choices=["flash_attn", "minference"],
+    )
+    args.add_argument("--context_window", type=int, default=100_000)
+    args = args.parse_args()
+
+    model_name = args.model_name
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+    sampling_params = SamplingParams(
+        temperature=0.8,
+        top_p=0.95,
+        max_tokens=1,
+    )
+
+    llm = LLM(
+        model_name,
+        max_num_seqs=1,
+        enforce_eager=True,
+        max_model_len=129000,
+    )
+
+    # Patch MInference Module
+    if args.attn_type == "minference":
+        minference_patch = MInference("vllm", model_name)
+        llm = minference_patch(llm)
+
+    run_target_length(args.context_window, llm, sampling_params, args.attn_type)
diff --git a/experiments/benchmarks/run_e2e_vllm.sh b/experiments/benchmarks/run_e2e_vllm.sh
new file mode 100644
index 0000000..4b20598
--- /dev/null
+++ b/experiments/benchmarks/run_e2e_vllm.sh
@@ -0,0 +1,9 @@
+# Copyright (c) 2024 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+# Download the prompt data
+wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt
+
+python experiments/benchmarks/benchmark_e2e_vllm.py \
+    --attn_type minference \
+    --context_window 100_000
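
Building on run_e2e_vllm.sh above, here is a minimal sketch (editor's addition, not part of the patch) of how one might sweep both supported attention types over several context windows by invoking the new script once per configuration. The flags it uses (--attn_type, --context_window) are the ones benchmark_e2e_vllm.py defines; the list of window sizes is an illustrative assumption chosen to stay under the ~130K vLLM limit noted in the README hunk, and prompt_hardest.txt is assumed to have been downloaded already.

```python
# Sketch: sweep both supported attn types over several context windows by
# shelling out to the new vLLM benchmark script, one run per configuration.
# Assumes it is launched from the repo root and prompt_hardest.txt exists.
import subprocess

ATTN_TYPES = ["flash_attn", "minference"]  # the two choices the script accepts
CONTEXT_WINDOWS = [1_000, 10_000, 50_000, 100_000, 128_000]  # stays under the ~130K vLLM limit

for attn_type in ATTN_TYPES:
    for ctx in CONTEXT_WINDOWS:
        # Each invocation builds a fresh vLLM engine, so runs do not interfere.
        subprocess.run(
            [
                "python",
                "experiments/benchmarks/benchmark_e2e_vllm.py",
                "--attn_type", attn_type,
                "--context_window", str(ctx),
            ],
            check=True,
        )
```

Because the script constructs the LLM engine inside its `__main__` block and prints `attn_type, context_window, latency` per run, the loop above produces one averaged latency line per configuration, matching the rows of the reference table.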