Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 142 additions & 2 deletions chutes/chute/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import uuid
from loguru import logger
from typing import List, Tuple, Callable
from fastapi import FastAPI
from pydantic import BaseModel, ConfigDict
from fastapi import FastAPI, Response, status
from pydantic import BaseModel, ConfigDict, Field
import aiohttp
import time
from chutes.image import Image
from chutes.util.context import is_remote
from chutes.chute.node_selector import NodeSelector
Expand Down Expand Up @@ -78,6 +80,11 @@ def __init__(
self.redoc_url = None
self.tee = tee

# Warmup state
self._warmup_state = WarmupState()
self._warmup_lock = asyncio.Lock()
self._warmup_task: asyncio.Task | None = None

@property
def name(self):
return self._name
Expand Down Expand Up @@ -216,6 +223,125 @@ async def initialize(self):
for job in self._jobs:
logger.info(f"Found job definition: {job._func.__name__}")

# Add warmup endpoints
self.add_api_route(
"/warmup/kick",
self._warmup_kick,
methods=["POST"],
status_code=status.HTTP_202_ACCEPTED,
)
self.add_api_route(
"/warmup/status",
self._warmup_status,
methods=["GET"],
)

async def _check_models_ready(self, base_url: str, auth_header: str | None) -> bool:
"""
Check if /v1/models returns 200.
"""
headers = {}
if auth_header:
headers["Authorization"] = auth_header

try:
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=10.0)) as session:
async with session.get(f"{base_url.rstrip('/')}/v1/models", headers=headers) as resp:
return resp.status == 200
except Exception:
return False

async def _do_warmup(self, base_url: str, auth_header: str | None):
"""
Background warmup task.
"""
t0 = time.time()
steps = [
("pulling", 10),
("loading", 40),
("tokenizer", 60),
("tiny_infer", 80),
]

try:
async with self._warmup_lock:
self._warmup_state.state = "warming"
self._warmup_state.phase = "pulling"
self._warmup_state.progress = 0
self._warmup_state.error = None
self._warmup_state.started_at = t0

# Simulate phases (in a real implementation, we might hook into actual events if possible)
# For now, we just advance through phases to show progress while waiting for the model.
# The real gate is /v1/models.
for phase, pct in steps:
await asyncio.sleep(1.0)
async with self._warmup_lock:
self._warmup_state.phase = phase
self._warmup_state.progress = pct
self._warmup_state.elapsed_sec = time.time() - t0

# Wait for canonical gate
# Try for up to ~3 minutes
for _ in range(180):
if await self._check_models_ready(base_url, auth_header):
async with self._warmup_lock:
self._warmup_state.phase = "ready"
self._warmup_state.progress = 100
self._warmup_state.state = "ready"
self._warmup_state.elapsed_sec = time.time() - t0
return
await asyncio.sleep(1.0)
async with self._warmup_lock:
self._warmup_state.elapsed_sec = time.time() - t0

# Timeout
async with self._warmup_lock:
self._warmup_state.state = "error"
self._warmup_state.error = "models_not_ready_within_timeout"
self._warmup_state.elapsed_sec = time.time() - t0

except Exception as e:
async with self._warmup_lock:
self._warmup_state.state = "error"
self._warmup_state.error = str(e)
self._warmup_state.elapsed_sec = time.time() - t0

async def _warmup_kick(self, response: Response, authorization: str | None = None):
"""
Kick off the warmup process.
"""
# Auth check is handled by middleware usually, but if we need to pass it to the check loop:
# We'll assume the request to this endpoint has the same auth as /v1/models needs.

# Determine base URL. Since we are running inside the chute, localhost:8000 is likely where /v1/models is.
# But we should respect CHUTES_API_BASE if set.
base = os.environ.get("CHUTES_API_BASE", "http://127.0.0.1:8000")

async with self._warmup_lock:
if self._warmup_state.state in ("warming", "ready"):
return self._warmup_state.model_dump(by_alias=True)

# Reset and spawn
self._warmup_state.state = "warming"
self._warmup_state.phase = "pulling"
self._warmup_state.progress = 0
self._warmup_state.error = None
self._warmup_state.started_at = time.time()
self._warmup_state.elapsed_sec = 0.0

if self._warmup_task is None or self._warmup_task.done():
self._warmup_task = asyncio.create_task(self._do_warmup(base, authorization))

return self._warmup_state.model_dump(by_alias=True)

async def _warmup_status(self):
"""
Get the current warmup status.
"""
async with self._warmup_lock:
return self._warmup_state.model_dump(by_alias=True)

def cord(self, **kwargs):
"""
Decorator to define a parachute cord (function).
Expand All @@ -241,3 +367,17 @@ def job(self, **kwargs):
class ChutePack(BaseModel):
chute: Chute
model_config = ConfigDict(arbitrary_types_allowed=True)


# Warmup state model
class WarmupState(BaseModel):
schema_version: str = Field("chutes.warmup.status.v1", alias="schema")
state: str = "idle" # idle|warming|ready|error
phase: str = "idle" # pulling|loading|tokenizer|tiny_infer|ready
progress: int = 0 # 0..100
elapsed_sec: float = 0.0
error: str | None = None
started_at: float | None = None

model_config = ConfigDict(populate_by_name=True)