diff --git a/examples/fanout/worker.py b/examples/fanout/worker.py index 5841632a..9c93067c 100644 --- a/examples/fanout/worker.py +++ b/examples/fanout/worker.py @@ -20,7 +20,7 @@ async def spawn(self, context: Context): for i in range(100): results.append( ( - await context.spawn_workflow( + await context.aio.spawn_workflow( "Child", {"a": str(i)}, key=f"child{i}" ) ).result() diff --git a/hatchet_sdk/clients/dispatcher.py b/hatchet_sdk/clients/dispatcher.py index a9a38ba6..97e63428 100644 --- a/hatchet_sdk/clients/dispatcher.py +++ b/hatchet_sdk/clients/dispatcher.py @@ -1,9 +1,10 @@ # relative imports import asyncio import json +import random import threading import time -from typing import AsyncGenerator, Callable, List, Union +from typing import Any, AsyncGenerator, Callable, List, Union import grpc @@ -106,6 +107,19 @@ def unregister(self): START_GET_GROUP_KEY = 2 +async def read_action(listener: Any, interrupt: asyncio.Event): + assigned_action = await listener.read() + interrupt.set() + return assigned_action + + +async def exp_backoff_sleep(attempt: int, max_sleep_time: float = 5): + base_time = 0.1 # starting sleep time in seconds (100 milliseconds) + jitter = random.uniform(0, base_time) # add random jitter + sleep_time = min(base_time * (2**attempt) + jitter, max_sleep_time) + await asyncio.sleep(sleep_time) + + class ActionListenerImpl(WorkerActionListener): config: ClientConfig @@ -123,44 +137,45 @@ def __init__( self.worker_id = worker_id self.retries = 0 self.last_connection_attempt = 0 - self.heartbeat_thread = None + self.heartbeat_task: asyncio.Task = None self.run_heartbeat = True self.listen_strategy = "v2" self.stop_signal = False + self.logger = logger.bind(worker_id=worker_id) - def heartbeat(self): + async def heartbeat(self): # send a heartbeat every 4 seconds while True: if not self.run_heartbeat: break try: - self.client.Heartbeat( + await self.aio_client.Heartbeat( HeartbeatRequest( workerId=self.worker_id, heartbeatAt=proto_timestamp_now(), ), - timeout=DEFAULT_REGISTER_TIMEOUT, + timeout=5, metadata=get_metadata(self.token), ) except grpc.RpcError as e: # we don't reraise the error here, as we don't want to stop the heartbeat thread logger.error(f"Failed to send heartbeat: {e}") + if self.interrupt is not None: + self.interrupt.set() + if e.code() == grpc.StatusCode.UNIMPLEMENTED: break - time.sleep(4) + await asyncio.sleep(4) def start_heartbeater(self): - if self.heartbeat_thread is not None: + if self.heartbeat_task is not None: return # create a new thread to send heartbeats - heartbeat_thread = threading.Thread(target=self.heartbeat) - heartbeat_thread.start() - - self.heartbeat_thread = heartbeat_thread + self.heartbeat_task = asyncio.create_task(self.heartbeat()) def __aiter__(self): return self._generator() @@ -170,8 +185,23 @@ async def _generator(self) -> AsyncGenerator[Action, None]: if self.stop_signal: break + listener = await self.get_listen_client() + try: - async for assigned_action in await self.get_listen_client(): + while True: + self.interrupt = asyncio.Event() + t = asyncio.create_task(read_action(listener, self.interrupt)) + await self.interrupt.wait() + + if not t.done(): + # print a warning + logger.warning("Interrupted read_action task") + + t.cancel() + listener.cancel() + break + + assigned_action = t.result() self.retries = 0 assigned_action: AssignedAction @@ -204,13 +234,12 @@ async def _generator(self) -> AsyncGenerator[Action, None]: ) yield action - except grpc.RpcError as e: # Handle different types of errors if e.code() == grpc.StatusCode.CANCELLED: # Context cancelled, unsubscribe and close - # self.logger.debug("Context cancelled, closing listener") - break + self.logger.debug("Context cancelled, closing listener") + continue elif e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: logger.info("Deadline exceeded, retrying subscription") continue @@ -225,8 +254,6 @@ async def _generator(self) -> AsyncGenerator[Action, None]: continue else: # Unknown error, report and break - # self.logger.error(f"Failed to receive message: {e}") - # err_ch(e) logger.error(f"Failed to receive message: {e}") self.retries = self.retries + 1 @@ -265,7 +292,10 @@ async def get_listen_client(self): elif self.retries >= 1: # logger.info # if we are retrying, we wait for a bit. this should eventually be replaced with exp backoff + jitter - time.sleep(DEFAULT_ACTION_LISTENER_RETRY_INTERVAL) + await exp_backoff_sleep( + self.retries, DEFAULT_ACTION_LISTENER_RETRY_INTERVAL + ) + logger.info( f"Could not connect to Hatchet, retrying... {self.retries}/{DEFAULT_ACTION_LISTENER_RETRY_COUNT}" ) @@ -288,7 +318,7 @@ async def get_listen_client(self): self.last_connection_attempt = current_time - logger.info("Listener established.") + logger.info("Established listener.") return listener def unregister(self): diff --git a/hatchet_sdk/connection.py b/hatchet_sdk/connection.py index 87a95f31..c9a44146 100644 --- a/hatchet_sdk/connection.py +++ b/hatchet_sdk/connection.py @@ -28,14 +28,27 @@ def new_conn(config, aio=False): strat = grpc if not aio else grpc.aio + channel_options = [ + ("grpc.keepalive_time_ms", 10 * 1000), + ("grpc.keepalive_timeout_ms", 60 * 1000), + ("grpc.client_idle_timeout_ms", 60 * 1000), + ("grpc.http2.max_pings_without_data", 5), + ("grpc.keepalive_permit_without_calls", 1), + ] + if config.tls_config.tls_strategy == "none": conn = strat.insecure_channel( target=config.host_port, + options=channel_options, ) else: + channel_options.append( + ("grpc.ssl_target_name_override", config.tls_config.server_name) + ) + conn = strat.secure_channel( target=config.host_port, credentials=credentials, - options=[("grpc.ssl_target_name_override", config.tls_config.server_name)], + options=channel_options, ) return conn diff --git a/hatchet_sdk/worker.py b/hatchet_sdk/worker.py index f6bbfd20..f8525a56 100644 --- a/hatchet_sdk/worker.py +++ b/hatchet_sdk/worker.py @@ -332,7 +332,7 @@ async def handle_cancel_action(self, run_id: str): del self.tasks[run_id] # grace period of 1 second - time.sleep(1) + await asyncio.sleep(1) # check if thread is still running, if so, kill it if run_id in self.threads: @@ -451,16 +451,6 @@ def exit_gracefully(self, signum, frame): sys.exit(0) def start(self, retry_count=1): - actions = self.action_registry.items() - - for action_name, action_func in actions: - logger.debug(f"Registered action: {action_name}") - - # if action_func._is_coroutine: - # raise Exception( - # "Cannot register async actions with the synchronous worker, use async_start instead." - # ) - try: loop = asyncio.get_running_loop() self.loop = loop