From 521a1a47755093690243e9730ffbbb8db0727f62 Mon Sep 17 00:00:00 2001
From: matt-bernstein <60152561+matt-bernstein@users.noreply.github.com>
Date: Mon, 18 Nov 2024 16:02:06 -0500
Subject: [PATCH] fix: DIA-1619: logging and graceful handling for memory leaks
 (#253)

---
 server/app.py                    |  1 -
 server/tasks/stream_inference.py | 29 ++++++++++++++++++++++++++++-
 server/utils.py                  |  2 ++
 3 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/server/app.py b/server/app.py
index 89e6a12b..c868fffe 100644
--- a/server/app.py
+++ b/server/app.py
@@ -209,7 +209,6 @@ async def submit_batch(batch: BatchData):
             f"Size of batch in bytes received for job_id:{batch.job_id} batch_size:{batch_size}"
         )
     except UnknownTopicOrPartitionError:
-        await producer.stop()
         raise HTTPException(
             status_code=500, detail=f"{topic=} for job {batch.job_id} not found"
         )
diff --git a/server/tasks/stream_inference.py b/server/tasks/stream_inference.py
index 468fdb99..82c565c0 100644
--- a/server/tasks/stream_inference.py
+++ b/server/tasks/stream_inference.py
@@ -1,6 +1,7 @@
 import asyncio
 import json
 import os
+import psutil
 import time
 import traceback
 
@@ -9,6 +10,7 @@
 from aiokafka import AIOKafkaConsumer
 from aiokafka.errors import UnknownTopicOrPartitionError
 from celery import Celery
+from celery.signals import worker_process_shutdown, worker_process_init
 from server.handlers.result_handlers import ResultHandler
 from server.utils import (
     Settings,
@@ -21,6 +23,8 @@
 
 logger = init_logger(__name__)
 
+settings = Settings()
+
 REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
 app = Celery(
     "worker",
@@ -28,9 +32,30 @@
     backend=REDIS_URL,
     accept_content=["json", "pickle"],
     broker_connection_retry_on_startup=True,
+    worker_max_memory_per_child=settings.celery_worker_max_memory_per_child_kb,
 )
 
-settings = Settings()
+
+@worker_process_init.connect
+def worker_process_init_handler(**kwargs):
+    """Called when a worker process starts."""
+    process = psutil.Process()
+    mem_info = process.memory_info()
+    logger.info(
+        f"Worker process starting. PID: {os.getpid()}, "
+        f"Memory RSS: {mem_info.rss / 1024 / 1024:.2f}MB"
+    )
+
+
+@worker_process_shutdown.connect
+def worker_process_shutdown_handler(**kwargs):
+    """Called when a worker process shuts down."""
+    process = psutil.Process()
+    mem_info = process.memory_info()
+    logger.info(
+        f"Worker process shutting down. PID: {os.getpid()}, "
+        f"Memory RSS: {mem_info.rss / 1024 / 1024:.2f}MB"
+    )
 
 
 def parent_job_error_handler(self, exc, task_id, args, kwargs, einfo):
@@ -69,6 +94,8 @@ async def run_streaming(
         serializer="pickle",
         on_failure=parent_job_error_handler,
         task_time_limit=settings.task_time_limit_sec,
+        task_ignore_result=True,
+        task_store_errors_even_if_ignored=True,
     )
     def streaming_parent_task(
         self, agent: Agent, result_handler: ResultHandler, batch_size: int = 1
diff --git a/server/utils.py b/server/utils.py
index 2a2fcad6..c40b290d 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -22,6 +22,8 @@ class Settings(BaseSettings):
     kafka_input_consumer_timeout_ms: int = 1500  # 1.5 seconds
     kafka_output_consumer_timeout_ms: int = 1500  # 1.5 seconds
     task_time_limit_sec: int = 60 * 60 * 6  # 6 hours
+    # https://docs.celeryq.dev/en/v5.4.0/userguide/configuration.html#worker-max-memory-per-child
+    celery_worker_max_memory_per_child_kb: int = 1024000  # 1GB
 
     model_config = SettingsConfigDict(
         # have to use an absolute path here so celery workers can find it