From a226363e1671794dea72be85d88dd330b7c5f4e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=86=A0=E9=B8=A3?= Date: Fri, 8 May 2026 14:34:31 +0800 Subject: [PATCH] feat: add embedded mode with auto-init runtime --- .../source/getting-started/embedded-mode.md | 117 ++++++ .../source/getting-started/index.md | 9 + .../getting-started/embedded-mode.md | 117 ++++++ .../source_zh/getting-started/index.md | 9 + examples/chat_sft/train.py | 2 + examples/countdown_rl/train.py | 2 + examples/embedded_quickstart/train.py | 87 +++++ src/tuft/__init__.py | 20 +- src/tuft/cli.py | 45 ++- src/tuft/runtime/__init__.py | 338 ++++++++++++++++++ src/tuft/runtime/_config_gen.py | 201 +++++++++++ src/tuft/runtime/_constants.py | 88 +++++ src/tuft/runtime/_discovery.py | 142 ++++++++ src/tuft/runtime/_launcher.py | 207 +++++++++++ tests/test_runtime_config_gen.py | 99 +++++ tests/test_runtime_discovery.py | 107 ++++++ tests/test_runtime_init.py | 106 ++++++ tests/test_runtime_launcher.py | 75 ++++ 18 files changed, 1763 insertions(+), 8 deletions(-) create mode 100644 docs/sphinx_doc/source/getting-started/embedded-mode.md create mode 100644 docs/sphinx_doc/source_zh/getting-started/embedded-mode.md create mode 100644 examples/embedded_quickstart/train.py create mode 100644 src/tuft/runtime/__init__.py create mode 100644 src/tuft/runtime/_config_gen.py create mode 100644 src/tuft/runtime/_constants.py create mode 100644 src/tuft/runtime/_discovery.py create mode 100644 src/tuft/runtime/_launcher.py create mode 100644 tests/test_runtime_config_gen.py create mode 100644 tests/test_runtime_discovery.py create mode 100644 tests/test_runtime_init.py create mode 100644 tests/test_runtime_launcher.py diff --git a/docs/sphinx_doc/source/getting-started/embedded-mode.md b/docs/sphinx_doc/source/getting-started/embedded-mode.md new file mode 100644 index 0000000..1ad0bfb --- /dev/null +++ b/docs/sphinx_doc/source/getting-started/embedded-mode.md @@ -0,0 +1,117 @@ +# Embedded Mode + +## Background 
+ +TuFT is designed to serve as a **transparent compute service layer** for RL training frameworks like Trinity and veRL. In production, TuFT typically runs as a standalone daemon (`tuft launch`), and users must: + +1. Write a `tuft_config.yaml` configuration file +2. Manually start the server with `tuft launch --config ...` +3. Set the `TINKER_BASE_URL` environment variable for clients to connect + +This manual setup creates friction, especially for: +- **RL framework users** who just want to run training scripts without learning TuFT internals +- **Development/debugging** workflows where quick iteration is key +- **CI pipelines** that need reproducible, self-contained environments + +**Embedded mode** solves this by providing a `tuft.init()` API — similar to `ray.init()` — that handles service discovery, configuration generation, startup, and connection automatically. + +## Two Modes of Operation + +| | Daemon Mode | Embedded Mode | +|---|---|---| +| How to start | `tuft launch --config ...` | `tuft.init(model=...)` | +| Lifecycle | Independent process, manually managed | Follows main process, auto-cleanup via atexit | +| Best for | Production deployments, multi-user shared clusters | Dev/debug, training scripts, CI | +| Service discovery | User sets `TINKER_BASE_URL` manually | Automatic (env var → address file → process scan → default port) | + +**Both modes coexist**: `tuft.init()` first tries to discover an existing daemon. Only when no running service is found does it start an embedded instance. + +## Quick Start + +```python +import tuft + +# Initialize TuFT — auto-discovers existing service or starts one +tuft.init(model="/path/to/Qwen2.5-0.5B-Instruct") + +# Use the service client for training +training_client = tuft.create_training_client( + base_model="Qwen2.5-0.5B-Instruct", + rank=8, +) +# ... your training loop ... 
+ +# Optional: explicit shutdown (atexit handles this automatically) +tuft.shutdown() +``` + +### Other `init()` patterns + +```python +# Connect to a specific running server +tuft.init(address="http://gpu-cluster:10610") + +# Use an existing config file +tuft.init(config="/path/to/tuft_config.yaml") + +# No arguments — relies on env vars or default config file +tuft.init() + +# Get a service client (auto-inits if not already done) +service_client = tuft.get_service_client() +``` + +## Service Discovery Priority + +When `tuft.init()` is called, it tries to find an existing service in this order: + +1. `address=...` argument passed to `init()` +2. `TUFT_ADDRESS` environment variable +3. Address file at `~/.tuft/tuft_current_server` +4. Process scan (looks for running `tuft launch` or `uvicorn` processes) +5. Default port probe: `http://127.0.0.1:10610` + +If no service is found, embedded mode starts a new one using configuration from: + +1. `config=...` argument passed to `init()` +2. `TUFT_CONFIG` environment variable +3. `model=...` argument → auto-generates minimal config +4. `TUFT_MODEL_PATH` environment variable → auto-generates minimal config +5. Default config file: `~/.tuft/configs/tuft_config.yaml` +6. None available → raises `RuntimeError` with helpful guidance + +## Environment Variables + +| Variable | Description | Default | +|----------|-------------|---------| +| `TUFT_ADDRESS` | Address of running TuFT service | — | +| `TUFT_API_KEY` | API authentication key | Auto-generated | +| `TUFT_CONFIG` | Path to configuration file | — | +| `TUFT_MODEL_PATH` | Model path for auto-config generation | — | +| `TUFT_ENABLE_AUTO_CONNECT` | Enable auto-connect in `get_service_client()` | `"1"` | +| `TUFT_HOME` | TuFT home directory | `~/.tuft` | +| `TUFT_HOST` | Server bind address | `127.0.0.1` | +| `TUFT_PORT` | Server bind port | `10610` | + +## Lifecycle + +- **Embedded services** are tied to the main process. 
When the Python process exits normally (or via a signal that Python handles), the embedded TuFT server is automatically terminated via `atexit`. Note that `atexit` handlers cannot run on an unhandled fatal signal such as `SIGKILL` or a hard crash.
设置 `TINKER_BASE_URL` 环境变量供客户端连接 + +这种手动配置带来了额外负担,尤其是: +- **RL 框架用户**:只想运行训练脚本,不想学习 TuFT 的安装和配置 +- **开发调试**:需要快速迭代的工作流 +- **CI 流水线**:需要可复现的自包含环境 + +**嵌入式模式**通过提供 `tuft.init()` API 解决了这个问题——类似 `ray.init()`——自动完成服务发现、配置生成、启动和连接。 + +## 两种运行模式 + +| | 守护进程模式 | 嵌入式模式 | +|---|---|---| +| 启动方式 | `tuft launch --config ...` | `tuft.init(model=...)` | +| 生命周期 | 独立进程,手动管理 | 跟随主进程,atexit 自动清理 | +| 适用场景 | 生产部署、多用户共享集群 | 开发调试、训练脚本、CI | +| 服务发现 | 用户手动设置 `TINKER_BASE_URL` | 自动(环境变量 → 地址文件 → 进程扫描 → 默认端口) | + +**两种模式共存**:`tuft.init()` 首先尝试发现已有的守护进程服务。只有在找不到运行中的服务时,才会启动嵌入式实例。 + +## 快速开始 + +```python +import tuft + +# 初始化 TuFT — 自动发现已有服务或启动一个新的 +tuft.init(model="/path/to/Qwen2.5-0.5B-Instruct") + +# 使用 service client 进行训练 +training_client = tuft.create_training_client( + base_model="Qwen2.5-0.5B-Instruct", + rank=8, +) +# ... 你的训练循环 ... + +# 可选:显式关闭(atexit 会自动处理) +tuft.shutdown() +``` + +### 其他 `init()` 模式 + +```python +# 连接到指定的运行中服务 +tuft.init(address="http://gpu-cluster:10610") + +# 使用已有配置文件 +tuft.init(config="/path/to/tuft_config.yaml") + +# 无参数 — 依赖环境变量或默认配置文件 +tuft.init() + +# 获取 service client(未初始化时自动触发 init) +service_client = tuft.get_service_client() +``` + +## 服务发现优先级 + +调用 `tuft.init()` 时,按以下顺序尝试发现已有服务: + +1. `address=...` 参数显式传入 +2. `TUFT_ADDRESS` 环境变量 +3. 地址文件 `~/.tuft/tuft_current_server` +4. 进程扫描(查找运行中的 `tuft launch` 或 `uvicorn` 进程) +5. 默认端口探测:`http://127.0.0.1:10610` + +如果未发现服务,嵌入式模式按以下优先级获取配置并启动: + +1. `config=...` 参数显式传入 +2. `TUFT_CONFIG` 环境变量 +3. `model=...` 参数 → 自动生成最小配置 +4. `TUFT_MODEL_PATH` 环境变量 → 自动生成最小配置 +5. 默认配置文件:`~/.tuft/configs/tuft_config.yaml` +6. 
全部没有 → 抛出 `RuntimeError` 并给出提示 + +## 环境变量 + +| 变量 | 说明 | 默认值 | +|------|------|--------| +| `TUFT_ADDRESS` | TuFT 服务地址 | — | +| `TUFT_API_KEY` | API 认证密钥 | 自动生成 | +| `TUFT_CONFIG` | 配置文件路径 | — | +| `TUFT_MODEL_PATH` | 模型路径(用于自动生成配置) | — | +| `TUFT_ENABLE_AUTO_CONNECT` | 启用 `get_service_client()` 自动连接 | `"1"` | +| `TUFT_HOME` | TuFT 主目录 | `~/.tuft` | +| `TUFT_HOST` | 服务绑定地址 | `127.0.0.1` | +| `TUFT_PORT` | 服务绑定端口 | `10610` | + +## 生命周期 + +- **嵌入式服务**绑定到主进程。当 Python 进程退出(正常或信号)时,嵌入式 TuFT 服务通过 `atexit` 自动终止。 +- **守护进程服务**(`tuft launch`)独立运行,持续到手动停止。 +- `tuft.shutdown()` 可显式调用以提前停止嵌入式服务。 +- `tuft.init()` 是**幂等的** — 多次调用安全(首次成功后为空操作)。 + +## 与 RL 框架集成 + +框架集成(如 Trinity)的模式: + +```python +import tuft + +# 在框架的初始化代码中: +tuft.init(model=model_path, ignore_reinit_error=True) +service_client = tuft.get_service_client() + +# 像之前一样使用 service_client... +``` + +这不需要改变用户的工作流 — 框架透明地处理 TuFT 的配置和启动。 diff --git a/docs/sphinx_doc/source_zh/getting-started/index.md b/docs/sphinx_doc/source_zh/getting-started/index.md index 2e411ba..5ad22a9 100644 --- a/docs/sphinx_doc/source_zh/getting-started/index.md +++ b/docs/sphinx_doc/source_zh/getting-started/index.md @@ -20,6 +20,14 @@ 使用 TuFT 运行您的第一个训练与推理示例。 ::: + +:::{grid-item-card} 嵌入式模式 +:link: embedded-mode +:link-type: doc +:shadow: none + +使用 `tuft.init()` 实现自动服务发现和启动。 +::: ``` ```{toctree} @@ -28,4 +36,5 @@ installation quickstart +embedded-mode ``` diff --git a/examples/chat_sft/train.py b/examples/chat_sft/train.py index 937c961..cd0b7d2 100644 --- a/examples/chat_sft/train.py +++ b/examples/chat_sft/train.py @@ -112,6 +112,8 @@ def compute_weighted_nll_from_outputs(loss_fn_outputs, datums) -> float: def connect(cfg: Config) -> tinker.ServiceClient: print(f"[1/6] connect service: {cfg.base_url}") + # Alternative: use tuft.get_service_client() for auto-discovery/embedded mode + # import tuft; return tuft.get_service_client() return tinker.ServiceClient(base_url=cfg.base_url, api_key=cfg.api_key) diff --git 
a/examples/countdown_rl/train.py b/examples/countdown_rl/train.py index d3023b5..b89ad95 100644 --- a/examples/countdown_rl/train.py +++ b/examples/countdown_rl/train.py @@ -97,6 +97,8 @@ def init_wandb(cfg: Config): def connect(cfg: Config) -> tinker.ServiceClient: print(f"[1/6] connect service: {cfg.base_url}") + # Alternative: use tuft.get_service_client() for auto-discovery/embedded mode + # import tuft; return tuft.get_service_client() return tinker.ServiceClient(base_url=cfg.base_url, api_key=cfg.api_key) diff --git a/examples/embedded_quickstart/train.py b/examples/embedded_quickstart/train.py new file mode 100644 index 0000000..f21a749 --- /dev/null +++ b/examples/embedded_quickstart/train.py @@ -0,0 +1,87 @@ +"""Embedded TuFT quickstart — demonstrates auto-init (embedded mode). + +This example shows how to use TuFT in embedded mode, where the service +is automatically started and managed within your training script's lifecycle. + +No manual `tuft launch` or configuration files needed! + +Usage: + python train.py --model /path/to/Qwen2.5-0.5B-Instruct + +The script will: +1. Auto-detect the model and GPU configuration +2. Start a TuFT server in the background +3. Connect and run a minimal training loop +4. 
Automatically shut down the server on exit +""" + +from __future__ import annotations + +import argparse + +from tinker import types + +import tuft + + +def main(): + parser = argparse.ArgumentParser(description="Embedded TuFT quickstart") + parser.add_argument( + "--model", + type=str, + required=True, + help="Path to the base model (e.g., /path/to/Qwen2.5-0.5B-Instruct)", + ) + parser.add_argument("--rank", type=int, default=8, help="LoRA rank") + parser.add_argument("--steps", type=int, default=5, help="Number of training steps") + args = parser.parse_args() + + # ========================================================================= + # Step 1: Initialize TuFT in embedded mode + # This will auto-detect GPUs, generate a minimal config, and start the server. + # If a TuFT server is already running, it will connect to it instead. + # ========================================================================= + print(f"[1/4] Initializing TuFT with model: {args.model}") + tuft.init(model=args.model) + print(" TuFT initialized (mode: embedded)") + + # ========================================================================= + # Step 2: Create a training client + # ========================================================================= + print(f"[2/4] Creating LoRA training client (rank={args.rank})") + training_client = tuft.create_training_client( + base_model=args.model, + rank=args.rank, + train_mlp=True, + train_attn=True, + ) + + # ========================================================================= + # Step 3: Run a minimal training loop + # ========================================================================= + print(f"[3/4] Running {args.steps} training steps (with fake data)") + for step in range(args.steps): + # Create a fake training datum (in practice, use real tokenized data) + datum = types.Datum( + model_input=types.ModelInput.from_ints([101, 42, 37, 102]), + loss_fn_inputs={ + "target_tokens": types.TensorData( + data=[101, 99, 73, 102], 
dtype="int64", shape=[4] + ), + "weights": types.TensorData(data=[1.0, 1.0, 1.0, 1.0], dtype="float32", shape=[4]), + }, + ) + training_client.forward_backward([datum], loss_fn="cross_entropy").result() + training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result() + print(f" Step {step + 1}/{args.steps} complete") + + # ========================================================================= + # Step 4: Clean up (optional — atexit handles this automatically) + # ========================================================================= + print("[4/4] Shutting down TuFT") + tuft.shutdown() + print(" Done!") + + +if __name__ == "__main__": + main() diff --git a/src/tuft/__init__.py b/src/tuft/__init__.py index 3373cc8..d0afa2f 100644 --- a/src/tuft/__init__.py +++ b/src/tuft/__init__.py @@ -1,6 +1,24 @@ """TuFT service package.""" +# Runtime API (embedded mode & service management) +from .runtime import ( # noqa: E402 + create_sampling_client, + create_training_client, + get_service_client, + init, + is_initialized, + shutdown, +) from .server import create_root_app -__all__ = ["create_root_app"] +__all__ = [ + "create_root_app", + # Runtime API + "init", + "shutdown", + "is_initialized", + "get_service_client", + "create_training_client", + "create_sampling_client", +] diff --git a/src/tuft/cli.py b/src/tuft/cli.py index 15677ae..d901c78 100644 --- a/src/tuft/cli.py +++ b/src/tuft/cli.py @@ -17,6 +17,7 @@ get_redis_store, validate_config_signature, ) +from .runtime._constants import get_address_file from .server import create_root_app from .telemetry import init_telemetry from .telemetry.metrics import ResourceMetricsCollector @@ -227,13 +228,43 @@ def launch( # Initialize telemetry before starting the server _init_telemetry(app_config, log_level) logging.getLogger("tuft").info("Server starting on %s:%s", host, port) - uvicorn.run( - create_root_app(app_config), - host=host, - port=port, - log_level=log_level, - reload=reload, - ) + + # Write address 
file so embedded mode / other processes can discover this server + _write_address_file(host, port) + + try: + uvicorn.run( + create_root_app(app_config), + host=host, + port=port, + log_level=log_level, + reload=reload, + ) + finally: + _remove_address_file(host, port) + + +def _write_address_file(host: str, port: int) -> None: + """Write server address to the discovery file.""" + address_file = get_address_file() + address = f"http://{host}:{port}" + try: + address_file.parent.mkdir(parents=True, exist_ok=True) + address_file.write_text(address) + except OSError: + pass + + +def _remove_address_file(host: str, port: int) -> None: + """Remove address file if it points to this server.""" + address_file = get_address_file() + try: + if address_file.exists(): + content = address_file.read_text().strip() + if content == f"http://{host}:{port}": + address_file.unlink() + except OSError: + pass def main() -> None: diff --git a/src/tuft/runtime/__init__.py b/src/tuft/runtime/__init__.py new file mode 100644 index 0000000..c300db5 --- /dev/null +++ b/src/tuft/runtime/__init__.py @@ -0,0 +1,338 @@ +"""TuFT Runtime — public API for embedded mode and service management. 
+ +Usage: + import tuft + + tuft.init(model="/path/to/model") # auto-discover or start embedded server + client = tuft.get_service_client() # returns tinker.ServiceClient + tuft.shutdown() # stop embedded server if any +""" + +from __future__ import annotations + +import logging +import os +import threading +from pathlib import Path +from typing import Optional + +import tinker + +from ._config_gen import generate_api_key, generate_config_file +from ._constants import ( + DEFAULT_HOST, + DEFAULT_PORT, + ENV_TUFT_API_KEY, + ENV_TUFT_CONFIG, + ENV_TUFT_ENABLE_AUTO_CONNECT, + ENV_TUFT_HOST, + ENV_TUFT_MODEL_PATH, + ENV_TUFT_PORT, + get_credentials_file, + get_default_config_path, +) +from ._discovery import discover +from ._launcher import EmbeddedServer + + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Global state (singleton, thread-safe) +# --------------------------------------------------------------------------- + +_lock = threading.Lock() +_initialized = False +_mode: Optional[str] = None # "connected" | "embedded" +_service_client: Optional[tinker.ServiceClient] = None +_embedded_server: Optional[EmbeddedServer] = None +_api_key: Optional[str] = None + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +__all__ = [ + "init", + "shutdown", + "is_initialized", + "get_service_client", + "create_training_client", + "create_sampling_client", + "generate_api_key", +] + + +def init( + *, + address: Optional[str] = None, + model: Optional[str | Path] = None, + config: Optional[str | Path] = None, + host: Optional[str] = None, + port: Optional[int] = None, + api_key: Optional[str] = None, + ignore_reinit_error: bool = True, +) -> None: + """Initialize TuFT: discover an existing service or start an embedded one. + + This function is idempotent. 
Calling it multiple times is safe when + ignore_reinit_error=True (default). + + Args: + address: Explicit service address to connect to. + model: Model path for auto-generating config and starting embedded server. + config: Path to a YAML config file for the embedded server. + host: Host to bind the embedded server (default: 127.0.0.1). + port: Port to bind the embedded server (default: 10610). + api_key: API key for authentication. Auto-generated if not provided. + ignore_reinit_error: If True, silently skip if already initialized. + + Raises: + RuntimeError: If already initialized and ignore_reinit_error is False. + RuntimeError: If no service found and cannot start embedded server. + """ + global _initialized, _mode, _service_client, _embedded_server, _api_key + + if _initialized: + if ignore_reinit_error: + return + raise RuntimeError( + "TuFT is already initialized. Call tuft.shutdown() first, " + "or use ignore_reinit_error=True." + ) + + with _lock: + # Double-check after acquiring lock + if _initialized: + if ignore_reinit_error: + return + raise RuntimeError("TuFT is already initialized.") + + resolved_host = host or os.environ.get(ENV_TUFT_HOST, DEFAULT_HOST) + resolved_port = port or int(os.environ.get(ENV_TUFT_PORT, str(DEFAULT_PORT))) + + # Phase 1: Try to discover an existing service + discovered = discover(explicit_address=address) + if discovered: + _api_key = api_key or os.environ.get(ENV_TUFT_API_KEY) + _service_client = tinker.ServiceClient( + base_url=discovered, + api_key=_api_key or "", + ) + _mode = "connected" + _initialized = True + logger.info("TuFT initialized in connected mode: %s", discovered) + return + + # If explicit address was given but not healthy, fail + if address: + raise RuntimeError( + f"Cannot connect to TuFT at {address}. " + "Ensure the server is running or remove the address parameter." 
+ ) + + # Phase 2: Auto-start embedded server + # Determine config source + config_path = _resolve_config_for_launch(config, model, resolved_host, resolved_port) + if config_path is None: + raise RuntimeError( + "Cannot start TuFT: no service found and no configuration available.\n" + "Please provide one of:\n" + " - tuft.init(address='http://...') to connect to existing service\n" + " - tuft.init(model='/path/to/model') to auto-start\n" + " - tuft.init(config='/path/to/config.yaml') to auto-start\n" + " - Set TUFT_ADDRESS, TUFT_MODEL_PATH, or TUFT_CONFIG env var\n" + " - Create ~/.tuft/configs/tuft_config.yaml" + ) + + _embedded_server = EmbeddedServer( + config_path=config_path, + host=resolved_host, + port=resolved_port, + ) + server_address = _embedded_server.start() + + # Resolve API key + if _api_key is None: + _api_key = api_key or os.environ.get(ENV_TUFT_API_KEY) or "" + + _service_client = tinker.ServiceClient( + base_url=server_address, + api_key=_api_key, + ) + _mode = "embedded" + _initialized = True + logger.info("TuFT initialized in embedded mode: %s", server_address) + + +def shutdown() -> None: + """Shutdown TuFT: disconnect and stop embedded server if running.""" + global _initialized, _mode, _service_client, _embedded_server, _api_key + + with _lock: + if _embedded_server is not None: + _embedded_server.shutdown() + _embedded_server = None + _service_client = None + _api_key = None + _mode = None + _initialized = False + logger.info("TuFT shut down.") + + +def is_initialized() -> bool: + """Return True if TuFT has been initialized.""" + return _initialized + + +def get_service_client() -> tinker.ServiceClient: + """Return the global ServiceClient, auto-initializing if needed. + + Returns: + A connected tinker.ServiceClient instance. + + Raises: + RuntimeError: If auto-initialization fails. 
+ """ + global _service_client + if not _initialized: + # Lazy init: check if auto-connect is enabled + auto_connect = os.environ.get(ENV_TUFT_ENABLE_AUTO_CONNECT, "1") + if auto_connect != "1": + raise RuntimeError( + "TuFT is not initialized and auto-connect is disabled " + f"(TUFT_ENABLE_AUTO_CONNECT={auto_connect}). " + "Call tuft.init() explicitly." + ) + init() + if _service_client is None: + raise RuntimeError("TuFT initialization failed: no service client available.") + return _service_client + + +def create_training_client( + base_model: str, + rank: int = 16, + **kwargs, +): + """Convenience: create a LoRA training client via the global ServiceClient. + + Args: + base_model: The base model name/path registered on the server. + If a full path is given, it will be resolved to the model directory name. + rank: LoRA rank. + **kwargs: Additional arguments passed to create_lora_training_client. + + Returns: + A LoRA training client. + """ + # If base_model looks like an absolute path, extract the directory name + # since the server registers models by directory name (e.g. "Qwen2.5-0.5B-Instruct") + if os.path.sep in base_model or base_model.startswith("/"): + base_model = Path(base_model).name + + client = get_service_client() + return client.create_lora_training_client( + base_model=base_model, + rank=rank, + **kwargs, + ) + + +def create_sampling_client( + base_model: Optional[str] = None, + model_path: Optional[str] = None, + **kwargs, +): + """Convenience: create a sampling client via the global ServiceClient. + + Args: + base_model: The base model name (for base model sampling). + If a full path is given, it will be resolved to the model directory name. + model_path: A specific model path (e.g., LoRA checkpoint). + **kwargs: Additional arguments passed to create_sampling_client. + + Returns: + A sampling client. 
+ """ + # Resolve absolute path to model name + if base_model and (os.path.sep in base_model or base_model.startswith("/")): + base_model = Path(base_model).name + + client = get_service_client() + if model_path: + return client.create_sampling_client(model_path=model_path, **kwargs) + return client.create_sampling_client(base_model=base_model, **kwargs) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _resolve_config_for_launch( + config: Optional[str | Path], + model: Optional[str | Path], + host: str, + port: int, +) -> Optional[Path]: + """Resolve config file path for launching embedded server. + + Priority: + 1. Explicit config argument + 2. TUFT_CONFIG env var + 3. model argument -> auto-generate config + 4. TUFT_MODEL_PATH env var -> auto-generate config + 5. Default config file (~/.tuft/configs/tuft_config.yaml) + """ + global _api_key + + # 1. Explicit config + if config is not None: + path = Path(config) + if not path.exists(): + raise RuntimeError(f"Config file not found: {path}") + return path + + # 2. TUFT_CONFIG env var + env_config = os.environ.get(ENV_TUFT_CONFIG) + if env_config: + path = Path(env_config) + if path.exists(): + return path + logger.warning("TUFT_CONFIG=%s does not exist, skipping", env_config) + + # 3. model argument -> auto-generate + if model is not None: + config_path, api_key = generate_config_file(model, host=host, port=port) + _api_key = api_key + _save_credentials(api_key) + return config_path + + # 4. TUFT_MODEL_PATH env var + env_model = os.environ.get(ENV_TUFT_MODEL_PATH) + if env_model: + config_path, api_key = generate_config_file(env_model, host=host, port=port) + _api_key = api_key + _save_credentials(api_key) + return config_path + + # 5. 
Default config file + default_config = get_default_config_path() + if default_config.exists(): + return default_config + + return None + + +def _save_credentials(api_key: str) -> None: + """Save auto-generated API key to credentials file.""" + creds_file = get_credentials_file() + try: + creds_file.parent.mkdir(parents=True, exist_ok=True) + creds_file.write_text(api_key) + # Set restrictive permissions + creds_file.chmod(0o600) + logger.debug("Saved credentials to %s", creds_file) + except OSError as e: + logger.warning("Failed to save credentials: %s", e) diff --git a/src/tuft/runtime/_config_gen.py b/src/tuft/runtime/_config_gen.py new file mode 100644 index 0000000..afaae9d --- /dev/null +++ b/src/tuft/runtime/_config_gen.py @@ -0,0 +1,201 @@ +"""Auto-generate a minimal AppConfig from a model path. + +Used by the embedded mode to create a working configuration without user +intervention. Detects GPU count/memory and infers reasonable defaults. +""" + +from __future__ import annotations + +import json +import logging +import secrets +import tempfile +from pathlib import Path +from typing import Optional + +from ._constants import get_default_checkpoint_dir + + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# GPU detection helpers +# --------------------------------------------------------------------------- + + +def _detect_gpu_count() -> int: + """Return the number of available NVIDIA GPUs, or 0 if none.""" + try: + import pynvml + + pynvml.nvmlInit() + count = pynvml.nvmlDeviceGetCount() + pynvml.nvmlShutdown() + return count + except Exception: + return 0 + + +def _detect_gpu_memory_gb() -> float: + """Return the memory (GB) of the first GPU, or 0.0 if unavailable.""" + try: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + info = pynvml.nvmlDeviceGetMemoryInfo(handle) + pynvml.nvmlShutdown() + return info.total / (1024**3) # type: ignore[operator] + 
except Exception: + return 0.0 + + +# --------------------------------------------------------------------------- +# Model metadata helpers +# --------------------------------------------------------------------------- + + +def _read_max_position_embeddings(model_path: Path) -> Optional[int]: + """Try to read max_position_embeddings from config.json in the model directory.""" + config_file = model_path / "config.json" + if not config_file.exists(): + return None + try: + with open(config_file) as f: + data = json.load(f) + # Common keys in HuggingFace model configs + for key in ("max_position_embeddings", "n_positions", "seq_length"): + if key in data: + return int(data[key]) + except (json.JSONDecodeError, OSError, ValueError): + pass + return None + + +# --------------------------------------------------------------------------- +# Config generation +# --------------------------------------------------------------------------- + + +def _infer_tensor_parallel_size(gpu_count: int, gpu_memory_gb: float) -> int: + """Infer a reasonable tensor_parallel_size. + + Heuristic: use all GPUs if model likely needs multiple (>40GB models), + otherwise default to 1. 
+ """ + if gpu_count <= 1: + return 1 + # For now, default to 1 for simplicity; users can override via config + return 1 + + +def _infer_max_model_len(model_path: Path) -> int: + """Infer max_model_len from model config or use a conservative default.""" + max_pos = _read_max_position_embeddings(model_path) + if max_pos is not None: + # Cap at 32768 to avoid OOM on smaller GPUs + return min(max_pos, 32768) + return 4096 # conservative default + + +def generate_api_key() -> str: + """Generate a random local API key.""" + return f"tml-{secrets.token_hex(16)}" + + +def generate_config_dict( + model_path: str | Path, + *, + host: str = "127.0.0.1", + port: int = 10610, + checkpoint_dir: Optional[Path] = None, + api_key: Optional[str] = None, +) -> dict: + """Generate a minimal AppConfig dict suitable for YAML serialization. + + Args: + model_path: Path to the model directory or HuggingFace model ID. + host: Server bind address. + port: Server bind port. + checkpoint_dir: Where to store checkpoints. + api_key: API key for auth; auto-generated if None. + + Returns: + A dict that can be passed to AppConfig.model_validate(). 
+ """ + model_path = Path(model_path) + model_name = model_path.name # Use directory name as model_name + + gpu_count = _detect_gpu_count() + gpu_memory_gb = _detect_gpu_memory_gb() + tp_size = _infer_tensor_parallel_size(gpu_count, gpu_memory_gb) + max_model_len = _infer_max_model_len(model_path) + + if checkpoint_dir is None: + checkpoint_dir = get_default_checkpoint_dir() + + if api_key is None: + api_key = generate_api_key() + + config = { + "checkpoint_dir": str(checkpoint_dir), + "authorized_users": {api_key: "local"}, # pragma: allowlist secret + "supported_models": [ + { + "model_name": model_name, + "model_path": str(model_path), + "max_model_len": max_model_len, + "tensor_parallel_size": tp_size, + } + ], + } + + logger.info( + "Generated minimal config: model=%s, tp_size=%d, max_model_len=%d, gpus_detected=%d", + model_name, + tp_size, + max_model_len, + gpu_count, + ) + return config + + +def generate_config_file( + model_path: str | Path, + *, + host: str = "127.0.0.1", + port: int = 10610, + checkpoint_dir: Optional[Path] = None, + api_key: Optional[str] = None, +) -> tuple[Path, str]: + """Generate a temporary YAML config file. + + Returns: + A tuple of (config_file_path, api_key). 
+ """ + from omegaconf import OmegaConf + + if api_key is None: + api_key = generate_api_key() + + config_dict = generate_config_dict( + model_path, + host=host, + port=port, + checkpoint_dir=checkpoint_dir, + api_key=api_key, + ) + + # Write to a temp file that persists until process exit + tmp = tempfile.NamedTemporaryFile( + mode="w", + suffix=".yaml", + prefix="tuft_auto_", + delete=False, + ) + conf = OmegaConf.create(config_dict) + tmp.write(OmegaConf.to_yaml(conf)) + tmp.close() + + logger.info("Generated temporary config file: %s", tmp.name) + return Path(tmp.name), api_key diff --git a/src/tuft/runtime/_constants.py b/src/tuft/runtime/_constants.py new file mode 100644 index 0000000..30d1b49 --- /dev/null +++ b/src/tuft/runtime/_constants.py @@ -0,0 +1,88 @@ +"""Constants for TuFT runtime: environment variable names, default paths, and ports.""" + +from __future__ import annotations + +import os +from pathlib import Path + + +# --------------------------------------------------------------------------- +# Environment variable names +# --------------------------------------------------------------------------- + +ENV_TUFT_ADDRESS = "TUFT_ADDRESS" +"""Service address, e.g. 
"""Constants for TuFT runtime: environment variable names, default paths, and ports."""

