From e8bf696494ec406f6020b08a22b05b181b2cbd84 Mon Sep 17 00:00:00 2001 From: dmitry krokhin Date: Sat, 23 Sep 2023 17:10:30 +0300 Subject: [PATCH] worker terminate signal handler --- sharded_queue/__init__.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/sharded_queue/__init__.py b/sharded_queue/__init__.py index 29a6db4..0887856 100644 --- a/sharded_queue/__init__.py +++ b/sharded_queue/__init__.py @@ -1,10 +1,10 @@ -from asyncio import ensure_future, get_event_loop, sleep +from asyncio import all_tasks, current_task, ensure_future, gather, get_event_loop, sleep from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import datetime, timedelta from functools import cache, partial from importlib import import_module -from signal import SIGTERM +from signal import SIGKILL, SIGTERM from typing import (Any, AsyncGenerator, Generic, NamedTuple, Optional, Self, TypeVar, get_type_hints) @@ -179,12 +179,19 @@ async def loop( processed = 0 while True and limit is None or limit > processed: tube = await self.acquire_tube(handler) - loop.add_signal_handler(SIGTERM, partial( - ensure_future, partial(self.lock.release, tube.pipe) - )) + loop.add_signal_handler(SIGTERM, partial(self.stop, tube.pipe)) processed = processed + await self.process(tube, limit) loop.remove_signal_handler(SIGTERM) + def stop(self, pipe: str) -> None: + get_event_loop().create_task(self.shutdown_worker(pipe)) + + async def shutdown_worker(self, pipe: str) -> None: + await self.lock.release(pipe) + tasks = [task for task in all_tasks() if task is not current_task()] + [task.cancel() for task in tasks] + await gather(*tasks) + async def process(self, tube: Tube, limit: Optional[int] = None) -> int: deserialize = self.queue.serializer.deserialize storage = self.queue.storage