From b69ec97bb32b9e6365a331c4fd40d32a294307e2 Mon Sep 17 00:00:00 2001
From: abelanger5
Date: Thu, 20 Jun 2024 18:25:54 -0400
Subject: [PATCH] fix: make workflow run listener more resilient (#56)

* fix: update requirements to indicate python version >=3.10

* fix: improve the workflow run listener

* chore: linting
---
 hatchet_sdk/clients/dispatcher.py        |  30 +---
 hatchet_sdk/clients/event_ts.py          |  25 +++
 hatchet_sdk/clients/workflow_listener.py | 196 ++++++++++++++++++-----
 pyproject.toml                           |   2 +-
 4 files changed, 188 insertions(+), 65 deletions(-)
 create mode 100644 hatchet_sdk/clients/event_ts.py

diff --git a/hatchet_sdk/clients/dispatcher.py b/hatchet_sdk/clients/dispatcher.py
index 3eadf2aa..488744dd 100644
--- a/hatchet_sdk/clients/dispatcher.py
+++ b/hatchet_sdk/clients/dispatcher.py
@@ -4,11 +4,12 @@
 import random
 import threading
 import time
-from typing import Any, AsyncGenerator, Callable, List, Union
+from typing import Any, AsyncGenerator, List

 import grpc
 from grpc._cython import cygrpc

+from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt
 from hatchet_sdk.connection import new_conn

 from ..dispatcher_pb2 import (
@@ -110,25 +111,6 @@ def unregister(self):
 START_GET_GROUP_KEY = 2


-class Event_ts(asyncio.Event):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        if self._loop is None:
-            self._loop = asyncio.get_event_loop()
-
-    def set(self):
-        self._loop.call_soon_threadsafe(super().set)
-
-    def clear(self):
-        self._loop.call_soon_threadsafe(super().clear)
-
-
-async def read_action(listener: Any, interrupt: Event_ts):
-    assigned_action = await listener.read()
-    interrupt.set()
-    return assigned_action
-
-
 async def exp_backoff_sleep(attempt: int, max_sleep_time: float = 5):
     base_time = 0.1  # starting sleep time in seconds (100 milliseconds)
     jitter = random.uniform(0, base_time)  # add random jitter
@@ -220,12 +202,16 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
             try:
                 while True:
                     self.interrupt = Event_ts()
-                    t = asyncio.create_task(read_action(listener, self.interrupt))
+                    t = asyncio.create_task(
+                        read_with_interrupt(listener, self.interrupt)
+                    )
                     await self.interrupt.wait()

                     if not t.done():
                         # print a warning
-                        logger.warning("Interrupted read_action task")
+                        logger.warning(
+                            "Interrupted read_with_interrupt task of action listener"
+                        )

                         t.cancel()
                         listener.cancel()
diff --git a/hatchet_sdk/clients/event_ts.py b/hatchet_sdk/clients/event_ts.py
new file mode 100644
index 00000000..e4415f69
--- /dev/null
+++ b/hatchet_sdk/clients/event_ts.py
@@ -0,0 +1,25 @@
+import asyncio
+from typing import Any
+
+
+class Event_ts(asyncio.Event):
+    """
+    Event_ts is a subclass of asyncio.Event that allows for thread-safe setting and clearing of the event.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if self._loop is None:
+            self._loop = asyncio.get_event_loop()
+
+    def set(self):
+        self._loop.call_soon_threadsafe(super().set)
+
+    def clear(self):
+        self._loop.call_soon_threadsafe(super().clear)
+
+
+async def read_with_interrupt(listener: Any, interrupt: Event_ts):
+    result = await listener.read()
+    interrupt.set()
+    return result
diff --git a/hatchet_sdk/clients/workflow_listener.py b/hatchet_sdk/clients/workflow_listener.py
index c7ded1d2..f04bf158 100644
--- a/hatchet_sdk/clients/workflow_listener.py
+++ b/hatchet_sdk/clients/workflow_listener.py
@@ -1,10 +1,12 @@
 import asyncio
 import json
+import time
 from collections.abc import AsyncIterator
 from typing import AsyncGenerator

 import grpc

+from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt
 from hatchet_sdk.connection import new_conn

 from ..dispatcher_pb2 import SubscribeToWorkflowRunsRequest, WorkflowRunEvent
@@ -15,14 +17,52 @@
 DEFAULT_WORKFLOW_LISTENER_RETRY_INTERVAL = 1  # seconds
 DEFAULT_WORKFLOW_LISTENER_RETRY_COUNT = 5
+DEFAULT_WORKFLOW_LISTENER_INTERRUPT_INTERVAL = 1800  # 30 minutes
+
+
+class _Subscription:
+    def __init__(self, id: int, workflow_run_id: str):
+        self.id = id
+        self.workflow_run_id = workflow_run_id
+        self.queue: asyncio.Queue[WorkflowRunEvent | None] = asyncio.Queue()
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self) -> WorkflowRunEvent:
+        return await self.get()
+
+    async def get(self) -> WorkflowRunEvent:
+        event = await self.queue.get()
+
+        if event is None:
+            raise StopAsyncIteration
+
+        return event
+
+    async def put(self, item: WorkflowRunEvent):
+        await self.queue.put(item)
+
+    async def close(self):
+        await self.queue.put(None)


 class PooledWorkflowRunListener:
+    # all active subscriptions, mapping from a subscription id to a workflow run id
+    subscriptionsToWorkflows: dict[int, str] = {}
+
+    # workflow run ids mapped to a list of subscription ids
+    workflowsToSubscriptions: dict[str, list[int]] = {}
+
+    subscription_counter: int = 0
+    subscription_counter_lock: asyncio.Lock = asyncio.Lock()
+
     requests: asyncio.Queue[SubscribeToWorkflowRunsRequest] = asyncio.Queue()

     listener: AsyncGenerator[WorkflowRunEvent, None] = None

-    events: dict[str, asyncio.Queue[WorkflowRunEvent]] = {}
+    # mapping from a subscription id to its event subscription
+    events: dict[int, _Subscription] = {}

     def __init__(self, config: ClientConfig):
         conn = new_conn(config, True)
@@ -35,30 +75,77 @@ def abort(self):
         self.stop_signal = True
         self.requests.put_nowait(False)

+    async def _interrupter(self):
+        """
+        _interrupter runs as a background task and interrupts the listener after a configurable interval.
+        """
+        await asyncio.sleep(DEFAULT_WORKFLOW_LISTENER_INTERRUPT_INTERVAL)
+
+        if self.interrupt is not None:
+            self.interrupt.set()
+
     async def _init_producer(self):
         try:
             if not self.listener:
-                self.listener = await self._retry_subscribe()
-                logger.debug(f"Workflow run listener connected.")
-            async for workflow_event in self.listener:
-                if workflow_event.workflowRunId in self.events:
-                    self.events[workflow_event.workflowRunId].put_nowait(
-                        workflow_event
-                    )
-                else:
-                    logger.warning(
-                        f"Received event for unknown workflow: {workflow_event.workflowRunId}"
-                    )
+                while True:
+                    try:
+                        self.listener = await self._retry_subscribe()
+
+                        logger.debug("Workflow run listener connected.")
+
+                        # spawn an interrupter task
+                        asyncio.create_task(self._interrupter())
+
+                        while True:
+                            self.interrupt = Event_ts()
+                            t = asyncio.create_task(
+                                read_with_interrupt(self.listener, self.interrupt)
+                            )
+                            await self.interrupt.wait()
+
+                            if not t.done():
+                                # print a warning
+                                logger.warning(
+                                    "Interrupted read_with_interrupt task of workflow run listener"
+                                )
+
+                                t.cancel()
+                                self.listener.cancel()
+                                break
+
+                            workflow_event: WorkflowRunEvent = t.result()
+
+                            # get a list of subscriptions for this workflow
+                            subscriptions = self.workflowsToSubscriptions.get(
+                                workflow_event.workflowRunId, []
+                            )
+
+                            for subscription_id in subscriptions:
+                                await self.events[subscription_id].put(workflow_event)
+
+                    except grpc.RpcError as e:
+                        logger.error(f"grpc error in workflow run listener: {e}")
+                        continue
         except Exception as e:
             logger.error(f"Error in workflow run listener: {e}")
+
             self.listener = None

-            # signal all subscribers to stop
-            # FIXME this is a bit of a hack, ideally we re-establish the listener and re-subscribe
-            for key in self.events.keys():
-                self.events[key].put_nowait(False)
+            # close all subscriptions
+            for subscription_id in self.events:
+                await self.events[subscription_id].close()
+
+            raise e

     async def _request(self) -> AsyncIterator[SubscribeToWorkflowRunsRequest]:
+        # replay all existing subscriptions
+        workflow_run_set = set(self.subscriptionsToWorkflows.values())
+
+        for workflow_run_id in workflow_run_set:
+            yield SubscribeToWorkflowRunsRequest(
+                workflowRunId=workflow_run_id,
+            )
+
         while True:
             request = await self.requests.get()

@@ -69,44 +156,69 @@ async def _request(self) -> AsyncIterator[SubscribeToWorkflowRunsRequest]:
             yield request
             self.requests.task_done()

+    def cleanup_subscription(self, subscription_id: int):
+        workflow_run_id = self.subscriptionsToWorkflows[subscription_id]
+
+        if workflow_run_id in self.workflowsToSubscriptions:
+            self.workflowsToSubscriptions[workflow_run_id].remove(subscription_id)
+
+        del self.subscriptionsToWorkflows[subscription_id]
+        del self.events[subscription_id]
+
     async def subscribe(self, workflow_run_id: str):
-        self.events[workflow_run_id] = asyncio.Queue()
+        try:
+            # create a new subscription id while holding the counter lock
+            await self.subscription_counter_lock.acquire()
+            self.subscription_counter += 1
+            subscription_id = self.subscription_counter
+            self.subscription_counter_lock.release()

-        asyncio.create_task(self._init_producer())
+            self.subscriptionsToWorkflows[subscription_id] = workflow_run_id

-        await self.requests.put(
-            SubscribeToWorkflowRunsRequest(
-                workflowRunId=workflow_run_id,
+            if workflow_run_id not in self.workflowsToSubscriptions:
+                self.workflowsToSubscriptions[workflow_run_id] = [subscription_id]
+            else:
+                self.workflowsToSubscriptions[workflow_run_id].append(subscription_id)
+
+            self.events[subscription_id] = _Subscription(
+                subscription_id, workflow_run_id
+            )
-        )

-        while True:
-            event = await self.events[workflow_run_id].get()
-            if event is False:
-                break
-            if event.workflowRunId == workflow_run_id:
-                yield event
-                break  # FIXME this should only break on terminal events... but we're not broadcasting event types
+
+            asyncio.create_task(self._init_producer())
+
+            await self.requests.put(
+                SubscribeToWorkflowRunsRequest(
+                    workflowRunId=workflow_run_id,
+                )
+            )
+
+            event = await self.events[subscription_id].get()

-        del self.events[workflow_run_id]
+            self.cleanup_subscription(subscription_id)
+
+            return event
+        except asyncio.CancelledError:
+            self.cleanup_subscription(subscription_id)
+            raise

     async def result(self, workflow_run_id: str):
-        async for event in self.subscribe(workflow_run_id):
-            errors = []
+        event = await self.subscribe(workflow_run_id)
+
+        errors = []

-            if event.results:
-                errors = [result.error for result in event.results if result.error]
+        if event.results:
+            errors = [result.error for result in event.results if result.error]

-            if errors:
-                raise Exception(f"Workflow Errors: {errors}")
+        if errors:
+            raise Exception(f"Workflow Errors: {errors}")

-            results = {
-                result.stepReadableId: json.loads(result.output)
-                for result in event.results
-                if result.output
-            }
+        results = {
+            result.stepReadableId: json.loads(result.output)
+            for result in event.results
+            if result.output
+        }

-            return results
+        return results

     async def _retry_subscribe(self):
         retries = 0
diff --git a/pyproject.toml b/pyproject.toml
index ae4ff1cb..4a7842cc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "hatchet-sdk"
-version = "0.27.0"
+version = "0.27.1"
 description = ""
 authors = ["Alexander Belanger "]
 readme = "README.md"
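
Notes on usage (illustrative sketches, not part of the patch):

The interrupt pattern is the same in both listeners: each iteration creates a
fresh Event_ts, spawns read_with_interrupt as a task, and waits on the event.
The event fires either because a message arrived (read_with_interrupt sets it
after read() returns) or because the interrupter tripped it; a task that is
still unfinished at that point means the read stalled, so it is cancelled and
the stream is re-established. A minimal self-contained sketch of that flow,
where SlowListener is a hypothetical stand-in for a gRPC stream whose read()
can stall:

import asyncio

from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt


class SlowListener:
    # Hypothetical stand-in for a stream whose read() never completes.
    async def read(self) -> str:
        await asyncio.sleep(60)
        return "event"


async def main():
    interrupt = Event_ts()

    # Trip the interrupt after one second, the way _interrupter does after
    # DEFAULT_WORKFLOW_LISTENER_INTERRUPT_INTERVAL in the pooled listener.
    asyncio.get_running_loop().call_later(1, interrupt.set)

    t = asyncio.create_task(read_with_interrupt(SlowListener(), interrupt))
    await interrupt.wait()

    if not t.done():
        # Interrupted before a message arrived: abandon the read. The real
        # listeners also cancel the gRPC stream and reconnect here.
        t.cancel()
    else:
        print(t.result())


asyncio.run(main())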
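
The pooled listener now supports multiple concurrent subscribers to the same
workflow run: each subscribe() call takes a unique subscription id, registers
its own _Subscription queue, and _init_producer fans each WorkflowRunEvent out
to every queue registered for that run id. A hedged usage sketch; it assumes
config is a valid ClientConfig for a reachable Hatchet instance, run_id is an
existing workflow run id, and the ClientConfig import path matches the loader
module referenced in the diff:

import asyncio

from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
from hatchet_sdk.loader import ClientConfig


async def await_run_twice(config: ClientConfig, run_id: str):
    listener = PooledWorkflowRunListener(config)

    # Two concurrent waiters on the same run: each gets its own subscription
    # id and queue, and both receive the terminal event.
    first, second = await asyncio.gather(
        listener.result(run_id),
        listener.result(run_id),
    )
    assert first == second
    return first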
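
result() collapses the terminal WorkflowRunEvent into a dict keyed by step
readable id, raising if any step reported an error. The same logic over plain
dicts, with hypothetical step entries standing in for the protobuf result
messages:

import json

# Hypothetical entries shaped like WorkflowRunEvent.results.
results = [
    {"stepReadableId": "step1", "error": "", "output": '{"greeting": "hello"}'},
    {"stepReadableId": "step2", "error": "", "output": '{"count": 2}'},
]

errors = [r["error"] for r in results if r["error"]]
if errors:
    raise Exception(f"Workflow Errors: {errors}")

outputs = {
    r["stepReadableId"]: json.loads(r["output"])
    for r in results
    if r["output"]
}
# outputs == {"step1": {"greeting": "hello"}, "step2": {"count": 2}}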