from __future__ import annotations

import os
from pathlib import Path


# ---------------------------------------------------------------------------
# Environment variable names
# ---------------------------------------------------------------------------

# Service address, e.g. http://127.0.0.1:10610
ENV_TUFT_ADDRESS = "TUFT_ADDRESS"

# API authentication key
ENV_TUFT_API_KEY = "TUFT_API_KEY"  # pragma: allowlist secret

# Path to config file
ENV_TUFT_CONFIG = "TUFT_CONFIG"

# Model path for auto-generating a minimal config
ENV_TUFT_MODEL_PATH = "TUFT_MODEL_PATH"

# TuFT home directory, defaults to ~/.tuft
ENV_TUFT_HOME = "TUFT_HOME"

# Server bind address
ENV_TUFT_HOST = "TUFT_HOST"

# Server bind port
ENV_TUFT_PORT = "TUFT_PORT"

# Whether to enable auto-connect, defaults to "1" (enabled)
ENV_TUFT_ENABLE_AUTO_CONNECT = "TUFT_ENABLE_AUTO_CONNECT"

# ---------------------------------------------------------------------------
# Defaults
# ---------------------------------------------------------------------------

DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 10610

# ---------------------------------------------------------------------------
# Derived paths (resolved lazily so TUFT_HOME overrides take effect per call)
# ---------------------------------------------------------------------------


