# chutes/chute/template/embedding.py -- shutdown handler for the vLLM embedding server.
@chute.on_event("shutdown")
async def cleanup_vllm(self):
    """
    Cleanup vLLM embedding subprocess and monitor task on shutdown.

    Best-effort teardown: every failure is logged and swallowed so this
    handler never raises.  The two steps are isolated from each other, so a
    monitor task that already ended with an error cannot prevent the
    subprocess from being terminated.

    Reads ``self._monitor_task`` and ``self._vllm_process``; either may be
    missing or ``None``, in which case the corresponding step is skipped.
    """
    # Imported locally so the handler does not depend on a module-level
    # ``import subprocess`` being present in this file.
    import subprocess

    logger.info("Cleaning up vLLM embedding subprocess...")

    # Cancel the monitor task
    monitor_task = getattr(self, "_monitor_task", None)
    if monitor_task is not None:
        monitor_task.cancel()
        try:
            await monitor_task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # The task had already finished with its own error.  Report it and
            # carry on: the subprocess below still has to be stopped.
            logger.error(f"Unexpected error during vLLM embedding cleanup: {e}")

    # Terminate the subprocess gracefully
    process = getattr(self, "_vllm_process", None)
    if process is None:
        return

    # NOTE(review): assumes a subprocess.Popen-style handle whose wait() blocks
    # and raises subprocess.TimeoutExpired -- confirm against the startup code.
    try:
        loop = asyncio.get_running_loop()
        process.terminate()
        try:
            # Wait up to 5 seconds for graceful shutdown.  The blocking wait
            # runs in the default executor so the event loop stays responsive.
            await loop.run_in_executor(None, process.wait, 5)
            logger.info("vLLM embedding subprocess terminated gracefully")
        except subprocess.TimeoutExpired:
            # Force kill if it doesn't terminate
            logger.warning("vLLM embedding subprocess did not terminate, forcing kill")
            process.kill()
            await loop.run_in_executor(None, process.wait)
            logger.info("vLLM embedding subprocess force killed")
    except ProcessLookupError:
        # Process already terminated
        logger.debug("vLLM embedding subprocess already terminated")
    except Exception as e:
        logger.error(f"Error cleaning up vLLM embedding subprocess: {e}")
# chutes/chute/template/sglang.py -- shutdown handler for the SGLang server.
@chute.on_event("shutdown")
async def cleanup_sglang(self):
    """
    Cleanup SGLang subprocess and monitor task on shutdown.

    Best-effort teardown: every failure is logged and swallowed so this
    handler never raises.  The two steps are isolated from each other, so a
    monitor task that already ended with an error cannot prevent the
    subprocess from being terminated.

    Reads ``self._monitor_task`` and ``self._sglang_process``; either may be
    missing or ``None``, in which case the corresponding step is skipped.
    """
    # Imported locally so the handler does not depend on a module-level
    # ``import subprocess`` being present in this file.
    import subprocess

    logger.info("Cleaning up SGLang subprocess...")

    # Cancel the monitor task
    monitor_task = getattr(self, "_monitor_task", None)
    if monitor_task is not None:
        monitor_task.cancel()
        try:
            await monitor_task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # The task had already finished with its own error.  Report it and
            # carry on: the subprocess below still has to be stopped.
            logger.error(f"Unexpected error during SGLang cleanup: {e}")

    # Terminate the subprocess gracefully
    process = getattr(self, "_sglang_process", None)
    if process is None:
        return

    # NOTE(review): assumes a subprocess.Popen-style handle whose wait() blocks
    # and raises subprocess.TimeoutExpired -- confirm against the startup code.
    try:
        loop = asyncio.get_running_loop()
        process.terminate()
        try:
            # Wait up to 5 seconds for graceful shutdown.  The blocking wait
            # runs in the default executor so the event loop stays responsive.
            await loop.run_in_executor(None, process.wait, 5)
            logger.info("SGLang subprocess terminated gracefully")
        except subprocess.TimeoutExpired:
            # Force kill if it doesn't terminate
            logger.warning("SGLang subprocess did not terminate, forcing kill")
            process.kill()
            await loop.run_in_executor(None, process.wait)
            logger.info("SGLang subprocess force killed")
    except ProcessLookupError:
        # Process already terminated
        logger.debug("SGLang subprocess already terminated")
    except Exception as e:
        logger.error(f"Error cleaning up SGLang subprocess: {e}")


# chutes/chute/template/vllm.py -- shutdown handler for the vLLM server.
@chute.on_event("shutdown")
async def cleanup_vllm(self):
    """
    Cleanup vLLM subprocess and monitor task on shutdown.

    Best-effort teardown: every failure is logged and swallowed so this
    handler never raises.  The two steps are isolated from each other, so a
    monitor task that already ended with an error cannot prevent the
    subprocess from being terminated.

    Reads ``self._monitor_task`` and ``self._vllm_process``; either may be
    missing or ``None``, in which case the corresponding step is skipped.
    """
    # Imported locally so the handler does not depend on a module-level
    # ``import subprocess`` being present in this file.
    import subprocess

    logger.info("Cleaning up vLLM subprocess...")

    # Cancel the monitor task
    monitor_task = getattr(self, "_monitor_task", None)
    if monitor_task is not None:
        monitor_task.cancel()
        try:
            await monitor_task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # The task had already finished with its own error.  Report it and
            # carry on: the subprocess below still has to be stopped.
            logger.error(f"Unexpected error during vLLM cleanup: {e}")

    # Terminate the subprocess gracefully
    process = getattr(self, "_vllm_process", None)
    if process is None:
        return

    # NOTE(review): assumes a subprocess.Popen-style handle whose wait() blocks
    # and raises subprocess.TimeoutExpired -- confirm against the startup code.
    try:
        loop = asyncio.get_running_loop()
        process.terminate()
        try:
            # Wait up to 5 seconds for graceful shutdown.  The blocking wait
            # runs in the default executor so the event loop stays responsive.
            await loop.run_in_executor(None, process.wait, 5)
            logger.info("vLLM subprocess terminated gracefully")
        except subprocess.TimeoutExpired:
            # Force kill if it doesn't terminate
            logger.warning("vLLM subprocess did not terminate, forcing kill")
            process.kill()
            await loop.run_in_executor(None, process.wait)
            logger.info("vLLM subprocess force killed")
    except ProcessLookupError:
        # Process already terminated
        logger.debug("vLLM subprocess already terminated")
    except Exception as e:
        logger.error(f"Error cleaning up vLLM subprocess: {e}")


# chutes/entrypoint/ssh.py (setup_ssh_access) -- only a fragment of this
# function is visible, so it is not reproduced here.  The visible change binds
# the sshd ``subprocess.Popen`` handle to ``sshd_process`` and returns it after
# logging that the SSH server started.