add a run_workflows method to run many workflows in bulk (#222)
* add a run_workflows method to run many workflows in bulk

* lint

* use the dedupe exceptions when we have a dedupe

* add spawn_workflows and run_workflows with an example

* think this is how the testing works

* Add a test although it hangs on the second one - I think because of event loop stuff

* is this how I bump version

* maybe

---------

Co-authored-by: Sean Reilly <[email protected]>
reillyse and Sean Reilly authored Oct 23, 2024
1 parent 8dd2efe commit 19eb988
Showing 46 changed files with 2,347 additions and 146 deletions.
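At a glance: the new run_workflows (and spawn_workflows, its in-step counterpart) takes a list of dicts with workflow_name, input, and options, sends them in a single BulkTriggerWorkflow gRPC request, and returns one WorkflowRunRef per run. A minimal sketch of the client-side call — assuming a configured client and a registered BulkParent workflow; the full version is examples/bulk_fanout/bulk_trigger.py below:

import asyncio

from hatchet_sdk import new_client


async def main():
    hatchet = new_client()

    # One dict per desired run; all of them go out in a single bulk request.
    refs = hatchet.admin.run_workflows(
        [
            {"workflow_name": "BulkParent", "input": {"n": i}, "options": {}}
            for i in range(3)
        ]
    )

    # return_exceptions=True keeps one failed run from cancelling the rest.
    results = await asyncio.gather(
        *[ref.result() for ref in refs], return_exceptions=True
    )
    print(results)


if __name__ == "__main__":
    asyncio.run(main())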
52 changes: 52 additions & 0 deletions examples/bulk_fanout/bulk_trigger.py
@@ -0,0 +1,52 @@
import asyncio

from dotenv import load_dotenv

from hatchet_sdk import new_client
from hatchet_sdk.clients.admin import WorkflowRunDict


async def main():
load_dotenv()
hatchet = new_client()

    workflowRuns: list[WorkflowRunDict] = []

    # Run the BulkParent workflow 20 times; each run spawns n BulkChild
    # workflows, with n = i for each i in range(20).
for i in range(20):
workflowRuns.append(
{
"workflow_name": "BulkParent",
"input": {"n": i},
"options": {
"additional_metadata": {
"bulk-trigger": i,
"hello-{i}": "earth-{i}",
},
},
}
)

workflowRunRefs = hatchet.admin.run_workflows(
workflowRuns,
)

results = await asyncio.gather(
*[workflowRunRef.result() for workflowRunRef in workflowRunRefs],
return_exceptions=True,
)

for result in results:
if isinstance(result, Exception):
print(f"An error occurred: {result}") # Handle the exception here
else:
print(result)


if __name__ == "__main__":
asyncio.run(main())
45 changes: 45 additions & 0 deletions examples/bulk_fanout/stream.py
@@ -0,0 +1,45 @@
import asyncio
import random

from dotenv import load_dotenv

from hatchet_sdk.v2.hatchet import Hatchet


async def main():
load_dotenv()
hatchet = Hatchet()

    # Generate a random stream value to track all
    # stream events for this workflow run.

streamKey = "streamKey"
streamVal = f"sk-{random.randint(1, 100)}"

# Specify the stream key as additional metadata
# when running the workflow.

# This key gets propagated to all child workflows
# and can have an arbitrary property name.

workflowRun = hatchet.admin.run_workflow(
"Parent",
{"n": 2},
options={"additional_metadata": {streamKey: streamVal}},
)

# Stream all events for the additional meta key value
listener = hatchet.listener.stream_by_additional_metadata(streamKey, streamVal)

async for event in listener:
print(event.type, event.payload)


if __name__ == "__main__":
asyncio.run(main())
25 changes: 25 additions & 0 deletions examples/bulk_fanout/test_bulk_fanout.py
@@ -0,0 +1,25 @@
import pytest

from hatchet_sdk import Hatchet
from tests.utils import fixture_bg_worker
from tests.utils.hatchet_client import hatchet_client_fixture

hatchet = hatchet_client_fixture()
worker = fixture_bg_worker(["poetry", "run", "bulk_fanout"])


# requires scope module or higher for shared event loop
@pytest.mark.asyncio(scope="session")
async def test_run(hatchet: Hatchet):
run = hatchet.admin.run_workflow("BulkParent", {"n": 12})
result = await run.result()
print(result)
assert len(result["spawn"]["results"]) == 12


# requires scope module or higher for shared event loop
@pytest.mark.asyncio(scope="session")
async def test_run2(hatchet: Hatchet):
run = hatchet.admin.run_workflow("BulkParent", {"n": 10})
result = await run.result()
assert len(result["spawn"]["results"]) == 10
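A note on the "shared event loop" comments above: @pytest.mark.asyncio(scope="session") runs both tests on one loop, which the pooled workflow listener needs to survive across tests. On pytest-asyncio versions without marker scopes, the usual substitute is a session-scoped event_loop fixture — a sketch of a hypothetical conftest.py, not part of this commit:

import asyncio

import pytest


# One loop for the whole session, so the pooled workflow-run listener's
# connections are not torn down between tests.
@pytest.fixture(scope="session")
def event_loop():
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()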
26 changes: 26 additions & 0 deletions examples/bulk_fanout/trigger.py
@@ -0,0 +1,26 @@
import asyncio

from dotenv import load_dotenv

from hatchet_sdk import new_client


async def main():
load_dotenv()
hatchet = new_client()

    hatchet.event.push(
        "parent:create", {"n": 999}, {"additional_metadata": {"no-dedupe": "world"}}
    )


if __name__ == "__main__":
asyncio.run(main())
84 changes: 84 additions & 0 deletions examples/bulk_fanout/worker.py
@@ -0,0 +1,84 @@
import asyncio
from typing import List

from dotenv import load_dotenv

from hatchet_sdk import Context, Hatchet
from hatchet_sdk.clients.admin import ChildWorkflowRunDict

load_dotenv()

hatchet = Hatchet(debug=True)


@hatchet.workflow(on_events=["parent:create"])
class BulkParent:
@hatchet.step(timeout="5m")
async def spawn(self, context: Context):
print("spawning child")

context.put_stream("spawning...")

n = context.workflow_input().get("n", 100)

child_workflow_runs: List[ChildWorkflowRunDict] = []

for i in range(n):

child_workflow_runs.append(
{
"workflow_name": "BulkChild",
"input": {"a": str(i)},
"key": f"child{i}",
"options": {"additional_metadata": {"hello": "earth"}},
}
)

if len(child_workflow_runs) == 0:
return

spawn_results = await context.aio.spawn_workflows(child_workflow_runs)

results = await asyncio.gather(
*[workflowRunRef.result() for workflowRunRef in spawn_results],
return_exceptions=True,
)

print("finished spawning children")

for result in results:
if isinstance(result, Exception):
print(f"An error occurred: {result}")
else:
print(result)

return {"results": results}


@hatchet.workflow(on_events=["child:create"])
class BulkChild:
@hatchet.step()
def process(self, context: Context):
a = context.workflow_input()["a"]
print(f"child process {a}")
context.put_stream("child 1...")
return {"status": "success " + a}

@hatchet.step()
def process2(self, context: Context):
print("child process2")
context.put_stream("child 2...")
return {"status2": "success"}


def main():

worker = hatchet.worker("fanout-worker", max_runs=40)
worker.register_workflow(BulkParent())
worker.register_workflow(BulkChild())
worker.start()


if __name__ == "__main__":
main()
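Because each child run above carries a dedupe key (f"child{i}"), a duplicated key surfaces as a dedupe error among the gathered results — the "use the dedupe exceptions when we have a dedupe" bullet in the commit message. A sketch of separating those from other failures, assuming DedupeViolationErr is the exception hatchet_sdk exports for this (verify the name against your installed version):

from hatchet_sdk import DedupeViolationErr  # assumed export; name may differ


def summarize(results: list) -> None:
    # Partition asyncio.gather(..., return_exceptions=True) output into
    # dedupe violations, other failures, and successful child results.
    for result in results:
        if isinstance(result, DedupeViolationErr):
            print(f"duplicate child key, run skipped: {result}")
        elif isinstance(result, Exception):
            print(f"child run failed: {result}")
        else:
            print(result)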
130 changes: 130 additions & 0 deletions hatchet_sdk/clients/admin.py
@@ -5,11 +5,14 @@
import grpc
from google.protobuf import timestamp_pb2

from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun
from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
from hatchet_sdk.clients.run_event_listener import new_listener
from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
from hatchet_sdk.connection import new_conn
from hatchet_sdk.contracts.workflows_pb2 import (
BulkTriggerWorkflowRequest,
BulkTriggerWorkflowResponse,
CreateWorkflowVersionOpts,
PutRateLimitRequest,
PutWorkflowRequest,
@@ -44,6 +47,19 @@ class ChildTriggerWorkflowOptions(TypedDict):
sticky: bool | None = None


class WorkflowRunDict(TypedDict):
workflow_name: str
input: Any
options: Optional[dict]


class ChildWorkflowRunDict(TypedDict):
workflow_name: str
input: Any
    options: ChildTriggerWorkflowOptions
key: str


class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, TypedDict):
additional_metadata: Dict[str, str] | None = None
desired_worker_id: str | None = None
@@ -203,6 +219,65 @@ async def run_workflow(

raise ValueError(f"gRPC error: {e}")

@tenacity_retry
async def run_workflows(
self, workflows: List[WorkflowRunDict], options: TriggerWorkflowOptions = None
) -> List[WorkflowRunRef]:

if len(workflows) == 0:
raise ValueError("No workflows to run")
try:
if not self.pooled_workflow_listener:
self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)

namespace = self.namespace

if (
options is not None
and "namespace" in options
and options["namespace"] is not None
):
namespace = options["namespace"]
del options["namespace"]

            workflow_run_requests: List[TriggerWorkflowRequest] = []

for workflow in workflows:

                workflow_name = workflow["workflow_name"]
                input_data = workflow["input"]
                workflow_options = workflow["options"]

                if namespace != "" and not workflow_name.startswith(namespace):
workflow_name = f"{namespace}{workflow_name}"

# Prepare and trigger workflow for each workflow name and input
request = self._prepare_workflow_request(
                    workflow_name, input_data, workflow_options
)
workflow_run_requests.append(request)

request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests)

resp: BulkTriggerWorkflowResponse = (
await self.aio_client.BulkTriggerWorkflow(
request,
metadata=get_metadata(self.token),
)
)

return [
WorkflowRunRef(
workflow_run_id=workflow_run_id,
workflow_listener=self.pooled_workflow_listener,
workflow_run_event_listener=self.listener_client,
)
for workflow_run_id in resp.workflow_run_ids
]

except grpc.RpcError as e:
raise ValueError(f"gRPC error: {e}")

@tenacity_retry
async def put_workflow(
self,
@@ -398,6 +473,61 @@ def run_workflow(

raise ValueError(f"gRPC error: {e}")

@tenacity_retry
def run_workflows(
self, workflows: List[WorkflowRunDict], options: TriggerWorkflowOptions = None
) -> list[WorkflowRunRef]:

        workflow_run_requests: List[TriggerWorkflowRequest] = []
try:
if not self.pooled_workflow_listener:
self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)

for workflow in workflows:

                workflow_name = workflow["workflow_name"]
                input_data = workflow["input"]
                workflow_options = workflow["options"]

                namespace = self.namespace

                if (
                    workflow_options is not None
                    and "namespace" in workflow_options
                    and workflow_options["namespace"] is not None
                ):
                    namespace = workflow_options["namespace"]
                    del workflow_options["namespace"]

                if namespace != "" and not workflow_name.startswith(namespace):
workflow_name = f"{namespace}{workflow_name}"

# Prepare and trigger workflow for each workflow name and input
request = self._prepare_workflow_request(
                    workflow_name, input_data, workflow_options
)

workflow_run_requests.append(request)

request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests)

resp: BulkTriggerWorkflowResponse = self.client.BulkTriggerWorkflow(
request,
metadata=get_metadata(self.token),
)

except grpc.RpcError as e:
raise ValueError(f"gRPC error: {e}")

return [
WorkflowRunRef(
workflow_run_id=workflow_run_id,
workflow_listener=self.pooled_workflow_listener,
workflow_run_event_listener=self.listener_client,
)
for workflow_run_id in resp.workflow_run_ids
]

def run(
self,
function: Union[str, Callable[[Any], T]],
