[minor] Improve code style and compatibility (#1961)
merrymercy authored Nov 8, 2024
1 parent 7ef0084 commit a509552
Showing 6 changed files with 109 additions and 35 deletions.
1 change: 1 addition & 0 deletions python/pyproject.toml
@@ -21,6 +21,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
"torchao", "uvicorn", "uvloop", "zmq",
"outlines>=0.0.44", "modelscope"]
srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]

# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"]
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/tokenizer_manager.py
@@ -461,7 +461,7 @@ async def sigterm_watchdog(self):
break

kill_child_process(include_self=True)
- sys.exit(-1)
+ sys.exit(0)

async def handle_loop(self):
"""The event loop that handles requests"""
4 changes: 2 additions & 2 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -32,7 +32,7 @@
LogitsProcessorOutput,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
- from sglang.srt.utils import monkey_patch_vllm_all_gather
+ from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather

if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -92,7 +92,7 @@ def set_torch_compile_config():
torch._dynamo.config.accumulated_cache_size_limit = 1024


- @torch.compile(dynamic=True)
+ @maybe_torch_compile(dynamic=True)
def clamp_position(seq_lens):
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)

40 changes: 25 additions & 15 deletions python/sglang/srt/server.py
@@ -79,6 +79,7 @@
add_api_key_middleware,
assert_pkg_version,
configure_logger,
delete_directory,
is_port_available,
kill_child_process,
maybe_set_triton_cache_manager,
@@ -97,8 +98,6 @@


app = FastAPI()
tokenizer_manager: TokenizerManager = None

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@@ -107,6 +106,10 @@
allow_headers=["*"],
)

tokenizer_manager: TokenizerManager = None

##### Native API endpoints #####


@app.get("/health")
async def health() -> Response:
@@ -275,6 +278,9 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
app.put("/classify")(classify_request)


##### OpenAI-compatible API endpoints #####


@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
return await v1_completions(tokenizer_manager, raw_request)
@@ -420,19 +426,6 @@ def launch_engine(
scheduler_pipe_readers[i].recv()


def add_prometheus_middleware(app: FastAPI):
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess

registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))

# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)


def launch_server(
server_args: ServerArgs,
pipe_finish_writer: Optional[mp.connection.Connection] = None,
@@ -492,6 +485,19 @@ def launch_server(
t.join()


def add_prometheus_middleware(app: FastAPI):
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess

registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))

# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)
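
The call site that registers this middleware is not shown in this diff; a hedged sketch of the assumed wiring in launch_server, based on the --enable-metrics flag defined in server_args.py:

    # Assumed wiring (not shown in this diff): mount /metrics only when metrics are enabled.
    if server_args.enable_metrics:
        add_prometheus_middleware(app)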


def _set_prometheus_env():
# Set prometheus multiprocess directory
# sglang uses prometheus multiprocess mode
Expand Down Expand Up @@ -565,6 +571,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
return

model_info = res.json()

# Send a warmup request
request_name = "/generate" if model_info["is_generation"] else "/encode"
max_new_tokens = 8 if model_info["is_generation"] else 1
@@ -602,6 +609,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
if pipe_finish_writer is not None:
pipe_finish_writer.send("ready")

if server_args.delete_ckpt_after_loading:
delete_directory(server_args.model_path)


class Runtime:
"""
49 changes: 32 additions & 17 deletions python/sglang/srt/server_args.py
@@ -63,26 +63,26 @@ class ServerArgs:
stream_interval: int = 1
random_seed: Optional[int] = None
constrained_json_whitespace_pattern: Optional[str] = None
decode_log_interval: int = 40
watchdog_timeout: float = 300

# Logging
log_level: str = "info"
log_level_http: Optional[str] = None
log_requests: bool = False
show_time_cost: bool = False
enable_metrics: bool = False
decode_log_interval: int = 40

# Other
# API related
api_key: Optional[str] = None
file_storage_pth: str = "SGLang_storage"
enable_cache_report: bool = False
watchdog_timeout: float = 600

# Data parallelism
dp_size: int = 1
load_balance_method: str = "round_robin"

# Distributed args
# Multi-node distributed serving
dist_init_addr: Optional[str] = None
nnodes: int = 1
node_rank: int = 0
@@ -128,6 +128,7 @@ class ServerArgs:
enable_p2p_check: bool = False
triton_attention_reduce_in_fp32: bool = False
num_continuous_decode_steps: int = 1
delete_ckpt_after_loading: bool = False

def __post_init__(self):
# Set missing default values
@@ -205,6 +206,7 @@ def __post_init__(self):

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and port args
parser.add_argument(
"--model-path",
type=str,
@@ -324,6 +326,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Whether to use a CausalLM as an embedding model.",
)

# Memory and scheduling
parser.add_argument(
"--mem-fraction-static",
type=float,
@@ -368,6 +372,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=ServerArgs.schedule_conservativeness,
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
)

# Other runtime options
parser.add_argument(
"--tensor-parallel-size",
"--tp-size",
@@ -393,6 +399,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=ServerArgs.constrained_json_whitespace_pattern,
help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
)
parser.add_argument(
"--watchdog-timeout",
type=float,
default=ServerArgs.watchdog_timeout,
help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
)

# Logging
parser.add_argument(
"--log-level",
type=str,
@@ -420,7 +434,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Enable log prometheus metrics.",
)
parser.add_argument(
"--decode-log-interval",
type=int,
default=ServerArgs.decode_log_interval,
help="The log interval of decode batch",
)

# API related
parser.add_argument(
"--api-key",
type=str,
@@ -438,18 +459,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
)
parser.add_argument(
"--watchdog-timeout",
type=float,
default=ServerArgs.watchdog_timeout,
help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
)
parser.add_argument(
"--decode-log-interval",
type=int,
default=ServerArgs.decode_log_interval,
help="The log interval of decode batch",
)

# Data parallelism
parser.add_argument(
@@ -470,7 +479,7 @@
],
)

# Multi-node distributed serving args
# Multi-node distributed serving
parser.add_argument(
"--dist-init-addr",
"--nccl-init-addr", # For backward compatbility. This will be removed in the future.
@@ -677,6 +686,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
"This can potentially increase throughput but may also increase time-to-first-token latency. "
"The default value is 1, meaning only run one decoding step at a time.",
)
parser.add_argument(
"--delete-ckpt-after-loading",
default=ServerArgs.delete_ckpt_after_loading,
action="store_true",
help="Delete the model checkpoint after loading the model.",
)
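
A hedged usage sketch of the new option set programmatically through ServerArgs (the model path below is a placeholder, not taken from this diff):

    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(
        model_path="/path/to/checkpoint",   # placeholder path
        delete_ckpt_after_loading=True,     # remove the checkpoint directory once weights are loaded
    )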

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
48 changes: 48 additions & 0 deletions python/sglang/srt/utils.py
@@ -23,6 +23,8 @@
import pickle
import random
import resource
import shutil
import signal
import socket
import time
import warnings
@@ -35,6 +37,7 @@
import requests
import torch
import torch.distributed as dist
import triton
import zmq
from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version
@@ -379,6 +382,10 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
if include_self:
try:
itself.kill()

# Sometimes processes cannot be killed with SIGKILL (e.g., PID=1 launched by Kubernetes),
# so we send an additional signal to kill them.
itself.send_signal(signal.SIGINT)
except psutil.NoSuchProcess:
pass

@@ -704,3 +711,44 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
raise ValueError(f"Unsupported socket type: {socket_type}")

return socket


def dump_to_file(dirpath, name, value):
from vllm.distributed import get_tensor_model_parallel_rank

if get_tensor_model_parallel_rank() != 0:
return

os.makedirs(dirpath, exist_ok=True)
if value.dtype is torch.bfloat16:
value = value.float()
value = value.cpu().numpy()
output_filename = os.path.join(dirpath, f"pytorch_dump_{name}.npy")
logger.info(f"Dump a tensor to {output_filename}. Shape = {value.shape}")
np.save(output_filename, value)
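
A hedged example call of this debugging helper (directory and tensor names are illustrative, not from this diff):

    # Writes <dirpath>/pytorch_dump_hidden_states.npy, and only on tensor-parallel rank 0.
    dump_to_file("/tmp/sglang_dumps", "hidden_states", hidden_states)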


def is_triton_3():
return triton.__version__.startswith("3.")


def maybe_torch_compile(*args, **kwargs):
"""
torch.compile does not work for triton 2.2.0, which is needed in xlm1's jax.
Therefore, we disable it here.
"""

def decorator(func):
if is_triton_3():
return torch.compile(*args, **kwargs)(func)
return func

return decorator
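
A hedged usage sketch of the helper (the "double" function below is illustrative, not part of this commit): the decorator applies torch.compile only under Triton 3.x and otherwise returns the function unchanged, so call sites need no branching.

    import torch
    from sglang.srt.utils import maybe_torch_compile

    @maybe_torch_compile(dynamic=True)
    def double(x: torch.Tensor) -> torch.Tensor:
        # Compiled when Triton 3.x is installed; plain eager execution otherwise.
        return x * 2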


def delete_directory(dirpath):
try:
# This will remove the directory and all its contents
shutil.rmtree(dirpath)
except OSError as e:
print(f"Warning: {dirpath} : {e.strerror}")