def get_tuft_home() -> Path:
    """Return the TUFT_HOME directory, defaulting to ~/.tuft."""
    return Path(os.getenv(ENV_TUFT_HOME, Path.home() / ".tuft"))


def get_address_file() -> Path:
    """Return the path of the file recording the current server address."""
    return get_tuft_home() / "tuft_current_server"


def get_default_config_path() -> Path:
    """Return the default config file path."""
    return get_tuft_home() / "configs" / "tuft_config.yaml"


def get_default_checkpoint_dir() -> Path:
    """Return the default checkpoint directory."""
    return get_tuft_home() / "checkpoints"


def get_credentials_file() -> Path:
    """Return the path of the auto-generated credentials file."""
    return get_tuft_home() / "credentials"


# ---------------------------------------------------------------------------
# Health check endpoint
# ---------------------------------------------------------------------------

# Health check endpoint path used for service discovery.
HEALTHZ_PATH = "/api/v1/healthz"

# Timeout in seconds for a single health check request.
HEALTHZ_TIMEOUT = 2.0

# Maximum seconds to wait for an embedded service to become healthy.
STARTUP_TIMEOUT = 120.0

# Seconds between health check polls during startup.
STARTUP_POLL_INTERVAL = 0.5
"""Multi-level service discovery for TuFT.

Discovery order:
1. TUFT_ADDRESS environment variable
2. Address file (~/.tuft/tuft_current_server) + healthz validation
3. Process scan (psutil: look for tuft/uvicorn processes)
4. Default port probe (http://127.0.0.1:10610)

Each level validates via GET /api/v1/healthz before returning.
"""

