Skip to content

Commit

Permalink
fix: make workflow run listener more resilient (#56)
Browse files Browse the repository at this point in the history
* fix: update requirements to indicate python version >=3.10

* fix: improve the workflow run listener

* chore: linting
  • Loading branch information
abelanger5 authored Jun 20, 2024
1 parent 464bc0f commit b69ec97
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 65 deletions.
30 changes: 8 additions & 22 deletions hatchet_sdk/clients/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import random
import threading
import time
from typing import Any, AsyncGenerator, Callable, List, Union
from typing import Any, AsyncGenerator, List

import grpc
from grpc._cython import cygrpc

from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt
from hatchet_sdk.connection import new_conn

from ..dispatcher_pb2 import (
Expand Down Expand Up @@ -110,25 +111,6 @@ def unregister(self):
START_GET_GROUP_KEY = 2


class Event_ts(asyncio.Event):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self._loop is None:
self._loop = asyncio.get_event_loop()

def set(self):
self._loop.call_soon_threadsafe(super().set)

def clear(self):
self._loop.call_soon_threadsafe(super().clear)


async def read_action(listener: Any, interrupt: Event_ts):
assigned_action = await listener.read()
interrupt.set()
return assigned_action


async def exp_backoff_sleep(attempt: int, max_sleep_time: float = 5):
base_time = 0.1 # starting sleep time in seconds (100 milliseconds)
jitter = random.uniform(0, base_time) # add random jitter
Expand Down Expand Up @@ -220,12 +202,16 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
try:
while True:
self.interrupt = Event_ts()
t = asyncio.create_task(read_action(listener, self.interrupt))
t = asyncio.create_task(
read_with_interrupt(listener, self.interrupt)
)
await self.interrupt.wait()

if not t.done():
# print a warning
logger.warning("Interrupted read_action task")
logger.warning(
"Interrupted read_with_interrupt task of action listener"
)

t.cancel()
listener.cancel()
Expand Down
25 changes: 25 additions & 0 deletions hatchet_sdk/clients/event_ts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import asyncio
from typing import Any


class Event_ts(asyncio.Event):
"""
Event_ts is a subclass of asyncio.Event that allows for thread-safe setting and clearing of the event.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self._loop is None:
self._loop = asyncio.get_event_loop()

def set(self):
self._loop.call_soon_threadsafe(super().set)

def clear(self):
self._loop.call_soon_threadsafe(super().clear)


async def read_with_interrupt(listener: Any, interrupt: Event_ts):
result = await listener.read()
interrupt.set()
return result
196 changes: 154 additions & 42 deletions hatchet_sdk/clients/workflow_listener.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
import json
import time
from collections.abc import AsyncIterator
from typing import AsyncGenerator

import grpc

from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt
from hatchet_sdk.connection import new_conn

from ..dispatcher_pb2 import SubscribeToWorkflowRunsRequest, WorkflowRunEvent
Expand All @@ -15,14 +17,52 @@

DEFAULT_WORKFLOW_LISTENER_RETRY_INTERVAL = 1 # seconds
DEFAULT_WORKFLOW_LISTENER_RETRY_COUNT = 5
DEFAULT_WORKFLOW_LISTENER_INTERRUPT_INTERVAL = 1800 # 30 minutes


class _Subscription:
def __init__(self, id: int, workflow_run_id: str):
self.id = id
self.workflow_run_id = workflow_run_id
self.queue: asyncio.Queue[WorkflowRunEvent | None] = asyncio.Queue()

async def __aiter__(self):
return self

async def __anext__(self) -> WorkflowRunEvent:
return await self.queue.get()

async def get(self) -> WorkflowRunEvent:
event = await self.queue.get()

if event is None:
raise StopAsyncIteration

return event

async def put(self, item: WorkflowRunEvent):
await self.queue.put(item)

async def close(self):
await self.queue.put(None)


class PooledWorkflowRunListener:
# list of all active subscriptions, mapping from a subscription id to a workflow run id
subscriptionsToWorkflows: dict[int, str] = {}

# list of workflow run ids mapped to an array of subscription ids
workflowsToSubscriptions: dict[str, list[int]] = {}

subscription_counter: int = 0
subscription_counter_lock: asyncio.Lock = asyncio.Lock()

requests: asyncio.Queue[SubscribeToWorkflowRunsRequest] = asyncio.Queue()

listener: AsyncGenerator[WorkflowRunEvent, None] = None

events: dict[str, asyncio.Queue[WorkflowRunEvent]] = {}
# events have keys of the format workflow_run_id + subscription_id
events: dict[int, _Subscription] = {}

def __init__(self, config: ClientConfig):
conn = new_conn(config, True)
Expand All @@ -35,30 +75,77 @@ def abort(self):
self.stop_signal = True
self.requests.put_nowait(False)

async def _interrupter(self):
"""
_interrupter runs in a separate thread and interrupts the listener according to a configurable duration.
"""
await asyncio.sleep(DEFAULT_WORKFLOW_LISTENER_INTERRUPT_INTERVAL)

if self.interrupt is not None:
self.interrupt.set()

async def _init_producer(self):
try:
if not self.listener:
self.listener = await self._retry_subscribe()
logger.debug(f"Workflow run listener connected.")
async for workflow_event in self.listener:
if workflow_event.workflowRunId in self.events:
self.events[workflow_event.workflowRunId].put_nowait(
workflow_event
)
else:
logger.warning(
f"Received event for unknown workflow: {workflow_event.workflowRunId}"
)
while True:
try:
self.listener = await self._retry_subscribe()

logger.debug(f"Workflow run listener connected.")

# spawn an interrupter task
asyncio.create_task(self._interrupter())

while True:
self.interrupt = Event_ts()
t = asyncio.create_task(
read_with_interrupt(self.listener, self.interrupt)
)
await self.interrupt.wait()

if not t.done():
# print a warning
logger.warning(
"Interrupted read_with_interrupt task of workflow run listener"
)

t.cancel()
self.listener.cancel()
break

workflow_event: WorkflowRunEvent = t.result()

# get a list of subscriptions for this workflow
subscriptions = self.workflowsToSubscriptions.get(
workflow_event.workflowRunId, []
)

for subscription_id in subscriptions:
await self.events[subscription_id].put(workflow_event)

except grpc.RpcError as e:
logger.error(f"grpc error in workflow run listener: {e}")
continue
except Exception as e:
logger.error(f"Error in workflow run listener: {e}")

self.listener = None

# signal all subscribers to stop
# FIXME this is a bit of a hack, ideally we re-establish the listener and re-subscribe
for key in self.events.keys():
self.events[key].put_nowait(False)
# close all subscriptions
for subscription_id in self.events:
await self.events[subscription_id].close()

raise e

async def _request(self) -> AsyncIterator[SubscribeToWorkflowRunsRequest]:
# replay all existing subscriptions
workflow_run_set = set(self.subscriptionsToWorkflows.values())

for workflow_run_id in workflow_run_set:
yield SubscribeToWorkflowRunsRequest(
workflowRunId=workflow_run_id,
)

while True:
request = await self.requests.get()

Expand All @@ -69,44 +156,69 @@ async def _request(self) -> AsyncIterator[SubscribeToWorkflowRunsRequest]:
yield request
self.requests.task_done()

def cleanup_subscription(self, subscription_id: int):
workflow_run_id = self.subscriptionsToWorkflows[subscription_id]

if workflow_run_id in self.workflowsToSubscriptions:
self.workflowsToSubscriptions[workflow_run_id].remove(subscription_id)

del self.subscriptionsToWorkflows[subscription_id]
del self.events[subscription_id]

async def subscribe(self, workflow_run_id: str):
self.events[workflow_run_id] = asyncio.Queue()
try:
# create a new subscription id, place a mutex on the counter
await self.subscription_counter_lock.acquire()
self.subscription_counter += 1
subscription_id = self.subscription_counter
self.subscription_counter_lock.release()

asyncio.create_task(self._init_producer())
self.subscriptionsToWorkflows[subscription_id] = workflow_run_id

await self.requests.put(
SubscribeToWorkflowRunsRequest(
workflowRunId=workflow_run_id,
if workflow_run_id not in self.workflowsToSubscriptions:
self.workflowsToSubscriptions[workflow_run_id] = [subscription_id]
else:
self.workflowsToSubscriptions[workflow_run_id].append(subscription_id)

self.events[subscription_id] = _Subscription(
subscription_id, workflow_run_id
)
)

while True:
event = await self.events[workflow_run_id].get()
if event is False:
break
if event.workflowRunId == workflow_run_id:
yield event
break # FIXME this should only break on terminal events... but we're not broadcasting event types
asyncio.create_task(self._init_producer())

await self.requests.put(
SubscribeToWorkflowRunsRequest(
workflowRunId=workflow_run_id,
)
)

event = await self.events[subscription_id].get()

del self.events[workflow_run_id]
self.cleanup_subscription(subscription_id)

return event
except asyncio.CancelledError:
self.cleanup_subscription(subscription_id)
raise

async def result(self, workflow_run_id: str):
async for event in self.subscribe(workflow_run_id):
errors = []
event = await self.subscribe(workflow_run_id)

errors = []

if event.results:
errors = [result.error for result in event.results if result.error]
if event.results:
errors = [result.error for result in event.results if result.error]

if errors:
raise Exception(f"Workflow Errors: {errors}")
if errors:
raise Exception(f"Workflow Errors: {errors}")

results = {
result.stepReadableId: json.loads(result.output)
for result in event.results
if result.output
}
results = {
result.stepReadableId: json.loads(result.output)
for result in event.results
if result.output
}

return results
return results

async def _retry_subscribe(self):
retries = 0
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "hatchet-sdk"
version = "0.27.0"
version = "0.27.1"
description = ""
authors = ["Alexander Belanger <[email protected]>"]
readme = "README.md"
Expand Down

0 comments on commit b69ec97

Please sign in to comment.