Feature(MInference): add e2e benchmark using vllm (#49)
* Feature(MInference): add e2e benchmark using vllm
* Feature(MInference): change the guideline

Co-authored-by: Yucheng Li <[email protected]>
Co-authored-by: Chengruidong Zhang <[email protected]>
3 people authored Jul 18, 2024
1 parent 0b9c81b commit ddfb462
Showing 5 changed files with 117 additions and 5 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -75,7 +75,8 @@ from transformers import pipeline

pipe = pipeline("text-generation", model=model_name, torch_dtype="auto", device_map="auto")

-# Patch MInference Module
+# Patch MInference Module,
+# If you use the local path, please use the model_name from HF when initializing MInference.
minference_patch = MInference("minference", model_name)
pipe.model = minference_patch(pipe.model)
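For context, a minimal sketch of what calling the patched pipeline could look like, continuing the snippet above (the prompt and generation settings are illustrative, not part of this diff):

```python
# Illustrative only: the patched model is still driven through the standard pipeline API.
result = pipe("Summarize the following report: ...", max_new_tokens=64)
print(result[0]["generated_text"])
```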

@@ -91,7 +92,8 @@ from vllm import LLM, SamplingParams

llm = LLM(model_name, max_num_seqs=1, enforce_eager=True, max_model_len=128000)

-# Patch MInference Module
+# Patch MInference Module,
+# If you use the local path, please use the model_name from HF when initializing MInference.
minference_patch = MInference("vllm", model_name)
llm = minference_patch(llm)
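Similarly, a minimal sketch of generating with the patched vLLM engine, continuing the snippet above (prompt and sampling settings are illustrative):

```python
# Illustrative only: the patched LLM object keeps the standard vLLM generate API.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
outputs = llm.generate(["Summarize the following report: ..."], sampling_params)
print(outputs[0].outputs[0].text)
```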

30 changes: 29 additions & 1 deletion experiments/README.md
@@ -51,7 +51,7 @@ wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/
2. Run a single context window size test using one method:

```bash
-python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 1000000
+python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 1_000_000
```

3. Run all latency experiments using different methods:
@@ -77,6 +77,34 @@ python experiments/benchmarks/benchmark_e2e.py --run_benchmark
> [!TIP]
> Based on our tests, **a single A100 can support up to 1.8M** context prompts during the pre-filling stage using LLaMA-3-8B-4M with **bf16**.
#### End-to-End Benchmark using vLLM

We also provide an end-to-end benchmark built on vLLM. You can run it with the following steps:

1. Download the prompt:

```bash
wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt
```

2. Run a single context window size test using one method:

```bash
python experiments/benchmarks/benchmark_e2e_vllm.py --attn_type minference --context_window 100_000
```
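Each run prints the attention mode, the target context length, and the mean latency in seconds on a single line, e.g. `minference 100000 14.08` (illustrative, consistent with the table below).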

For reference, here are some vLLM latency numbers measured on a single A100:
| Context Window | FlashAttention-2 (s) | MInference (s) |
|---------------:|---------------------:|---------------:|
| 1K   | 0.08062  | 3.01744  |
| 10K  | 0.83215  | 2.76216  |
| 50K  | 7.71675  | 7.53989  |
| 100K | 21.73080 | 14.08111 |
| 128K | 32.86252 | 18.82662 |

Please note that the current vLLM version **only supports** MInference and FlashAttention modes. Due to vLLM's PagedAttention management of the KV cache, a single A100 can handle a maximum context window size of **130k**.
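For reference, the benchmark script added in this commit stays under that ceiling by capping the engine's context length; a minimal sketch mirroring its configuration (the model name is the script's default):

```python
from vllm import LLM

# Keep max_model_len just below the ~130k ceiling a single A100 can hold.
llm = LLM(
    "gradientai/Llama-3-8B-Instruct-Gradient-1048k",
    max_num_seqs=1,
    enforce_eager=True,
    max_model_len=129_000,
)
```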

### Micro-Benchmark


4 changes: 2 additions & 2 deletions experiments/benchmarks/benchmark_e2e.py
@@ -11,7 +11,7 @@
from minference import MInference


-def run_target_length(m: int, model, attn_type):
+def run_target_length(m: int, model, attn_type: str):
    # wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt
    prompt_complex = open("./prompt_hardest.txt").read()
    input_ids = tokenizer(prompt_complex)["input_ids"]
@@ -37,7 +37,7 @@ def run_target_length(m: int, model, attn_type):
            )
        torch.cuda.synchronize()
        s += time.time() - start
-    print(m, s / T)
+    print(attn_type, m, s / T)
    return s / T


73 changes: 73 additions & 0 deletions experiments/benchmarks/benchmark_e2e_vllm.py
@@ -0,0 +1,73 @@
# Copyright (c) 2024 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import argparse
import time

import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from minference import MInference


def run_target_length(m: int, model, sampling_params, attn_type: str):
    # NOTE: relies on the module-level `llm` and `tokenizer` defined in __main__.
    # wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt
    prompt_complex = open("./prompt_hardest.txt").read()
    input_ids = tokenizer(prompt_complex)["input_ids"]
    n = len(input_ids)
    # Tile the base prompt until it covers the target context window m, then truncate.
    b = m // n + 1

    new_input_ids = (input_ids * b)[:m]
    prompt = tokenizer.decode(new_input_ids)

    # Average end-to-end latency over T runs; with max_tokens=1 this effectively
    # measures the pre-filling stage.
    s = 0
    T = 10
    for _ in range(T):
        torch.cuda.synchronize()
        start = time.time()
        with torch.no_grad():
            outputs = llm.generate([prompt], sampling_params)
        torch.cuda.synchronize()
        s += time.time() - start
    print(attn_type, m, s / T)
    return s / T


if __name__ == "__main__":
    args = argparse.ArgumentParser()
    args.add_argument(
        "--model_name",
        type=str,
        default="gradientai/Llama-3-8B-Instruct-Gradient-1048k",
    )
    args.add_argument(
        "--attn_type",
        type=str,
        choices=["flash_attn", "minference"],
    )
    args.add_argument("--context_window", type=int, default=100_000)
    args = args.parse_args()

    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Generate a single token per run so the timing is dominated by pre-filling.
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=1,
    )

    # max_model_len stays just below the ~130k context a single A100 can hold
    # under vLLM's PagedAttention KV-cache management.
    llm = LLM(
        model_name,
        max_num_seqs=1,
        enforce_eager=True,
        max_model_len=129000,
    )

    # Patch MInference Module; with --attn_type flash_attn the engine keeps vLLM's default attention.
    if args.attn_type == "minference":
        minference_patch = MInference("vllm", model_name)
        llm = minference_patch(llm)

    run_target_length(args.context_window, llm, sampling_params, args.attn_type)
9 changes: 9 additions & 0 deletions experiments/benchmarks/run_e2e_vllm.sh
@@ -0,0 +1,9 @@
# Copyright (c) 2024 Microsoft
# Licensed under The MIT License [see LICENSE for details]

# Load data
wget https://raw.githubusercontent.com/FranxYao/chain-of-thought-hub/main/gsm8k/lib_prompt/prompt_hardest.txt

python experiments/benchmarks/benchmark_e2e_vllm.py \
--attn_type minference \
--context_window 100_000