from __future__ import annotations

import logging
import os
from typing import Optional

import httpx
import psutil

from ._constants import (
    DEFAULT_HOST,
    DEFAULT_PORT,
    ENV_TUFT_ADDRESS,
    HEALTHZ_PATH,
    HEALTHZ_TIMEOUT,
    get_address_file,
)


logger = logging.getLogger(__name__)


def _check_health(address: str) -> bool:
    """Send a GET to /api/v1/healthz and return True if status 200."""
    probe_url = address.rstrip("/") + HEALTHZ_PATH
    try:
        return httpx.get(probe_url, timeout=HEALTHZ_TIMEOUT).status_code == 200
    except (httpx.ConnectError, httpx.TimeoutException, OSError):
        return False


def _discover_from_env() -> Optional[str]:
    """Level 1: Check TUFT_ADDRESS environment variable."""
    raw = os.environ.get(ENV_TUFT_ADDRESS)
    if not raw:
        return None
    candidate = raw.strip()
    if _check_health(candidate):
        logger.info("Discovered TuFT service via TUFT_ADDRESS: %s", candidate)
        return candidate
    logger.debug("TUFT_ADDRESS=%s is set but not healthy", candidate)
    return None


def _discover_from_address_file() -> Optional[str]:
    """Level 2: Read address from ~/.tuft/tuft_current_server."""
    record = get_address_file()
    if not record.exists():
        return None
    try:
        candidate = record.read_text().strip()
    except OSError:
        return None
    if not candidate:
        return None
    if _check_health(candidate):
        logger.info("Discovered TuFT service via address file: %s", candidate)
        return candidate
    logger.debug("Address file points to %s but not healthy", candidate)
    return None


