diff --git a/python/sglang/README.md b/python/sglang/README.md
index c0a61e23611..95c626fd0b5 100644
--- a/python/sglang/README.md
+++ b/python/sglang/README.md
@@ -13,3 +13,4 @@
 - `launch_server.py`: The entry point for launching the local server.
 - `llama3_eval.py`: Llama 3.1 evaluation with meta-llama dataset.
 - `utils.py`: Common utilities.
+- `download.sh`: Script to download the datasets.
diff --git a/python/sglang/backend_request_func.py b/python/sglang/backend_request_func.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/python/sglang/bench/nextqa/__init__.py b/python/sglang/bench/nextqa/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/python/sglang/bench/nextqa/client.py b/python/sglang/bench/nextqa/client.py
new file mode 100644
index 00000000000..3e98c7daf6f
--- /dev/null
+++ b/python/sglang/bench/nextqa/client.py
@@ -0,0 +1,254 @@
+"""
+Launch the benchmark client for the Llava-video model.
+Sends all the videos in a directory to the server and asks the LLM to describe them.
+Example: unpack videos into ./videos and run the following command:
+python client.py --port 3000
+"""
+
+import argparse
+import os
+import sys
+import time
+from typing import List
+
+import requests
+from video import NExTQALoader, Video, VideoFileLoader, VideoPrompt
+
+import sglang as sgl
+from sglang.utils import encode_video_base64
+
+
+@sgl.function
+def video_qa(s, num_frames, video_path, question):
+    # note: the order of video and question does not matter
+    s += sgl.user(
+        sgl.video(video_path, num_frames) + question
+    )  # build request and encode video frames
+    # TODO: video_path
+    # s += sgl.user(question + sgl.video(video_path, num_frames))
+    s += sgl.assistant(sgl.gen("answer"))  # send request to the LLM
+
+
+# @sgl.function
+# def next_qa(s, num_frames, video_path, question, ):
+
+
+class VideoClient:
+    def __init__(self, port: int):
+        self.port = port
+        # self.port = port
+        # self.endpoint = sgl.RuntimeEndpoint(f"http://localhost:{port}")
+        # sgl.set_default_backend(self.endpoint)
+        # print(f"chat template: {self.endpoint.chat_template.name}")
+
+    def single(self, video_path: str, num_frames):
+        print("single() is not implemented yet")
+
+    def batch(self, video_dir: str, num_frames, batch_size, save_dir):
+        print("batch() is not implemented yet")
+
+
+class VideoClientSgl(VideoClient):
+    def __init__(self, port: int):
+        super().__init__(port)
+        self.endpoint = sgl.RuntimeEndpoint(f"http://localhost:{port}")
+        sgl.set_default_backend(self.endpoint)
+        print(f"chat template: {self.endpoint.chat_template.name}")
+
+    def single(self, video: Video, prompt: str):
+        """
+        Handle a single video
+        """
+        if video.num_frames == 0:
+            print(f"Video {video.path} has 0 frames. Skipping...")
+            return
+
+        print(video)
+
+        start_time = time.time()
+        state = video_qa.run(
+            num_frames=video.num_frames,
+            video_path=video.path,
+            question=prompt,
+            temperature=0.0,
+            max_new_tokens=1024,
+        )
+        answer = state["answer"]
+        total_time = time.time() - start_time
+
+        print("Prompt: ", prompt)
+        print(f"Answer: {answer}")
+        print(f"Latency: {total_time} seconds.")
+
+    def batch(self, video_prompts: List[VideoPrompt], save_dir="./answers"):
+        """
+        Handle a batch of videos
+        """
+        # remove invalid videos
+        valid_videos = []
+        for video in video_prompts:
+            if video.num_frames == 0:
+                print(f"Video {video.path} has 0 frames. Skipping...")
+            else:
+                valid_videos.append(video)
+        if len(valid_videos) == 0:
+            print("No valid videos in this batch.")
+            return
+        videos = valid_videos
+
+        # process batch input
+        print(f"Processing batch of {len(videos)} video(s)...")
+
+        batch_input = [
+            {
+                "num_frames": video.num_frames,
+                "video_path": video.path,
+                "question": video.prompt,
+            }
+            for video in videos
+        ]
+
+        start_time = time.time()
+
+        # query
+        states = video_qa.run_batch(batch_input, max_new_tokens=512, temperature=0.2)
+        # save batch results
+        for state, video in zip(states, videos):
+            with open(
+                os.path.join(save_dir, os.path.basename(video.path) + ".log"), "w"
+            ) as f:
+                f.write(state["answer"])
+
+        total_time = time.time() - start_time
+        throughput = len(videos) / total_time
+        print(
+            f"Number of videos in batch: {len(videos)}.\n"
+            f"Total time for this batch: {total_time:.2f} seconds.\n"
+            f"Throughput: {throughput:.2f} videos/second"
+        )
+
+
+class VideoDiscrptClientSgl(VideoClientSgl):
+    """
+    SGLang client for video description
+    """
+
+    def __init__(self, port: int):
+        super().__init__(port)
+
+    def single(self, video: Video):
+        super().single(
+            video,
+            "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes.",
+        )
+
+    def batch(self, videos: List[Video], save_dir="./answers"):
+        prompt = "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes."
+        videos = [VideoPrompt(video.path, video.num_frames, prompt) for video in videos]
+        super().batch(
+            video_prompts=videos,
+            save_dir=save_dir,
+        )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Video benchmark client connected to a specific server port."
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=3000,
+        help="The port of the inference server to connect to.",
+    )
+    parser.add_argument(
+        "--video-dir",
+        type=str,
+        default="./videos",
+        help="The directory containing the video files (or a path to a single video file).",
+    )
+    parser.add_argument(
+        "--max-frames",
+        type=int,
+        default=sys.maxsize,
+        help="The maximum number of frames to process in each video.",
+    )
+    parser.add_argument(
+        "--save-dir",
+        type=str,
+        default="./output",
+        help="The directory to save the generated answers.",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=1,
+        help="The number of videos to process per batch.",
+    )
+
+    args = parser.parse_args()
+
+    # -- load files and process
+    # client = VideoDiscrptClientSgl(args.port)
+    # videos = VideoFileLoader(
+    #     video_dir=args.video_dir, batch_size=args.batch_size, max_frames=args.max_frames
+    # )
+    videos = NExTQALoader(
+        video_dir=args.video_dir, max_frames=args.max_frames, batch_size=args.batch_size
+    )
+
+    # print(args.max_frames)
+    # if args.batch_size > 1:
+    #     if not os.path.exists(args.save_dir):
+    #         os.makedirs(args.save_dir)
+    #     for batch in videos:
+    #         client.batch(batch, save_dir=args.save_dir)
+    # else:
+    #     for video in videos:
+    #         client.single(video)
+
+    # -- load NExTQA and process with SGLang frontend
+    # client = VideoClientSgl(args.port)
+    # if args.batch_size > 1:
+    #     for batch in videos:
+    #         # TODO: can fail if the frame size (W*H) is too large
+    #         client.batch(batch, save_dir=args.save_dir)
+    # else:
+    #     for video in videos:
+    #         client.single(video, video.prompt)
+
+    # -- load NExTQA and process with chat completions APIs
+    payload = {
+        "model": "lmms-lab/LLaVA-NeXT-Video-7B",
+        "temperature": 0.0,
+        "stream": True,
+    }
+
+    headers = {
"application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + for video in videos: + path = video.path + num_frames = video.num_frames + base64_data = encode_video_base64(path, num_frames) + # print(base64_data) + message = { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": base64_data}}, + {"type": "text", "text": video.prompt}, + ], + } + payload["messages"] = [message] + payload["max_tokens"] = 1024 + print("send: ", message["content"][1]) + response = requests.post( + url=f"http://localhost:{args.port}/v1/chat/completions", + json=payload, + headers=headers, + ) + print(response.json()) + + # -- TODO: load NExTQA and process with /generate APIs diff --git a/python/sglang/bench/nextqa/server.py b/python/sglang/bench/nextqa/server.py new file mode 100644 index 00000000000..1759d0dc1e8 --- /dev/null +++ b/python/sglang/bench/nextqa/server.py @@ -0,0 +1,93 @@ +""" +Launch the inference server for Llava-video model. +Example: python server.py --model-path lmms-lab/LLaVA-NeXT-Video-7B --tokenizer-path llava-hf/llava-1.5-7b-hf --port 3000 --chat-template vicuna_v1.1 +""" + +import argparse +import multiprocessing as mp + +from sglang.srt.server import ServerArgs, launch_server + +if __name__ == "__main__": + # command line arguments + parser = argparse.ArgumentParser() + # add arguments + parser.add_argument( + "--max-frames", + type=int, + choices=[16, 32], + default=16, + help="The max number of frames to process in each video. If the input is less then max_frames, the model will pad the input to max_frames, and most of the time the output will be correct. However, if the input is more than max_frames, the model will output wrong answer", + ) + ServerArgs.add_cli_args(parser) + # parse cli arguments + args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) + + # model specific arguments + model_overide_args = {} + model_overide_args["mm_spatial_pool_stride"] = 2 + model_overide_args["architectures"] = ["LlavaVidForCausalLM"] + model_overide_args["num_frames"] = args.max_frames + model_overide_args["model_type"] = "llavavid" + if model_overide_args["num_frames"] == 32: + model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"} + model_overide_args["max_sequence_length"] = 4096 * 2 + model_overide_args["tokenizer_model_max_length"] = 4096 * 2 + model_overide_args["model_max_length"] = 4096 * 2 + + print(f"num_frames: {model_overide_args['num_frames']}") + + if "34b" in args.model_path.lower(): + model_overide_args["image_token_index"] = 64002 + + pipe_reader, pipe_writer = mp.Pipe(duplex=False) + + launch_server(server_args, pipe_writer, model_overide_args) + +""" +Launch the inference server for Llava-video model. +Example: python server.py --model-path lmms-lab/LLaVA-NeXT-Video-7B --tokenizer-path llava-hf/llava-1.5-7b-hf --port 3000 --chat-template vicuna_v1.1 +""" + +import argparse +import multiprocessing as mp + +from sglang.srt.server import ServerArgs, launch_server + +if __name__ == "__main__": + # command line arguments + parser = argparse.ArgumentParser() + # add arguments + parser.add_argument( + "--max-frames", + type=int, + choices=[16, 32], + default=16, + help="The max number of frames to process in each video. If the input is less then max_frames, the model will pad the input to max_frames, and most of the time the output will be correct. 
However, if the input is more than max_frames, the model will output wrong answer", + ) + ServerArgs.add_cli_args(parser) + # parse cli arguments + args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) + + # model specific arguments + model_overide_args = {} + model_overide_args["mm_spatial_pool_stride"] = 2 + model_overide_args["architectures"] = ["LlavaVidForCausalLM"] + model_overide_args["num_frames"] = args.max_frames + model_overide_args["model_type"] = "llavavid" + if model_overide_args["num_frames"] == 32: + model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"} + model_overide_args["max_sequence_length"] = 4096 * 2 + model_overide_args["tokenizer_model_max_length"] = 4096 * 2 + model_overide_args["model_max_length"] = 4096 * 2 + + print(f"num_frames: {model_overide_args['num_frames']}") + + if "34b" in args.model_path.lower(): + model_overide_args["image_token_index"] = 64002 + + pipe_reader, pipe_writer = mp.Pipe(duplex=False) + + launch_server(server_args, pipe_writer, model_overide_args) diff --git a/python/sglang/bench/nextqa/video.py b/python/sglang/bench/nextqa/video.py new file mode 100644 index 00000000000..6174121c103 --- /dev/null +++ b/python/sglang/bench/nextqa/video.py @@ -0,0 +1,234 @@ +import base64 +import os +import sys +from concurrent.futures import ThreadPoolExecutor +from io import BytesIO +from typing import List, Tuple + +import av +import numpy as np +from datasets import load_dataset + + +# Adopt from SGLang +def encode_frame(frame): + import cv2 # pip install opencv-python-headless + from PIL import Image + + # Convert the frame to RGB (OpenCV uses BGR by default) + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # Convert the frame to PIL Image to easily convert to bytes + im_pil = Image.fromarray(frame) + + # Convert to bytes + buffered = BytesIO() + + # frame_format = str(os.getenv('FRAME_FORMAT', "JPEG")) + + im_pil.save(buffered, format="PNG") + + frame_bytes = buffered.getvalue() + + # Return the bytes of the frame + return frame_bytes + + +# Adopt from SGLang +def encode_video_base64(video_path: str, num_frames: int = 16): + import cv2 # pip install opencv-python-headless + + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise IOError(f"Could not open video file:{video_path}") + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + print(f"target_frames: {num_frames}") + + frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) + + frames = [] + for i in range(total_frames): + ret, frame = cap.read() + if ret: + frames.append(frame) + else: + # Handle the case where the frame could not be read + # print(f"Warning: Could not read frame at index {i}.") + pass + + cap.release() + + # Safely select frames based on frame_indices, avoiding IndexError + frames = [frames[i] for i in frame_indices if i < len(frames)] + + # If there are not enough frames, duplicate the last frame until we reach the target + while len(frames) < num_frames: + frames.append(frames[-1]) + + # Use ThreadPoolExecutor to process and encode frames in parallel + with ThreadPoolExecutor() as executor: + encoded_frames = list(executor.map(encode_frame, frames)) + + # encoded_frames = list(map(encode_frame, frames)) + + # Concatenate all frames bytes + video_bytes = b"".join(encoded_frames) + + # Encode the concatenated bytes to base64 + video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8") + + return video_base64 + + +def find_video_files(video_dir) -> List[str]: + if os.path.isfile(video_dir): 
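+        # A single file path was given; return it as a one-element list.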
+        return [video_dir]
+
+    video_files = []
+    for root, dirs, files in os.walk(video_dir):
+        for file in files:
+            if file.endswith((".mp4", ".avi", ".mov")):
+                video_files.append(os.path.join(root, file))
+            # if the entry is a directory, recurse into it
+            elif os.path.isdir(os.path.join(root, file)):
+                video_files.extend(find_video_files(os.path.join(root, file)))
+    return video_files
+
+
+def video_frames(video_path, max_frames) -> int:
+    container = av.open(video_path)
+    total_frames = container.streams.video[0].frames
+    return min(total_frames, max_frames)
+
+
+class Video:
+    def __init__(self, video_path, num_frames):
+        self.path = video_path
+        self.num_frames = num_frames
+
+    def __str__(self):
+        return f"Video({self.path}, {self.num_frames})"
+
+    def __iter__(self):
+        return iter((self.path, self.num_frames))
+
+
+class VideoPrompt(Video):
+    def __init__(self, video_path, num_frames, prompt):
+        super().__init__(video_path, num_frames)
+        self.prompt = prompt
+
+    def __str__(self):
+        return f"VideoPrompt({self.path}, {self.num_frames}, {self.prompt})"
+
+    def __iter__(self):
+        return iter((self.path, self.num_frames, self.prompt))
+
+
+class VideoLoader:
+    pass
+
+
+class VideoFileLoader(VideoLoader):
+    """
+    Load all the videos in a directory
+    """
+
+    def __init__(self, video_dir, batch_size=1, max_frames=sys.maxsize):
+        super().__init__()
+        self.video_dir = video_dir
+        self.video_files = find_video_files(video_dir)
+        self.batch_size = batch_size
+        self.max_frames = max_frames
+        print(f"batch_size: {batch_size}, max_frames: {max_frames}")
+
+    def __iter__(self):  # (file, number of frames)
+        if self.batch_size == 1:
+            for video_file in self.video_files:
+                yield Video(video_file, video_frames(video_file, self.max_frames))
+        else:
+            batch = []
+            for video_file in self.video_files:
+                video = Video(video_file, video_frames(video_file, self.max_frames))
+                batch.append(video)
+                if len(batch) == self.batch_size:
+                    yield batch
+                    batch = []
+
+
+class NExTQALoader(VideoLoader):
+    """
+    Load videos and prompts from the NExTQA dataset.
+    dset: train, test or validation
+    """
+
+    def __init__(
+        self, video_dir, batch_size=1, max_frames=sys.maxsize, dset="test", task="OE"
+    ):
+        """
+        task: 'MC' (multiple choice) or 'OE' (open-ended)
+        """
+        super().__init__()
+        self.task = task
+        print(f"Loading the {dset} data of {task} from lmms-lab/NExTQA")
+        self.ds = load_dataset("lmms-lab/NExTQA", task)
+        self.ds = self.ds[dset]
+
+        # self.n = ds.num_rows
+        self.video_dir = video_dir
+        self.video_files = find_video_files(video_dir)
+        self.video_to_path = dict()
+        for video_file in self.video_files:
+            video_id = video_file.split("/")[-1].split(".")[0]
+            self.video_to_path[video_id] = video_file
+
+        self.batch_size = batch_size
+        self.max_frames = max_frames
+
+    def get_video_prompt(self, entry, max_frames) -> VideoPrompt:
+        # Get video
+        video_id = entry["video"]
+        video_path = self.video_to_path[video_id]
+        assert os.path.exists(video_path), f"Video not found: {video_path}"
+        num_frames = min(entry["frame_count"], max_frames)
+        video = Video(video_path, num_frames)
+        prompt = entry["question"] + "?"
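+        # For the multiple-choice (MC) task, append the candidate answers so the model can pick one.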
+ if self.task == "MC": # add choices + prompt += f' a0: {entry["a0"]}, a1: {entry["a1"]}, a2: {entry["a2"]}, a3: {entry["a3"]}' + return VideoPrompt(video_path, num_frames, prompt) + + def __iter__(self): + if self.batch_size == 1: + for entry in self.ds: + yield self.get_video_prompt(entry, self.max_frames) + else: + batch = [] + for entry in self.ds: + video = self.get_video_prompt(entry, self.max_frames) + batch.append(video) + if len(batch) == self.batch_size: + yield batch + batch = [] + + +# main +if __name__ == "__main__": + video_dir = "./videos" + # video_loader = VideoFileLoader(video_dir, batch_size=16) + # for batch in video_loader: + # print(f"Number of videos in batch: {len(batch)}") + # for video_file, num_frames in batch: + # print(f"Video: {video_file} number of frames: {num_frames}") + + video_loader = NExTQALoader(video_dir, batch_size=16, dset="test", task="OE") + for batch in video_loader: + print(f"Number of videos in batch: {len(batch)}") + for video_file, num_frames, prompt in batch: + print( + f"Video: {video_file} number of frames: {num_frames}, prompt: {prompt}" + ) + # break + # for video_file, prompt in batch: + # print(f"Video: {video_file} prompt: {prompt}") + # break diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 6067a7444eb..b020faff746 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -15,7 +15,6 @@ import asyncio import json import os -import pickle import random import resource import sys @@ -25,7 +24,6 @@ from argparse import ArgumentParser from dataclasses import dataclass, field from datetime import datetime -from pathlib import Path from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union import aiohttp @@ -39,32 +37,38 @@ PreTrainedTokenizerFast, ) -AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) +from sglang.data_processing import SampleOutput, get_dataset +from sglang.utils import MsgContent + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60) global args @dataclass class RequestFuncInput: - prompt: str + prompts: List[Tuple[MsgContent, int, int]] api_url: str - prompt_len: int - output_len: int model: str lora_name: str extra_request_body: Dict[str, Any] + # For multiturn chat, store the context + prev_messages: List = field(default_factory=list) + finished_prompts: int = 0 + @dataclass class RequestFuncOutput: - generated_text: str = "" - success: bool = False - latency: float = 0.0 - ttft: float = 0.0 # Time to first token + generated_text: List[str] = field(default_factory=list) + prompt_len: List[int] = field(default_factory=list) + output_len: List[int] = field(default_factory=list) + latency: List[float] = field(default_factory=list) + ttft: List[float] = field(default_factory=list) itl: List[float] = field(default_factory=list) # List of inter-token latencies - prompt_len: int = 0 + + success: bool = False error: str = "" - output_len: int = 0 def remove_prefix(text: str, prefix: str) -> str: @@ -75,6 +79,8 @@ def remove_prefix(text: str, prefix: str) -> str: # https://github.com/triton-inference-server/tensorrtllm_backend/issues/505 async def async_request_trt_llm( request_func_input: RequestFuncInput, + queue: asyncio.Queue, + tokenizer: PreTrainedTokenizerBase, pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url @@ -83,21 +89,35 @@ async def async_request_trt_llm( async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: payload = { "accumulate_tokens": True, - "text_input": 
request_func_input.prompt, "temperature": 0.000001, "top_p": 1.0, - "max_tokens": request_func_input.output_len, "stream": True, - "min_length": request_func_input.output_len, "end_id": 1048576, **request_func_input.extra_request_body, } + + prompt_idx = request_func_input.finished_prompts + messages = request_func_input.prev_messages + prompt, input_len, max_tokens = request_func_input.prompts[prompt_idx] + prompt_len = sum( + prompt[1] + prompt[2] # input_len + output_len + for prompt in request_func_input.prompts[:prompt_idx] + ) + prompt_len += input_len + + # TODO: Check out whether trt-llm supports native multiturn chat + messages.append(prompt) + payload["text_input"] = " ".join(messages) + payload["max_tokens"] = max_tokens + payload["min_length"] = max_tokens + if args.disable_ignore_eos: del payload["min_length"] del payload["end_id"] + output = RequestFuncOutput() - output.prompt_len = request_func_input.prompt_len + generated_text = "" ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st @@ -112,12 +132,12 @@ async def async_request_trt_llm( chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:") data = json.loads(chunk) - output.generated_text += data["text_output"] + generated_text += data["text_output"] timestamp = time.perf_counter() # First token if ttft == 0.0: ttft = time.perf_counter() - st - output.ttft = ttft + output.ttft.append(ttft) # Decoding phase else: @@ -125,9 +145,27 @@ async def async_request_trt_llm( most_recent_timestamp = timestamp - output.latency = most_recent_timestamp - st + output_len = len(tokenizer(generated_text).input_ids) + output.prompt_len.append(prompt_len - 1) # truncate + output.output_len.append(output_len) + output.generated_text.append(generated_text) + output.latency.append(most_recent_timestamp - st) output.success = True - output.output_len = request_func_input.output_len + + # Prepare for the new request + request_func_input.prompts[prompt_idx] = ( + prompt, + input_len, + output_len, # changes from max_tokens to output_len + ) + prompt_idx += 1 + messages.append(generated_text) + + # Move the new request to the end of the queue + if prompt_idx < len(request_func_input.prompts): + request_func_input.finished_prompts = prompt_idx + request_func_input.prev_messages = messages + await queue.put(request_func_input) else: output.error = response.reason or "" @@ -145,6 +183,8 @@ async def async_request_trt_llm( # set ignore_eos True by default async def async_request_openai_completions( request_func_input: RequestFuncInput, + queue: asyncio.Queue, + tokenizer: PreTrainedTokenizerBase, pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url @@ -152,23 +192,43 @@ async def async_request_openai_completions( "completions" ), "OpenAI Completions API URL must end with 'completions'." 
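+    # Each RequestFuncInput now carries a whole conversation: one user turn is sent
+    # per call, and the request is re-queued until every turn has been answered.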
- prompt = request_func_input.prompt - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: payload = { "model": request_func_input.model, - "prompt": prompt, "temperature": 0.0, "best_of": 1, - "max_tokens": request_func_input.output_len, "stream": not args.disable_stream, "ignore_eos": not args.disable_ignore_eos, **request_func_input.extra_request_body, } - headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } output = RequestFuncOutput() - output.prompt_len = request_func_input.prompt_len + + prompt_idx = request_func_input.finished_prompts + messages = request_func_input.prev_messages + prompt, input_len, max_tokens = request_func_input.prompts[prompt_idx] + prompt_len = sum( + prompt[1] + prompt[2] # input_len + output_len + for prompt in request_func_input.prompts[:prompt_idx] + ) + prompt_len += input_len + + # Messages + messages.append( + { + "role": "user", + "content": prompt, + } + ) + payload["messages"] = messages + payload["max_tokens"] = max_tokens + + # output.prompt_len = request_func_input.prompt_len + # print(payload) generated_text = "" ttft = 0.0 @@ -190,28 +250,51 @@ async def async_request_openai_completions( pass else: data = json.loads(chunk) - + timestamp = time.perf_counter() # NOTE: Some completion API might have a last # usage summary response without a token so we # want to check a token was generated - if data["choices"][0]["text"]: - timestamp = time.perf_counter() + delta = data["choices"][0]["delta"] + + if delta.get("content", None): # First token if ttft == 0.0: ttft = time.perf_counter() - st - output.ttft = ttft + output.ttft.append(ttft) # Decoding phase else: output.itl.append(timestamp - most_recent_timestamp) - most_recent_timestamp = timestamp - generated_text += data["choices"][0]["text"] + generated_text += delta["content"] + most_recent_timestamp = timestamp - output.generated_text = generated_text + output_len = len(tokenizer(generated_text).input_ids) + output.prompt_len.append(prompt_len - 1) # truncate + output.output_len.append(output_len) + output.generated_text.append(generated_text) output.success = True - output.latency = latency - output.output_len = request_func_input.output_len + output.latency.append(latency) + + # Prepare for the new request + request_func_input.prompts[prompt_idx] = ( + prompt, + input_len, + output_len, # changes from max_tokens to output_len + ) + prompt_idx += 1 + messages.append( + { + "role": "assistant", + "content": generated_text, + } + ) + + # Move the new request to the end of the queue + if prompt_idx < len(request_func_input.prompts): + request_func_input.finished_prompts = prompt_idx + request_func_input.prev_messages = messages + await queue.put(request_func_input) else: output.error = response.reason or "" output.success = False @@ -227,19 +310,17 @@ async def async_request_openai_completions( async def async_request_truss( request_func_input: RequestFuncInput, + queue: asyncio.Queue, + tokenizer: PreTrainedTokenizerBase, pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url - prompt = request_func_input.prompt - async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: payload = { "model": request_func_input.model, - "prompt": prompt, "temperature": 0.0, "best_of": 1, - "max_tokens": request_func_input.output_len, "stream": not args.disable_stream, "ignore_eos": not args.disable_ignore_eos, 
**request_func_input.extra_request_body, @@ -247,7 +328,20 @@ async def async_request_truss( headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} output = RequestFuncOutput() - output.prompt_len = request_func_input.prompt_len + + prompt_idx = request_func_input.finished_prompts + messages = request_func_input.prev_messages + prompt, input_len, max_tokens = request_func_input.prompts[prompt_idx] + prompt_len = sum( + prompt[1] + prompt[2] # input_len + output_len + for prompt in request_func_input.prompts[:prompt_idx] + ) + prompt_len += input_len + + # TODO: Checkout truss to see whether there is a another field + messages.append(prompt) + payload["prompt"] = " ".join(messages) + payload["max_tokens"] = max_tokens generated_text = "" ttft = 0.0 @@ -278,7 +372,7 @@ async def async_request_truss( # First token if ttft == 0.0: ttft = time.perf_counter() - st - output.ttft = ttft + output.ttft.append(ttft) # Decoding phase else: @@ -287,10 +381,27 @@ async def async_request_truss( most_recent_timestamp = timestamp generated_text += data["choices"][0]["delta"]["content"] - output.generated_text = generated_text + output_len = len(tokenizer(generated_text).input_ids) + output.prompt_len.append(prompt_len - 1) # truncate + output.output_len.append(output_len) + output.generated_text.append(generated_text) output.success = True - output.latency = latency - output.output_len = request_func_input.output_len + output.latency.append(latency) + + # Prepare for the new request + request_func_input.prompts[prompt_idx] = ( + prompt, + input_len, + output_len, # changes from max_tokens to output_len + ) + prompt_idx += 1 + messages.append(generated_text) + + # Move the new request to the end of the queue + if prompt_idx < len(request_func_input.prompts): + request_func_input.finished_prompts = prompt_idx + request_func_input.prev_messages = messages + await queue.put(request_func_input) else: output.error = response.reason or "" output.success = False @@ -306,17 +417,16 @@ async def async_request_truss( async def async_request_sglang_generate( request_func_input: RequestFuncInput, + queue: asyncio.Queue, + tokenizer: PreTrainedTokenizerBase, pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url - prompt = request_func_input.prompt async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: payload = { - "text": prompt, "sampling_params": { "temperature": 0.0, - "max_new_tokens": request_func_input.output_len, "ignore_eos": not args.disable_ignore_eos, }, "stream": not args.disable_stream, @@ -328,7 +438,23 @@ async def async_request_sglang_generate( headers = {} output = RequestFuncOutput() - output.prompt_len = request_func_input.prompt_len + + prompt_idx = request_func_input.finished_prompts + messages = request_func_input.prev_messages + prompt, input_len, max_tokens = request_func_input.prompts[prompt_idx] + prompt_len = sum( + prompt[1] + prompt[2] # input_len + output_len + for prompt in request_func_input.prompts[:prompt_idx] + ) + prompt_len += input_len + + # TODO: Make use of the new session field of the GenerateReqInput + # Now we simply concatenate all the prompts and responses in the + # text field + + messages.append(prompt) + payload["text"] = " ".join(messages) + payload["sampling_params"]["max_new_tokens"] = max_tokens generated_text = "" ttft = 0.0 @@ -360,19 +486,36 @@ async def async_request_sglang_generate( # First token if ttft == 0.0: ttft = time.perf_counter() - st - output.ttft = ttft + output.ttft.append(ttft) # 
Decoding phase else: output.itl.append(timestamp - most_recent_timestamp) - most_recent_timestamp = timestamp generated_text = data["text"] + most_recent_timestamp = timestamp - output.generated_text = generated_text + output_len = len(tokenizer(generated_text).input_ids) + output.prompt_len.append(prompt_len - 1) # truncate + output.output_len.append(output_len) + output.generated_text.append(generated_text) output.success = True - output.latency = latency - output.output_len = request_func_input.output_len + output.latency.append(latency) + + # Prepare for the new request + request_func_input.prompts[prompt_idx] = ( + prompt, + input_len, + output_len, # changes from max_tokens to output_len + ) + prompt_idx += 1 + messages.append(generated_text) + + # Move the new request to the end of the queue + if prompt_idx < len(request_func_input.prompts): + request_func_input.finished_prompts = prompt_idx + request_func_input.prev_messages = messages + await queue.put(request_func_input) else: output.error = response.reason or "" output.success = False @@ -388,6 +531,8 @@ async def async_request_sglang_generate( async def async_request_gserver( request_func_input: RequestFuncInput, + queue: asyncio.Queue, + tokenizer: PreTrainedTokenizerBase, pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: raise NotImplementedError() @@ -445,37 +590,6 @@ def get_tokenizer( ) -def get_dataset(args, tokenizer): - if args.dataset_name == "sharegpt": - input_requests = sample_sharegpt_requests( - dataset_path=args.dataset_path, - num_requests=args.num_prompts, - tokenizer=tokenizer, - fixed_output_len=args.sharegpt_output_len, - ) - elif args.dataset_name == "random": - input_requests = sample_random_requests( - input_len=args.random_input_len, - output_len=args.random_output_len, - num_prompts=args.num_prompts, - range_ratio=args.random_range_ratio, - tokenizer=tokenizer, - dataset_path=args.dataset_path, - ) - elif args.dataset_name == "generated-shared-prefix": - input_requests = sample_generated_shared_prefix_requests( - num_groups=args.gen_num_groups, - prompts_per_group=args.gen_prompts_per_group, - system_prompt_len=args.gen_system_prompt_len, - question_len=args.gen_question_len, - output_len=args.gen_output_len, - tokenizer=tokenizer, - ) - else: - raise ValueError(f"Unknown dataset: {args.dataset_name}") - return input_requests - - ASYNC_REQUEST_FUNCS = { "sglang": async_request_sglang_generate, "sglang-native": async_request_sglang_generate, @@ -503,309 +617,46 @@ class BenchmarkMetrics: mean_ttft_ms: float median_ttft_ms: float std_ttft_ms: float + p90_ttft_ms: float p99_ttft_ms: float mean_tpot_ms: float median_tpot_ms: float std_tpot_ms: float + p90_tpot_ms: float p99_tpot_ms: float mean_itl_ms: float median_itl_ms: float std_itl_ms: float + p90_itl_ms: float p99_itl_ms: float mean_e2e_latency_ms: float median_e2e_latency_ms: float -SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" - - -def download_and_cache_file(url: str, filename: Optional[str] = None): - """Read and cache a file from a url.""" - if filename is None: - filename = os.path.join("/tmp", url.split("/")[-1]) - - # Check if the cache file already exists - if os.path.exists(filename): - return filename - - print(f"Downloading from {url} to {filename}") - - # Stream the response to show the progress bar - response = requests.get(url, stream=True) - response.raise_for_status() # Check for request errors - - # Total size of the file in bytes - 
total_size = int(response.headers.get("content-length", 0)) - chunk_size = 1024 # Download in chunks of 1KB - - # Use tqdm to display the progress bar - with open(filename, "wb") as f, tqdm( - desc=filename, - total=total_size, - unit="B", - unit_scale=True, - unit_divisor=1024, - ) as bar: - for chunk in response.iter_content(chunk_size=chunk_size): - f.write(chunk) - bar.update(len(chunk)) - - return filename - - -def sample_sharegpt_requests( - dataset_path: str, - num_requests: int, - tokenizer: PreTrainedTokenizerBase, - fixed_output_len: Optional[int] = None, -) -> List[Tuple[str, int, int]]: - if fixed_output_len is not None and fixed_output_len < 4: - raise ValueError("output_len too small") - - # Download sharegpt if necessary - if not os.path.isfile(dataset_path): - dataset_path = download_and_cache_file(SHAREGPT_URL) - - # Load the dataset. - with open(dataset_path) as f: - dataset = json.load(f) - # Filter out the conversations with less than 2 turns. - dataset = [data for data in dataset if len(data["conversations"]) >= 2] - # Only keep the first two turns of each conversation. - dataset = [ - (data["conversations"][0]["value"], data["conversations"][1]["value"]) - for data in dataset - ] - - # Shuffle the dataset. - random.shuffle(dataset) - - # Filter out sequences that are too long or too short - filtered_dataset: List[Tuple[str, int, int]] = [] - for i in range(len(dataset)): - if len(filtered_dataset) == num_requests: +async def get_requests( + input_requests_queue: asyncio.Queue, + request_rate: float, + num_actual_requests: int, +) -> AsyncGenerator[RequestFuncInput, None]: + for _ in range(num_actual_requests): + try: + request = await asyncio.wait_for( + input_requests_queue.get(), timeout=300 + ) # Wait for 5 minites then abort + except Exception as e: + print(f"exception: {e}") break - # Tokenize the prompts and completions. - prompt = dataset[i][0] - prompt_token_ids = tokenizer.encode(prompt) - completion = dataset[i][1] - completion_token_ids = tokenizer.encode(completion) - prompt_len = len(prompt_token_ids) - output_len = ( - len(completion_token_ids) if fixed_output_len is None else fixed_output_len - ) - if prompt_len < 4 or output_len < 4: - # Prune too short sequences. - continue - if prompt_len > 1024 or ( - prompt_len + output_len > 2048 and fixed_output_len is None - ): - # Prune too long sequences. - continue - filtered_dataset.append((prompt, prompt_len, output_len)) - - print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}") - print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}") - return filtered_dataset - - -def sample_random_requests( - input_len: int, - output_len: int, - num_prompts: int, - range_ratio: float, - tokenizer: PreTrainedTokenizerBase, - dataset_path: str, -) -> List[Tuple[str, int, int]]: - - input_lens = np.random.randint( - max(int(input_len * range_ratio), 1), - input_len + 1, - size=num_prompts, - ) - output_lens = np.random.randint( - int(output_len * range_ratio), - output_len + 1, - size=num_prompts, - ) - - if True: - # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens - - # Download sharegpt if necessary - if not os.path.isfile(dataset_path): - dataset_path = download_and_cache_file(SHAREGPT_URL) - - # Load the dataset. - with open(dataset_path) as f: - dataset = json.load(f) - # Filter out the conversations with less than 2 turns. - dataset = [data for data in dataset if len(data["conversations"]) >= 2] - # Only keep the first two turns of each conversation. 
- dataset = [ - (data["conversations"][0]["value"], data["conversations"][1]["value"]) - for data in dataset - ] - # Shuffle the dataset. - random.shuffle(dataset) - - # Filter out sequences that are too long or too short - input_requests: List[Tuple[str, int, int]] = [] - for data in dataset: - i = len(input_requests) - if i == num_prompts: - break - - # Tokenize the prompts and completions. - prompt = data[0] - prompt_token_ids = tokenizer.encode(prompt) - prompt_len = len(prompt_token_ids) - - # Skip empty prompt - if prompt_len == 0: - continue - - if prompt_len > input_lens[i]: - input_ids = prompt_token_ids[: input_lens[i]] - else: - ratio = (input_lens[i] + prompt_len - 1) // prompt_len - input_ids = (prompt_token_ids * ratio)[: input_lens[i]] - prompt = tokenizer.decode(input_ids) - input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) - else: - # Sample token ids from random integers. This can cause some NaN issues. - offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) - input_requests = [] - for i in range(num_prompts): - prompt = tokenizer.decode( - [ - (offsets[i] + i + j) % tokenizer.vocab_size - for j in range(input_lens[i]) - ] - ) - input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) - - print(f"#Input tokens: {np.sum(input_lens)}") - print(f"#Output tokens: {np.sum(output_lens)}") - return input_requests - - -def gen_prompt(tokenizer, token_num): - """Generate a random prompt of specified token length using tokenizer vocabulary.""" - all_available_tokens = list(tokenizer.get_vocab().values()) - selected_tokens = random.choices(all_available_tokens, k=token_num) - return tokenizer.decode(selected_tokens) - - -def get_gen_prefix_cache_path(args, tokenizer): - """Create cache directory under ~/.cache/sglang/benchmark""" - cache_dir = Path.home() / ".cache" / "sglang" / "benchmark" - - # Create a unique cache filename based on the generation parameters - cache_key = ( - f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_" - f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_" - f"{tokenizer.__class__.__name__}.pkl" - ) - return cache_dir / cache_key - - -def sample_generated_shared_prefix_requests( - num_groups: int, - prompts_per_group: int, - system_prompt_len: int, - question_len: int, - output_len: int, - tokenizer: PreTrainedTokenizerBase, -) -> List[Tuple[str, int, int]]: - """Generate benchmark requests with shared system prompts using random tokens and caching.""" - cache_path = get_gen_prefix_cache_path(args, tokenizer) - - # Try to load from cache first - if cache_path.exists(): - print(f"\nLoading cached generated input data from {cache_path}") - with open(cache_path, "rb") as f: - return pickle.load(f) - - print("\nGenerating new input data...") - - # Generate system prompts for each group - system_prompts = [] - for _ in range(num_groups): - system_prompt = gen_prompt(tokenizer, system_prompt_len) - system_prompts.append(system_prompt) - - # Generate questions - questions = [] - for _ in range(num_groups * prompts_per_group): - question = gen_prompt(tokenizer, question_len) - questions.append(question) - - # Combine system prompts with questions - input_requests = [] - total_input_tokens = 0 - total_output_tokens = 0 - - for group_idx in tqdm(range(num_groups), desc="Generating system prompt"): - system_prompt = system_prompts[group_idx] - for prompt_idx in tqdm( - range(prompts_per_group), desc="Generating questions", leave=False - ): - question = questions[group_idx 
* prompts_per_group + prompt_idx] - full_prompt = f"{system_prompt}\n\n{question}" - prompt_len = len(tokenizer.encode(full_prompt)) - - input_requests.append((full_prompt, prompt_len, output_len)) - total_input_tokens += prompt_len - total_output_tokens += output_len - - # Shuffle questions - random.shuffle(input_requests) - - # Print statistics - print(f"\nGenerated shared prefix dataset statistics:") - print(f"Number of groups: {num_groups}") - print(f"Prompts per group: {prompts_per_group}") - print(f"Total prompts: {len(input_requests)}") - print(f"Total input tokens: {total_input_tokens}") - print(f"Total output tokens: {total_output_tokens}") - print( - f"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens" - ) - print( - f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n" - ) - - # Save to cache - cache_path.parent.mkdir(parents=True, exist_ok=True) - print(f"Caching generated input data to {cache_path}") - with open(cache_path, "wb") as f: - pickle.dump(input_requests, f) - - return input_requests - - -async def get_request( - input_requests: List[Tuple[str, int, int]], - request_rate: float, -) -> AsyncGenerator[Tuple[str, int, int], None]: - input_requests = iter(input_requests) - for request in input_requests: yield request if request_rate == float("inf"): - # If the request rate is infinity, then we don't need to wait. continue - # Sample the request interval from the exponential distribution. interval = np.random.exponential(1.0 / request_rate) - # The next request will be sent after the interval. await asyncio.sleep(interval) def calculate_metrics( - input_requests: List[Tuple[str, int, int]], outputs: List[RequestFuncOutput], dur_s: float, tokenizer: PreTrainedTokenizerBase, @@ -819,23 +670,32 @@ def calculate_metrics( tpots: List[float] = [] ttfts: List[float] = [] e2e_latencies: List[float] = [] + output_success = 0 for i in range(len(outputs)): if outputs[i].success: - output_len = outputs[i].output_len - output_lens.append(output_len) - retokenized_output_len = len( - tokenizer.encode(outputs[i].generated_text, add_special_tokens=False) - ) - retokenized_output_lens.append(retokenized_output_len) - total_input += input_requests[i][1] - if output_len > 1: - tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1)) + output_success += 1 + assert len(outputs[i].generated_text) == len(outputs[i].latency) + assert len(outputs[i].generated_text) == len(outputs[i].ttft) + for j in range(len(outputs[i].generated_text)): + output_len = outputs[i].output_len[j] + output_lens.append(output_len) + retokenized_output_len = len( + tokenizer.encode( + outputs[i].generated_text[j], add_special_tokens=False + ) + ) + retokenized_output_lens.append(retokenized_output_len) + total_input += outputs[i].prompt_len[j] + if output_len > 1: + tpots.append( + (outputs[i].latency[j] - outputs[i].ttft[j]) / (output_len - 1) + ) + + completed += 1 itls += outputs[i].itl - ttfts.append(outputs[i].ttft) + ttfts += outputs[i].ttft + e2e_latencies += outputs[i].latency - e2e_latencies.append(outputs[i].latency) - - completed += 1 else: output_lens.append(0) retokenized_output_lens.append(0) @@ -862,14 +722,17 @@ def calculate_metrics( * 1000, # ttfts is empty if streaming is not supported by backend median_ttft_ms=np.median(ttfts or 0) * 1000, std_ttft_ms=np.std(ttfts or 0) * 1000, + p90_ttft_ms=np.percentile(ttfts or 0, 90) * 1000, p99_ttft_ms=np.percentile(ttfts 
or 0, 99) * 1000, mean_tpot_ms=np.mean(tpots or 0) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000, + p90_tpot_ms=np.percentile(tpots or 0, 90) * 1000, p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, mean_itl_ms=np.mean(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000, + p90_itl_ms=np.percentile(itls or 0, 90) * 1000, p99_itl_ms=np.percentile(itls or 0, 99) * 1000, mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000, median_e2e_latency_ms=np.median(e2e_latencies) * 1000, @@ -884,13 +747,14 @@ async def benchmark( base_url: str, model_id: str, tokenizer: PreTrainedTokenizerBase, - input_requests: List[Tuple[str, int, int]], + input_requests: SampleOutput, request_rate: float, max_concurrency: Optional[int], disable_tqdm: bool, lora_name: str, extra_request_body: Dict[str, Any], profile: bool, + enable_shared_prefix: bool, ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -901,25 +765,42 @@ async def benchmark( # From https://github.com/vllm-project/vllm/pull/9390 semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None - async def limited_request_func(request_func_input, pbar): + async def limited_request_func(request_func_input, queue, tokenizer, pbar): if semaphore is None: - return await request_func(request_func_input=request_func_input, pbar=pbar) + return await request_func( + request_func_input=request_func_input, + queue=queue, + tokenizer=tokenizer, + pbar=pbar, + ) async with semaphore: - return await request_func(request_func_input=request_func_input, pbar=pbar) + return await request_func( + request_func_input=request_func_input, + queue=queue, + tokenizer=tokenizer, + pbar=pbar, + ) + + num_actual_requests = sum(len(r) for r in input_requests) + print(f"Num of shared prefixes or conversations: {len(input_requests)}") + print(f"Num of total requests: {num_actual_requests}") - # Warmup + # flatten the requests for shared prefix + if enable_shared_prefix: + input_requests = [[r] for requests in input_requests for r in requests] + inputs_requests_queue = asyncio.Queue(maxsize=len(input_requests)) print("Starting initial single prompt test run...") - test_prompt, test_prompt_len, test_output_len = input_requests[0] + # NOTE: Just use the first request of the first conversation for warmup test_input = RequestFuncInput( model=model_id, - prompt=test_prompt, + prompts=input_requests[0][:1], api_url=api_url, - prompt_len=test_prompt_len, - output_len=min(test_output_len, 32), lora_name=lora_name, extra_request_body=extra_request_body, ) - test_output = await request_func(request_func_input=test_input) + test_output = await request_func( + request_func_input=test_input, queue=inputs_requests_queue, tokenizer=tokenizer + ) if not test_output.success: raise ValueError( "Initial test run failed - Please make sure benchmark arguments " @@ -928,6 +809,9 @@ async def limited_request_func(request_func_input, pbar): else: print("Initial test run completed. 
Starting main benchmark run...") + # Check the states + assert inputs_requests_queue.empty() + # Flush cache if "sglang" in backend: requests.post(base_url + "/flush_cache") @@ -943,25 +827,37 @@ async def limited_request_func(request_func_input, pbar): if profile_output.success: print("Profiler started") - pbar = None if disable_tqdm else tqdm(total=len(input_requests)) - - # Run all requests - benchmark_start_time = time.perf_counter() - tasks: List[asyncio.Task] = [] - async for request in get_request(input_requests, request_rate): - prompt, prompt_len, output_len = request + for request in input_requests: request_func_input = RequestFuncInput( model=model_id, - prompt=prompt, + prompts=request, api_url=api_url, - prompt_len=prompt_len, - output_len=output_len, lora_name=lora_name, extra_request_body=extra_request_body, ) + inputs_requests_queue.put_nowait(request_func_input) + if ( + not args.enable_multiturn + and not args.enable_shared_prefix + and not args.dataset_name == "generated-shared-prefix" + ): + assert len(input_requests) == num_actual_requests + + pbar = None if disable_tqdm else tqdm(total=num_actual_requests) + + benchmark_start_time = time.perf_counter() + tasks: List[asyncio.Task] = [] + async for request in get_requests( + inputs_requests_queue, request_rate, num_actual_requests + ): tasks.append( asyncio.create_task( - limited_request_func(request_func_input=request_func_input, pbar=pbar) + limited_request_func( + request_func_input=request, + queue=inputs_requests_queue, + tokenizer=tokenizer, + pbar=pbar, + ) ) ) outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) @@ -979,7 +875,6 @@ async def limited_request_func(request_func_input, pbar): # Compute metrics and print results benchmark_duration = time.perf_counter() - benchmark_start_time metrics, output_lens = calculate_metrics( - input_requests=input_requests, outputs=outputs, dur_s=benchmark_duration, tokenizer=tokenizer, @@ -1036,16 +931,19 @@ async def limited_request_func(request_func_input, pbar): print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-")) print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms)) + print("{:<40} {:<10.2f}".format("P90 TTFT (ms):", metrics.p90_ttft_ms)) print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) print( "{s:{c}^{n}}".format(s="Time per Output Token (excl. 
1st token)", n=50, c="-") ) print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms)) + print("{:<40} {:<10.2f}".format("P90 TPOT (ms):", metrics.p90_tpot_ms)) print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-")) print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) + print("{:<40} {:<10.2f}".format("P90 ITL (ms):", metrics.p90_itl_ms)) print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) print("=" * 50) @@ -1067,7 +965,7 @@ async def limited_request_func(request_func_input, pbar): "median_ttft_ms": metrics.median_ttft_ms, "median_itl_ms": metrics.median_itl_ms, "output_throughput": metrics.output_throughput, - "sharegpt_output_len": args.sharegpt_output_len, + "fixed_output_len": args.fixed_output_len, "random_input_len": args.random_input_len, "random_output_len": args.random_output_len, "random_range_ratio": args.random_range_ratio, @@ -1104,14 +1002,17 @@ async def limited_request_func(request_func_input, pbar): "mean_ttft_ms": metrics.mean_ttft_ms, "median_ttft_ms": metrics.median_ttft_ms, "std_ttft_ms": metrics.std_ttft_ms, + "p90_ttft_ms": metrics.p90_ttft_ms, "p99_ttft_ms": metrics.p99_ttft_ms, "mean_tpot_ms": metrics.mean_tpot_ms, "median_tpot_ms": metrics.median_tpot_ms, "std_tpot_ms": metrics.std_tpot_ms, + "p90_tpot_ms": metrics.p90_tpot_ms, "p99_tpot_ms": metrics.p99_tpot_ms, "mean_itl_ms": metrics.mean_itl_ms, "median_itl_ms": metrics.median_itl_ms, "std_itl_ms": metrics.std_itl_ms, + "p90_itl_ms": metrics.p90_itl_ms, "p99_itl_ms": metrics.p99_itl_ms, "input_lens": [output.prompt_len for output in outputs], "output_lens": output_lens, @@ -1186,9 +1087,9 @@ def run_benchmark(args_: argparse.Namespace): ) elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]: api_url = ( - f"{args.base_url}/v1/completions" + f"{args.base_url}/v1/chat/completions" if args.base_url - else f"http://{args.host}:{args.port}/v1/completions" + else f"http://{args.host}:{args.port}/v1/chat/completions" ) elif args.backend == "trt": api_url = ( @@ -1240,6 +1141,20 @@ def run_benchmark(args_: argparse.Namespace): "Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n" ) + # Dataset compatibility check + if args.enable_multiturn: + # TODO: Support multiturn for random + if args.dataset_name not in ["sharegpt", "ultrachat", "loogle", "nextqa"]: + print( + "Multiturn conversation is only supported for sharegpt, ultrachat, loogle, and nextqa datasets." 
+ ) + sys.exit(1) + + if args.enable_shared_prefix: + if args.dataset_name not in ["loogle", "nextqa"]: + print("Shared prefix is only supported for loogle and nextqa datasets.") + sys.exit(1) + print(f"{args}\n") # Read dataset @@ -1266,6 +1181,7 @@ def run_benchmark(args_: argparse.Namespace): lora_name=args.lora_name, extra_request_body=extra_request_body, profile=args.profile, + enable_shared_prefix=args.enable_shared_prefix, ) ) else: @@ -1287,6 +1203,7 @@ def run_benchmark(args_: argparse.Namespace): lora_name=args.lora_name, extra_request_body=extra_request_body, profile=args.profile, + enable_shared_prefix=args.enable_shared_prefix, ) ) @@ -1329,7 +1246,14 @@ def set_ulimit(target_soft_limit=65535): "--dataset-name", type=str, default="sharegpt", - choices=["sharegpt", "random", "generated-shared-prefix"], + choices=[ + "sharegpt", + "random", + "generated-shared-prefix", + "ultrachat", + "loogle", + "nextqa", + ], help="Name of the dataset to benchmark on.", ) parser.add_argument( @@ -1352,10 +1276,10 @@ def set_ulimit(target_soft_limit=65535): help="Number of prompts to process. Default is 1000.", ) parser.add_argument( - "--sharegpt-output-len", + "--fixed-output-len", type=int, default=None, - help="Output length for each request. Overrides the output length from the ShareGPT dataset.", + help="Output length for each request. Overrides the output length from the dataset.", ) parser.add_argument( "--random-input-len", @@ -1409,6 +1333,27 @@ def set_ulimit(target_soft_limit=65535): help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.", ) parser.add_argument("--output-file", type=str, help="Output JSONL file name.") + parser.add_argument( + "--enable-multiturn", + action="store_true", + help="Enable multiturn chat for online serving benchmarking. " + "This option is effective on the following datasets: " + "sharegpt, ultrachat, loogle, nextqa", + ) + parser.add_argument( + "--enable-shared-prefix", + action="store_true", + help="Enable shared prefix for online serving benchmarking. " + "This option is effective on the following datasets: " + "loogle, nextqa", + ) + + parser.add_argument( + "--disable-shuffle", + action="store_true", + help="Disable shuffling datasets. This is useful to generate stable output " + "in benchmarking", + ) parser.add_argument( "--disable-tqdm", action="store_true", @@ -1480,5 +1425,19 @@ def set_ulimit(target_soft_limit=65535): default=None, help="The name of LoRA adapter", ) + # videos specific + parser.add_argument( + "--max-frames", + type=int, + default=sys.maxsize, + help="The maximum number of frames to extract from each video. " + "This option is specific to the nextqa dataset (video benchmark). ", + ) args = parser.parse_args() + + if args.enable_multiturn and args.enable_shared_prefix: + parser.error( + "--enable-multiturn and --enable-shared-prefix cannot be set at the same time." 
+ ) + run_benchmark(args) diff --git a/python/sglang/data_processing.py b/python/sglang/data_processing.py new file mode 100644 index 00000000000..d67bdf8a832 --- /dev/null +++ b/python/sglang/data_processing.py @@ -0,0 +1,593 @@ +import json +import os +import pickle +import random +from pathlib import Path +from typing import List, Optional, Tuple + +import numpy as np +import requests +from tqdm.asyncio import tqdm +from transformers import PreTrainedTokenizerBase + +from sglang.bench.nextqa.video import NExTQALoader, VideoPrompt, encode_video_base64 +from sglang.utils import MsgContent + +SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" + +# A list of all the conversations. Each conversation is a list of +# tuples. If multiturn is not enabled, the length of list is 1, +# containing only the first Q&A pair. +# For the shared prefix workload (synthetic, loogle, nextqa), it +# is a list of conversations sharing the same prefix (synthetic, +# doc, video) +SampleOutput = List[List[Tuple[MsgContent, int, int]]] + + +def common_filter_chat( + num_requests: int, + new_dataset: List, + tokenizer: PreTrainedTokenizerBase, + min_prompt_len: Optional[int], + min_output_len: Optional[int], + max_prompt_len: Optional[int], + max_output_len: Optional[int], + fixed_output_len: Optional[int], +) -> SampleOutput: + # Filter out sequences that are too long or too short + filtered_dataset: SampleOutput = [] + l = 0 + input_tokens = 0 + output_tokens = 0 + while l < num_requests: + for i in range(len(new_dataset)): + if l == num_requests: + break + processed = [] + for j in new_dataset[i]: + # Tokenize the prompts and completions. + prompt = j[0] + prompt_token_ids = tokenizer.encode(prompt) + prompt_len = len(prompt_token_ids) + + completion = j[1] + completion_token_ids = tokenizer.encode(completion) + output_len = ( + len(completion_token_ids) + if fixed_output_len is None + else fixed_output_len + ) + if ( + min_prompt_len is not None + and prompt_len < min_prompt_len + or min_output_len is not None + and output_len < min_output_len + or max_prompt_len is not None + and prompt_len > max_prompt_len + or max_output_len is not None + and output_len > max_output_len + ): + # Prune too short sequences. 
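+                    # (also prunes prompts/outputs above the optional max lengths)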
+ continue + input_tokens += prompt_len + output_tokens += output_len + processed.append((prompt, prompt_len, output_len)) + filtered_dataset.append(processed) + l += 1 + + print(f"#Input tokens: {input_tokens}") + print(f"#Output tokens: {output_tokens}") + return filtered_dataset + + +def download_and_cache_file(url: str, filename: Optional[str] = None): + """Read and cache a file from a url.""" + if filename is None: + filename = os.path.join("/tmp", url.split("/")[-1]) + + # Check if the cache file already exists + if os.path.exists(filename): + return filename + + print(f"Downloading from {url} to {filename}") + + # Stream the response to show the progress bar + response = requests.get(url, stream=True) + response.raise_for_status() # Check for request errors + + # Total size of the file in bytes + total_size = int(response.headers.get("content-length", 0)) + chunk_size = 1024 # Download in chunks of 1KB + + # Use tqdm to display the progress bar + with open(filename, "wb") as f, tqdm( + desc=filename, + total=total_size, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=chunk_size): + f.write(chunk) + bar.update(len(chunk)) + + return filename + + +def sample_sharegpt_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + disable_shuffle: bool = False, + enable_multiturn: bool = True, + fixed_output_len: Optional[int] = None, +) -> SampleOutput: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path): + dataset_path = download_and_cache_file(SHAREGPT_URL) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + + # Keep one conversation in one list + new_dataset = [] + for data in dataset: + if len(data["conversations"]) % 2 != 0: + continue + if data["conversations"][0]["from"] != "human": + continue + chat = [] + total_len = 2 + if enable_multiturn: + total_len = len(data["conversations"]) + for i in range(0, total_len, 2): + # One user One Assistant + chat.append( + ( + data["conversations"][i]["value"], + data["conversations"][i + 1]["value"], + ) + ) + new_dataset.append(chat) + + if not disable_shuffle: + # Shuffle the dataset. + random.shuffle(new_dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: SampleOutput = common_filter_chat( + num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len + ) + return filtered_dataset + + +def sample_ultrachat_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + disable_shuffle: bool = False, + enable_multiturn: bool = True, + fixed_output_len: Optional[int] = None, +) -> SampleOutput: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset + dataset = [] + with open(dataset_path) as f: + while True: + line = f.readline() + if not line: + break + dataset.append(json.loads(line)) + + # Filter out the conversations with less than 2 turns. 
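+    # Each UltraChat record stores its turns as a flat list in data["data"],
+    # alternating user and assistant messages, so a usable conversation needs
+    # at least one full Q&A pair (two entries).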
+ dataset = [data for data in dataset if len(data["data"]) >= 2] + + # Keep one conversation in one list + new_dataset = [] + for data in dataset: + if len(data["data"]) % 2 != 0: + continue + chat = [] + total_len = 2 + if enable_multiturn: + total_len = len(data["data"]) + for i in range(0, total_len, 2): + # One user One Assistant + chat.append((data["data"][i], data["data"][i + 1])) + new_dataset.append(chat) + + # Shuffle the dataset. + if not disable_shuffle: + random.shuffle(new_dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: SampleOutput = common_filter_chat( + num_requests, new_dataset, tokenizer, 4, 4, None, None, fixed_output_len + ) + return filtered_dataset + + +def sample_loogle_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + disable_shuffle: bool = False, + enable_multiturn: bool = True, + enable_shared_prefix: bool = False, + fixed_output_len: Optional[int] = None, +) -> SampleOutput: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset + dataset = [] + with open(dataset_path) as f: + while True: + line = f.readline() + if not line: + break + dataset.append(json.loads(line)) + + # Keep one conversation in one list + new_dataset = [] + # TODO: Add shared prefix support for loogle + # NOTE: Now we preprocess it only for chat + for data in dataset: + chat = [] + if ( + "qa_pairs" not in data + or data["qa_pairs"] == "none" + or len(data["qa_pairs"]) == 0 + ): + # If Q is none (for summarization), + # We add a question for summarization + # And keep the summary up to 1024 words + chat.append( + ( + "Input: " + + data["input"] + + " Question: " + + "Please summarize the input", + data["input"][:1024], + ) + ) + new_dataset.append(chat) + else: + qa_pairs = eval(data["qa_pairs"]) + for i, qa in enumerate(qa_pairs): + if i == 0 or enable_shared_prefix: + # Combine input with the first Q + chat.append( + ("Input: " + data["input"] + " Question: " + qa["Q"], qa["A"]) + ) + elif enable_multiturn: + chat.append((qa["Q"], qa["A"])) + + new_dataset.append(chat) + + # Shuffle the dataset. 
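+    # Shuffling is skipped when --disable-shuffle is set so that repeated runs
+    # sample the same conversations in the same order.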
+ if not disable_shuffle: + random.shuffle(new_dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: SampleOutput = common_filter_chat( + num_requests, new_dataset, tokenizer, 4, None, None, None, fixed_output_len + ) + return filtered_dataset + + +def sample_nextqa_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + max_frames: int, # Specific for video + disable_shuffle: bool = False, + enable_multiturn: bool = True, # No multiturn support for now + fixed_output_len: Optional[int] = None, +) -> SampleOutput: + """ + Example of messages: + message = { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": base64_data}}, + {"type": "text", "text": video.prompt}, + ], + } + """ + if fixed_output_len is None: + fixed_output_len = 4096 + + # TODO: Check for multiturn + dataset = NExTQALoader(video_dir=dataset_path, max_frames=max_frames) + new_dataset = [] + for v in dataset: + new_dataset.append(v) + + if not disable_shuffle: + random.shuffle(new_dataset) + + # TODO: prompt len can get from server side + filtered_dataset = [] + l = 0 + while l < num_requests: + for i in range(len(new_dataset)): + if l == num_requests: + break + + video = new_dataset[i] + + # text prompt + prompt = video.prompt + prompt_token_ids = tokenizer(prompt).input_ids + prompt_len = len(prompt_token_ids) + output_len = fixed_output_len # max output len, not real output len + + # video input + base64_data = encode_video_base64(video.path, video.num_frames) + # TODO: Support more models than 7B + # TODO: Remove fixed 144 + prompt_len += video.num_frames * 144 + + # add to content + content = [ + {"type": "image_url", "image_url": {"url": base64_data}}, + {"type": "text", "text": prompt}, + ] + + filtered_dataset.append([(content, prompt_len, output_len)]) + l += 1 + return filtered_dataset + + +def sample_random_requests( + input_len: int, + output_len: int, + num_prompts: int, + range_ratio: float, + tokenizer: PreTrainedTokenizerBase, + dataset_path: str, + disable_shuffle: bool = False, +) -> SampleOutput: + + input_lens = np.random.randint( + max(int(input_len * range_ratio), 1), + input_len + 1, + size=num_prompts, + ) + output_lens = np.random.randint( + int(output_len * range_ratio), + output_len + 1, + size=num_prompts, + ) + + if True: + # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens + + # Download sharegpt if necessary + if not os.path.isfile(dataset_path): + dataset_path = download_and_cache_file(SHAREGPT_URL) + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] + + if not disable_shuffle: + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + input_requests: SampleOutput = [] + for data in dataset: + i = len(input_requests) + if i == num_prompts: + break + + # Tokenize the prompts and completions. 
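+            # The ShareGPT text is only used as realistic token material: below,
+            # each prompt is truncated or repeated so that its token count matches
+            # the sampled input_lens[i] target.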
+ prompt = data[0] + prompt_token_ids = tokenizer.encode(prompt) + prompt_len = len(prompt_token_ids) + + # Skip empty prompt + if prompt_len == 0: + continue + + if prompt_len > input_lens[i]: + input_ids = prompt_token_ids[: input_lens[i]] + else: + ratio = (input_lens[i] + prompt_len - 1) // prompt_len + input_ids = (prompt_token_ids * ratio)[: input_lens[i]] + prompt = tokenizer.decode(input_ids) + input_requests.append([(prompt, int(input_lens[i]), int(output_lens[i]))]) + else: + # Sample token ids from random integers. This can cause some NaN issues. + offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) + input_requests = [] + for i in range(num_prompts): + prompt = tokenizer.decode( + [ + (offsets[i] + i + j) % tokenizer.vocab_size + for j in range(input_lens[i]) + ] + ) + input_requests.append([(prompt, int(input_lens[i]), int(output_lens[i]))]) + + print(f"#Input tokens: {np.sum(input_lens)}") + print(f"#Output tokens: {np.sum(output_lens)}") + return input_requests + + +def gen_prompt(tokenizer, token_num): + """Generate a random prompt of specified token length using tokenizer vocabulary.""" + all_available_tokens = list(tokenizer.get_vocab().values()) + selected_tokens = random.choices(all_available_tokens, k=token_num) + return tokenizer.decode(selected_tokens) + + +def get_gen_prefix_cache_path(args, tokenizer): + """Create cache directory under ~/.cache/sglang/benchmark""" + cache_dir = Path.home() / ".cache" / "sglang" / "benchmark" + + # Create a unique cache filename based on the generation parameters + cache_key = ( + f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_" + f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_" + f"{tokenizer.__class__.__name__}.pkl" + ) + return cache_dir / cache_key + + +def sample_generated_shared_prefix_requests( + num_groups: int, + prompts_per_group: int, + system_prompt_len: int, + question_len: int, + output_len: int, + tokenizer: PreTrainedTokenizerBase, + args, + disable_shuffle: bool = False, +) -> SampleOutput: + """Generate benchmark requests with shared system prompts using random tokens and caching.""" + cache_path = get_gen_prefix_cache_path(args, tokenizer) + + # Try to load from cache first + if cache_path.exists(): + print(f"\nLoading cached generated input data from {cache_path}") + with open(cache_path, "rb") as f: + return pickle.load(f) + + print("\nGenerating new input data...") + + # Generate system prompts for each group + system_prompts = [] + for _ in range(num_groups): + system_prompt = gen_prompt(tokenizer, system_prompt_len) + system_prompts.append(system_prompt) + + # Generate questions + questions = [] + for _ in range(num_groups * prompts_per_group): + question = gen_prompt(tokenizer, question_len) + questions.append(question) + + # Combine system prompts with questions + input_requests = [] + total_input_tokens = 0 + total_output_tokens = 0 + + for group_idx in tqdm(range(num_groups), desc="Generating system prompt"): + system_prompt = system_prompts[group_idx] + input_requests.append([]) + for prompt_idx in tqdm( + range(prompts_per_group), desc="Generating questions", leave=False + ): + question = questions[group_idx * prompts_per_group + prompt_idx] + full_prompt = f"{system_prompt}\n\n{question}" + prompt_len = len(tokenizer.encode(full_prompt)) + input_requests[-1].append((full_prompt, prompt_len, output_len)) + total_input_tokens += prompt_len + total_output_tokens += output_len + + if not disable_shuffle: + # Shuffle questions + 
random.shuffle(input_requests) + + # Print statistics + print(f"\nGenerated shared prefix dataset statistics:") + print(f"Number of groups: {num_groups}") + print(f"Prompts per group: {prompts_per_group}") + print(f"Total prompts: {len(input_requests) * prompts_per_group}") + print(f"Total input tokens: {total_input_tokens}") + print(f"Total output tokens: {total_output_tokens}") + print( + f"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens" + ) + print( + f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n" + ) + + # Save to cache + cache_path.parent.mkdir(parents=True, exist_ok=True) + print(f"Caching generated input data to {cache_path}") + with open(cache_path, "wb") as f: + pickle.dump(input_requests, f) + + return input_requests + + +def get_dataset(args, tokenizer): + if args.dataset_name == "sharegpt": + input_requests = sample_sharegpt_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + disable_shuffle=args.disable_shuffle, + enable_multiturn=args.enable_multiturn, + fixed_output_len=args.fixed_output_len, + ) + elif args.dataset_name == "ultrachat": + input_requests = sample_ultrachat_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + disable_shuffle=args.disable_shuffle, + enable_multiturn=args.enable_multiturn, + fixed_output_len=args.fixed_output_len, + ) + elif args.dataset_name == "loogle": + input_requests = sample_loogle_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + disable_shuffle=args.disable_shuffle, + enable_multiturn=args.enable_multiturn, + enable_shared_prefix=args.enable_shared_prefix, + fixed_output_len=args.fixed_output_len, + ) + elif args.dataset_name == "nextqa": + input_requests = sample_nextqa_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + max_frames=args.max_frames, + disable_shuffle=args.disable_shuffle, + enable_multiturn=args.enable_multiturn, + fixed_output_len=args.fixed_output_len, + ) + elif args.dataset_name == "random": + input_requests = sample_random_requests( + input_len=args.random_input_len, + output_len=args.random_output_len, + num_prompts=args.num_prompts, + range_ratio=args.random_range_ratio, + tokenizer=tokenizer, + dataset_path=args.dataset_path, + ) + elif args.dataset_name == "generated-shared-prefix": + input_requests = sample_generated_shared_prefix_requests( + num_groups=args.gen_num_groups, + prompts_per_group=args.gen_prompts_per_group, + system_prompt_len=args.gen_system_prompt_len, + question_len=args.gen_question_len, + output_len=args.gen_output_len, + args=args, + tokenizer=tokenizer, + ) + else: + raise ValueError(f"Unknown dataset: {args.dataset_name}") + return input_requests diff --git a/python/sglang/download.sh b/python/sglang/download.sh new file mode 100755 index 00000000000..340bcc3bcb9 --- /dev/null +++ b/python/sglang/download.sh @@ -0,0 +1,66 @@ +#!/usr/bin/bash + +# The usage function +usage() { + echo "Usage: $0 {sharegpt|ultragpt|loogle|nextqa|all}" + exit 1 +} + +# The download function +download() { + case "$1" in + sharegpt) + echo $1 + wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + ;; + ultragpt) + echo $1 + # Questions about the world + wget 
https://cloud.tsinghua.edu.cn/seafhttp/files/be1d7b87-22ca-449e-a6a7-c61d1ea7e010/ultrachat_release_230407.json + # Writing and Creation + wget https://cloud.tsinghua.edu.cn/seafhttp/files/61742d2a-25e2-4d08-b2b9-15f47ae50ace/ultrachat_material_release_230417.json + wget https://cloud.tsinghua.edu.cn/seafhttp/files/f71f6aa6-d346-4b16-85b7-8502efa3d608/ultrachat_material_release_230412.json + # External materials + wget https://cloud.tsinghua.edu.cn/seafhttp/files/42d22e28-e899-4975-a70f-5eda163e265d/ultrachat_existent_material_release_230420.json.gz + gunzip ultrachat_existent_material_release_230420.json.gz + ;; + loogle) + echo $1 + git lfs install + git clone git@hf.co:datasets/bigainlco/LooGLE + unzip LooGLE/data.zip + ;; + nextqa) + echo $1 + git lfs install + git clone https://huggingface.co/datasets/lmms-lab/NExTQA + unzip NExTQA/videos.zip + ;; + *) + usage + exit 1 + ;; + esac +} + +# Arg check +if [ "$#" -ne 1 ]; then + usage +fi + +# Invoke + +case "$1" in + sharegpt|ultragpt|loogle|nextqa) + download "$1" + ;; + all) + download sharegpt + download ultragpt + download loogle + download nextqa + ;; + *) + usage + ;; +esac diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 98e0f3f4f8d..3768102693f 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -15,15 +15,20 @@ from concurrent.futures import ThreadPoolExecutor from io import BytesIO from json import dumps -from typing import Optional, Union +from typing import List, Optional, Union import numpy as np import requests from IPython.display import HTML, display from tqdm import tqdm +from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart + logger = logging.getLogger(__name__) +# type of content fields, can be only prompts or with images/videos +MsgContent = Union[str, List[ChatCompletionMessageContentPart]] + def get_exception_traceback(): etype, value, tb = sys.exc_info()