diff --git a/examples/async/worker.py b/examples/async/worker.py
index f822f932..476efb9f 100644
--- a/examples/async/worker.py
+++ b/examples/async/worker.py
@@ -28,10 +28,12 @@ async def step2(self, context):
         await asyncio.sleep(2)
         print("finished step2")
 
+
 async def main():
     workflow = AsyncWorkflow()
     worker = hatchet.worker("test-worker", max_runs=4)
     worker.register_workflow(workflow)
     await worker.async_start()
 
-asyncio.run(main())
\ No newline at end of file
+
+asyncio.run(main())
diff --git a/hatchet_sdk/clients/dispatcher.py b/hatchet_sdk/clients/dispatcher.py
index f7ac8268..fac81a2c 100644
--- a/hatchet_sdk/clients/dispatcher.py
+++ b/hatchet_sdk/clients/dispatcher.py
@@ -6,6 +6,7 @@
 from typing import AsyncGenerator, Callable, List, Union
 
 import grpc
+
 from hatchet_sdk.connection import new_conn
 
 from ..dispatcher_pb2 import (
@@ -107,7 +108,13 @@ def unregister(self):
 class ActionListenerImpl(WorkerActionListener):
     config: ClientConfig
 
-    def __init__(self, client: DispatcherStub, aio_client: DispatcherStub, config: ClientConfig, worker_id):
+    def __init__(
+        self,
+        client: DispatcherStub,
+        aio_client: DispatcherStub,
+        config: ClientConfig,
+        worker_id,
+    ):
         self.aio_client = aio_client
         self.client = client
         self.config = config
@@ -311,7 +318,9 @@ def __init__(self, config: ClientConfig):
         # self.logger = logger
         # self.validator = validator
 
-    async def get_action_listener(self, req: GetActionListenerRequest) -> ActionListenerImpl:
+    async def get_action_listener(
+        self, req: GetActionListenerRequest
+    ) -> ActionListenerImpl:
         # Register the worker
         response: WorkerRegisterResponse = await self.aio_client.Register(
             WorkerRegisterRequest(
@@ -324,7 +333,9 @@ async def get_action_listener(self, req: GetActionListenerRequest) -> ActionList
             metadata=get_metadata(self.token),
         )
 
-        return ActionListenerImpl(self.client, self.aio_client, self.config, response.workerId)
+        return ActionListenerImpl(
+            self.client, self.aio_client, self.config, response.workerId
+        )
 
     def send_step_action_event(self, in_: StepActionEvent):
         response: ActionEventResponse = self.client.SendStepActionEvent(
diff --git a/hatchet_sdk/worker.py b/hatchet_sdk/worker.py
index 5e23e131..e8664d77 100644
--- a/hatchet_sdk/worker.py
+++ b/hatchet_sdk/worker.py
@@ -1,5 +1,4 @@
 import asyncio
-from concurrent.futures import ThreadPoolExecutor
 import ctypes
 import functools
 import json
@@ -8,6 +7,7 @@
 import threading
 import time
 import traceback
+from concurrent.futures import ThreadPoolExecutor
 from threading import Thread, current_thread
 from typing import Any, Callable, Dict
 
@@ -17,7 +17,12 @@
 from hatchet_sdk.loader import ClientConfig
 
 from .client import new_client
-from .clients.dispatcher import Action, ActionListenerImpl, GetActionListenerRequest, new_dispatcher
+from .clients.dispatcher import (
+    Action,
+    ActionListenerImpl,
+    GetActionListenerRequest,
+    new_dispatcher,
+)
 from .context import Context
 from .dispatcher_pb2 import (
     GROUP_KEY_EVENT_TYPE_COMPLETED,
@@ -35,6 +40,7 @@
 from .logger import logger
 from .workflow import WorkflowMeta
 
+
 class Worker:
     def __init__(
         self,
@@ -53,7 +59,7 @@ def __init__(
         self.contexts: Dict[str, Context] = {}  # Store step run ids and contexts
         self.action_registry: dict[str, Callable[..., Any]] = {}
 
-        # The thread pool is used for synchronous functions which need to run concurrently 
+        # The thread pool is used for synchronous functions which need to run concurrently
         self.thread_pool = ThreadPoolExecutor(max_workers=max_runs)
         self.threads: Dict[str, Thread] = {}  # Store step run ids and threads
 
@@ -75,6 +81,7 @@ async def handle_start_step_run(self, action: Action):
         action_func = self.action_registry.get(action_name)
 
         if action_func:
+
             def callback(task: asyncio.Task):
                 errored = False
                 cancelled = task.cancelled()
@@ -120,8 +127,10 @@ async def async_wrapped_action_func(context):
                     if action_func._is_coroutine:
                         return await action_func(context)
                     else:
-
-                        pfunc = functools.partial(thread_action_func, context, action_func)
+
+                        pfunc = functools.partial(
+                            thread_action_func, context, action_func
+                        )
                         res = await self.loop.run_in_executor(self.thread_pool, pfunc)
 
                         if action.step_run_id in self.threads:
@@ -141,7 +150,7 @@ async def async_wrapped_action_func(context):
                 finally:
                     if action.step_run_id in self.tasks:
                         del self.tasks[action.step_run_id]
-
+
             task = self.loop.create_task(async_wrapped_action_func(context))
             task.add_done_callback(callback)
             self.tasks[action.step_run_id] = task
@@ -167,6 +176,7 @@ async def handle_start_group_key_run(self, action: Action):
         action_func = self.action_registry.get(action_name)
 
         if action_func:
+
             def callback(task: asyncio.Task):
                 errored = False
                 cancelled = task.cancelled()
@@ -214,8 +224,10 @@ async def async_wrapped_action_func(context):
                     if action_func._is_coroutine:
                         return await action_func(context)
                     else:
-
-                        pfunc = functools.partial(thread_action_func, context, action_func)
+
+                        pfunc = functools.partial(
+                            thread_action_func, context, action_func
+                        )
                         res = await self.loop.run_in_executor(self.thread_pool, pfunc)
 
                         if action.step_run_id in self.threads:
@@ -384,7 +396,7 @@ def register_workflow(self, workflow: WorkflowMeta):
         def create_action_function(action_func):
            def action_function(context):
                return action_func(workflow, context)
-
+
            if asyncio.iscoroutinefunction(action_func):
                action_function._is_coroutine = True
            else:
@@ -414,7 +426,7 @@ def exit_gracefully(self, signum, frame):
 
         if self.handle_kill:
             logger.info("Exiting...")
-            sys.exit(0) 
+            sys.exit(0)
 
     def start(self, retry_count=1):
         try:
@@ -452,7 +464,7 @@ async def async_start(self, retry_count=1):
             )
 
             # It's important that this iterates async so it doesn't block the event loop. This is
-            # what allows self.loop.create_task to work. 
+            # what allows self.loop.create_task to work.
             async for action in self.listener:
                 if action.action_type == ActionType.START_STEP_RUN:
                     self.loop.create_task(self.handle_start_step_run(action))
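
For reference, the reformatted examples/async/worker.py exercises the async worker path touched above. The sketch below is a minimal reassembly of that example, not the full file: only the step2 body, main(), and the worker calls (hatchet.worker, register_workflow, async_start) come from the hunks; the Hatchet() construction and the @hatchet.workflow / @hatchet.step decorator arguments are assumed boilerplate the patch does not show.

# Minimal sketch of the async worker example; decorator arguments and the
# Hatchet() setup are assumptions -- only step2, main(), and the worker calls
# appear in the diff above.
import asyncio

from hatchet_sdk import Hatchet

hatchet = Hatchet()  # assumed client setup, not part of the patch


@hatchet.workflow(on_events=["async:run"])  # assumed event name
class AsyncWorkflow:
    @hatchet.step()  # assumed decorator; the step body matches the diff context
    async def step2(self, context):
        await asyncio.sleep(2)
        print("finished step2")


async def main():
    workflow = AsyncWorkflow()
    # max_runs bounds concurrent step runs; it also sizes the worker's thread pool
    worker = hatchet.worker("test-worker", max_runs=4)
    worker.register_workflow(workflow)
    await worker.async_start()


asyncio.run(main())

Because async_start awaits the listener on the event loop, coroutine steps run directly as tasks, while synchronous steps are dispatched to the worker's ThreadPoolExecutor via run_in_executor -- the behavior the reformatted handle_start_step_run and handle_start_group_key_run hunks preserve unchanged.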