def _port_from_cmdline(cmdline: list[str]) -> int:
    """Extract the value of the first --port/-p flag, defaulting to DEFAULT_PORT."""
    for i, arg in enumerate(cmdline):
        if arg in ("--port", "-p") and i + 1 < len(cmdline):
            try:
                return int(cmdline[i + 1])
            except ValueError:
                return DEFAULT_PORT
    return DEFAULT_PORT


def _discover_from_process_scan() -> Optional[str]:
    """Level 3: Scan local processes for a running tuft server."""
    for proc in psutil.process_iter(["pid", "name", "cmdline"]):
        try:
            cmdline = proc.info.get("cmdline") or []
            joined = " ".join(cmdline)
            # Look for 'tuft launch' or 'uvicorn' serving tuft
            if "tuft" not in joined:
                continue
            if "launch" not in joined and "uvicorn" not in joined:
                continue
            candidate = f"http://{DEFAULT_HOST}:{_port_from_cmdline(cmdline)}"
            if _check_health(candidate):
                logger.info(
                    "Discovered TuFT service via process scan (pid=%s): %s",
                    proc.info["pid"],
                    candidate,
                )
                return candidate
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            continue
    return None


def _discover_from_default_port() -> Optional[str]:
    """Level 4: Probe the default port."""
    candidate = f"http://{DEFAULT_HOST}:{DEFAULT_PORT}"
    if not _check_health(candidate):
        return None
    logger.info("Discovered TuFT service on default port: %s", candidate)
    return candidate


