Skip to content

Commit

Permalink
chore: lint
Browse files Browse the repository at this point in the history
  • Loading branch information
abelanger5 committed May 27, 2024
1 parent d931dd1 commit f4deb3e
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 15 deletions.
4 changes: 3 additions & 1 deletion examples/async/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ async def step2(self, context):
await asyncio.sleep(2)
print("finished step2")


async def main():
workflow = AsyncWorkflow()
worker = hatchet.worker("test-worker", max_runs=4)
worker.register_workflow(workflow)
await worker.async_start()

asyncio.run(main())

asyncio.run(main())
17 changes: 14 additions & 3 deletions hatchet_sdk/clients/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import AsyncGenerator, Callable, List, Union

import grpc

from hatchet_sdk.connection import new_conn

from ..dispatcher_pb2 import (
Expand Down Expand Up @@ -107,7 +108,13 @@ def unregister(self):
class ActionListenerImpl(WorkerActionListener):
config: ClientConfig

def __init__(self, client: DispatcherStub, aio_client: DispatcherStub, config: ClientConfig, worker_id):
def __init__(
self,
client: DispatcherStub,
aio_client: DispatcherStub,
config: ClientConfig,
worker_id,
):
self.aio_client = aio_client
self.client = client
self.config = config
Expand Down Expand Up @@ -311,7 +318,9 @@ def __init__(self, config: ClientConfig):
# self.logger = logger
# self.validator = validator

async def get_action_listener(self, req: GetActionListenerRequest) -> ActionListenerImpl:
async def get_action_listener(
self, req: GetActionListenerRequest
) -> ActionListenerImpl:
# Register the worker
response: WorkerRegisterResponse = await self.aio_client.Register(
WorkerRegisterRequest(
Expand All @@ -324,7 +333,9 @@ async def get_action_listener(self, req: GetActionListenerRequest) -> ActionList
metadata=get_metadata(self.token),
)

return ActionListenerImpl(self.client, self.aio_client, self.config, response.workerId)
return ActionListenerImpl(
self.client, self.aio_client, self.config, response.workerId
)

def send_step_action_event(self, in_: StepActionEvent):
response: ActionEventResponse = self.client.SendStepActionEvent(
Expand Down
34 changes: 23 additions & 11 deletions hatchet_sdk/worker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
import ctypes
import functools
import json
Expand All @@ -8,6 +7,7 @@
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from threading import Thread, current_thread
from typing import Any, Callable, Dict

Expand All @@ -17,7 +17,12 @@
from hatchet_sdk.loader import ClientConfig

from .client import new_client
from .clients.dispatcher import Action, ActionListenerImpl, GetActionListenerRequest, new_dispatcher
from .clients.dispatcher import (
Action,
ActionListenerImpl,
GetActionListenerRequest,
new_dispatcher,
)
from .context import Context
from .dispatcher_pb2 import (
GROUP_KEY_EVENT_TYPE_COMPLETED,
Expand All @@ -35,6 +40,7 @@
from .logger import logger
from .workflow import WorkflowMeta


class Worker:
def __init__(
self,
Expand All @@ -53,7 +59,7 @@ def __init__(
self.contexts: Dict[str, Context] = {} # Store step run ids and contexts
self.action_registry: dict[str, Callable[..., Any]] = {}

# The thread pool is used for synchronous functions which need to run concurrently
# The thread pool is used for synchronous functions which need to run concurrently
self.thread_pool = ThreadPoolExecutor(max_workers=max_runs)
self.threads: Dict[str, Thread] = {} # Store step run ids and threads

Expand All @@ -75,6 +81,7 @@ async def handle_start_step_run(self, action: Action):
action_func = self.action_registry.get(action_name)

if action_func:

def callback(task: asyncio.Task):
errored = False
cancelled = task.cancelled()
Expand Down Expand Up @@ -120,8 +127,10 @@ async def async_wrapped_action_func(context):
if action_func._is_coroutine:
return await action_func(context)
else:

pfunc = functools.partial(thread_action_func, context, action_func)

pfunc = functools.partial(
thread_action_func, context, action_func
)
res = await self.loop.run_in_executor(self.thread_pool, pfunc)

if action.step_run_id in self.threads:
Expand All @@ -141,7 +150,7 @@ async def async_wrapped_action_func(context):
finally:
if action.step_run_id in self.tasks:
del self.tasks[action.step_run_id]

task = self.loop.create_task(async_wrapped_action_func(context))
task.add_done_callback(callback)
self.tasks[action.step_run_id] = task
Expand All @@ -167,6 +176,7 @@ async def handle_start_group_key_run(self, action: Action):
action_func = self.action_registry.get(action_name)

if action_func:

def callback(task: asyncio.Task):
errored = False
cancelled = task.cancelled()
Expand Down Expand Up @@ -214,8 +224,10 @@ async def async_wrapped_action_func(context):
if action_func._is_coroutine:
return await action_func(context)
else:

pfunc = functools.partial(thread_action_func, context, action_func)

pfunc = functools.partial(
thread_action_func, context, action_func
)
res = await self.loop.run_in_executor(self.thread_pool, pfunc)

if action.step_run_id in self.threads:
Expand Down Expand Up @@ -384,7 +396,7 @@ def register_workflow(self, workflow: WorkflowMeta):
def create_action_function(action_func):
def action_function(context):
return action_func(workflow, context)

if asyncio.iscoroutinefunction(action_func):
action_function._is_coroutine = True
else:
Expand Down Expand Up @@ -414,7 +426,7 @@ def exit_gracefully(self, signum, frame):

if self.handle_kill:
logger.info("Exiting...")
sys.exit(0)
sys.exit(0)

def start(self, retry_count=1):
try:
Expand Down Expand Up @@ -452,7 +464,7 @@ async def async_start(self, retry_count=1):
)

# It's important that this iterates async so it doesn't block the event loop. This is
# what allows self.loop.create_task to work.
# what allows self.loop.create_task to work.
async for action in self.listener:
if action.action_type == ActionType.START_STEP_RUN:
self.loop.create_task(self.handle_start_step_run(action))
Expand Down

0 comments on commit f4deb3e

Please sign in to comment.