Merge pull request #29 from hatchet-dev/belanger/reliability
fix: network reliability improvements
abelanger5 authored May 30, 2024
2 parents 49c5758 + 9e867ea commit 8d00a85
Showing 4 changed files with 65 additions and 32 deletions.
2 changes: 1 addition & 1 deletion examples/fanout/worker.py
@@ -20,7 +20,7 @@ async def spawn(self, context: Context):
         for i in range(100):
             results.append(
                 (
-                    await context.spawn_workflow(
+                    await context.aio.spawn_workflow(
                         "Child", {"a": str(i)}, key=f"child{i}"
                     )
                 ).result()
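The example's only change is moving `spawn_workflow` under the `context.aio` namespace. A minimal sketch of the calling pattern, assuming the argument order shown in the diff (workflow name, input payload, per-child key); the surrounding worker setup and return value are elided:

# Sketch only: spawning child workflows from a parent step via the async API.
# The workflow name "Child" and the payload shape are illustrative, taken
# from the example above.
async def spawn(self, context: Context):
    results = []
    for i in range(100):
        ref = await context.aio.spawn_workflow(
            "Child", {"a": str(i)}, key=f"child{i}"  # per-child key, as above
        )
        results.append(ref.result())  # collect result handles, as above
    ...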
68 changes: 49 additions & 19 deletions hatchet_sdk/clients/dispatcher.py
@@ -1,9 +1,10 @@
 # relative imports
 import asyncio
 import json
+import random
 import threading
 import time
-from typing import AsyncGenerator, Callable, List, Union
+from typing import Any, AsyncGenerator, Callable, List, Union
 
 import grpc
@@ -106,6 +107,19 @@ def unregister(self):
 START_GET_GROUP_KEY = 2
+
+
+async def read_action(listener: Any, interrupt: asyncio.Event):
+    assigned_action = await listener.read()
+    interrupt.set()
+    return assigned_action
+
+
+async def exp_backoff_sleep(attempt: int, max_sleep_time: float = 5):
+    base_time = 0.1  # starting sleep time in seconds (100 milliseconds)
+    jitter = random.uniform(0, base_time)  # add random jitter
+    sleep_time = min(base_time * (2**attempt) + jitter, max_sleep_time)
+    await asyncio.sleep(sleep_time)
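`exp_backoff_sleep` doubles a 100 ms base per attempt, adds uniform jitter of up to one base interval, and caps the result at `max_sleep_time`. Ignoring jitter, the schedule it produces looks like this (a quick check of the formula, not part of the SDK):

# Sketch: the sleep durations exp_backoff_sleep computes, jitter omitted.
base, cap = 0.1, 5.0
for attempt in range(8):
    print(attempt, min(base * (2**attempt), cap))
# 0 -> 0.1s, 1 -> 0.2s, 2 -> 0.4s, 3 -> 0.8s, 4 -> 1.6s,
# 5 -> 3.2s, 6 -> 5.0s (capped), 7 -> 5.0s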


 class ActionListenerImpl(WorkerActionListener):
     config: ClientConfig

@@ -123,44 +137,45 @@ def __init__(
         self.worker_id = worker_id
         self.retries = 0
         self.last_connection_attempt = 0
-        self.heartbeat_thread = None
+        self.heartbeat_task: asyncio.Task = None
         self.run_heartbeat = True
         self.listen_strategy = "v2"
         self.stop_signal = False
         self.logger = logger.bind(worker_id=worker_id)
 
-    def heartbeat(self):
+    async def heartbeat(self):
         # send a heartbeat every 4 seconds
         while True:
             if not self.run_heartbeat:
                 break
 
             try:
-                self.client.Heartbeat(
+                await self.aio_client.Heartbeat(
                     HeartbeatRequest(
                         workerId=self.worker_id,
                         heartbeatAt=proto_timestamp_now(),
                     ),
-                    timeout=DEFAULT_REGISTER_TIMEOUT,
+                    timeout=5,
                     metadata=get_metadata(self.token),
                 )
             except grpc.RpcError as e:
                 # we don't reraise the error here, as we don't want to stop the heartbeat thread
                 logger.error(f"Failed to send heartbeat: {e}")
 
+                if self.interrupt is not None:
+                    self.interrupt.set()
+
                 if e.code() == grpc.StatusCode.UNIMPLEMENTED:
                     break
 
-            time.sleep(4)
+            await asyncio.sleep(4)
 
     def start_heartbeater(self):
-        if self.heartbeat_thread is not None:
+        if self.heartbeat_task is not None:
             return
 
-        # create a new thread to send heartbeats
-        heartbeat_thread = threading.Thread(target=self.heartbeat)
-        heartbeat_thread.start()
-
-        self.heartbeat_thread = heartbeat_thread
+        self.heartbeat_task = asyncio.create_task(self.heartbeat())
 
     def __aiter__(self):
         return self._generator()
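The hunk above replaces the heartbeat thread with an asyncio task on the same event loop, which is what lets a failed heartbeat set `self.interrupt` and wake the read loop in the next hunk. The shape of the pattern in isolation (a sketch; `send_heartbeat` is a hypothetical stand-in for the `aio_client.Heartbeat` RPC):

import asyncio

class HeartbeatSketch:
    def __init__(self):
        self.heartbeat_task: asyncio.Task | None = None
        self.run_heartbeat = True
        self.interrupt: asyncio.Event | None = None

    async def send_heartbeat(self):
        ...  # hypothetical stand-in for the Heartbeat RPC

    async def heartbeat(self):
        while self.run_heartbeat:
            try:
                await self.send_heartbeat()
            except Exception:
                # on failure, wake any coroutine blocked on the shared event
                if self.interrupt is not None:
                    self.interrupt.set()
            await asyncio.sleep(4)  # same 4 s cadence as the diff

    def start_heartbeater(self):
        if self.heartbeat_task is None:  # idempotent, like the guard above
            self.heartbeat_task = asyncio.create_task(self.heartbeat())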
@@ -170,8 +185,23 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
             if self.stop_signal:
                 break
 
+            listener = await self.get_listen_client()
+
             try:
-                async for assigned_action in await self.get_listen_client():
+                while True:
+                    self.interrupt = asyncio.Event()
+                    t = asyncio.create_task(read_action(listener, self.interrupt))
+                    await self.interrupt.wait()
+
+                    if not t.done():
+                        # print a warning
+                        logger.warning("Interrupted read_action task")
+
+                        t.cancel()
+                        listener.cancel()
+                        break
+
+                    assigned_action = t.result()
                     self.retries = 0
                     assigned_action: AssignedAction
 
@@ -204,13 +234,12 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
                     )
 
                     yield action
 
             except grpc.RpcError as e:
                 # Handle different types of errors
                 if e.code() == grpc.StatusCode.CANCELLED:
                     # Context cancelled, unsubscribe and close
-                    # self.logger.debug("Context cancelled, closing listener")
-                    break
+                    self.logger.debug("Context cancelled, closing listener")
+                    continue
                 elif e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
                     logger.info("Deadline exceeded, retrying subscription")
                     continue
@@ -225,8 +254,6 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
                     continue
                 else:
                     # Unknown error, report and break
-                    # self.logger.error(f"Failed to receive message: {e}")
-                    # err_ch(e)
                     logger.error(f"Failed to receive message: {e}")
 
                     self.retries = self.retries + 1
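Together, `read_action` and the loop above implement an interruptible blocking read: the generator waits on an `asyncio.Event` rather than on the stream itself, so the heartbeat task can set the same event and unblock a read stuck on a dead connection. The pattern reduced to its essentials (a sketch; `stream.read()` stands in for the gRPC stream call, and in the SDK the event is stored on `self` so the heartbeat can reach it):

import asyncio

async def read_and_signal(stream, interrupt: asyncio.Event):
    # Mirrors read_action: read one message, then release the waiter.
    msg = await stream.read()
    interrupt.set()
    return msg

async def consume(stream):
    """Yield messages; a stale read is cancelled when the event is set
    externally (in the SDK, by the failing heartbeat)."""
    while True:
        interrupt = asyncio.Event()
        task = asyncio.create_task(read_and_signal(stream, interrupt))
        await interrupt.wait()  # wakes on a message OR an external set()
        if not task.done():
            task.cancel()  # event was set before a message arrived
            break          # caller reconnects and calls consume() again
        yield task.result()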
@@ -265,7 +292,10 @@ async def get_listen_client(self):
         elif self.retries >= 1:
             # logger.info
             # if we are retrying, we wait for a bit. this should eventually be replaced with exp backoff + jitter
-            time.sleep(DEFAULT_ACTION_LISTENER_RETRY_INTERVAL)
+            await exp_backoff_sleep(
+                self.retries, DEFAULT_ACTION_LISTENER_RETRY_INTERVAL
+            )
 
             logger.info(
                 f"Could not connect to Hatchet, retrying... {self.retries}/{DEFAULT_ACTION_LISTENER_RETRY_COUNT}"
             )
@@ -288,7 +318,7 @@ async def get_listen_client(self):

         self.last_connection_attempt = current_time
 
-        logger.info("Listener established.")
+        logger.info("Established listener.")
         return listener
 
     def unregister(self):
15 changes: 14 additions & 1 deletion hatchet_sdk/connection.py
@@ -28,14 +28,27 @@ def new_conn(config, aio=False):
 
     strat = grpc if not aio else grpc.aio
 
+    channel_options = [
+        ("grpc.keepalive_time_ms", 10 * 1000),
+        ("grpc.keepalive_timeout_ms", 60 * 1000),
+        ("grpc.client_idle_timeout_ms", 60 * 1000),
+        ("grpc.http2.max_pings_without_data", 5),
+        ("grpc.keepalive_permit_without_calls", 1),
+    ]
+
     if config.tls_config.tls_strategy == "none":
         conn = strat.insecure_channel(
             target=config.host_port,
+            options=channel_options,
         )
     else:
+        channel_options.append(
+            ("grpc.ssl_target_name_override", config.tls_config.server_name)
+        )
+
         conn = strat.secure_channel(
             target=config.host_port,
             credentials=credentials,
-            options=[("grpc.ssl_target_name_override", config.tls_config.server_name)],
+            options=channel_options,
        )
     return conn
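The new `channel_options` turn on HTTP/2 keepalive so a half-open connection is detected instead of leaving reads hanging: ping every 10 s, allow 60 s for the ping ack, idle out an unused transport after 60 s, tolerate up to 5 pings without payload data, and permit pings even when no RPC is in flight. A self-contained sketch applying the same options (the target address is illustrative):

import grpc

options = [
    ("grpc.keepalive_time_ms", 10 * 1000),       # ping the server every 10 s
    ("grpc.keepalive_timeout_ms", 60 * 1000),    # wait up to 60 s for the ack
    ("grpc.client_idle_timeout_ms", 60 * 1000),  # close an idle transport after 60 s
    ("grpc.http2.max_pings_without_data", 5),    # pings allowed without data frames
    ("grpc.keepalive_permit_without_calls", 1),  # ping even with no active RPCs
]
channel = grpc.insecure_channel("localhost:7070", options=options)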
12 changes: 1 addition & 11 deletions hatchet_sdk/worker.py
@@ -332,7 +332,7 @@ async def handle_cancel_action(self, run_id: str):
             del self.tasks[run_id]
 
         # grace period of 1 second
-        time.sleep(1)
+        await asyncio.sleep(1)
 
         # check if thread is still running, if so, kill it
         if run_id in self.threads:
@@ -451,16 +451,6 @@ def exit_gracefully(self, signum, frame):
         sys.exit(0)
 
     def start(self, retry_count=1):
-        actions = self.action_registry.items()
-
-        for action_name, action_func in actions:
-            logger.debug(f"Registered action: {action_name}")
-
-            # if action_func._is_coroutine:
-            #     raise Exception(
-            #         "Cannot register async actions with the synchronous worker, use async_start instead."
-            #     )
-
         try:
             loop = asyncio.get_running_loop()
             self.loop = loop
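The `handle_cancel_action` change is small but load-bearing: `time.sleep(1)` inside a coroutine blocks the entire event loop, stalling heartbeats and reads for the whole grace period, whereas `await asyncio.sleep(1)` suspends only this coroutine. A contrast sketch (the cleanup body is illustrative):

import asyncio

async def cancel_with_grace(run_id: str):
    # BAD: time.sleep(1) here would freeze every task on the loop for 1 s.
    await asyncio.sleep(1)  # grace period; other tasks keep running
    print(f"cleaning up {run_id}")  # illustrative stand-in for thread cleanup

asyncio.run(cancel_with_grace("run-123"))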
