diff --git a/sharded_queue/__init__.py b/sharded_queue/__init__.py index 7861cbf..a213e77 100644 --- a/sharded_queue/__init__.py +++ b/sharded_queue/__init__.py @@ -132,6 +132,7 @@ async def register( class Worker: lock: Lock queue: Queue + pipe: Optional[str] = None async def acquire_tube( self, handler: Optional[type[Handler]] = None @@ -175,22 +176,19 @@ async def loop( limit: Optional[int] = None, handler: Optional[type[Handler]] = None, ) -> None: - loop = get_event_loop() + get_event_loop().add_signal_handler(SIGTERM, self.housekeep) processed = 0 while True and limit is None or limit > processed: tube = await self.acquire_tube(handler) - loop.add_signal_handler(SIGTERM, partial(self.stop, tube.pipe)) + self.pipe = tube.pipe processed = processed + await self.process(tube, limit) - loop.remove_signal_handler(SIGTERM) + self.pipe = None - def stop(self, pipe: str) -> None: - get_event_loop().create_task(self.shutdown_worker(pipe)) + get_event_loop().remove_signal_handler(SIGTERM) - async def shutdown_worker(self, pipe: str) -> None: - await self.lock.release(pipe) - tasks = [task for task in all_tasks() if task is not current_task()] - [task.cancel() for task in tasks] - await gather(*tasks) + def housekeep(self) -> None: + if self.pipe: + get_event_loop().create_task(self.lock.release(self.pipe)) async def process(self, tube: Tube, limit: Optional[int] = None) -> int: deserialize = self.queue.serializer.deserialize