Offline LLM Engine Benchmark Throughput #1968

Open: wants to merge 27 commits into main
Changes from 3 commits
79 changes: 77 additions & 2 deletions python/sglang/bench_serving.py
@@ -37,6 +37,9 @@
PreTrainedTokenizerFast,
)

from sglang.api import Engine as getEngine
from sglang.srt.server import Engine

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

global args
@@ -45,7 +48,9 @@
@dataclass
class RequestFuncInput:
prompt: str
api_url: str
# Exactly one of api_url or engine must be set, never both.
api_url: Optional[str]
engine: Optional[Engine]
prompt_len: int
output_len: int
model: str
@@ -222,6 +227,68 @@ async def async_request_openai_completions(
return output


async def async_request_sglang_offline_engine(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
engine = request_func_input.engine
if not engine:
raise ValueError("Please pass in an Engine")

prompt = request_func_input.prompt

payload = {
"temperature": 0.0,
"n": 1,
"max_new_tokens": request_func_input.output_len,
"ignore_eos": not args.disable_ignore_eos,
**request_func_input.extra_request_body,
}
stream = not args.disable_stream

output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len

generated_text = ""
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
gen_out = await engine.async_generate(prompt, payload, stream=stream)
if stream:
async for chunk in gen_out:
latency = time.perf_counter() - st
if chunk["text"]:
timestamp = time.perf_counter()
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
else:
output.itl.append(timestamp - most_recent_timestamp)

most_recent_timestamp = timestamp
generated_text += chunk["text"]
else:
if gen_out[0]["text"]:
# Non-streaming path: the full output arrives at once, so TTFT equals total latency.
latency = time.perf_counter() - st
ttft = latency
output.ttft = ttft
generated_text = gen_out[0]["text"]
output.generated_text = generated_text
output.success = True
output.latency = latency
output.output_len = request_func_input.output_len
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))

if pbar:
pbar.update(1)
return output


async def async_request_truss(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
@@ -425,6 +492,7 @@ def get_tokenizer(
"sglang": async_request_sglang_generate,
"sglang-native": async_request_sglang_generate,
"sglang-oai": async_request_openai_completions,
"sglang-offline-engine": async_request_sglang_offline_engine,
"vllm": async_request_openai_completions,
"lmdeploy": async_request_openai_completions,
"trt": async_request_trt_llm,
@@ -718,7 +786,7 @@ def calculate_metrics(

async def benchmark(
backend: str,
api_url: str,
api_url: Optional[str],
model_id: str,
tokenizer: PreTrainedTokenizerBase,
input_requests: List[Tuple[str, int, int]],
@@ -730,13 +798,17 @@
request_func = ASYNC_REQUEST_FUNCS[backend]
else:
raise ValueError(f"Unknown backend: {backend}")
engine = None
if backend == "sglang-offline-engine":
engine = getEngine(model_path=model_id)

print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len = input_requests[0]
test_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
api_url=api_url,
engine=engine,
prompt_len=test_prompt_len,
output_len=test_output_len,
extra_request_body=extra_request_body,
@@ -762,6 +834,7 @@
model=model_id,
prompt=prompt,
api_url=api_url,
engine=engine,
prompt_len=prompt_len,
output_len=output_len,
extra_request_body=extra_request_body,
@@ -974,6 +1047,8 @@ def run_benchmark(args_: argparse.Namespace):
if args.base_url
else f"http://{args.host}:{args.port}/v1/completions"
)
elif args.backend in ["sglang-offline-engine"]:
api_url = None
elif args.backend == "trt":
api_url = (
f"{args.base_url}/v2/models/ensemble/generate_stream"
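For reviewers who want to exercise the new offline-engine path outside the benchmark, the following is a minimal sketch built only from what this diff shows: Engine imported from sglang.api, construction via model_path, and an awaited async_generate that yields dict chunks carrying a "text" field when streaming. The model path is a placeholder, not something this PR prescribes.

# Minimal sketch of the call sequence wrapped by async_request_sglang_offline_engine.
# Assumes only the Engine import and async_generate usage shown in this diff;
# the model path is a placeholder.
import asyncio

from sglang.api import Engine


async def main():
    engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
    payload = {
        "temperature": 0.0,
        "n": 1,
        "max_new_tokens": 32,
        "ignore_eos": True,
    }
    # Mirror the streaming branch above: await the call, then iterate the chunks.
    gen_out = await engine.async_generate("The capital of France is", payload, stream=True)
    generated_text = ""
    async for chunk in gen_out:
        generated_text += chunk["text"]
    print(generated_text)


if __name__ == "__main__":
    asyncio.run(main())

The non-streaming branch instead reads gen_out[0]["text"] and records the whole response time as both TTFT and latency.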
16 changes: 9 additions & 7 deletions python/sglang/test/test_utils.py
@@ -522,23 +522,25 @@ def run_bench_serving(
num_prompts,
request_rate,
other_server_args,
backend="sglang",
dataset_name="random",
random_input_len=4096,
random_output_len=2048,
disable_stream=False,
):
# Launch the server
base_url = DEFAULT_URL_FOR_TEST
process = popen_launch_server(
model,
base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_server_args,
)
if backend == "sglang":
process = popen_launch_server(
model,
base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_server_args,
)

# Run benchmark
args = SimpleNamespace(
backend="sglang",
backend=backend,
base_url=base_url,
host=None,
port=None,
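A short sketch of how the updated helper is meant to be driven with the new backend argument; it uses only names that appear in this diff and in the tests below (run_bench_serving, DEFAULT_MODEL_NAME_FOR_TEST) and relies on the behavior added here that no server is launched unless the backend is "sglang".

# Sketch: invoking the updated helper with the new backend parameter.
# Names are taken from this diff and test_bench_serving.py below.
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, run_bench_serving

res = run_bench_serving(
    backend="sglang-offline-engine",
    model=DEFAULT_MODEL_NAME_FOR_TEST,
    num_prompts=500,
    request_rate=float("inf"),
    other_server_args=[],
)
print(res["output_throughput"])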
20 changes: 20 additions & 0 deletions test/srt/test_bench_serving.py
@@ -86,6 +86,26 @@ def test_offline_throughput_default_fp8(self):
if is_in_ci():
assert res["output_throughput"] > 3100

def test_offline_throughput_default_engine(self):
res = run_bench_serving(
model=DEFAULT_MODEL_NAME_FOR_TEST,
num_prompts=500,
request_rate=float("inf"),
other_server_args=[],
)

def test_offline_throughput_llm_engine(self):
res = run_bench_serving(
backend="sgl-offline-engine",
model=DEFAULT_MODEL_NAME_FOR_TEST,
num_prompts=500,
request_rate=float("inf"),
other_server_args=[],
)

if is_in_ci():
assert res["output_throughput"] > 2830

def test_online_latency_default(self):
res = run_bench_serving(
model=DEFAULT_MODEL_NAME_FOR_TEST,