Provide an offline engine API #1567

Merged
28 changes: 28 additions & 0 deletions examples/runtime/srt_engine.py
@@ -0,0 +1,28 @@
import sglang as sgl


def main():
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Create a sampling params object.
    sampling_params = {"temperature": 0.8, "top_p": 0.95}

    # Create an LLM.
    llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")

    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for prompt, output in zip(prompts, outputs):
        print("===============================")
        print(f"Prompt: {prompt}\nGenerated text: {output['text']}")


# The __main__ guard is necessary here because we use the "spawn" method to
# create subprocesses. Spawn starts a fresh Python program each time; without
# the guard, every spawned process would re-run sgl.Engine and keep spawning
# new processes in an infinite loop.
if __name__ == "__main__":
    main()
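As a standalone illustration of why the guard matters, here is a minimal sketch using only the standard library; it is independent of SGLang, and the worker function is hypothetical:

import multiprocessing as mp


def worker():
    print("hello from the child process")


if __name__ == "__main__":
    # "spawn" starts a fresh interpreter that re-imports this module.
    # Without the __main__ guard, the process-creating code below would
    # run again inside every child, spawning processes without end.
    mp.set_start_method("spawn")
    p = mp.Process(target=worker)
    p.start()
    p.join()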
2 changes: 2 additions & 0 deletions python/sglang/__init__.py
@@ -1,6 +1,7 @@
# SGL API Components

from sglang.api import (
    Engine,
    Runtime,
    assistant,
    assistant_begin,
@@ -31,6 +32,7 @@
# SGLang DSL APIs
__all__ = [
    "Runtime",
    "Engine",
    "assistant",
    "assistant_begin",
    "assistant_end",
12 changes: 11 additions & 1 deletion python/sglang/api.py
@@ -33,13 +33,23 @@ def decorator(func):


def Runtime(*args, **kwargs):
    # Silence TensorFlow's C++ logging before any dependency imports it.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    # Avoid importing unnecessary dependency
    from sglang.srt.server import Runtime

    return Runtime(*args, **kwargs)


def Engine(*args, **kwargs):
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    # Avoid importing unnecessary dependency
    from sglang.srt.server import Engine

    return Engine(*args, **kwargs)


def set_default_backend(backend: BaseBackend):
    global_config.default_backend = backend
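Both wrappers defer importing sglang.srt.server until a Runtime or Engine is actually constructed, so a plain `import sglang` stays fast. A minimal sketch of the deferred-import pattern, with a hypothetical module and class name:

def Client(*args, **kwargs):
    # Importing inside the function delays the cost (and side effects) of
    # the heavy dependency until a client is actually constructed.
    from heavy_backend import Client as _Client  # hypothetical module

    return _Client(*args, **kwargs)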
77 changes: 73 additions & 4 deletions python/sglang/srt/server.py
@@ -19,6 +19,7 @@
"""

import asyncio
import atexit
import dataclasses
import json
import logging
@@ -161,6 +162,7 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
    )


# FastAPI implicitly converts the JSON request body to obj (a dataclass).
async def generate_request(obj: GenerateReqInput, request: Request):
    """Handle a generate request."""
    if obj.stream:
@@ -290,11 +292,13 @@ async def retrieve_file_content(file_id: str):
    return await v1_retrieve_file_content(file_id)


def launch_engine(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
    """
    Launch the Tokenizer Manager in the main process, the Scheduler in a
    subprocess, and the Detokenizer Manager in another subprocess.
    """

    global tokenizer_manager

    # Configure global environment
@@ -355,6 +359,29 @@ def launch_server(
    for i in range(len(scheduler_pipe_readers)):
        scheduler_pipe_readers[i].recv()
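The recv() calls above act as a readiness barrier: each scheduler subprocess sends a message over its pipe once initialization finishes, and the parent blocks until every subprocess has reported. A minimal sketch of the pattern, with a hypothetical scheduler worker:

import multiprocessing as mp


def scheduler(pipe):
    # ... expensive initialization would happen here ...
    pipe.send("ready")  # tell the parent that startup finished


if __name__ == "__main__":
    parent_end, child_end = mp.Pipe()
    proc = mp.Process(target=scheduler, args=(child_end,))
    proc.start()
    parent_end.recv()  # block until the child reports readiness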


def launch_server(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
    """
    Launch the SRT (SGLang Runtime) Server.

    The SRT server consists of an HTTP server and the SRT engine.

    1. HTTP server: A FastAPI server that routes requests to the engine.
    2. SRT engine:
        1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
        3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

    Note:
    1. The HTTP server and Tokenizer Manager both run in the main process.
    2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
    """

    launch_engine(server_args=server_args)

    # Add API key authorization
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)
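The note above mentions ZMQ-based inter-process communication; the sketch below shows how two endpoints can exchange Python objects over a ZMQ PUSH/PULL pair. The port and payload are hypothetical, and both ends live in one process here for brevity; this is not SGLang's actual wire format:

import zmq

ctx = zmq.Context()

# Sender side (e.g., tokenizer -> scheduler): a PUSH socket bound to a port.
sender = ctx.socket(zmq.PUSH)
sender.bind("tcp://127.0.0.1:5555")  # hypothetical port

# Receiver side: a PULL socket connected to the same endpoint.
receiver = ctx.socket(zmq.PULL)
receiver.connect("tcp://127.0.0.1:5555")

sender.send_pyobj({"request_id": 0, "input_ids": [1, 2, 3]})
print(receiver.recv_pyobj())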
@@ -435,7 +462,6 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
        return

    model_info = res.json()
    # Send a warmup request
    request_name = "/generate" if model_info["is_generation"] else "/encode"
    max_new_tokens = 8 if model_info["is_generation"] else 1
@@ -626,3 +652,46 @@ def encode(

    def __del__(self):
        self.shutdown()


class Engine:
    """
    SRT Engine without an HTTP server layer.

    This class provides a direct inference engine without the need for an
    HTTP server. It is designed for use cases where launching the HTTP
    server adds unnecessary complexity or overhead.
    """

    def __init__(self, *args, **kwargs):
        # Register shutdown to run before the Python program terminates, so
        # users don't have to call .shutdown() explicitly.
        atexit.register(self.shutdown)

        server_args = ServerArgs(*args, **kwargs)
        launch_engine(server_args=server_args)

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        obj = GenerateReqInput(
            text=prompt,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            lora_path=lora_path,
        )

        # Run the async handler to completion so the call is synchronous.
        return asyncio.run(generate_request(obj, None))

    def shutdown(self):
        kill_child_process(os.getpid(), including_parent=False)

    # TODO (ByronHsu): encode and async generate
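Engine.generate above bridges the async generate_request handler into a blocking call with asyncio.run. A minimal sketch of the same sync-over-async pattern, with a hypothetical handler standing in for generate_request:

import asyncio


async def handle(request: dict) -> dict:
    # Stand-in for an async handler such as generate_request.
    await asyncio.sleep(0)
    return {"text": "echo: " + request["text"]}


def generate_sync(request: dict) -> dict:
    # asyncio.run creates an event loop, runs the coroutine to completion,
    # and tears the loop down, so the call blocks until the result is ready.
    return asyncio.run(handle(request))


print(generate_sync({"text": "hello"}))

One caveat of this pattern: asyncio.run raises a RuntimeError when called from a thread that already has a running event loop, so a blocking generate cannot be invoked from inside async code, which is presumably why the TODO above mentions an async generate.

Likewise, shutdown terminates the child processes (scheduler and detokenizer) while leaving the main process alive. A sketch of what killing a process's descendants can look like with psutil; this is an assumption, not necessarily how kill_child_process is implemented:

import os

import psutil


def kill_children(pid: int) -> None:
    # Terminate every descendant of `pid` while leaving `pid` itself
    # alive, mirroring including_parent=False.
    for child in psutil.Process(pid).children(recursive=True):
        child.terminate()


kill_children(os.getpid())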
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -19,6 +19,7 @@
    "test_pytorch_sampling_backend.py",
    "test_server_args.py",
    "test_skip_tokenizer_init.py",
    "test_srt_engine.py",
    "test_srt_endpoint.py",
    "test_torch_compile.py",
    "test_torchao.py",
33 changes: 33 additions & 0 deletions test/srt/test_srt_engine.py
@@ -0,0 +1,33 @@
import json
import unittest

import sglang as sgl
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST


class TestSRTBackend(unittest.TestCase):

    def test_engine_runtime_consistency(self):
        prompt = "Today is a sunny day and I like"
        model_path = DEFAULT_MODEL_NAME_FOR_TEST

        # temperature 0 means greedy decoding, so with a fixed random seed
        # both backends should produce identical text.
        sampling_params = {"temperature": 0, "max_new_tokens": 8}

        engine = sgl.Engine(model_path=model_path, random_seed=42)
        out1 = engine.generate(prompt, sampling_params)["text"]
        engine.shutdown()

        runtime = sgl.Runtime(model_path=model_path, random_seed=42)
        out2 = json.loads(runtime.generate(prompt, sampling_params))["text"]
        runtime.shutdown()

        print("==== Answer 1 ====")
        print(out1)

        print("==== Answer 2 ====")
        print(out2)
        assert out1 == out2, f"{out1} != {out2}"


if __name__ == "__main__":
    unittest.main()