def discover(explicit_address: Optional[str] = None) -> Optional[str]:
    """Run multi-level service discovery.

    Args:
        explicit_address: If provided, check this address first (highest priority).

    Returns:
        A validated service address, or None if no healthy service found.
    """
    # Level 0: Explicit address passed to init()
    if explicit_address:
        candidate = explicit_address.strip()
        if _check_health(candidate):
            logger.info("Connected to explicitly provided address: %s", candidate)
            return candidate
        logger.warning("Explicit address %s is not healthy", candidate)
        return None

    # Level 1-4: progressive discovery
    for probe in (
        _discover_from_env,
        _discover_from_address_file,
        _discover_from_process_scan,
        _discover_from_default_port,
    ):
        found = probe()
        if found:
            return found

    return None
"""Embedded mode launcher: start a TuFT server as a subprocess.

Responsibilities:
- Launch `tuft launch` in a subprocess
- Poll healthz until ready (timeout configurable)
- Write address file for discovery
- Register atexit cleanup (terminate subprocess + remove address file)
"""

from __future__ import annotations

import atexit
import logging
import os
import signal
import subprocess
import sys
import time
from pathlib import Path
from typing import Optional

import httpx

from ._constants import (
    DEFAULT_HOST,
    DEFAULT_PORT,
    HEALTHZ_PATH,
    HEALTHZ_TIMEOUT,
    STARTUP_POLL_INTERVAL,
    STARTUP_TIMEOUT,
    get_address_file,
)


logger = logging.getLogger(__name__)


class EmbeddedServer:
    """Manages a TuFT server subprocess (embedded mode)."""

    def __init__(
        self,
        config_path: Path,
        host: str = DEFAULT_HOST,
        port: int = DEFAULT_PORT,
        timeout: float = STARTUP_TIMEOUT,
    ):
        self.config_path = config_path
        self.host = host
        self.port = port
        self.timeout = timeout
        self.process: Optional[subprocess.Popen] = None
        self._address: Optional[str] = None
        self._atexit_registered = False

    @property
    def address(self) -> Optional[str]:
        """Address of the running server, or None before start()."""
        return self._address

    def start(self) -> str:
        """Start the server subprocess and wait for it to be healthy.

        Idempotent: if the subprocess is already alive, its address is
        returned without spawning a second server.

        Returns:
            The address of the running server.

        Raises:
            RuntimeError: If the server fails to start within the timeout.
        """
        if self.process is not None and self.process.poll() is None:
            # Already running. BUGFIX: previously this fell through and
            # spawned a second subprocess (leaking the first) whenever
            # _address happened to be unset; now we always reuse the live
            # process and reconstruct the address from host/port if needed.
            if not self._address:
                self._address = f"http://{self.host}:{self.port}"
            return self._address

        cmd = [
            sys.executable,
            "-m",
            "tuft.cli",
            "launch",
            "--config",
            str(self.config_path),
            "--host",
            self.host,
            "--port",
            str(self.port),
        ]

        env = os.environ.copy()
        # Ensure subprocess doesn't inherit auto-connect to avoid recursion
        env.pop("TUFT_ADDRESS", None)

        logger.info("Starting embedded TuFT server: %s", " ".join(cmd))

        # When log level is DEBUG or server fails, show subprocess output directly.
        # Otherwise pipe it for collection on failure.
        # For now, always inherit stderr so users can see server startup progress.
        inherit_io = True

        self.process = subprocess.Popen(
            cmd,
            env=env,
            stdout=None if inherit_io else subprocess.PIPE,
            stderr=None if inherit_io else subprocess.PIPE,
            # New session => new process group, so the whole server tree can
            # be terminated with killpg. start_new_session is the documented,
            # thread-safe equivalent of preexec_fn=os.setsid.
            start_new_session=sys.platform != "win32",
        )

        # Wait for healthz
        self._address = f"http://{self.host}:{self.port}"
        if not self._wait_for_healthy():
            # Collect stderr for debugging. NOTE: while inherit_io is True
            # the stderr pipe is None, so this is best-effort only.
            stderr_output = ""
            if self.process.stderr:
                try:
                    stderr_output = self.process.stderr.read().decode(errors="replace")[:2000]
                except Exception:
                    pass
            self._terminate()
            raise RuntimeError(
                f"TuFT embedded server failed to start within {self.timeout}s. "
                f"Config: {self.config_path}\n"
                f"Stderr: {stderr_output}"
            )

        # Write address file
        self._write_address_file()

        # Register cleanup
        if not self._atexit_registered:
            atexit.register(self.shutdown)
            self._atexit_registered = True

        logger.info("Embedded TuFT server ready at %s (pid=%d)", self._address, self.process.pid)
        return self._address

    def _wait_for_healthy(self) -> bool:
        """Poll healthz until healthy or timeout."""
        url = f"{self._address}{HEALTHZ_PATH}"
        deadline = time.monotonic() + self.timeout

        while time.monotonic() < deadline:
            # Check if process died
            if self.process and self.process.poll() is not None:
                return False
            try:
                resp = httpx.get(url, timeout=HEALTHZ_TIMEOUT)
                if resp.status_code == 200:
                    return True
            except (httpx.ConnectError, httpx.TimeoutException, OSError):
                pass
            time.sleep(STARTUP_POLL_INTERVAL)

        return False

    def _write_address_file(self) -> None:
        """Write the server address to the address file."""
        address_file = get_address_file()
        try:
            address_file.parent.mkdir(parents=True, exist_ok=True)
            address_file.write_text(self._address or "")
            logger.debug("Wrote address file: %s", address_file)
        except OSError as e:
            logger.warning("Failed to write address file %s: %s", address_file, e)

    def _remove_address_file(self) -> None:
        """Remove the address file, but only if it still points to our address."""
        address_file = get_address_file()
        try:
            if address_file.exists():
                content = address_file.read_text().strip()
                if content == self._address:
                    address_file.unlink()
                    logger.debug("Removed address file: %s", address_file)
        except OSError:
            pass

    def _terminate(self) -> None:
        """Terminate the subprocess, escalating from SIGTERM to kill."""
        if self.process is None:
            return
        if self.process.poll() is not None:
            return
        try:
            if sys.platform != "win32":
                # Kill the entire process group
                os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
            else:
                self.process.terminate()
            self.process.wait(timeout=10)
        except (OSError, subprocess.TimeoutExpired):
            # Graceful termination failed; force-kill the direct child.
            try:
                self.process.kill()
                self.process.wait(timeout=5)
            except Exception:
                pass
        logger.info("Terminated embedded TuFT server (pid=%d)", self.process.pid)

    def shutdown(self) -> None:
        """Stop the embedded server and clean up."""
        self._terminate()
        self._remove_address_file()
        self.process = None
        self._address = None

    @property
    def is_running(self) -> bool:
        """Check if the subprocess is still running."""
        return self.process is not None and self.process.poll() is None
