diff --git a/chutes/_version.py b/chutes/_version.py index 797946c..b372e25 100644 --- a/chutes/_version.py +++ b/chutes/_version.py @@ -1 +1 @@ -version = "0.5.4.rc9" +version = "0.6.0.rc0" diff --git a/chutes/chutes-netnanny.so b/chutes/chutes-netnanny.so old mode 100755 new mode 100644 index 5f7e3a9..4863516 Binary files a/chutes/chutes-netnanny.so and b/chutes/chutes-netnanny.so differ diff --git a/chutes/constants.py b/chutes/constants.py index 49ea3b5..97930b9 100644 --- a/chutes/constants.py +++ b/chutes/constants.py @@ -1,4 +1,8 @@ CHUTES_DIR = ".chutes" + +# TEE attestation service in-cluster static IP (attestation-service-internal) +ATTESTATION_SERVICE_BASE_URL = "https://10.43.50.50:8443" + HOTKEY_HEADER = "X-Chutes-Hotkey" SIGNATURE_HEADER = "X-Chutes-Signature" NONCE_HEADER = "X-Chutes-Nonce" diff --git a/chutes/entrypoint/run.py b/chutes/entrypoint/run.py index 875ba92..6eec7da 100644 --- a/chutes/entrypoint/run.py +++ b/chutes/entrypoint/run.py @@ -17,7 +17,6 @@ import typer import psutil import base64 -import socket import secrets import threading import traceback @@ -30,16 +29,20 @@ from pydantic import BaseModel from ipaddress import ip_address from uvicorn import Config, Server -from fastapi import Request, Response, status, HTTPException +from fastapi import FastAPI, Request, Response, status, HTTPException from fastapi.responses import ORJSONResponse from starlette.middleware.base import BaseHTTPMiddleware -from chutes.entrypoint.verify import GpuVerifier +from chutes.entrypoint.verify import ( + GpuVerifier, + TeeEvidenceService, +) from chutes.util.hf import verify_cache, CacheVerificationError from prometheus_client import generate_latest, CONTENT_TYPE_LATEST from substrateinterface import Keypair, KeypairType from chutes.entrypoint._shared import ( get_launch_token, get_launch_token_data, + is_tee_env, load_chute, miner, authenticate_request, @@ -641,7 +644,7 @@ async def dispatch(self, request: Request, call_next): class GraValMiddleware(BaseHTTPMiddleware): - def __init__(self, app, concurrency: int = 1): + def __init__(self, app: FastAPI, concurrency: int = 1): """ Initialize a semaphore for concurrency control/limits. """ @@ -822,105 +825,12 @@ async def wrapped_iterator(): if not response or not hasattr(response, "body_iterator"): _conn_stats.requests_in_flight.pop(request.request_id, None) - -def start_dummy_socket(port_mapping, symmetric_key): - """ - Start a dummy socket based on the port mapping configuration to validate ports. - """ - proto = port_mapping["proto"].lower() - internal_port = port_mapping["internal_port"] - response_text = f"response from {proto} {internal_port}" - if proto in ["tcp", "http"]: - return start_tcp_dummy(internal_port, symmetric_key, response_text) - return start_udp_dummy(internal_port, symmetric_key, response_text) - - -def encrypt_response(symmetric_key, plaintext): - """ - Encrypt the response using AES-CBC with PKCS7 padding. - """ - padder = padding.PKCS7(128).padder() - new_iv = secrets.token_bytes(16) - cipher = Cipher( - algorithms.AES(symmetric_key), - modes.CBC(new_iv), - backend=default_backend(), - ) - padded_data = padder.update(plaintext.encode()) + padder.finalize() - encryptor = cipher.encryptor() - encrypted_data = encryptor.update(padded_data) + encryptor.finalize() - response_cipher = base64.b64encode(encrypted_data).decode() - return new_iv, response_cipher - - -def start_tcp_dummy(port, symmetric_key, response_plaintext): - """ - TCP port check socket. 
- """ - - def tcp_handler(): - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - try: - sock.bind(("0.0.0.0", port)) - sock.listen(1) - logger.info(f"TCP socket listening on port {port}") - conn, addr = sock.accept() - logger.info(f"TCP connection from {addr}") - data = conn.recv(1024) - logger.info(f"TCP received: {data.decode('utf-8', errors='ignore')}") - iv, encrypted_response = encrypt_response(symmetric_key, response_plaintext) - full_response = f"{iv.hex()}|{encrypted_response}".encode() - conn.send(full_response) - logger.info(f"TCP sent encrypted response on port {port}: {full_response=}") - conn.close() - except Exception as e: - logger.info(f"TCP socket error on port {port}: {e}") - raise - finally: - sock.close() - logger.info(f"TCP socket on port {port} closed") - - thread = threading.Thread(target=tcp_handler, daemon=True) - thread.start() - return thread - - -def start_udp_dummy(port, symmetric_key, response_plaintext): - """ - UDP port check socket. - """ - - def udp_handler(): - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - try: - sock.bind(("0.0.0.0", port)) - logger.info(f"UDP socket listening on port {port}") - data, addr = sock.recvfrom(1024) - logger.info(f"UDP received from {addr}: {data.decode('utf-8', errors='ignore')}") - iv, encrypted_response = encrypt_response(symmetric_key, response_plaintext) - full_response = f"{iv.hex()}|{encrypted_response}".encode() - sock.sendto(full_response, addr) - logger.info(f"UDP sent encrypted response on port {port}") - except Exception as e: - logger.info(f"UDP socket error on port {port}: {e}") - raise - finally: - sock.close() - logger.info(f"UDP socket on port {port} closed") - - thread = threading.Thread(target=udp_handler, daemon=True) - thread.start() - return thread - - async def _gather_devices_and_initialize( host: str, port_mappings: list[dict[str, Any]], chute_abspath: str, inspecto_hash: str, -) -> tuple[bool, str, dict[str, Any]]: +) -> tuple[bool, bytes, dict[str, Any]]: """ Gather the GPU info assigned to this pod, submit with our one-time token to get GraVal seed. """ @@ -929,7 +839,6 @@ async def _gather_devices_and_initialize( logger.info("Collecting GPUs and port mappings...") body = {"gpus": [], "port_mappings": port_mappings, "host": host} token_data = get_launch_token_data() - url = token_data.get("url") key = token_data.get("env_key", "a" * 32) logger.info("Collecting full envdump...") @@ -973,16 +882,9 @@ async def _gather_devices_and_initialize( logger.error(f"Error checking disk space: {e}") raise Exception(f"Failed to verify disk space availability: {e}") - # Start up dummy sockets to test port mappings. - dummy_socket_threads = [] - for port_map in port_mappings: - if port_map.get("default"): - continue - dummy_socket_threads.append(start_dummy_socket(port_map, symmetric_key)) - - # Verify GPUs for symmetric key - verifier = GpuVerifier.create(url, body) - symmetric_key, response = await verifier.verify_devices() + # Verify GPUs, spin up dummy sockets, and finalize verification. 
+ verifier = GpuVerifier.create(body) + response = await verifier.verify() # Derive runint session key from validator's pubkey via ECDH if provided # Key derivation happens entirely in C - key never touches Python memory @@ -993,7 +895,7 @@ async def _gather_devices_and_initialize( else: logger.warning("Failed to derive runint session key - using legacy encryption") - return egress, symmetric_key, response + return egress, response # Run a chute (which can be an async job or otherwise long-running process). @@ -1225,310 +1127,339 @@ async def _run_chute(): "default": False, } ) + try: + if is_tee_env(): + await TeeEvidenceService().start() + + # GPU verification plus job fetching. + job_data: dict | None = None + job_id: str | None = None + job_obj: Job | None = None + job_method: str | None = None + job_status_url: str | None = None + activation_url: str | None = None + allow_external_egress: bool | None = False + + chute_filename = os.path.basename(chute_ref_str.split(":")[0] + ".py") + chute_abspath: str = os.path.abspath(os.path.join(os.getcwd(), chute_filename)) + if token: + ( + allow_external_egress, + response, + ) = await _gather_devices_and_initialize( + external_host, + port_mappings, + chute_abspath, + inspecto_hash, + ) + job_id = response.get("job_id") + job_method = response.get("job_method") + job_status_url = response.get("job_status_url") + job_data = response.get("job_data") + activation_url = response.get("activation_url") + code = response["code"] + fs_key = response["fs_key"] + encrypted_cache = response.get("efs") is True + if ( + fs_key + and netnanny.set_secure_fs(chute_abspath.encode(), fs_key.encode(), encrypted_cache) + != 0 + ): + logger.error("NetNanny failed to set secure FS, aborting!") + sys.exit(137) + with open(chute_abspath, "w") as outfile: + outfile.write(code) + + # Secret environment variables, e.g. HF tokens for private models. + if response.get("secrets"): + for secret_key, secret_value in response["secrets"].items(): + os.environ[secret_key] = secret_value + + elif not dev: + logger.error("No GraVal token supplied!") + sys.exit(1) + + # Now we have the chute code available, either because it's dev and the file is plain text here, + # or it's prod and we've fetched the code from the validator and stored it securely. + chute_module, chute = load_chute(chute_ref_str=chute_ref_str, config_path=None, debug=debug) + chute = chute.chute if isinstance(chute, ChutePack) else chute + if job_method: + job_obj = next(j for j in chute._jobs if j.name == job_method) + + # Configure dev method job payload/method/etc. + if dev and dev_job_data_path: + with open(dev_job_data_path) as infile: + job_data = json.loads(infile.read()) + job_id = str(uuid.uuid4()) + job_method = dev_job_method + job_obj = next(j for j in chute._jobs if j.name == dev_job_method) + logger.info(f"Creating task, dev mode, for {job_method=}") + + # Run the chute's initialization code. + await chute.initialize() + + # Encryption/rate-limiting middleware setup. + if dev: + chute.add_middleware(DevMiddleware) + else: + chute.add_middleware( + GraValMiddleware, + concurrency=chute.concurrency, + ) + + # Slurps and processes. 
+ async def _handle_slurp(request: Request): + nonlocal chute_module + return await handle_slurp(request, chute_module) + + async def _wait_for_server_ready(timeout: float = 30.0): + """Wait until the server is accepting connections.""" + import socket + start = asyncio.get_event_loop().time() + while (asyncio.get_event_loop().time() - start) < timeout: + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(1) + result = sock.connect_ex(("127.0.0.1", port)) + sock.close() + if result == 0: + return True + except Exception: + pass + await asyncio.sleep(0.1) + return False + + async def _do_activation(): + """Activate after server is listening.""" + if not activation_url: + return + if not await _wait_for_server_ready(): + logger.error("Server failed to start listening") + raise Exception("Server not ready for activation") + activated = False + for attempt in range(10): + if attempt > 0: + await asyncio.sleep(attempt) + try: + async with aiohttp.ClientSession(raise_for_status=False) as session: + async with session.get( + activation_url, headers={"Authorization": token} + ) as resp: + if resp.ok: + logger.success(f"Instance activated: {await resp.text()}") + activated = True + if not dev and not allow_external_egress: + if netnanny.lock_network() != 0: + logger.error("Failed to unlock network") + sys.exit(137) + logger.success("Successfully enabled NetNanny network lock.") + break + logger.error( + f"Instance activation failed: {resp.status=}: {await resp.text()}" + ) + if resp.status == 423: + break + except Exception as e: + logger.error(f"Unexpected error attempting to activate instance: {str(e)}") + if not activated: + logger.error("Failed to activate instance, aborting...") + sys.exit(137) - # GPU verification plus job fetching. - job_data: dict | None = None - symmetric_key: str | None = None - job_id: str | None = None - job_obj: Job | None = None - job_method: str | None = None - job_status_url: str | None = None - activation_url: str | None = None - allow_external_egress: bool | None = False - - chute_filename = os.path.basename(chute_ref_str.split(":")[0] + ".py") - chute_abspath: str = os.path.abspath(os.path.join(os.getcwd(), chute_filename)) - if token: - ( - allow_external_egress, - symmetric_key, - response, - ) = await _gather_devices_and_initialize( - external_host, - port_mappings, - chute_abspath, - inspecto_hash, - ) - job_id = response.get("job_id") - job_method = response.get("job_method") - job_status_url = response.get("job_status_url") - job_data = response.get("job_data") - activation_url = response.get("activation_url") - code = response["code"] - fs_key = response["fs_key"] - encrypted_cache = response.get("efs") is True - if ( - fs_key - and netnanny.set_secure_fs(chute_abspath.encode(), fs_key.encode(), encrypted_cache) - != 0 - ): - logger.error("NetNanny failed to set secure FS, aborting!") - sys.exit(137) - with open(chute_abspath, "w") as outfile: - outfile.write(code) - - # Secret environment variables, e.g. HF tokens for private models. - if response.get("secrets"): - for secret_key, secret_value in response["secrets"].items(): - os.environ[secret_key] = secret_value - - elif not dev: - logger.error("No GraVal token supplied!") - sys.exit(1) - - # Now we have the chute code available, either because it's dev and the file is plain text here, - # or it's prod and we've fetched the code from the validator and stored it securely. 
- chute_module, chute = load_chute(chute_ref_str=chute_ref_str, config_path=None, debug=debug) - chute = chute.chute if isinstance(chute, ChutePack) else chute - if job_method: - job_obj = next(j for j in chute._jobs if j.name == job_method) - - # Configure dev method job payload/method/etc. - if dev and dev_job_data_path: - with open(dev_job_data_path) as infile: - job_data = json.loads(infile.read()) - job_id = str(uuid.uuid4()) - job_method = dev_job_method - job_obj = next(j for j in chute._jobs if j.name == dev_job_method) - logger.info(f"Creating task, dev mode, for {job_method=}") - - # Run the chute's initialization code. - await chute.initialize() - - # Encryption/rate-limiting middleware setup. - if dev: - chute.add_middleware(DevMiddleware) - else: - chute.add_middleware( - GraValMiddleware, - concurrency=chute.concurrency, - ) - - # Slurps and processes. - async def _handle_slurp(request: Request): - nonlocal chute_module - - return await handle_slurp(request, chute_module) - - async def _wait_for_server_ready(timeout: float = 30.0): - """Wait until the server is accepting connections.""" - import socket - - start = asyncio.get_event_loop().time() - while (asyncio.get_event_loop().time() - start) < timeout: - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(1) - result = sock.connect_ex(("127.0.0.1", port)) - sock.close() - if result == 0: - return True - except Exception: - pass - await asyncio.sleep(0.1) - return False - - async def _do_activation(): - """Activate after server is listening.""" - if not activation_url: - return - - if not await _wait_for_server_ready(): - logger.error("Server failed to start listening") - raise Exception("Server not ready for activation") - - activated = False - for attempt in range(10): - if attempt > 0: - await asyncio.sleep(attempt) - try: - async with aiohttp.ClientSession(raise_for_status=False) as session: - async with session.get( - activation_url, headers={"Authorization": token} - ) as resp: - if resp.ok: - logger.success(f"Instance activated: {await resp.text()}") - activated = True - if not dev and not allow_external_egress: - if netnanny.lock_network() != 0: - logger.error("Failed to unlock network") - sys.exit(137) - logger.success("Successfully enabled NetNanny network lock.") - break - - logger.error( - f"Instance activation failed: {resp.status=}: {await resp.text()}" - ) - if resp.status == 423: - break - - except Exception as e: - logger.error(f"Unexpected error attempting to activate instance: {str(e)}") - if not activated: - logger.error("Failed to activate instance, aborting...") - sys.exit(137) - - @chute.on_event("startup") - async def activate_on_startup(): - asyncio.create_task(_do_activation()) + @chute.on_event("startup") + async def activate_on_startup(): + asyncio.create_task(_do_activation()) - async def _handle_fs_hash_challenge(request: Request): - nonlocal chute_abspath - data = request.state.decrypted - return { - "result": await generate_filesystem_hash( - data["salt"], chute_abspath, mode=data.get("mode", "sparse") - ) - } - - async def _handle_conn_stats(request: Request): - return _conn_stats.get_stats() - - # Validation endpoints. 
- chute.add_api_route("/_ping", pong, methods=["POST"]) - chute.add_api_route("/_token", get_token, methods=["POST"]) - chute.add_api_route("/_metrics", get_metrics, methods=["GET"]) - chute.add_api_route("/_conn_stats", _handle_conn_stats, methods=["GET"]) - chute.add_api_route("/_slurp", _handle_slurp, methods=["POST"]) - chute.add_api_route("/_procs", get_all_process_info, methods=["GET"]) - chute.add_api_route("/_env_sig", get_env_sig, methods=["POST"]) - chute.add_api_route("/_env_dump", get_env_dump, methods=["POST"]) - chute.add_api_route("/_devices", get_devices, methods=["GET"]) - chute.add_api_route("/_device_challenge", process_device_challenge, methods=["GET"]) - chute.add_api_route("/_fs_challenge", process_fs_challenge, methods=["POST"]) - chute.add_api_route("/_fs_hash", _handle_fs_hash_challenge, methods=["POST"]) - chute.add_api_route("/_connectivity", check_connectivity, methods=["POST"]) - - def _handle_nn(request: Request): - return process_netnanny_challenge(chute, request) - - chute.add_api_route("/_netnanny_challenge", _handle_nn, methods=["POST"]) - - # Runtime integrity challenge endpoint. - def _handle_rint(request: Request): - """Handle runtime integrity challenge.""" - challenge = request.state.decrypted.get("challenge") - if not challenge: - return {"error": "missing challenge"} - result = runint_prove(challenge) - if result is None: - return {"error": "runtime integrity not initialized or not bound"} - signature, epoch = result - return { - "signature": signature, - "epoch": epoch, - } - - chute.add_api_route("/_rint", _handle_rint, methods=["POST"]) - - # New envdump endpoints. - import chutes.envdump as envdump - - chute.add_api_route("/_dump", envdump.handle_dump, methods=["POST"]) - chute.add_api_route("/_sig", envdump.handle_sig, methods=["POST"]) - chute.add_api_route("/_toca", envdump.handle_toca, methods=["POST"]) - chute.add_api_route("/_eslurp", envdump.handle_slurp, methods=["POST"]) - - async def _handle_hf_check(request: Request): - """ - Verify HuggingFace cache integrity. - """ - data = request.state.decrypted - repo_id = data.get("repo_id") - revision = data.get("revision") - full_hash_check = data.get("full_hash_check", False) - - if not repo_id or not revision: + async def _handle_fs_hash_challenge(request: Request): + nonlocal chute_abspath + data = request.state.decrypted return { - "error": True, - "reason": "bad_request", - "message": "repo_id and revision are required", - "repo_id": repo_id, - "revision": revision, + "result": await generate_filesystem_hash( + data["salt"], chute_abspath, mode=data.get("mode", "sparse") + ) } - try: - result = await verify_cache( - repo_id=repo_id, - revision=revision, - full_hash_check=full_hash_check, - ) - result["error"] = False - return result - except CacheVerificationError as e: - return e.to_dict() + async def _handle_conn_stats(request: Request): + return _conn_stats.get_stats() + + # Validation endpoints. 
+ chute.add_api_route("/_ping", pong, methods=["POST"]) + chute.add_api_route("/_token", get_token, methods=["POST"]) + chute.add_api_route("/_metrics", get_metrics, methods=["GET"]) + chute.add_api_route("/_conn_stats", _handle_conn_stats, methods=["GET"]) + chute.add_api_route("/_slurp", _handle_slurp, methods=["POST"]) + chute.add_api_route("/_procs", get_all_process_info, methods=["GET"]) + chute.add_api_route("/_env_sig", get_env_sig, methods=["POST"]) + chute.add_api_route("/_env_dump", get_env_dump, methods=["POST"]) + chute.add_api_route("/_devices", get_devices, methods=["GET"]) + chute.add_api_route("/_device_challenge", process_device_challenge, methods=["GET"]) + chute.add_api_route("/_fs_challenge", process_fs_challenge, methods=["POST"]) + chute.add_api_route("/_fs_hash", _handle_fs_hash_challenge, methods=["POST"]) + chute.add_api_route("/_connectivity", check_connectivity, methods=["POST"]) + + def _handle_nn(request: Request): + return process_netnanny_challenge(chute, request) + + chute.add_api_route("/_netnanny_challenge", _handle_nn, methods=["POST"]) + + # Runtime integrity challenge endpoint. + def _handle_rint(request: Request): + """Handle runtime integrity challenge.""" + challenge = request.state.decrypted.get("challenge") + if not challenge: + return {"error": "missing challenge"} + result = runint_prove(challenge) + if result is None: + return {"error": "runtime integrity not initialized or not bound"} + signature, epoch = result + return { + "signature": signature, + "epoch": epoch, + } - chute.add_api_route("/_hf_check", _handle_hf_check, methods=["POST"]) + chute.add_api_route("/_rint", _handle_rint, methods=["POST"]) + + # New envdump endpoints. + import chutes.envdump as envdump + + chute.add_api_route("/_dump", envdump.handle_dump, methods=["POST"]) + chute.add_api_route("/_sig", envdump.handle_sig, methods=["POST"]) + chute.add_api_route("/_toca", envdump.handle_toca, methods=["POST"]) + chute.add_api_route("/_eslurp", envdump.handle_slurp, methods=["POST"]) + + async def _handle_hf_check(request: Request): + """ + Verify HuggingFace cache integrity. + """ + data = request.state.decrypted + repo_id = data.get("repo_id") + revision = data.get("revision") + full_hash_check = data.get("full_hash_check", False) + + if not repo_id or not revision: + return { + "error": True, + "reason": "bad_request", + "message": "repo_id and revision are required", + "repo_id": repo_id, + "revision": revision, + } - logger.success("Added all chutes internal endpoints.") + try: + result = await verify_cache( + repo_id=repo_id, + revision=revision, + full_hash_check=full_hash_check, + ) + result["error"] = False + return result + except CacheVerificationError as e: + return e.to_dict() + + chute.add_api_route("/_hf_check", _handle_hf_check, methods=["POST"]) + + async def _handle_hf_check(request: Request): + """ + Verify HuggingFace cache integrity. + """ + data = request.state.decrypted + repo_id = data.get("repo_id") + revision = data.get("revision") + full_hash_check = data.get("full_hash_check", False) + + if not repo_id or not revision: + return { + "error": True, + "reason": "bad_request", + "message": "repo_id and revision are required", + "repo_id": repo_id, + "revision": revision, + } - # Job shutdown/kill endpoint. 
- async def _shutdown(): - nonlocal job_obj, server - if not job_obj: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Job task not found", - ) - logger.warning("Shutdown requested.") - if job_obj and not job_obj.cancel_event.is_set(): - job_obj.cancel_event.set() - server.should_exit = True - return {"ok": True} - - # Jobs can't be started until the full suite of validation tests run, - # so we need to provide an endpoint for the validator to use to kick - # it off. - if job_id: - job_task = None - - async def start_job_with_monitoring(**kwargs): - nonlocal job_task - ssh_process = None - job_task = asyncio.create_task(job_obj.run(job_status_url=job_status_url, **kwargs)) - - async def monitor_job(): - try: - result = await job_task - logger.info(f"Job completed with result: {result}") - except Exception as e: - logger.error(f"Job failed with error: {e}") - finally: - logger.info("Job finished, shutting down server...") - if ssh_process: - try: - ssh_process.terminate() - await asyncio.sleep(0.5) - if ssh_process.poll() is None: - ssh_process.kill() - logger.info("SSH server stopped") - except Exception as e: - logger.error(f"Error stopping SSH server: {e}") - server.should_exit = True - - # If the pod defines SSH access, enable it. - if job_obj.ssh and job_data.get("_ssh_public_key"): - ssh_process = await setup_ssh_access(job_data["_ssh_public_key"]) - - asyncio.create_task(monitor_job()) - - await start_job_with_monitoring(**job_data) - logger.info("Started job!") - - chute.add_api_route("/_shutdown", _shutdown, methods=["POST"]) - logger.info("Added shutdown endpoint") - - # Start the uvicorn process, whether in job mode or not. - config = Config( - app=chute, - host=host or "0.0.0.0", - port=port or 8000, - limit_concurrency=1000, - ssl_certfile=certfile, - ssl_keyfile=keyfile, - ) - server = Server(config) - await server.serve() + try: + result = await verify_cache( + repo_id=repo_id, + revision=revision, + full_hash_check=full_hash_check, + ) + result["error"] = False + return result + except CacheVerificationError as e: + return e.to_dict() + + chute.add_api_route("/_hf_check", _handle_hf_check, methods=["POST"]) + + logger.success("Added all chutes internal endpoints.") + + # Job shutdown/kill endpoint. + async def _shutdown(): + nonlocal job_obj, server + if not job_obj: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Job task not found", + ) + logger.warning("Shutdown requested.") + if job_obj and not job_obj.cancel_event.is_set(): + job_obj.cancel_event.set() + server.should_exit = True + return {"ok": True} + + # Jobs can't be started until the full suite of validation tests run, + # so we need to provide an endpoint for the validator to use to kick + # it off. 
+ if job_id: + job_task = None + + async def start_job_with_monitoring(**kwargs): + nonlocal job_task + ssh_process = None + job_task = asyncio.create_task(job_obj.run(job_status_url=job_status_url, **kwargs)) + + async def monitor_job(): + try: + result = await job_task + logger.info(f"Job completed with result: {result}") + except Exception as e: + logger.error(f"Job failed with error: {e}") + finally: + logger.info("Job finished, shutting down server...") + if ssh_process: + try: + ssh_process.terminate() + await asyncio.sleep(0.5) + if ssh_process.poll() is None: + ssh_process.kill() + logger.info("SSH server stopped") + except Exception as e: + logger.error(f"Error stopping SSH server: {e}") + server.should_exit = True + + # If the pod defines SSH access, enable it. + if job_obj.ssh and job_data.get("_ssh_public_key"): + ssh_process = await setup_ssh_access(job_data["_ssh_public_key"]) + + asyncio.create_task(monitor_job()) + + await start_job_with_monitoring(**job_data) + logger.info("Started job!") + + chute.add_api_route("/_shutdown", _shutdown, methods=["POST"]) + logger.info("Added shutdown endpoint") + + # Start the uvicorn process, whether in job mode or not. + config = Config( + app=chute, + host=host or "0.0.0.0", + port=port or 8000, + limit_concurrency=1000, + ssl_certfile=certfile, + ssl_keyfile=keyfile, + ) + server = Server(config) + await server.serve() + finally: + if is_tee_env(): + await TeeEvidenceService().stop() # Kick everything off async def _logged_run(): diff --git a/chutes/entrypoint/verify.py b/chutes/entrypoint/verify.py index 8f2723b..fcc8c6b 100644 --- a/chutes/entrypoint/verify.py +++ b/chutes/entrypoint/verify.py @@ -1,104 +1,224 @@ from abc import abstractmethod +import asyncio import base64 from contextlib import asynccontextmanager +from functools import lru_cache import json import os import ssl +import socket +import threading from urllib.parse import urljoin, urlparse import aiohttp from loguru import logger +from fastapi import FastAPI, Request, HTTPException, status +from uvicorn import Config, Server -from chutes.entrypoint._shared import encrypt_response, get_launch_token, is_tee_env, miner +from chutes.constants import ATTESTATION_SERVICE_BASE_URL +from chutes.entrypoint._shared import encrypt_response, get_launch_token, get_launch_token_data, is_tee_env, miner + + +# TEE endpoint constants +TEE_VERIFICATION_ENDPOINT = "/verify" # Validator only: uses nonce from fetch_symmetric_key +TEE_EVIDENCE_RUNTIME_ENDPOINT = "/evidence" # Third parties: accept nonce from request + +# Global nonce storage for TEE verification (validator flow only) +# Set during fetch_symmetric_key; used only by /verify to prove same instance +_evidence_nonce: str | None = None +_evidence_nonce_locked: bool = False + + + +@asynccontextmanager +async def _use_evidence_nonce(validator_url: str): + """ + Context manager for TEE evidence nonce lifecycle. Fetches nonce from validator, + makes it available for the duration, and automatically clears it when done. + + Args: + validator_url: The base URL of the validator + + Yields: + The nonce value + + Raises: + RuntimeError: If nonce is already locked (multiple verification processes detected) + """ + global _evidence_nonce, _evidence_nonce_locked + + if _evidence_nonce_locked: + raise RuntimeError( + "TEE nonce already locked. Only one verification process should be running." 
+ ) + + # Fetch nonce from validator + url = urljoin(validator_url, "/instances/nonce") + async with aiohttp.ClientSession(raise_for_status=True) as http_session: + async with http_session.get(url) as resp: + logger.success(f"Successfully initiated attestation with validator {validator_url}.") + nonce = await resp.json() + + # Set the nonce and lock + _evidence_nonce = nonce + _evidence_nonce_locked = True + + try: + yield nonce + finally: + # Clean up nonce state + _evidence_nonce = None + _evidence_nonce_locked = False + + +def _get_evidence_nonce() -> str | None: + """Get the current evidence nonce (used by evidence endpoint).""" + return _evidence_nonce + + +@asynccontextmanager +async def _attestation_session(): + """ + Creates an aiohttp session configured for the attestation service. + + SSL verification is disabled because certificate authenticity is verified + through TDX quotes, which include a hash of the service's public key. + """ + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + connector = aiohttp.TCPConnector(ssl=ssl_context) + + async with aiohttp.ClientSession(connector=connector, raise_for_status=True) as session: + yield session class GpuVerifier: - def __init__(self, url, body): + def __init__(self, body: dict): self._token = get_launch_token() - self._url = url + token_data = get_launch_token_data() + self._url = token_data.get("url") + self._symmetric_key: bytes | None = None + self._dummy_threads: list[threading.Thread] = [] self._body = body @classmethod - def create(cls, url, body) -> "GpuVerifier": + def create(cls, body: dict) -> "GpuVerifier": if is_tee_env(): - return TeeGpuVerifier(url, body) + return TeeGpuVerifier(body) else: - return GravalGpuVerifier(url, body) + return GravalGpuVerifier(body) + + def _start_dummy_sockets(self): + if not self._symmetric_key: + raise RuntimeError("Cannot start dummy sockets without symmetric key.") + for port_map in self._body.get("port_mappings", []): + if port_map.get("default"): + continue + self._dummy_threads.append(start_dummy_socket(port_map, self._symmetric_key)) + + async def verify(self): + """ + Execute full verification flow and spin up dummy sockets for port validation. + """ + await self.fetch_symmetric_key() + self._start_dummy_sockets() + response = await self.finalize_verification() + return response + + @abstractmethod + async def fetch_symmetric_key(self) -> bytes: ... @abstractmethod - async def verify_devices(self): ... + async def finalize_verification(self) -> dict: ... class GravalGpuVerifier(GpuVerifier): - async def verify_devices(self): + def __init__(self, body: dict): + super().__init__(body) + self._init_params: dict | None = None + self._proofs = None + self._response_plaintext: str | None = None + + async def fetch_symmetric_key(self): # Fetch the challenges. token = self._token - url = self._url - body = self._body + url = urljoin(self._url + "/", "graval") - body["gpus"] = self.gather_gpus() + self._body["gpus"] = self.gather_gpus() async with aiohttp.ClientSession(raise_for_status=True) as session: logger.info(f"Collected all environment data, submitting to validator: {url}") - async with session.post(url, headers={"Authorization": token}, json=body) as resp: - init_params = await resp.json() - logger.success(f"Successfully fetched initialization params: {init_params=}") - - # First, we initialize graval on all GPUs from the provided seed. 
- miner()._graval_seed = init_params["seed"] - iterations = init_params.get("iterations", 1) - logger.info(f"Generating proofs from seed={miner()._graval_seed}") - proofs = miner().prove(miner()._graval_seed, iterations=iterations) - - # Use GraVal to extract the symmetric key from the challenge. - sym_key = init_params["symmetric_key"] - bytes_ = base64.b64decode(sym_key["ciphertext"]) - iv = bytes_[:16] - cipher = bytes_[16:] - logger.info("Decrypting payload via proof challenge matrix...") - device_index = [ - miner().get_device_info(i)["uuid"] for i in range(miner()._device_count) - ].index(sym_key["uuid"]) - symmetric_key = bytes.fromhex( - miner().decrypt( - init_params["seed"], - cipher, - iv, - len(cipher), - device_index, - ) - ) - - # Now, we can respond to the URL by encrypting a payload with the symmetric key and sending it back. - plaintext = sym_key["response_plaintext"] - new_iv, response_cipher = encrypt_response(symmetric_key, plaintext) + async with session.post(url, headers={"Authorization": token}, json=self._body) as resp: + self._init_params = await resp.json() logger.success( - f"Completed PoVW challenge, sending back: {plaintext=} " - f"as {response_cipher=} where iv={new_iv.hex()}" + f"Successfully fetched initialization params: {self._init_params=}" ) - # Post the response to the challenge, which returns job data (if any). - async with session.put( - url, - headers={"Authorization": token}, - json={ - "response": response_cipher, - "iv": new_iv.hex(), - "proof": proofs, - }, - raise_for_status=False, - ) as resp: - if resp.ok: - logger.success("Successfully negotiated challenge response!") - response = await resp.json() - # validator_pubkey is returned in POST response, needed for ECDH session key - if "validator_pubkey" in init_params: - response["validator_pubkey"] = init_params["validator_pubkey"] - return symmetric_key, response - else: - # log down the reason of failure to the challenge - detail = await resp.text(encoding="utf-8", errors="replace") - logger.error(f"Failed: {resp.reason} ({resp.status}) {detail}") - resp.raise_for_status() + # First, we initialize graval on all GPUs from the provided seed. + miner()._graval_seed = self._init_params["seed"] + iterations = self._init_params.get("iterations", 1) + logger.info(f"Generating proofs from seed={miner()._graval_seed}") + self._proofs = miner().prove(miner()._graval_seed, iterations=iterations) + + # Use GraVal to extract the symmetric key from the challenge. + sym_key = self._init_params["symmetric_key"] + bytes_ = base64.b64decode(sym_key["ciphertext"]) + iv = bytes_[:16] + cipher = bytes_[16:] + logger.info("Decrypting payload via proof challenge matrix...") + device_index = [ + miner().get_device_info(i)["uuid"] for i in range(miner()._device_count) + ].index(sym_key["uuid"]) + self._symmetric_key = bytes.fromhex( + miner().decrypt( + self._init_params["seed"], + cipher, + iv, + len(cipher), + device_index, + ) + ) + + # Now, we can respond to the URL by encrypting a payload with the symmetric key and sending it back. 
+ self._response_plaintext = sym_key["response_plaintext"] + + async def finalize_verification(self): + + token = self._token + url = urljoin(self._url + "/", "graval") + plaintext = self._response_plaintext + + async with aiohttp.ClientSession(raise_for_status=True) as session: + new_iv, response_cipher = encrypt_response(self._symmetric_key, plaintext) + logger.success( + f"Completed PoVW challenge, sending back: {plaintext=} " + f"as {response_cipher=} where iv={new_iv.hex()}" + ) + + # Post the response to the challenge, which returns job data (if any). + async with session.put( + url, + headers={"Authorization": token}, + json={ + "response": response_cipher, + "iv": new_iv.hex(), + "proof": self._proofs, + }, + raise_for_status=False, + ) as resp: + if resp.ok: + logger.success("Successfully negotiated challenge response!") + response = await resp.json() + # validator_pubkey is returned in POST response, needed for ECDH session key + if "validator_pubkey" in self._init_params: + response["validator_pubkey"] = self._init_params["validator_pubkey"] + return response + else: + # log down the reason of failure to the challenge + detail = await resp.text(encoding="utf-8", errors="replace") + logger.error(f"Failed: {resp.reason} ({resp.status}) {detail}") + resp.raise_for_status() def gather_gpus(self): gpus = [] @@ -108,83 +228,311 @@ def gather_gpus(self): return gpus -class TeeGpuVerifier(GpuVerifier): - @asynccontextmanager - async def _attestation_session(self): - """ - Creates an aiohttp session configured for the attestation service. +def _parse_evidence_port() -> int: + """Parse TEE evidence port from env (default 8002).""" + evidence_port = os.getenv("CHUTES_TEE_EVIDENCE_PORT", "8002") + if not evidence_port.isdigit(): + raise ValueError(f"CHUTES_TEE_EVIDENCE_PORT must be a valid port number, got: {evidence_port}") + return int(evidence_port) + + +class TeeEvidenceService: + """ + Singleton TEE attestation evidence server. Serves GET /_tee_evidence for the + validator (during verification) and for third-party runtime verification. + """ + + _instance: "TeeEvidenceService | None" = None + + def __new__(cls) -> "TeeEvidenceService": + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance - SSL verification is disabled because certificate authenticity is verified - through TDX quotes, which include a hash of the service's public key. + def __init__(self): + if not hasattr(self, "_port"): + self._port: int | None = None + self._server: Server | None = None + self._task: asyncio.Task | None = None + + async def start(self) -> dict: """ - ssl_context = ssl.create_default_context() - ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE + Start the evidence server. Idempotent: if already started, returns the same + port mapping without starting a second server. 
- connector = aiohttp.TCPConnector(ssl=ssl_context) + Returns: + Port mapping dict to append to port_mappings in the run entrypoint, e.g.: + {"proto": "tcp", "internal_port": 8002, "external_port": 8002, "default": False} + """ + if self._task is not None: + return self._port_mapping() + self._port = _parse_evidence_port() + evidence_app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None) + evidence_app.add_api_route(TEE_VERIFICATION_ENDPOINT, self._get_verification_evidence, methods=["GET"]) + evidence_app.add_api_route(TEE_EVIDENCE_RUNTIME_ENDPOINT, self._get_runtime_evidence, methods=["GET"]) + config = Config( + app=evidence_app, + host="0.0.0.0", + port=self._port, + limit_concurrency=1000, + log_level="warning", + ) + self._server = Server(config) + self._task = asyncio.create_task(self._server.serve()) + await asyncio.sleep(0.5) + logger.info(f"Started TEE evidence server on port {self._port}") + return self._port_mapping() - async with aiohttp.ClientSession(connector=connector, raise_for_status=True) as session: - yield session + def _port_mapping(self) -> dict: + """Port mapping dict for this service (internal_port == external_port == configured port).""" + return { + "proto": "tcp", + "internal_port": self._port, + "external_port": self._port, + "default": False, + } - async def _get_nonce(self): - parsed = urlparse(self._url) + async def stop(self) -> None: + """Stop the evidence server. No-op if not started.""" + if self._server is None or self._task is None: + return + logger.info("Stopping TEE evidence server...") + self._server.should_exit = True + try: + await asyncio.wait_for(self._task, timeout=2.0) + except asyncio.TimeoutError: + logger.warning("TEE evidence server did not stop within timeout") + except Exception as e: + logger.warning(f"Error stopping TEE evidence server: {e}") + finally: + self._server = None + self._task = None + self._port = None - # Get just the scheme + netloc (host) - validator_url = f"{parsed.scheme}://{parsed.netloc}" - url = urljoin(validator_url, "/servers/nonce") - async with aiohttp.ClientSession(raise_for_status=True) as http_session: - async with http_session.get(url) as resp: - logger.success("Successfully retrieved nonce for attestation evidence.") - data = await resp.json() - return data["nonce"] - - async def _get_gpu_evidence(self): - """ """ - ssl_context = ssl.create_default_context() - ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE - - connector = aiohttp.TCPConnector(ssl=ssl_context) - - url = "https://attestation-service-internal.attestation-system.svc.cluster.local.:8443/server/nvtrust/evidence" - nonce = await self._get_nonce() + async def _fetch_evidence(self, nonce: str) -> dict: + """Request evidence from the attestation service for the given nonce.""" + url = f"{ATTESTATION_SERVICE_BASE_URL}/server/attest" params = { - "name": os.environ.get("HOSTNAME"), "nonce": nonce, "gpu_ids": os.environ.get("CHUTES_NVIDIA_DEVICES"), } - async with aiohttp.ClientSession( - connector=connector, raise_for_status=True - ) as http_session: + async with _attestation_session() as http_session: async with http_session.get(url, params=params) as resp: - logger.success("Successfully retrieved attestation evidence.") - evidence = json.loads(await resp.json()) - return nonce, evidence + evidence = await resp.json() + return {"evidence": evidence, "nonce": nonce} + + async def _get_verification_evidence(self, request: Request): + """ + TEE evidence for initial verification only. 
Called by the validator during + fetch_symmetric_key (Phase 2). Uses the nonce retrieved from the validator + when we started the process so we can prove we are the same instance. + """ + nonce = _get_evidence_nonce() + if nonce is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="No nonce found. Attestation not initiated. Use /evidence?nonce=... for third-party evidence.", + ) + try: + logger.success("Retrieved attestation evidence for validator verification.") + return await self._fetch_evidence(nonce) + except Exception as e: + logger.error(f"Failed to fetch TEE evidence: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to fetch evidence: {str(e)}", + ) + + async def _get_runtime_evidence(self, request: Request): + """ + TEE evidence for third-party runtime verification. Caller supplies a nonce + via query param ?nonce=...; we request evidence bound to that nonce and return it. + """ + nonce = request.query_params.get("nonce") + if not nonce or not nonce.strip(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Query parameter 'nonce' is required for runtime evidence.", + ) + nonce = nonce.strip() + try: + logger.success("Retrieved attestation evidence.") + return await self._fetch_evidence(nonce) + except Exception as e: + logger.error(f"Failed to fetch TEE evidence: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to fetch evidence: {str(e)}", + ) + + +class TeeGpuVerifier(GpuVerifier): + @property + @lru_cache(maxsize=1) + def validator_url(self) -> str: + parsed = urlparse(self._url) + return f"{parsed.scheme}://{parsed.netloc}" + + @property + @lru_cache(maxsize=1) + def deployment_id(self) -> str: + hostname = os.environ.get("HOSTNAME") + # Pod name format: chute-{deployment_id}-{k8s-suffix} + # Service name format: chute-service-{deployment_id} + # We need to extract just the deployment_id by removing the prefix and k8s suffix + if not hostname.startswith("chute-"): + raise ValueError(f"Unexpected hostname format: {hostname}") + # Remove 'chute-' prefix + _deployment_id = hostname[6:] # len("chute-") = 6 + # Remove k8s-generated pod suffix (everything after the last hyphen) + _deployment_id = _deployment_id.rsplit("-", 1)[0] + return _deployment_id - async def verify_devices(self): + async def fetch_symmetric_key(self): + """ + TEE verification flow (3 phases): + Phase 1: Get nonce from validator (evidence server is already running at chute startup) + Phase 2: Validator calls our /_tee_evidence endpoint, we fetch evidence from attestation service, + validator verifies and returns symmetric key + Phase 3: Start dummy sockets and finalize verification (handled by base class) + """ token = self._token - url = urljoin(f"{self._url}/", "attest") - body = self._body + + # Gather GPUs before sending request + gpus = await self.gather_gpus() + # Append /tee to instance path; urljoin(base, "/tee") replaces path, urljoin(base, "tee") replaces last segment + url = urljoin(self._url + "/", "tee") + + async with _use_evidence_nonce(self.validator_url) as _nonce: + async with aiohttp.ClientSession(raise_for_status=True) as session: + headers = { + "Authorization": token, + "X-Chutes-Nonce": _nonce + } + _body = self._body.copy() + _body["deployment_id"] = self.deployment_id + _body["gpus"] = gpus + logger.info(f"Requesting verification from validator: {url}") + async with session.post(url, headers=headers, json=_body) as resp: + data = 
await resp.json() + self._symmetric_key = bytes.fromhex(data["symmetric_key"]) + self._validator_pubkey = data["validator_pubkey"] if "validator_pubkey" in data else None + logger.success("Successfully received symmetric key from validator") - body["gpus"] = await self.gather_gpus() - nonce, evidence = await self._get_gpu_evidence() - body["gpu_evidence"] = evidence - async with aiohttp.ClientSession(raise_for_status=True) as session: - headers = {"Authorization": token, "X-Chutes-Nonce": nonce} - logger.info(f"Collected all environment data, submitting to validator: {url}") - async with session.post(url, headers=headers, json=body) as resp: - logger.info("Successfully verified instance with validator.") - data = await resp.json() - symmetric_key = bytes.fromhex(data["symmetric_key"]) - return symmetric_key, data + async def finalize_verification(self): + """ + Send final verification with port mappings (same as GraVal flow). + """ + if not self._symmetric_key: + raise RuntimeError("Symmetric key must be fetched before finalizing verification.") + + token = self._token + # Append /tee to instance path; urljoin(base, "/tee") would replace path with /tee (RFC 3986) + url = urljoin(self._url + "/", "tee") + + async with aiohttp.ClientSession(raise_for_status=False) as session: + logger.info("Sending final verification request.") + + async with session.put( + url, + headers={"Authorization": token}, + json={}, + raise_for_status=False, + ) as resp: + if resp.ok: + logger.success("Successfully completed final verification!") + response = await resp.json() + if self._validator_pubkey: + response["validator_pubkey"] = self._validator_pubkey + return response + else: + detail = await resp.text(encoding="utf-8", errors="replace") + logger.error(f"Final verification failed: {resp.reason} ({resp.status}) {detail}") + resp.raise_for_status() async def gather_gpus(self): devices = [] - async with self._attestation_session() as http_session: - url = "https://attestation-service-internal.attestation-system.svc.cluster.local.:8443/server/devices" + async with _attestation_session() as http_session: + url = f"{ATTESTATION_SERVICE_BASE_URL}/server/devices" params = {"gpu_ids": os.environ.get("CHUTES_NVIDIA_DEVICES")} async with http_session.get(url=url, params=params) as resp: devices = await resp.json() logger.success(f"Retrieved {len(devices)} GPUs.") return devices + + +def start_dummy_socket(port_mapping, symmetric_key): + """ + Start a dummy socket based on the port mapping configuration to validate ports. + """ + proto = port_mapping["proto"].lower() + internal_port = port_mapping["internal_port"] + response_text = f"response from {proto} {internal_port}" + if proto in ["tcp", "http"]: + return start_tcp_dummy(internal_port, symmetric_key, response_text) + return start_udp_dummy(internal_port, symmetric_key, response_text) + + +def start_tcp_dummy(port, symmetric_key, response_plaintext): + """ + TCP port check socket. 
+ """ + + def tcp_handler(): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + try: + sock.bind(("0.0.0.0", port)) + sock.listen(1) + logger.info(f"TCP socket listening on port {port}") + conn, addr = sock.accept() + logger.info(f"TCP connection from {addr}") + data = conn.recv(1024) + logger.info(f"TCP received: {data.decode('utf-8', errors='ignore')}") + iv, encrypted_response = encrypt_response(symmetric_key, response_plaintext) + full_response = f"{iv.hex()}|{encrypted_response}".encode() + conn.send(full_response) + logger.info(f"TCP sent encrypted response on port {port}: {full_response=}") + conn.close() + except Exception as e: + logger.info(f"TCP socket error on port {port}: {e}") + raise + finally: + sock.close() + logger.info(f"TCP socket on port {port} closed") + + thread = threading.Thread(target=tcp_handler, daemon=True) + thread.start() + return thread + + +def start_udp_dummy(port, symmetric_key, response_plaintext): + """ + UDP port check socket. + """ + + def udp_handler(): + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + try: + sock.bind(("0.0.0.0", port)) + logger.info(f"UDP socket listening on port {port}") + data, addr = sock.recvfrom(1024) + logger.info(f"UDP received from {addr}: {data.decode('utf-8', errors='ignore')}") + iv, encrypted_response = encrypt_response(symmetric_key, response_plaintext) + full_response = f"{iv.hex()}|{encrypted_response}".encode() + sock.sendto(full_response, addr) + logger.info(f"UDP sent encrypted response on port {port}") + except Exception as e: + logger.info(f"UDP socket error on port {port}: {e}") + raise + finally: + sock.close() + logger.info(f"UDP socket on port {port} closed") + + thread = threading.Thread(target=udp_handler, daemon=True) + thread.start() + return thread + + diff --git a/chutes/entrypoint/warmup.py b/chutes/entrypoint/warmup.py index 190b0fd..aeb058d 100644 --- a/chutes/entrypoint/warmup.py +++ b/chutes/entrypoint/warmup.py @@ -6,6 +6,8 @@ import asyncio import aiohttp import orjson as json +import sys +import time from loguru import logger import typer from chutes.config import get_config @@ -13,6 +15,248 @@ from chutes.util.auth import sign_request +async def poll_for_instances(chute_name: str, config, headers, poll_interval: float = 2.0, max_wait: float = 600.0): + """ + Poll for instances of a chute. Returns a list of instance info dicts (with instance_id and status). 
+ + Args: + chute_name: Name or ID of the chute + config: Config object + headers: Request headers + poll_interval: Seconds between polls + max_wait: Maximum seconds to wait for an instance (default 10 minutes) + + Returns: + List of dicts with 'instance_id' and status information + """ + start_time = time.time() + async with aiohttp.ClientSession(base_url=config.generic.api_base_url) as session: + while time.time() - start_time < max_wait: + try: + async with session.get( + f"/chutes/{chute_name}", + headers=headers, + ) as response: + if response.status == 200: + data = await response.json() + instances = data.get("instances", []) + if instances: + # Return all instances with their status info + instance_infos = [ + { + "instance_id": inst.get("instance_id"), + "active": inst.get("active", False), + "verified": inst.get("verified", False), + "region": inst.get("region", "n/a"), + "last_verified_at": inst.get("last_verified_at"), + } + for inst in instances + if inst.get("instance_id") + ] + if instance_infos: + return instance_infos + elif response.status == 404: + # Chute doesn't exist - this is an error, not a polling condition + error_text = await response.text() + raise ValueError(f"Chute '{chute_name}' not found: {error_text}") + else: + error_text = await response.text() + logger.debug(f"Failed to get chute (status {response.status}): {error_text}") + except ValueError: + # Re-raise ValueError (chute not found) immediately + raise + except Exception as e: + logger.debug(f"Error polling for instances: {e}") + + await asyncio.sleep(poll_interval) + + # Timeout reached + raise TimeoutError(f"No instances found for chute {chute_name} within {max_wait} seconds") + + +async def stream_instance_logs(instance_id: str, config, backfill: int = 100): + """ + Stream logs from an instance. 
+ + Raises: + aiohttp.ClientResponseError: If the request fails + Exception: For other streaming errors + """ + # Sign the request with purpose for instances endpoint + headers, _ = sign_request(purpose="logs") + async with aiohttp.ClientSession(base_url=config.generic.api_base_url) as session: + async with session.get( + f"/instances/{instance_id}/logs", + headers=headers, + params={"backfill": str(backfill)}, + ) as response: + if response.status != 200: + error_text = await response.text() + raise aiohttp.ClientResponseError( + request_info=response.request_info, + history=response.history, + status=response.status, + message=f"Failed to stream logs from instance {instance_id}: {error_text}", + ) + + logger.info(f"Streaming logs from instance {instance_id}...") + # Parse SSE format and extract log content + buffer = b"" + skipped_lines_count = 0 + try: + async for chunk in response.content.iter_any(): + if chunk: + buffer += chunk + # Process complete lines + while b"\n" in buffer: + line, buffer = buffer.split(b"\n", 1) + # Skip empty lines completely + if not line.strip(): + skipped_lines_count += 1 + continue + # Skip SSE comment lines (lines starting with :) + if line.startswith(b":"): + skipped_lines_count += 1 + continue + # Only process SSE data lines - ignore everything else + if line.startswith(b"data: "): + # Check for empty data: lines (just "data: " or "data:") + data_content = line[6:].strip() if len(line) > 6 else b"" + if not data_content: + skipped_lines_count += 1 + continue + try: + # Parse JSON from SSE data line + data = json.loads(data_content) + log_message = data.get("log", "") + # Only output if we have actual non-empty log content + # Also check that it's not just a single character like "." + if log_message and log_message.strip() and len(log_message.strip()) > 1: + # Write just the log message with a newline + sys.stdout.buffer.write(log_message.encode("utf-8") + b"\n") + sys.stdout.buffer.flush() + elif log_message and log_message.strip(): + # Single character - likely a keepalive, skip it + skipped_lines_count += 1 + logger.debug(f"Skipping single-character log message (likely keepalive): {repr(log_message)}") + else: + skipped_lines_count += 1 + except (json.JSONDecodeError, KeyError) as e: + # Log what we're skipping for debugging + logger.debug(f"Skipping unparseable SSE line: {line[:100]}, error: {e}") + skipped_lines_count += 1 + else: + # Skip non-SSE lines silently (likely keepalive messages) + skipped_lines_count += 1 + except asyncio.CancelledError: + raise + except (aiohttp.ClientError, ConnectionError, OSError) as e: + # Connection errors - instance might have been deleted + raise Exception(f"Stream ended for instance {instance_id} (instance may have been deleted): {e}") from e + except Exception as e: + # Re-raise to allow caller to handle (e.g., try another instance) + raise Exception(f"Error streaming logs from instance {instance_id}: {e}") from e + + +async def monitor_warmup(chute_name: str, config, headers): + """ + Monitor the warmup stream and log status updates. 
+ """ + async with aiohttp.ClientSession(base_url=config.generic.api_base_url) as session: + async with session.get( + f"/chutes/warmup/{chute_name}", + headers=headers, + ) as response: + if response.status == 200: + async for raw_chunk in response.content: + if raw_chunk.startswith(b"data:"): + chunk = json.loads(raw_chunk[5:]) + if chunk["status"] == "hot": + logger.success(chunk["log"]) + else: + logger.warning(f"Status: {chunk['status']} -- {chunk['log']}") + else: + logger.error(await response.text()) + + +async def poll_and_stream_logs(chute_name: str, config, headers): + """ + Poll for instances and stream logs when found. Tries multiple instances if one fails. + """ + try: + instance_infos = await poll_for_instances(chute_name, config, headers) + logger.info(f"Found {len(instance_infos)} instance(s), attempting to stream logs...") + + + # Try each instance until one works + # Keep trying instances even if they get deleted during streaming + tried_instance_ids = set() + max_retries = 10 # Limit retries to avoid infinite loops + + for attempt in range(max_retries): + # Refresh instance list in case instances were deleted + try: + current_instance_infos = await poll_for_instances(chute_name, config, headers, poll_interval=1.0, max_wait=5.0) + # Filter out instances we've already tried + available_instances = [ + inst for inst in current_instance_infos + if inst["instance_id"] not in tried_instance_ids + ] + + if not available_instances: + # No new instances available + if tried_instance_ids: + if attempt < max_retries - 1: + logger.info("All instances have been tried or deleted. Waiting for new instances...") + await asyncio.sleep(2.0) + continue + else: + raise Exception("No new instances available after retries") + else: + raise Exception("No instances available for log streaming") + + # Try the first available instance + inst_info = available_instances[0] + instance_id = inst_info["instance_id"] + tried_instance_ids.add(instance_id) + + logger.info(f"Attempting to stream logs from instance {instance_id}...") + + try: + # Stream logs - this will continue until interrupted or stream ends + await stream_instance_logs(instance_id, config) + # If we get here, streaming completed successfully (unlikely for a stream) + return + except asyncio.CancelledError: + # User interrupted, don't try other instances + raise + except Exception as e: + # Stream ended (instance deleted, connection error, etc.) + logger.warning(f"Stream ended for instance {instance_id}: {e}") + logger.info("Looking for another instance to stream from...") + # Continue the loop to try another instance + continue + except TimeoutError: + # No instances found when refreshing + if tried_instance_ids and attempt < max_retries - 1: + logger.info("No new instances found. Waiting...") + await asyncio.sleep(2.0) + continue + else: + raise + + # Exhausted retries + raise Exception(f"Failed to maintain stream after {max_retries} attempts") + + except asyncio.CancelledError: + pass + except TimeoutError as e: + logger.warning(str(e)) + except Exception as e: + logger.error(f"Error in poll_and_stream_logs: {e}") + raise + + def warmup_chute( chute_id_or_ref_str: str = typer.Argument( ..., @@ -22,12 +266,13 @@ def warmup_chute( None, help="Custom path to the chutes config (credentials, API URL, etc.)" ), debug: bool = typer.Option(False, help="enable debug logging"), + stream_logs: bool = typer.Option(False, help="automatically stream logs from the first instance that appears"), ): async def warmup(): """ Do the warmup. 
""" - nonlocal chute_id_or_ref_str, config_path, debug + nonlocal chute_id_or_ref_str, config_path, debug, stream_logs chute_name = chute_id_or_ref_str if ":" in chute_id_or_ref_str and os.path.exists(chute_id_or_ref_str.split(":")[0] + ".py"): from chutes.chute.base import Chute @@ -38,20 +283,37 @@ async def warmup(): if config_path: os.environ["CHUTES_CONFIG_PATH"] = config_path headers, _ = sign_request(purpose="chutes") - async with aiohttp.ClientSession(base_url=config.generic.api_base_url) as session: - async with session.get( - f"/chutes/warmup/{chute_name}", - headers=headers, - ) as response: - if response.status == 200: - async for raw_chunk in response.content: - if raw_chunk.startswith(b"data:"): - chunk = json.loads(raw_chunk[5:]) - if chunk["status"] == "hot": - logger.success(chunk["log"]) - else: - logger.warning(f"Status: {chunk['status']} -- {chunk['log']}") - else: - logger.error(await response.text()) + + if stream_logs: + # Run warmup monitoring and log streaming in parallel + warmup_task = asyncio.create_task(monitor_warmup(chute_name, config, headers)) + poll_task = asyncio.create_task(poll_and_stream_logs(chute_name, config, headers)) + + # Wait for both tasks, but log streaming should continue even after warmup completes + try: + # Wait for warmup to complete (or be cancelled) + try: + await warmup_task + except Exception as e: + logger.debug(f"Warmup task ended: {e}") + + # Log streaming continues independently - wait for it or user interrupt + try: + await poll_task + except asyncio.CancelledError: + pass + except Exception as e: + logger.debug(f"Poll task ended: {e}") + except KeyboardInterrupt: + warmup_task.cancel() + poll_task.cancel() + try: + await asyncio.gather(warmup_task, poll_task, return_exceptions=True) + except Exception: + pass + raise + else: + # Just monitor warmup + await monitor_warmup(chute_name, config, headers) return asyncio.run(warmup())