79 changes: 72 additions & 7 deletions src/litserve/server.py
@@ -29,7 +29,7 @@
import warnings
from abc import ABC, abstractmethod
from collections import deque
-from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Callable, Iterable, Mapping, Sequence
from contextlib import asynccontextmanager
from queue import Queue
from typing import TYPE_CHECKING, Literal, Optional, Union
@@ -797,6 +797,7 @@ def __init__(
self.lit_api = lit_api
self.enable_shutdown_api = enable_shutdown_api
self.workers_per_device = workers_per_device
self._workers_per_device_by_api_path = self._resolve_workers_per_device_config(workers_per_device)
self.max_payload_size = max_payload_size
self.model_metadata = model_metadata
self._connector = _Connector(accelerator=accelerator, devices=devices)
@@ -822,12 +823,16 @@ def __init__(
device_list = range(devices)
self.devices = [self.device_identifiers(accelerator, device) for device in device_list]

-        self.inference_workers_config = self.devices * self.workers_per_device
self.transport_config = TransportConfig(transport_config="zmq" if self.use_zmq else "mp")
self.register_endpoints()
# register middleware
self._register_middleware()

def _inference_workers_config_for_api(self, api_path: str):
wpd = self._workers_per_device_by_api_path[api_path]
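        # List multiplication repeats the device list, e.g. two device entries with
        # wpd=2 yield four worker slots ([d0, d1, d0, d1]); order matches worker_id.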
return self.devices * wpd

def launch_inference_worker(self, lit_api: LitAPI):
specs = [lit_api.spec] if lit_api.spec else []
for spec in specs:
@@ -839,7 +844,10 @@ def launch_inference_worker(self, lit_api: LitAPI):

process_list = []
endpoint = lit_api.api_path.split("/")[-1]
-        for worker_id, device in enumerate(self.inference_workers_config):
+        inference_workers_config = self._inference_workers_config_for_api(lit_api.api_path)
+
+        for worker_id, device in enumerate(inference_workers_config):
if len(device) == 1:
device = device[0]

@@ -873,7 +881,8 @@ def launch_single_inference_worker(self, lit_api: LitAPI, worker_id: int):
del server_copy.app, server_copy.transport_config, server_copy.litapi_connector
spec.setup(server_copy)

-        device = self.inference_workers_config[worker_id]
+        inference_workers_config = self._inference_workers_config_for_api(lit_api.api_path)
+        device = inference_workers_config[worker_id]
endpoint = lit_api.api_path.split("/")[-1]
if len(device) == 1:
device = device[0]
@@ -1183,6 +1192,42 @@ def _perform_graceful_shutdown(

manager.shutdown()

def _resolve_workers_per_device_config(self, workers_per_device):
"""Resolve workers_per_device into a dict[api_path, workers_per_device_int]."""
api_paths = [api.api_path for api in self.litapi_connector]

if isinstance(workers_per_device, int):
if workers_per_device < 1:
raise ValueError("workers_per_device must be >= 1")
return dict.fromkeys(api_paths, workers_per_device)

if isinstance(workers_per_device, (list, tuple)):
if len(workers_per_device) != len(api_paths):
raise ValueError(
f"workers_per_device list length must match number of APIs \n"
f"({len(api_paths)}), got {len(workers_per_device)}"
)
cfg = {}
for p, w in zip(api_paths, workers_per_device):
if not isinstance(w, int) or w < 1:
raise ValueError("workers_per_device values must be integers >= 1")
cfg[p] = w
return cfg

if isinstance(workers_per_device, Mapping):
unknown = sorted(set(workers_per_device.keys()) - set(api_paths))
if unknown:
raise ValueError(f"workers_per_device contains unknown api_path values: {unknown} (unknown api_path)")
cfg = {}
for p in api_paths:
w = workers_per_device.get(p, 1)
if not isinstance(w, int) or w < 1:
raise ValueError("workers_per_device values must be integers >= 1")
cfg[p] = w
return cfg

raise TypeError("workers_per_device must be an int, a list/tuple of ints, or a mapping of api_path -> int")

def run(
self,
host: str = "0.0.0.0",
@@ -1388,7 +1433,10 @@ def run(
sockets = [config.bind_socket()]

if num_api_servers is None:
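+            # default to one API server per inference worker across all routes,
+            # e.g. devices [0, 1] with workers_per_device {"/sentiment": 2, "/generate": 3}
+            # gives 2*2 + 2*3 = 10 (illustrative values from the tests below)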
-            num_api_servers = len(self.inference_workers_config)
+            num_api_servers = sum(
+                len(self._inference_workers_config_for_api(lit_api.api_path))
+                for lit_api in self.litapi_connector
+            )

if num_api_servers < 1:
raise ValueError("num_api_servers must be greater than 0")
@@ -1528,8 +1576,25 @@ def monitor():
broken_workers[i] = proc

for idx, proc in broken_workers.items():
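+                # Worked example (illustrative counts from the tests below): two APIs
+                # with 4 and 6 workers give flat indices 0-9; idx=5 falls past the
+                # first API's 4 slots, so lit_api_id=1 and worker_id=5-4=1.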
-                lit_api_id = idx // len(self.inference_workers_config)
-                worker_id = idx % len(self.inference_workers_config)
+                # map the flat worker index back to (lit_api_id, worker_id)
+                # by walking the per-API worker counts
+                count = 0
+                for lit_api_id, lit_api in enumerate(self.litapi_connector):
+                    num_workers_for_api = len(self._inference_workers_config_for_api(lit_api.api_path))
+                    if idx < count + num_workers_for_api:
+                        worker_id = idx - count
+                        break
+                    count += num_workers_for_api
+                else:
+                    logger.error(f"Could not map worker index {idx} to an API.")
+                    continue

for uid, resp in self.response_buffer.items():
if resp.worker_id is None or resp.worker_id != worker_id:
82 changes: 82 additions & 0 deletions tests/unit/test_lit_server.py
@@ -837,3 +837,85 @@ async def test_worker_restart_and_server_shutdown_streaming():
):
resp = await ac.post("/predict", json={"input": 0})
assert resp.status_code == 200


class MultiRouteAPI(ls.test_examples.SimpleLitAPI):
    """Mock API for testing multi-route server behavior."""
def __init__(self, api_path="/predict"):
super().__init__(api_path=api_path)


@pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix")
@pytest.mark.parametrize(
("workers_cfg", "expected_total_by_path"),
[
# dict: explicit per-route config
({"/sentiment": 2, "/generate": 3}, {"/sentiment": 4, "/generate": 6}),
# list: per-api (connector order) config
([2, 3], {"/sentiment": 4, "/generate": 6}),
],
)
def test_workers_per_device_can_be_configured_per_route(monkeypatch, workers_cfg, expected_total_by_path):
monkeypatch.setattr("litserve.server.uvicorn", MagicMock())

sentiment = MultiRouteAPI(api_path="/sentiment")
generate = MultiRouteAPI(api_path="/generate")
server = LitServer([sentiment, generate], accelerator="cuda", devices=[0, 1], workers_per_device=workers_cfg)

created = [] # list[(api_path, worker_id, device)]

class FakeProcess:
def __init__(self, target, args, name):
# inference_worker args = (lit_api, device, worker_id, request_q, transport, ...)
lit_api, device, worker_id = args[0], args[1], args[2]
created.append((lit_api.api_path, worker_id, device))
self.pid = 123
self.name = name

def start(self): ...
def terminate(self): ...
def join(self, timeout=None): ...
def is_alive(self):
return False

def kill(self): ...

class FakeCtx:
def Process(self, target, args, name): # noqa: N802
return FakeProcess(target=target, args=args, name=name)

monkeypatch.setattr("litserve.server.mp.get_context", lambda *_args, **_kwargs: FakeCtx())

# prevent server.run() from actually running uvicorn / waiting forever
server.verify_worker_status = MagicMock()
server._start_server = MagicMock(return_value={})
server._perform_graceful_shutdown = MagicMock()
server._start_worker_monitoring = MagicMock()
server._transport = MagicMock()
server._shutdown_event = MagicMock()
server._shutdown_event.wait = MagicMock(return_value=None) # don't block

# init manager + queues without real multiprocessing manager
with patch("litserve.server.mp.Manager", return_value=MagicMock()):
server.run(api_server_worker_type="process", generate_client_file=False)

# count workers created per api_path
total_by_path = {}
for api_path, _worker_id, _device in created:
total_by_path[api_path] = total_by_path.get(api_path, 0) + 1

assert total_by_path == expected_total_by_path


@pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix")
def test_workers_per_device_per_route_raises_on_unknown_route():
sentiment = MultiRouteAPI(api_path="/sentiment")
generate = MultiRouteAPI(api_path="/generate")

with pytest.raises(ValueError, match="workers_per_device.*unknown api_path"):
LitServer(
[sentiment, generate],
accelerator="cuda",
devices=[0, 1],
workers_per_device={"/sentiment": 2, "/unknown": 1},
)
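
For reference, a minimal usage sketch of the new per-route option (dict form, mirroring the tests above; the API subclasses and port are illustrative, and `SimpleLitAPI` is assumed to accept `api_path` as `MultiRouteAPI` does):

```python
import litserve as ls


class SentimentAPI(ls.test_examples.SimpleLitAPI):
    """Stand-in for a real sentiment model."""


class GenerateAPI(ls.test_examples.SimpleLitAPI):
    """Stand-in for a real generation model."""


# Each per-route value is multiplied by the number of devices, so with
# devices=[0, 1] this launches 4 sentiment workers and 6 generate workers.
server = ls.LitServer(
    [SentimentAPI(api_path="/sentiment"), GenerateAPI(api_path="/generate")],
    accelerator="cuda",
    devices=[0, 1],
    workers_per_device={"/sentiment": 2, "/generate": 3},
)

if __name__ == "__main__":
    server.run(port=8000)
```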