fix: DIA-1619: logging and graceful handling for memory leaks (#253)
matt-bernstein authored Nov 18, 2024
1 parent 2e56218 commit 521a1a4
Showing 3 changed files with 30 additions and 2 deletions.
1 change: 0 additions & 1 deletion server/app.py
@@ -209,7 +209,6 @@ async def submit_batch(batch: BatchData):
             f"Size of batch in bytes received for job_id:{batch.job_id} batch_size:{batch_size}"
         )
     except UnknownTopicOrPartitionError:
-        await producer.stop()
        raise HTTPException(
            status_code=500, detail=f"{topic=} for job {batch.job_id} not found"
        )
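Context for the server/app.py change: the removal suggests producer is shared across requests, so stopping it inside the UnknownTopicOrPartitionError branch would have left a stopped producer behind for every later submit after a single bad topic; dropping the call leaves the producer running and lets the HTTPException report the error. A minimal sketch of the lifecycle this implies, with the producer started once and stopped only at application shutdown (the lifespan wiring and bootstrap address below are assumptions for illustration; this diff does not show how server/app.py actually constructs its producer):

from contextlib import asynccontextmanager

from aiokafka import AIOKafkaProducer
from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    # One long-lived producer for the whole app, started at startup ...
    producer = AIOKafkaProducer(bootstrap_servers="localhost:9092")
    await producer.start()
    app.state.producer = producer
    try:
        yield
    finally:
        # ... and stopped exactly once at shutdown, never per request.
        await producer.stop()


app = FastAPI(lifespan=lifespan)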
29 changes: 28 additions & 1 deletion server/tasks/stream_inference.py
@@ -1,6 +1,7 @@
 import asyncio
 import json
 import os
+import psutil
 import time
 import traceback
 
@@ -9,6 +10,7 @@
 from aiokafka import AIOKafkaConsumer
 from aiokafka.errors import UnknownTopicOrPartitionError
 from celery import Celery
+from celery.signals import worker_process_shutdown, worker_process_init
 from server.handlers.result_handlers import ResultHandler
 from server.utils import (
     Settings,
@@ -21,16 +23,39 @@
 
 logger = init_logger(__name__)
 
+settings = Settings()
+
 REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
 app = Celery(
     "worker",
     broker=REDIS_URL,
     backend=REDIS_URL,
     accept_content=["json", "pickle"],
     broker_connection_retry_on_startup=True,
+    worker_max_memory_per_child=settings.celery_worker_max_memory_per_child_kb,
 )
 
-settings = Settings()
+
+@worker_process_init.connect
+def worker_process_init_handler(**kwargs):
+    """Called when a worker process starts."""
+    process = psutil.Process()
+    mem_info = process.memory_info()
+    logger.info(
+        f"Worker process starting. PID: {os.getpid()}, "
+        f"Memory RSS: {mem_info.rss / 1024 / 1024:.2f}MB"
+    )
+
+
+@worker_process_shutdown.connect
+def worker_process_shutdown_handler(**kwargs):
+    """Called when a worker process shuts down."""
+    process = psutil.Process()
+    mem_info = process.memory_info()
+    logger.info(
+        f"Worker process shutting down. PID: {os.getpid()}, "
+        f"Memory RSS: {mem_info.rss / 1024 / 1024:.2f}MB"
+    )
 
 
 def parent_job_error_handler(self, exc, task_id, args, kwargs, einfo):
@@ -69,6 +94,8 @@ async def run_streaming(
     serializer="pickle",
     on_failure=parent_job_error_handler,
     task_time_limit=settings.task_time_limit_sec,
+    task_ignore_result=True,
+    task_store_errors_even_if_ignored=True,
 )
 def streaming_parent_task(
     self, agent: Agent, result_handler: ResultHandler, batch_size: int = 1
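Context for the server/tasks/stream_inference.py change: worker_max_memory_per_child is Celery's documented cap, in kilobytes, on a prefork child's resident memory; Celery checks RSS after each task completes and replaces the child once it exceeds the limit, so a slow leak gets recycled instead of growing until the OOM killer steps in. Moving settings = Settings() above the Celery(...) call is what lets the new kwarg read the configured value. The two signal handlers bracket each child's lifetime with an RSS reading so those recycles are visible in the logs; a minimal standalone sketch of the same psutil measurement the handlers rely on:

import os

import psutil

# Resident set size of the current process, reported the same way the handlers log it.
process = psutil.Process()
rss_mb = process.memory_info().rss / 1024 / 1024
print(f"PID {os.getpid()} RSS: {rss_mb:.2f}MB")

The two new task options complement the memory cap: task_ignore_result=True keeps the parent task's return value out of the Redis result backend (one less thing accumulating there), while task_store_errors_even_if_ignored=True still records failures so they remain inspectable.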
2 changes: 2 additions & 0 deletions server/utils.py
@@ -22,6 +22,8 @@ class Settings(BaseSettings):
     kafka_input_consumer_timeout_ms: int = 1500 # 1.5 seconds
     kafka_output_consumer_timeout_ms: int = 1500 # 1.5 seconds
     task_time_limit_sec: int = 60 * 60 * 6 # 6 hours
+    # https://docs.celeryq.dev/en/v5.4.0/userguide/configuration.html#worker-max-memory-per-child
+    celery_worker_max_memory_per_child_kb: int = 1024000 # 1GB
 
     model_config = SettingsConfigDict(
         # have to use an absolute path here so celery workers can find it
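Context for the server/utils.py change: because Settings is a pydantic BaseSettings, the new cap can be tuned per deployment without code changes; environment variables match field names case-insensitively and take precedence over values from the env file. A hedged sketch of such an override (the 512000 figure is an example value, not anything this commit sets):

import os

from server.utils import Settings

# Hypothetical override: cap workers at ~500MB instead of the 1GB default.
os.environ["CELERY_WORKER_MAX_MEMORY_PER_CHILD_KB"] = "512000"
print(Settings().celery_worker_max_memory_per_child_kb)  # 512000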
