diff --git a/src/litserve/server.py b/src/litserve/server.py
index 87b9a6da..cabc4933 100644
--- a/src/litserve/server.py
+++ b/src/litserve/server.py
@@ -29,7 +29,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import deque
-from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Callable, Iterable, Mapping, Sequence
 from contextlib import asynccontextmanager
 from queue import Queue
 from typing import TYPE_CHECKING, Literal, Optional, Union
@@ -797,6 +797,7 @@ def __init__(
         self.lit_api = lit_api
         self.enable_shutdown_api = enable_shutdown_api
         self.workers_per_device = workers_per_device
+        self._workers_per_device_by_api_path = self._resolve_workers_per_device_config(workers_per_device)
         self.max_payload_size = max_payload_size
         self.model_metadata = model_metadata
         self._connector = _Connector(accelerator=accelerator, devices=devices)
@@ -822,12 +823,16 @@ def __init__(
             device_list = range(devices)
         self.devices = [self.device_identifiers(accelerator, device) for device in device_list]
 
-        self.inference_workers_config = self.devices * self.workers_per_device
+        # replaced by per-API worker configs; see _inference_workers_config_for_api
         self.transport_config = TransportConfig(transport_config="zmq" if self.use_zmq else "mp")
         self.register_endpoints()
         # register middleware
         self._register_middleware()
 
+    def _inference_workers_config_for_api(self, api_path: str):
+        wpd = self._workers_per_device_by_api_path[api_path]
+        return self.devices * wpd
+
     def launch_inference_worker(self, lit_api: LitAPI):
         specs = [lit_api.spec] if lit_api.spec else []
         for spec in specs:
@@ -839,7 +844,10 @@ def launch_inference_worker(self, lit_api: LitAPI):
         process_list = []
         endpoint = lit_api.api_path.split("/")[-1]
-        for worker_id, device in enumerate(self.inference_workers_config):
+
+        inference_workers_config = self._inference_workers_config_for_api(lit_api.api_path)
+
+        for worker_id, device in enumerate(inference_workers_config):
             if len(device) == 1:
                 device = device[0]
@@ -873,7 +881,8 @@ def launch_single_inference_worker(self, lit_api: LitAPI, worker_id: int):
             del server_copy.app, server_copy.transport_config, server_copy.litapi_connector
             spec.setup(server_copy)
 
-        device = self.inference_workers_config[worker_id]
+        inference_workers_config = self._inference_workers_config_for_api(lit_api.api_path)
+        device = inference_workers_config[worker_id]
         endpoint = lit_api.api_path.split("/")[-1]
         if len(device) == 1:
             device = device[0]
@@ -1183,6 +1192,42 @@ def _perform_graceful_shutdown(
 
         manager.shutdown()
 
+    def _resolve_workers_per_device_config(self, workers_per_device):
+        """Resolve workers_per_device into a dict mapping api_path -> workers per device."""
+        api_paths = [api.api_path for api in self.litapi_connector]
+
+        if isinstance(workers_per_device, int):
+            if workers_per_device < 1:
+                raise ValueError("workers_per_device must be >= 1")
+            return dict.fromkeys(api_paths, workers_per_device)
+
+        if isinstance(workers_per_device, (list, tuple)):
+            if len(workers_per_device) != len(api_paths):
+                raise ValueError(
+                    f"workers_per_device list length must match the number of APIs "
+                    f"({len(api_paths)}), got {len(workers_per_device)}"
+                )
+            cfg = {}
+            for p, w in zip(api_paths, workers_per_device):
+                if not isinstance(w, int) or w < 1:
+                    raise ValueError("workers_per_device values must be integers >= 1")
+                cfg[p] = w
+            return cfg
+
+        if isinstance(workers_per_device, Mapping):
+            unknown = sorted(set(workers_per_device.keys()) - set(api_paths))
+            if unknown:
+                raise ValueError(f"workers_per_device contains unknown api_path values: {unknown}")
ValueError(f"workers_per_device contains unknown api_path values: {unknown} (unknown api_path)") + cfg = {} + for p in api_paths: + w = workers_per_device.get(p, 1) + if not isinstance(w, int) or w < 1: + raise ValueError("workers_per_device values must be integers >= 1") + cfg[p] = w + return cfg + + raise TypeError("workers_per_device must be an int, a list/tuple of ints, or a mapping of api_path -> int") + def run( self, host: str = "0.0.0.0", @@ -1388,7 +1433,10 @@ def run( sockets = [config.bind_socket()] if num_api_servers is None: - num_api_servers = len(self.inference_workers_config) + total_workers = 0 + for lit_api in self.litapi_connector: + total_workers += len(self._inference_workers_config_for_api(lit_api.api_path)) + num_api_servers = total_workers if num_api_servers < 1: raise ValueError("num_api_servers must be greater than 0") @@ -1528,8 +1576,25 @@ def monitor(): broken_workers[i] = proc for idx, proc in broken_workers.items(): - lit_api_id = idx // len(self.inference_workers_config) - worker_id = idx % len(self.inference_workers_config) + lit_api_id = 0 + worker_id = 0 + count = 0 + found = False + + for i, lit_api in enumerate(self.litapi_connector): + workers_conf = self._inference_workers_config_for_api(lit_api.api_path) + num_workers_for_api = len(workers_conf) + + if idx < count + num_workers_for_api: + lit_api_id = i + worker_id = idx - count + found = True + break + count += num_workers_for_api + + if not found: + logger.error(f"Could not map worker index {idx} to an API.") + continue for uid, resp in self.response_buffer.items(): if resp.worker_id is None or resp.worker_id != worker_id: diff --git a/tests/unit/test_lit_server.py b/tests/unit/test_lit_server.py index 09e57aca..21a8b04d 100644 --- a/tests/unit/test_lit_server.py +++ b/tests/unit/test_lit_server.py @@ -837,3 +837,85 @@ async def test_worker_restart_and_server_shutdown_streaming(): ): resp = await ac.post("/predict", json={"input": 0}) assert resp.status_code == 200 + + +class MultiRouteAPI(ls.test_examples.SimpleLitAPI): + # Mock API for testing multi-route server behavior + def __init__(self, api_path="/predict"): + super().__init__(api_path=api_path) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix") +@pytest.mark.parametrize( + ("workers_cfg", "expected_total_by_path"), + [ + # dict: explicit per-route config + ({"/sentiment": 2, "/generate": 3}, {"/sentiment": 4, "/generate": 6}), + # list: per-api (connector order) config + ([2, 3], {"/sentiment": 4, "/generate": 6}), + ], +) +def test_workers_per_device_can_be_configured_per_route(monkeypatch, workers_cfg, expected_total_by_path): + monkeypatch.setattr("litserve.server.uvicorn", MagicMock()) + + sentiment = MultiRouteAPI(api_path="/sentiment") + generate = MultiRouteAPI(api_path="/generate") + server = LitServer([sentiment, generate], accelerator="cuda", devices=[0, 1], workers_per_device=workers_cfg) + + created = [] # list[(api_path, worker_id, device)] + + class FakeProcess: + def __init__(self, target, args, name): + # inference_worker args = (lit_api, device, worker_id, request_q, transport, ...) + lit_api, device, worker_id = args[0], args[1], args[2] + created.append((lit_api.api_path, worker_id, device)) + self.pid = 123 + self.name = name + + def start(self): ... + def terminate(self): ... + def join(self, timeout=None): ... + def is_alive(self): + return False + + def kill(self): ... 
+
+    class FakeCtx:
+        def Process(self, target, args, name):  # noqa: N802
+            return FakeProcess(target=target, args=args, name=name)
+
+    monkeypatch.setattr("litserve.server.mp.get_context", lambda *_args, **_kwargs: FakeCtx())
+
+    # prevent server.run() from actually running uvicorn / waiting forever
+    server.verify_worker_status = MagicMock()
+    server._start_server = MagicMock(return_value={})
+    server._perform_graceful_shutdown = MagicMock()
+    server._start_worker_monitoring = MagicMock()
+    server._transport = MagicMock()
+    server._shutdown_event = MagicMock()
+    server._shutdown_event.wait = MagicMock(return_value=None)  # don't block
+
+    # init manager + queues without a real multiprocessing manager
+    with patch("litserve.server.mp.Manager", return_value=MagicMock()):
+        server.run(api_server_worker_type="process", generate_client_file=False)
+
+    # count workers created per api_path
+    total_by_path = {}
+    for api_path, _worker_id, _device in created:
+        total_by_path[api_path] = total_by_path.get(api_path, 0) + 1
+
+    assert total_by_path == expected_total_by_path
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix")
+def test_workers_per_device_per_route_raises_on_unknown_route():
+    sentiment = MultiRouteAPI(api_path="/sentiment")
+    generate = MultiRouteAPI(api_path="/generate")
+
+    with pytest.raises(ValueError, match="workers_per_device.*unknown api_path"):
+        LitServer(
+            [sentiment, generate],
+            accelerator="cuda",
+            devices=[0, 1],
+            workers_per_device={"/sentiment": 2, "/unknown": 1},
+        )
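For context, a minimal usage sketch of what this patch enables. It reuses the `MultiRouteAPI` helper from the new tests; the route names, worker counts, and port below are illustrative only, not part of the patch:

```python
import litserve as ls


class MultiRouteAPI(ls.test_examples.SimpleLitAPI):
    # Same helper as in the new tests: a SimpleLitAPI pinned to a custom route.
    def __init__(self, api_path="/predict"):
        super().__init__(api_path=api_path)


if __name__ == "__main__":
    sentiment = MultiRouteAPI(api_path="/sentiment")
    generate = MultiRouteAPI(api_path="/generate")

    # dict form: worker counts keyed by api_path. With devices=[0, 1],
    # "/sentiment" gets 2 workers per device (4 total), "/generate" gets 3 (6 total).
    server = ls.LitServer(
        [sentiment, generate],
        accelerator="cuda",
        devices=[0, 1],
        workers_per_device={"/sentiment": 2, "/generate": 3},
    )

    # Equivalent forms accepted by _resolve_workers_per_device_config:
    #   workers_per_device=[2, 3]  # list/tuple: one entry per API, in connector order
    #   workers_per_device=2       # plain int: same count for every route (previous behavior)
    server.run(port=8000)
```

With the mapping form, routes omitted from the dict fall back to 1 worker per device (`workers_per_device.get(p, 1)`), while unknown api_path keys raise a ValueError, as exercised by the second test.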