Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose Workflow Run Step Error in Runner Context #140

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions examples/on_failure/test_on_failure.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,32 @@
# from hatchet_sdk import Hatchet
# import pytest
import asyncio

# from tests.utils import fixture_bg_worker
# from tests.utils.hatchet_client import hatchet_client_fixture
import pytest

from hatchet_sdk import Hatchet
from hatchet_sdk.clients.rest.models.job_run_status import JobRunStatus
from tests.utils import fixture_bg_worker
from tests.utils.hatchet_client import hatchet_client_fixture

# hatchet = hatchet_client_fixture()
# worker = fixture_bg_worker(["poetry", "run", "manual_trigger"])
hatchet = hatchet_client_fixture()
worker = fixture_bg_worker(["poetry", "run", "on_failure"])

# # requires scope module or higher for shared event loop
# @pytest.mark.asyncio(scope="session")
# async def test_run(hatchet: Hatchet):
# # TODO

# requires scope module or higher for shared event loop
@pytest.mark.asyncio(scope="session")
async def test_run_timeout(hatchet: Hatchet):
    """Check that a failing workflow run triggers its on-failure job.

    Runs ``OnFailureWorkflow``, expects awaiting the result to raise an
    error mentioning ``"step1 failed"``, then verifies through the REST API
    that the workflow run ended up with exactly two job runs: one failed
    and one succeeded.
    """
    run = hatchet.admin.run_workflow("OnFailureWorkflow", {})
    try:
        await run.result()
    except Exception as e:
        assert "step1 failed" in str(e)
    else:
        # Raised outside the ``try`` body on purpose: an AssertionError is
        # itself an Exception, so asserting inside the ``try`` would be
        # swallowed by the handler above and reported as a confusing
        # "'step1 failed' not in ..." failure instead.
        raise AssertionError("Expected workflow to timeout")

    await asyncio.sleep(2)  # Wait for the on_failure job to finish

    job_runs = hatchet.rest.workflow_run_get(run.workflow_run_id).job_runs
    assert len(job_runs) == 2

    successful_job_runs = [
        jr for jr in job_runs if jr.status == JobRunStatus.SUCCEEDED
    ]
    failed_job_runs = [jr for jr in job_runs if jr.status == JobRunStatus.FAILED]
    assert len(successful_job_runs) == 1
    assert len(failed_job_runs) == 1
10 changes: 7 additions & 3 deletions examples/on_failure/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,18 @@

@hatchet.workflow(on_events=["user:create"])
class OnFailureWorkflow:
    """Example workflow whose only step fails, exercising the on-failure step."""

    @hatchet.step(timeout="1s")
    def step1(self, context: Context):
        """Always raise so that the workflow run fails."""
        raise Exception("step1 failed")

    @hatchet.on_failure_step()
    def on_failure(self, context: Context):
        """Inspect the run's failures and report whether they look as expected.

        Returns ``{"status": "success"}`` when exactly one failure is
        reported and its error mentions ``"step1 failed"``; raises otherwise.
        """
        failures = context.fetch_run_failures()
        print("executed on_failure")
        print(json.dumps(failures, indent=2))
        if len(failures) == 1 and "step1 failed" in failures[0]["error"]:
            return {"status": "success"}
        raise Exception("unexpected failure")


def main():
Expand Down
26 changes: 26 additions & 0 deletions hatchet_sdk/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from concurrent.futures import Future, ThreadPoolExecutor

