From 3cd2e72423f0ff5a9b93ccd17c315fc59b6dc9d9 Mon Sep 17 00:00:00 2001 From: Dairus01 Date: Tue, 9 Dec 2025 16:44:17 +0100 Subject: [PATCH] feat: implement canonical warmup cords (kick + status) --- chutes/chute/base.py | 144 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 142 insertions(+), 2 deletions(-) diff --git a/chutes/chute/base.py b/chutes/chute/base.py index c3ee622..42cd94f 100644 --- a/chutes/chute/base.py +++ b/chutes/chute/base.py @@ -7,8 +7,10 @@ import uuid from loguru import logger from typing import List, Tuple, Callable -from fastapi import FastAPI -from pydantic import BaseModel, ConfigDict +from fastapi import FastAPI, Response, status +from pydantic import BaseModel, ConfigDict, Field +import aiohttp +import time from chutes.image import Image from chutes.util.context import is_remote from chutes.chute.node_selector import NodeSelector @@ -78,6 +80,11 @@ def __init__( self.redoc_url = None self.tee = tee + # Warmup state + self._warmup_state = WarmupState() + self._warmup_lock = asyncio.Lock() + self._warmup_task: asyncio.Task | None = None + @property def name(self): return self._name @@ -216,6 +223,125 @@ async def initialize(self): for job in self._jobs: logger.info(f"Found job definition: {job._func.__name__}") + # Add warmup endpoints + self.add_api_route( + "/warmup/kick", + self._warmup_kick, + methods=["POST"], + status_code=status.HTTP_202_ACCEPTED, + ) + self.add_api_route( + "/warmup/status", + self._warmup_status, + methods=["GET"], + ) + + async def _check_models_ready(self, base_url: str, auth_header: str | None) -> bool: + """ + Check if /v1/models returns 200. + """ + headers = {} + if auth_header: + headers["Authorization"] = auth_header + + try: + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=10.0)) as session: + async with session.get(f"{base_url.rstrip('/')}/v1/models", headers=headers) as resp: + return resp.status == 200 + except Exception: + return False + + async def _do_warmup(self, base_url: str, auth_header: str | None): + """ + Background warmup task. + """ + t0 = time.time() + steps = [ + ("pulling", 10), + ("loading", 40), + ("tokenizer", 60), + ("tiny_infer", 80), + ] + + try: + async with self._warmup_lock: + self._warmup_state.state = "warming" + self._warmup_state.phase = "pulling" + self._warmup_state.progress = 0 + self._warmup_state.error = None + self._warmup_state.started_at = t0 + + # Simulate phases (in a real implementation, we might hook into actual events if possible) + # For now, we just advance through phases to show progress while waiting for the model. + # The real gate is /v1/models. + for phase, pct in steps: + await asyncio.sleep(1.0) + async with self._warmup_lock: + self._warmup_state.phase = phase + self._warmup_state.progress = pct + self._warmup_state.elapsed_sec = time.time() - t0 + + # Wait for canonical gate + # Try for up to ~3 minutes + for _ in range(180): + if await self._check_models_ready(base_url, auth_header): + async with self._warmup_lock: + self._warmup_state.phase = "ready" + self._warmup_state.progress = 100 + self._warmup_state.state = "ready" + self._warmup_state.elapsed_sec = time.time() - t0 + return + await asyncio.sleep(1.0) + async with self._warmup_lock: + self._warmup_state.elapsed_sec = time.time() - t0 + + # Timeout + async with self._warmup_lock: + self._warmup_state.state = "error" + self._warmup_state.error = "models_not_ready_within_timeout" + self._warmup_state.elapsed_sec = time.time() - t0 + + except Exception as e: + async with self._warmup_lock: + self._warmup_state.state = "error" + self._warmup_state.error = str(e) + self._warmup_state.elapsed_sec = time.time() - t0 + + async def _warmup_kick(self, response: Response, authorization: str | None = None): + """ + Kick off the warmup process. + """ + # Auth check is handled by middleware usually, but if we need to pass it to the check loop: + # We'll assume the request to this endpoint has the same auth as /v1/models needs. + + # Determine base URL. Since we are running inside the chute, localhost:8000 is likely where /v1/models is. + # But we should respect CHUTES_API_BASE if set. + base = os.environ.get("CHUTES_API_BASE", "http://127.0.0.1:8000") + + async with self._warmup_lock: + if self._warmup_state.state in ("warming", "ready"): + return self._warmup_state.model_dump(by_alias=True) + + # Reset and spawn + self._warmup_state.state = "warming" + self._warmup_state.phase = "pulling" + self._warmup_state.progress = 0 + self._warmup_state.error = None + self._warmup_state.started_at = time.time() + self._warmup_state.elapsed_sec = 0.0 + + if self._warmup_task is None or self._warmup_task.done(): + self._warmup_task = asyncio.create_task(self._do_warmup(base, authorization)) + + return self._warmup_state.model_dump(by_alias=True) + + async def _warmup_status(self): + """ + Get the current warmup status. + """ + async with self._warmup_lock: + return self._warmup_state.model_dump(by_alias=True) + def cord(self, **kwargs): """ Decorator to define a parachute cord (function). @@ -241,3 +367,17 @@ def job(self, **kwargs): class ChutePack(BaseModel): chute: Chute model_config = ConfigDict(arbitrary_types_allowed=True) + + +# Warmup state model +class WarmupState(BaseModel): + schema_version: str = Field("chutes.warmup.status.v1", alias="schema") + state: str = "idle" # idle|warming|ready|error + phase: str = "idle" # pulling|loading|tokenizer|tiny_infer|ready + progress: int = 0 # 0..100 + elapsed_sec: float = 0.0 + error: str | None = None + started_at: float | None = None + + model_config = ConfigDict(populate_by_name=True) +