"""Tests for tuft.runtime._config_gen module."""

from __future__ import annotations

import json
from unittest.mock import patch

from tuft.runtime._config_gen import (
    _infer_max_model_len,
    _read_max_position_embeddings,
    generate_api_key,
    generate_config_dict,
    generate_config_file,
)


def _write_hf_config(model_dir, payload):
    """Write a HuggingFace-style config.json into *model_dir*."""
    (model_dir / "config.json").write_text(json.dumps(payload))


class TestReadMaxPositionEmbeddings:
    def test_reads_from_config_json(self, tmp_path):
        _write_hf_config(tmp_path, {"max_position_embeddings": 8192, "hidden_size": 4096})
        assert _read_max_position_embeddings(tmp_path) == 8192

    def test_returns_none_when_no_config(self, tmp_path):
        assert _read_max_position_embeddings(tmp_path) is None

    def test_returns_none_on_invalid_json(self, tmp_path):
        (tmp_path / "config.json").write_text("not json")
        assert _read_max_position_embeddings(tmp_path) is None

    def test_reads_n_positions(self, tmp_path):
        _write_hf_config(tmp_path, {"n_positions": 2048})
        assert _read_max_position_embeddings(tmp_path) == 2048


class TestInferMaxModelLen:
    def test_caps_at_32768(self, tmp_path):
        _write_hf_config(tmp_path, {"max_position_embeddings": 131072})
        assert _infer_max_model_len(tmp_path) == 32768

    def test_uses_value_when_below_cap(self, tmp_path):
        _write_hf_config(tmp_path, {"max_position_embeddings": 4096})
        assert _infer_max_model_len(tmp_path) == 4096

    def test_default_when_no_config(self, tmp_path):
        assert _infer_max_model_len(tmp_path) == 4096


class TestGenerateApiKey:
    def test_format(self):
        key = generate_api_key()
        assert key.startswith("tml-")
        assert len(key) > 20

    def test_uniqueness(self):
        assert generate_api_key() != generate_api_key()


class TestGenerateConfigDict:
    @patch("tuft.runtime._config_gen._detect_gpu_count", return_value=2)
    @patch("tuft.runtime._config_gen._detect_gpu_memory_gb", return_value=80.0)
    def test_basic_structure(self, mock_mem, mock_gpu, tmp_path):
        model_dir = tmp_path / "Qwen2.5-0.5B-Instruct"
        model_dir.mkdir()
        _write_hf_config(model_dir, {"max_position_embeddings": 16384})

        generated = generate_config_dict(model_dir, api_key="test-key")
        assert "supported_models" in generated
        assert len(generated["supported_models"]) == 1
        entry = generated["supported_models"][0]
        assert entry["model_name"] == "Qwen2.5-0.5B-Instruct"
        assert entry["max_model_len"] == 16384
        assert generated["authorized_users"] == {"test-key": "local"}  # pragma: allowlist secret

    @patch("tuft.runtime._config_gen._detect_gpu_count", return_value=0)
    @patch("tuft.runtime._config_gen._detect_gpu_memory_gb", return_value=0.0)
    def test_no_gpu_still_generates(self, mock_mem, mock_gpu, tmp_path):
        model_dir = tmp_path / "my-model"
        model_dir.mkdir()
        generated = generate_config_dict(model_dir)
        assert generated["supported_models"][0]["tensor_parallel_size"] == 1


class TestGenerateConfigFile:
    @patch("tuft.runtime._config_gen._detect_gpu_count", return_value=1)
    @patch("tuft.runtime._config_gen._detect_gpu_memory_gb", return_value=24.0)
    def test_creates_yaml_file(self, mock_mem, mock_gpu, tmp_path):
        model_dir = tmp_path / "test-model"
        model_dir.mkdir()

        config_path, api_key = generate_config_file(model_dir, api_key="my-key")
        assert config_path.exists()
        assert config_path.suffix == ".yaml"
        assert api_key == "my-key"  # pragma: allowlist secret
        assert "test-model" in config_path.read_text()
"""Tests for tuft.runtime._discovery module."""

from __future__ import annotations

from unittest.mock import patch

import pytest

from tuft.runtime._constants import DEFAULT_HOST, DEFAULT_PORT
from tuft.runtime._discovery import (
    _discover_from_address_file,
    _discover_from_default_port,
    _discover_from_env,
    discover,
)


@pytest.fixture
def mock_healthy():
    """Force _check_health to report every address as healthy."""
    with patch("tuft.runtime._discovery._check_health", return_value=True) as m:
        yield m


@pytest.fixture
def mock_unhealthy():
    """Force _check_health to report every address as unhealthy."""
    with patch("tuft.runtime._discovery._check_health", return_value=False) as m:
        yield m


class TestDiscoverFromEnv:
    def test_returns_address_when_set_and_healthy(self, mock_healthy, monkeypatch):
        monkeypatch.setenv("TUFT_ADDRESS", "http://myhost:9999")
        assert _discover_from_env() == "http://myhost:9999"

    def test_returns_none_when_not_set(self, mock_healthy, monkeypatch):
        monkeypatch.delenv("TUFT_ADDRESS", raising=False)
        assert _discover_from_env() is None

    def test_returns_none_when_unhealthy(self, mock_unhealthy, monkeypatch):
        monkeypatch.setenv("TUFT_ADDRESS", "http://dead:1234")
        assert _discover_from_env() is None


