Offline LLM Engine Benchmark Throughput #1968
Open: zolinthecow wants to merge 29 commits into sgl-project:main from zolinthecow:benchmark-script
+375 −31
Commits (29)
807a3f0  add offline engine bench (zolinthecow)
e3ec623  llm_engine -> engine (zolinthecow)
8b1232b  add to unit test bench (zolinthecow)
e6293a8  first draft bench offline throughput (zolinthecow)
5564a96  script works (zolinthecow)
0078bc3  reset bench serving stuff (zolinthecow)
9f6c31a  merge (zolinthecow)
3158414  most recent commit? (zolinthecow)
550ec14  restore test utils (zolinthecow)
a6b183e  Merge branch 'main' into benchmark-script (zolinthecow)
c1c6226  lint (zolinthecow)
1895c79  use sharegpt from bench_serving (zolinthecow)
3c8faf9  add unit test (zolinthecow)
170c83f  lint (zolinthecow)
696dd95  add support for runtime backend + dataclass generic args (zolinthecow)
21b6ed5  push not being processed? (zolinthecow)
0589a6b  lint (zolinthecow)
383b6d1  fix benches (zolinthecow)
8db0340  lint (zolinthecow)
568ce97  Merge branch 'main' into benchmark-script (zolinthecow)
c6a6827  add review (ByronHsu)
ed1a133  address todos (zolinthecow)
c485dbe  not sure how the tuple stuff got there (zolinthecow)
3565766  Merge branch 'main' into benchmark-script (zolinthecow)
fd2d04d  fix (zolinthecow)
ea3b60a  fix (zolinthecow)
732e3ba  lint (zolinthecow)
41aad44  format benchmark + add diff metrics (zolinthecow)
fa76ac9  lint (zolinthecow)
@@ -0,0 +1,319 @@
"""
Benchmark the throughput of using the offline LLM engine.
This script does not launch a server.
It accepts the same arguments as launch_server.py and additional benchmark arguments.

# Usage
## Sharegpt dataset with default args
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct

## Random dataset with default args
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct --dataset-name random

## Shared prefix dataset with default args
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct --dataset-name generated-shared-prefix

## Sharegpt dataset on runtime backend
python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3-8B-Instruct --backend runtime
"""

import argparse
import dataclasses
import itertools
import json
import logging
import random
import time
from typing import Dict, List, Tuple, Union

import numpy as np

from sglang.api import Engine as getEngine
from sglang.bench_serving import (
    get_dataset,
    get_tokenizer,
    sample_generated_shared_prefix_requests,
    sample_random_requests,
    sample_sharegpt_requests,
    set_ulimit,
)
from sglang.srt.server import Engine, Runtime
from sglang.srt.server_args import ServerArgs

@dataclasses.dataclass
class BenchArgs:
    backend: str = "engine"
    result_filename: str = ""
    dataset_name: str = "sharegpt"
    dataset_path: str = ""
    num_prompts: int = 1000
    sharegpt_output_len: int = 256
    random_input_len: int = 256
    random_output_len: int = 256
    random_range_ratio: float = 0.0
    gen_num_groups: int = 8
    gen_prompts_per_group: int = 16
    gen_system_prompt_len: int = 128
    gen_question_len: int = 256
    disable_ignore_eos: bool = False
    seed: int = 1

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument("--backend", type=str, default=BenchArgs.backend)
        parser.add_argument(
            "--result-filename", type=str, default=BenchArgs.result_filename
        )
        parser.add_argument(
            "--dataset-name",
            type=str,
            default="sharegpt",
            choices=["sharegpt", "random", "generated-shared-prefix"],
            help="Name of the dataset to benchmark on.",
        )
        parser.add_argument(
            "--dataset-path", type=str, default="", help="Path to the dataset."
        )
        parser.add_argument(
            "--num-prompts",
            type=int,
            default=BenchArgs.num_prompts,
            help="Number of prompts to process. Default is 1000.",
        )
        parser.add_argument(
            "--sharegpt-output-len",
            type=int,
            default=BenchArgs.sharegpt_output_len,
            help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
        )
        parser.add_argument(
            "--random-input-len",
            type=int,
            default=BenchArgs.random_input_len,
            help="Number of input tokens per request, used only for the random dataset.",
        )
        parser.add_argument(
            "--random-output-len",
            type=int,
            default=BenchArgs.random_output_len,
            help="Number of output tokens per request, used only for the random dataset.",
        )
        parser.add_argument(
            "--random-range-ratio",
            type=float,
            default=BenchArgs.random_range_ratio,
            help="Range of sampled ratio of input/output length, "
            "used only for the random dataset.",
        )
        parser.add_argument(
            "--gen-num-groups",
            type=int,
            default=BenchArgs.gen_num_groups,
            help="Number of groups with a shared prefix, used "
            "only for the generated-shared-prefix dataset.",
        )
        parser.add_argument(
            "--gen-prompts-per-group",
            type=int,
            default=BenchArgs.gen_prompts_per_group,
            help="Number of prompts per shared-prefix group, used "
            "only for the generated-shared-prefix dataset.",
        )
        parser.add_argument(
            "--gen-system-prompt-len",
            type=int,
            default=BenchArgs.gen_system_prompt_len,
            help="System prompt length, used only for the generated-shared-prefix dataset.",
        )
        parser.add_argument(
            "--gen-question-len",
            type=int,
            default=BenchArgs.gen_question_len,
            help="Question length, used only for the generated-shared-prefix dataset.",
        )
        parser.add_argument(
            "--disable-ignore-eos",
            action="store_true",
            help="Disable ignoring the EOS token (i.e., stop generation at EOS).",
        )
        parser.add_argument("--seed", type=int, default=1, help="The random seed.")

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # Use the default value's type to cast the args into the correct types.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )

def throughput_test_once(
    backend_name: str,
    backend: Union[Engine, Runtime],
    reqs: List[Tuple[str, int, int]],
    ignore_eos: bool,
):
    measurement_results = {
        "backend": backend_name,
        "successful_requests": len(reqs),
        "total_latency": -1,
        "total_input_tokens": sum(r[1] for r in reqs),
        "total_output_tokens": -1,
        "request_throughput": -1,
        "input_throughput": -1,
        "output_throughput": -1,
        "total_throughput": -1,
    }

    prompt = [r[0] for r in reqs]
    sampling_params = [
        {
            "temperature": 0,
            "max_new_tokens": r[2],
            "ignore_eos": ignore_eos,
        }
        for r in reqs
    ]

    st = time.perf_counter()
    gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
    latency = time.perf_counter() - st

    if backend_name == "runtime":
        gen_out = json.loads(gen_out)

    measurement_results["total_latency"] = latency
    measurement_results["total_output_tokens"] = sum(
        o["meta_info"]["completion_tokens"] for o in gen_out
    )
    measurement_results["request_throughput"] = (
        measurement_results["successful_requests"] / latency
    )
    measurement_results["input_throughput"] = (
        measurement_results["total_input_tokens"] / latency
    )
    measurement_results["output_throughput"] = (
        measurement_results["total_output_tokens"] / latency
    )
    measurement_results["total_throughput"] = (
        measurement_results["total_input_tokens"]
        + measurement_results["total_output_tokens"]
    ) / latency

    return measurement_results

def throughput_test(
    server_args: ServerArgs,
    bench_args: BenchArgs,
):
    if bench_args.backend == "engine":
        backend = getEngine(**dataclasses.asdict(server_args))
        if not backend:
            raise ValueError("Please provide valid engine arguments")
    elif bench_args.backend == "runtime":
        backend = Runtime(**dataclasses.asdict(server_args))
    else:
        raise ValueError('Please set backend to either "engine" or "runtime"')

    tokenizer_id = server_args.model_path
    tokenizer = get_tokenizer(tokenizer_id)

    # Set global environments
    set_ulimit()
    random.seed(bench_args.seed)
    np.random.seed(bench_args.seed)

    input_requests = get_dataset(bench_args, tokenizer)

    warmup_requests = sample_random_requests(
        input_len=20,
        output_len=4,
        num_prompts=2,
        range_ratio=0.8,
        tokenizer=tokenizer,
        dataset_path=bench_args.dataset_path,
    )

    # Warm up
    throughput_test_once(
        backend_name=bench_args.backend,
        backend=backend,
        reqs=warmup_requests,
        ignore_eos=not bench_args.disable_ignore_eos,
    )

    result = throughput_test_once(
        backend_name=bench_args.backend,
        backend=backend,
        reqs=input_requests,
        ignore_eos=not bench_args.disable_ignore_eos,
    )

    if bench_args.result_filename:
        with open(bench_args.result_filename, "a") as fout:
            fout.write(json.dumps(result) + "\n")

    return result

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    try:
        res = throughput_test(server_args, bench_args)
        print(
            "\n{s:{c}^{n}}".format(
                s=" Offline Throughput Benchmark Result ", n=50, c="="
            )
        )
        print("{:<40} {:<10}".format("Backend:", res["backend"]))
        print(
            "{:<40} {:<10}".format("Successful requests:", res["successful_requests"])
        )
        print(
            "{:<40} {:<10.2f}".format("Benchmark duration (s):", res["total_latency"])
        )
        print("{:<40} {:<10}".format("Total input tokens:", res["total_input_tokens"]))
        print(
            "{:<40} {:<10}".format(
                "Total generated tokens:", res["total_output_tokens"]
            )
        )
        print(
            "{:<40} {:<10.2f}".format(
                "Request throughput (req/s):", res["request_throughput"]
            )
        )
        print(
            "{:<40} {:<10.2f}".format(
                "Input token throughput (tok/s):", res["input_throughput"]
            )
        )
        print(
            "{:<40} {:<10.2f}".format(
                "Output token throughput (tok/s):", res["output_throughput"]
            )
        )
        print(
            "{:<40} {:<10.2f}".format(
                "Total token throughput (tok/s):", res["total_throughput"]
            )
        )
    except Exception as e:
        raise e
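Editor's note: besides the CLI entry point above, the benchmark can also be driven from Python. The sketch below is not part of this PR; it only uses names defined in the diff (ServerArgs, BenchArgs, throughput_test) and assumes the module path sglang.bench_offline_throughput from the usage docstring, with the same example model path.

# Minimal sketch (not part of this PR) of calling the benchmark programmatically.
# The module path and model path are assumptions taken from the usage docstring.
from sglang.bench_offline_throughput import BenchArgs, throughput_test
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
bench_args = BenchArgs(backend="engine", dataset_name="random", num_prompts=100)

# throughput_test builds the chosen backend, runs a small warmup batch, then
# benchmarks the sampled requests and returns the measurement dict shown above.
result = throughput_test(server_args, bench_args)
print(result["output_throughput"])  # generated tokens per second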