From a509552087fa29a62113ca0e24a6c35aa9502b30 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 8 Nov 2024 02:19:41 -0800 Subject: [PATCH] [minor] Improve code style and compatibility (#1961) --- python/pyproject.toml | 1 + .../sglang/srt/managers/tokenizer_manager.py | 2 +- .../srt/model_executor/cuda_graph_runner.py | 4 +- python/sglang/srt/server.py | 40 +++++++++------ python/sglang/srt/server_args.py | 49 ++++++++++++------- python/sglang/srt/utils.py | 48 ++++++++++++++++++ 6 files changed, 109 insertions(+), 35 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 6239963f497..5ac2e8372c7 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -21,6 +21,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu "torchao", "uvicorn", "uvloop", "zmq", "outlines>=0.0.44", "modelscope"] srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"] + # HIP (Heterogeneous-computing Interface for Portability) for AMD # => base docker rocm/vllm-dev:20241022, not from public vllm whl srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"] diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 78f35903f4a..cf412b8fa56 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -461,7 +461,7 @@ async def sigterm_watchdog(self): break kill_child_process(include_self=True) - sys.exit(-1) + sys.exit(0) async def handle_loop(self): """The event loop that handles requests""" diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index e91fbac6523..236a57f1a09 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -32,7 +32,7 @@ LogitsProcessorOutput, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.utils import monkey_patch_vllm_all_gather +from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner @@ -92,7 +92,7 @@ def set_torch_compile_config(): torch._dynamo.config.accumulated_cache_size_limit = 1024 -@torch.compile(dynamic=True) +@maybe_torch_compile(dynamic=True) def clamp_position(seq_lens): return torch.clamp((seq_lens - 1), min=0).to(torch.int64) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 1ed8af0e707..3d0816ce066 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -79,6 +79,7 @@ add_api_key_middleware, assert_pkg_version, configure_logger, + delete_directory, is_port_available, kill_child_process, maybe_set_triton_cache_manager, @@ -97,8 +98,6 @@ app = FastAPI() -tokenizer_manager: TokenizerManager = None - app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -107,6 +106,10 @@ allow_headers=["*"], ) +tokenizer_manager: TokenizerManager = None + +##### Native API endpoints ##### + @app.get("/health") async def health() -> Response: @@ -275,6 +278,9 @@ async def classify_request(obj: EmbeddingReqInput, request: Request): app.put("/classify")(classify_request) +##### OpenAI-compatible API endpoints ##### + + @app.post("/v1/completions") async def openai_v1_completions(raw_request: Request): return await v1_completions(tokenizer_manager, raw_request) @@ -420,19 +426,6 @@ def launch_engine( scheduler_pipe_readers[i].recv() -def 
add_prometheus_middleware(app: FastAPI):
-    # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
-    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
-
-    registry = CollectorRegistry()
-    multiprocess.MultiProcessCollector(registry)
-    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
-
-    # Workaround for 307 Redirect for /metrics
-    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
-    app.routes.append(metrics_route)
-
-
 def launch_server(
     server_args: ServerArgs,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
@@ -492,6 +485,19 @@ def launch_server(
         t.join()
 
 
+def add_prometheus_middleware(app: FastAPI):
+    # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
+    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
+
+    registry = CollectorRegistry()
+    multiprocess.MultiProcessCollector(registry)
+    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
+
+    # Workaround for 307 Redirect for /metrics
+    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
+    app.routes.append(metrics_route)
+
+
 def _set_prometheus_env():
     # Set prometheus multiprocess directory
     # sglang uses prometheus multiprocess mode
@@ -565,6 +571,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
             return
     model_info = res.json()
 
+    # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
     max_new_tokens = 8 if model_info["is_generation"] else 1
@@ -602,6 +609,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")
 
+    if server_args.delete_ckpt_after_loading:
+        delete_directory(server_args.model_path)
+
 
 class Runtime:
     """
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 53a493bdeae..769f435cd20 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -63,7 +63,7 @@ class ServerArgs:
     stream_interval: int = 1
     random_seed: Optional[int] = None
     constrained_json_whitespace_pattern: Optional[str] = None
-    decode_log_interval: int = 40
+    watchdog_timeout: float = 300
 
     # Logging
     log_level: str = "info"
@@ -71,18 +71,18 @@ class ServerArgs:
     log_requests: bool = False
     show_time_cost: bool = False
     enable_metrics: bool = False
+    decode_log_interval: int = 40
 
-    # Other
+    # API related
     api_key: Optional[str] = None
     file_storage_pth: str = "SGLang_storage"
     enable_cache_report: bool = False
-    watchdog_timeout: float = 600
 
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
 
-    # Distributed args
+    # Multi-node distributed serving
     dist_init_addr: Optional[str] = None
     nnodes: int = 1
     node_rank: int = 0
@@ -128,6 +128,7 @@ class ServerArgs:
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
     num_continuous_decode_steps: int = 1
+    delete_ckpt_after_loading: bool = False
 
     def __post_init__(self):
         # Set missing default values
@@ -205,6 +206,7 @@ def __post_init__(self):
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
+        # Model and port args
         parser.add_argument(
             "--model-path",
             type=str,
@@ -324,6 +326,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
             action="store_true",
             help="Whether to use a CausalLM as an embedding model.",
         )
+
+        # Memory and scheduling
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -368,6 +372,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
             default=ServerArgs.schedule_conservativeness,
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
+
+        # Other runtime options
         parser.add_argument(
             "--tensor-parallel-size",
             "--tp-size",
@@ -393,6 +399,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
             default=ServerArgs.constrained_json_whitespace_pattern,
             help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
         )
+        parser.add_argument(
+            "--watchdog-timeout",
+            type=float,
+            default=ServerArgs.watchdog_timeout,
+            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
+        )
+
+        # Logging
         parser.add_argument(
             "--log-level",
             type=str,
@@ -420,7 +434,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
             action="store_true",
             help="Enable log prometheus metrics.",
         )
+        parser.add_argument(
+            "--decode-log-interval",
+            type=int,
+            default=ServerArgs.decode_log_interval,
+            help="The log interval of decode batch",
+        )
 
+        # API related
         parser.add_argument(
             "--api-key",
             type=str,
@@ -438,18 +459,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
             action="store_true",
             help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
         )
-        parser.add_argument(
-            "--watchdog-timeout",
-            type=float,
-            default=ServerArgs.watchdog_timeout,
-            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
-        )
-        parser.add_argument(
-            "--decode-log-interval",
-            type=int,
-            default=ServerArgs.decode_log_interval,
-            help="The log interval of decode batch",
-        )
 
         # Data parallelism
         parser.add_argument(
@@ -470,7 +479,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
             ],
         )
 
-        # Multi-node distributed serving args
+        # Multi-node distributed serving
        parser.add_argument(
            "--dist-init-addr",
            "--nccl-init-addr",  # For backward compatbility. This will be removed in the future.
@@ -677,6 +686,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
            "This can potentially increase throughput but may also increase time-to-first-token latency. "
            "The default value is 1, meaning only run one decoding step at a time.",
        )
+        parser.add_argument(
+            "--delete-ckpt-after-loading",
+            default=ServerArgs.delete_ckpt_after_loading,
+            action="store_true",
+            help="Delete the model checkpoint after loading the model.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 0c3ae0c5a7c..5ee0fe59dc0 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -23,6 +23,8 @@
 import pickle
 import random
 import resource
+import shutil
+import signal
 import socket
 import time
 import warnings
@@ -35,6 +37,7 @@
 import requests
 import torch
 import torch.distributed as dist
+import triton
 import zmq
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
@@ -379,6 +382,10 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
     if include_self:
         try:
             itself.kill()
+
+            # Sometimes processes cannot be killed with SIGKILL (e.g., PID=1 launched by Kubernetes),
+            # so we send an additional signal to kill them.
+ itself.send_signal(signal.SIGINT) except psutil.NoSuchProcess: pass @@ -704,3 +711,44 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: raise ValueError(f"Unsupported socket type: {socket_type}") return socket + + +def dump_to_file(dirpath, name, value): + from vllm.distributed import get_tensor_model_parallel_rank + + if get_tensor_model_parallel_rank() != 0: + return + + os.makedirs(dirpath, exist_ok=True) + if value.dtype is torch.bfloat16: + value = value.float() + value = value.cpu().numpy() + output_filename = os.path.join(dirpath, f"pytorch_dump_{name}.npy") + logger.info(f"Dump a tensor to {output_filename}. Shape = {value.shape}") + np.save(output_filename, value) + + +def is_triton_3(): + return triton.__version__.startswith("3.") + + +def maybe_torch_compile(*args, **kwargs): + """ + torch.compile does not work for triton 2.2.0, which is needed in xlm1's jax. + Therefore, we disable it here. + """ + + def decorator(func): + if is_triton_3(): + return torch.compile(*args, **kwargs)(func) + return func + + return decorator + + +def delete_directory(dirpath): + try: + # This will remove the directory and all its contents + shutil.rmtree(dirpath) + except OSError as e: + print(f"Warning: {dirpath} : {e.strerror}")
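
Reviewer note (not part of the diff): the sketch below illustrates how the new maybe_torch_compile helper added to python/sglang/srt/utils.py is expected to behave, mirroring the clamp_position change in cuda_graph_runner.py. It is a minimal, illustrative copy that assumes torch >= 2.0 and triton are installed; clamp_example is a hypothetical stand-in, not code from this patch.

# Illustrative sketch only; mirrors the maybe_torch_compile pattern from this patch.
import torch
import triton


def maybe_torch_compile(*args, **kwargs):
    """Apply torch.compile only when Triton 3.x is available; otherwise return the function unchanged."""

    def decorator(func):
        if triton.__version__.startswith("3."):
            return torch.compile(*args, **kwargs)(func)
        return func

    return decorator


@maybe_torch_compile(dynamic=True)
def clamp_example(seq_lens: torch.Tensor) -> torch.Tensor:
    # Compiled with dynamic shapes on Triton 3.x; runs eagerly on older Triton (e.g., 2.2.0).
    return torch.clamp(seq_lens - 1, min=0).to(torch.int64)


print(clamp_example(torch.tensor([0, 3, 7])))  # tensor([0, 2, 6])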