from hatchet_sdk.clients.events import EventClient
from hatchet_sdk.clients.rest_client import RestApi
from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
from hatchet_sdk.context.worker_context import WorkerContext
Expand Down Expand Up @@ -65,6 +66,7 @@ def __init__(
dispatcher_client: DispatcherClient,
admin_client: AdminClient,
event_client: EventClient,
rest_client: RestApi,
workflow_listener: PooledWorkflowRunListener,
workflow_run_event_listener: RunEventListenerClient,
worker: WorkerContext,
Expand All @@ -74,6 +76,7 @@ def __init__(
self.dispatcher_client = dispatcher_client
self.admin_client = admin_client
self.event_client = event_client
self.rest_client = rest_client
self.workflow_listener = workflow_listener
self.workflow_run_event_listener = workflow_run_event_listener
self.namespace = namespace
Expand Down Expand Up @@ -116,6 +119,7 @@ def __init__(
dispatcher_client: DispatcherClient,
admin_client: AdminClient,
event_client: EventClient,
rest_client: RestApi,
workflow_listener: PooledWorkflowRunListener,
workflow_run_event_listener: RunEventListenerClient,
worker: WorkerContext,
Expand All @@ -128,6 +132,7 @@ def __init__(
dispatcher_client,
admin_client,
event_client,
rest_client,
workflow_listener,
workflow_run_event_listener,
worker,
Expand Down Expand Up @@ -157,6 +162,7 @@ def __init__(
self.dispatcher_client = dispatcher_client
self.admin_client = admin_client
self.event_client = event_client
self.rest_client = rest_client
self.workflow_listener = workflow_listener
self.workflow_run_event_listener = workflow_run_event_listener
self.namespace = namespace
Expand Down Expand Up @@ -291,3 +297,23 @@ def child_key(self):

def parent_workflow_run_id(self):
return self.action.parent_workflow_run_id

def fetch_run_failures(self):
    """Return the failed step runs of the other jobs in this workflow run.

    Fetches the current workflow run through the REST client, ignores the
    job run whose ``job_id`` matches this action's own job, and collects
    every step run that carries a non-empty ``error``.

    Returns:
        A list of dicts with the keys ``"step_id"``,
        ``"step_run_action_name"`` and ``"error"``. The list is empty when
        no other job reported a failed step run.
        ``"step_run_action_name"`` is ``None`` when the step run has no
        ``step`` attached.
    """
    data = self.rest_client.workflow_run_get(self.action.workflow_run_id)
    # NOTE(review): job_runs / step_runs / step are assumed to be optional
    # on the REST models; a missing value is treated as "nothing to report"
    # rather than raising TypeError/AttributeError.
    other_job_runs = [
        run for run in (data.job_runs or []) if run.job_id != self.action.job_id
    ]
    # TODO: Parse Step Runs using a Pydantic Model rather than a hand crafted dictionary
    failed_step_runs = [
        {
            "step_id": step_run.step_id,
            "step_run_action_name": (
                step_run.step.action if step_run.step else None
            ),
            "error": step_run.error,
        }
        for job_run in other_job_runs
        for step_run in (job_run.step_runs or [])
        if step_run.error
    ]

    return failed_step_runs
7 changes: 5 additions & 2 deletions hatchet_sdk/worker/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ async def handle_start_step_run(self, action: Action):
self.dispatcher_client,
self.admin_client,
self.client.event,
self.client.rest,
self.client.workflow_listener,
self.workflow_run_event_listener,
self.worker_context,
Expand All @@ -345,6 +346,7 @@ async def handle_start_step_run(self, action: Action):
self.dispatcher_client,
self.admin_client,
self.client.event,
self.client.rest,
self.client.workflow_listener,
self.workflow_run_event_listener,
self.worker_context,
Expand Down Expand Up @@ -373,7 +375,7 @@ async def handle_start_step_run(self, action: Action):

try:
await task
except Exception as e:
except Exception:
# do nothing, this should be caught in the callback
pass

Expand All @@ -384,6 +386,7 @@ async def handle_start_group_key_run(self, action: Action):
self.dispatcher_client,
self.admin_client,
self.client.event,
self.client.rest,
self.client.workflow_listener,
self.workflow_run_event_listener,
self.worker_context,
Expand Down Expand Up @@ -415,7 +418,7 @@ async def handle_start_group_key_run(self, action: Action):

try:
await task
except Exception as e:
except Exception:
# do nothing, this should be caught in the callback
pass

Expand Down
Loading