class TestDiscoverFromAddressFile:
    def test_returns_address_when_file_exists_and_healthy(
        self, mock_healthy, tmp_path, monkeypatch
    ):
        # Point TUFT_HOME at a temp dir so we control the address file
        monkeypatch.setenv("TUFT_HOME", str(tmp_path))
        (tmp_path / "tuft_current_server").write_text("http://filehost:8080")
        assert _discover_from_address_file() == "http://filehost:8080"

    def test_returns_none_when_no_file(self, mock_healthy, tmp_path, monkeypatch):
        monkeypatch.setenv("TUFT_HOME", str(tmp_path))
        assert _discover_from_address_file() is None

    def test_returns_none_when_file_unhealthy(self, mock_unhealthy, tmp_path, monkeypatch):
        monkeypatch.setenv("TUFT_HOME", str(tmp_path))
        (tmp_path / "tuft_current_server").write_text("http://dead:1111")
        assert _discover_from_address_file() is None


class TestDiscoverFromDefaultPort:
    def test_returns_address_when_healthy(self, mock_healthy):
        assert _discover_from_default_port() == f"http://{DEFAULT_HOST}:{DEFAULT_PORT}"

    def test_returns_none_when_unhealthy(self, mock_unhealthy):
        assert _discover_from_default_port() is None


class TestDiscover:
    def test_explicit_address_takes_priority(self, mock_healthy):
        assert discover(explicit_address="http://explicit:5555") == "http://explicit:5555"

    def test_explicit_address_unhealthy_returns_none(self, mock_unhealthy):
        assert discover(explicit_address="http://dead:5555") is None

    def test_env_takes_priority_over_file(self, tmp_path, monkeypatch):
        """When TUFT_ADDRESS is set and healthy, address file is not checked."""
        monkeypatch.setenv("TUFT_ADDRESS", "http://envhost:7777")
        monkeypatch.setenv("TUFT_HOME", str(tmp_path))
        (tmp_path / "tuft_current_server").write_text("http://filehost:8888")

        with patch("tuft.runtime._discovery._check_health", return_value=True):
            assert discover() == "http://envhost:7777"

    def test_returns_none_when_nothing_found(self, mock_unhealthy, tmp_path, monkeypatch):
        monkeypatch.delenv("TUFT_ADDRESS", raising=False)
        monkeypatch.setenv("TUFT_HOME", str(tmp_path))
        assert discover() is None
monkeypatch.setenv("TUFT_HOME", str(tmp_path)) + monkeypatch.delenv("TUFT_CONFIG", raising=False) + monkeypatch.delenv("TUFT_MODEL_PATH", raising=False) + with pytest.raises(RuntimeError, match="Cannot start TuFT"): + tuft.init() + + +class TestShutdown: + @patch("tuft.runtime.discover", return_value="http://localhost:10610") + @patch("tuft.runtime.tinker.ServiceClient") + def test_shutdown_resets_state(self, mock_client_cls, mock_discover): + mock_client_cls.return_value = MagicMock() + tuft.init() + assert tuft.is_initialized() + tuft.shutdown() + assert not tuft.is_initialized() + + def test_shutdown_when_not_initialized(self): + # Should not raise + tuft.shutdown() + + +class TestGetServiceClient: + @patch("tuft.runtime.discover", return_value="http://localhost:10610") + @patch("tuft.runtime.tinker.ServiceClient") + def test_returns_client_after_init(self, mock_client_cls, mock_discover): + mock_instance = MagicMock() + mock_client_cls.return_value = mock_instance + tuft.init() + client = tuft.get_service_client() + assert client is mock_instance + + @patch("tuft.runtime.discover", return_value="http://localhost:10610") + @patch("tuft.runtime.tinker.ServiceClient") + def test_auto_init_on_get_service_client(self, mock_client_cls, mock_discover): + mock_client_cls.return_value = MagicMock() + # Should auto-init + client = tuft.get_service_client() + assert tuft.is_initialized() + assert client is not None + + def test_raises_when_auto_connect_disabled(self, monkeypatch): + monkeypatch.setenv("TUFT_ENABLE_AUTO_CONNECT", "0") + with pytest.raises(RuntimeError, match="auto-connect is disabled"): + tuft.get_service_client() + + +class TestIsInitialized: + def test_false_initially(self): + assert not tuft.is_initialized() + + @patch("tuft.runtime.discover", return_value="http://localhost:10610") + @patch("tuft.runtime.tinker.ServiceClient") + def test_true_after_init(self, mock_client_cls, mock_discover): + mock_client_cls.return_value = MagicMock() + tuft.init() + 
assert tuft.is_initialized() diff --git a/tests/test_runtime_launcher.py b/tests/test_runtime_launcher.py new file mode 100644 index 0000000..e7bb387 --- /dev/null +++ b/tests/test_runtime_launcher.py @@ -0,0 +1,75 @@ +"""Tests for tuft.runtime._launcher module.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from tuft.runtime._launcher import EmbeddedServer + + +class TestEmbeddedServer: + def test_init_attributes(self, tmp_path): + config = tmp_path / "config.yaml" + config.write_text("supported_models: []") + server = EmbeddedServer(config_path=config, host="127.0.0.1", port=9999) + assert server.config_path == config + assert server.host == "127.0.0.1" + assert server.port == 9999 + assert server.process is None + assert not server.is_running + + @patch("tuft.runtime._launcher.EmbeddedServer._wait_for_healthy", return_value=True) + @patch("subprocess.Popen") + def test_start_success(self, mock_popen, mock_wait, tmp_path, monkeypatch): + monkeypatch.setenv("TUFT_HOME", str(tmp_path)) + config = tmp_path / "config.yaml" + config.write_text("supported_models: []") + + mock_proc = MagicMock() + mock_proc.poll.return_value = None + mock_proc.pid = 12345 + mock_popen.return_value = mock_proc + + server = EmbeddedServer(config_path=config) + address = server.start() + assert address == "http://127.0.0.1:10610" + assert server.is_running + + @patch("tuft.runtime._launcher.EmbeddedServer._wait_for_healthy", return_value=False) + @patch("subprocess.Popen") + def test_start_failure_raises(self, mock_popen, mock_wait, tmp_path, monkeypatch): + monkeypatch.setenv("TUFT_HOME", str(tmp_path)) + config = tmp_path / "config.yaml" + config.write_text("supported_models: []") + + mock_proc = MagicMock() + mock_proc.poll.return_value = None + mock_proc.pid = 12345 + mock_proc.stderr = MagicMock() + mock_proc.stderr.read.return_value = b"some error" + mock_popen.return_value = mock_proc + + server = 
EmbeddedServer(config_path=config, timeout=1) + with pytest.raises(RuntimeError, match="failed to start"): + server.start() + + @patch("tuft.runtime._launcher.EmbeddedServer._wait_for_healthy", return_value=True) + @patch("subprocess.Popen") + def test_shutdown_terminates_process(self, mock_popen, mock_wait, tmp_path, monkeypatch): + monkeypatch.setenv("TUFT_HOME", str(tmp_path)) + config = tmp_path / "config.yaml" + config.write_text("supported_models: []") + + mock_proc = MagicMock() + mock_proc.poll.return_value = None + mock_proc.pid = 12345 + mock_popen.return_value = mock_proc + + server = EmbeddedServer(config_path=config) + server.start() + server.shutdown() + + assert server.process is None + assert server.address is None