diff --git a/examples/on_failure/test_on_failure.py b/examples/on_failure/test_on_failure.py
index 55103aa6..2bd74280 100644
--- a/examples/on_failure/test_on_failure.py
+++ b/examples/on_failure/test_on_failure.py
@@ -1,14 +1,32 @@
-# from hatchet_sdk import Hatchet
-# import pytest
+import asyncio
 
-# from tests.utils import fixture_bg_worker
-# from tests.utils.hatchet_client import hatchet_client_fixture
+import pytest
+from hatchet_sdk import Hatchet
+from hatchet_sdk.clients.rest.models.job_run_status import JobRunStatus
+from tests.utils import fixture_bg_worker
+from tests.utils.hatchet_client import hatchet_client_fixture
 
-# hatchet = hatchet_client_fixture()
-# worker = fixture_bg_worker(["poetry", "run", "manual_trigger"])
+hatchet = hatchet_client_fixture()
+worker = fixture_bg_worker(["poetry", "run", "on_failure"])
 
-# # requires scope module or higher for shared event loop
-# @pytest.mark.asyncio(scope="session")
-# async def test_run(hatchet: Hatchet):
-#     # TODO
+
+# requires scope module or higher for shared event loop
+@pytest.mark.asyncio(scope="session")
+async def test_run_timeout(hatchet: Hatchet):
+    run = hatchet.admin.run_workflow("OnFailureWorkflow", {})
+    try:
+        await run.result()
+        assert False, "Expected workflow to timeout"
+    except Exception as e:
+        assert "step1 failed" in str(e)
+
+    await asyncio.sleep(2)  # Wait for the on_failure job to finish
+
+    job_runs = hatchet.rest.workflow_run_get(run.workflow_run_id).job_runs
+    assert len(job_runs) == 2
+
+    successful_job_runs = [jr for jr in job_runs if jr.status == JobRunStatus.SUCCEEDED]
+    failed_job_runs = [jr for jr in job_runs if jr.status == JobRunStatus.FAILED]
+    assert len(successful_job_runs) == 1
+    assert len(failed_job_runs) == 1
 
diff --git a/examples/on_failure/worker.py b/examples/on_failure/worker.py
index e1b36e3e..804203a3 100644
--- a/examples/on_failure/worker.py
+++ b/examples/on_failure/worker.py
@@ -11,14 +11,18 @@
 
 @hatchet.workflow(on_events=["user:create"])
 class OnFailureWorkflow:
-    @hatchet.step()
+    @hatchet.step(timeout="1s")
     def step1(self, context: Context):
         raise Exception("step1 failed")
 
     @hatchet.on_failure_step()
-    def on_failure(self, context):
+    def on_failure(self, context: Context):
+        failures = context.fetch_run_failures()
         print("executed on_failure")
-        print(context)
+        print(json.dumps(failures, indent=2))
+        if len(failures) == 1 and "step1 failed" in failures[0]["error"]:
+            return {"status": "success"}
+        raise Exception("unexpected failure")
 
 
 def main():
diff --git a/hatchet_sdk/context/context.py b/hatchet_sdk/context/context.py
index 986f03b2..ccfced90 100644
--- a/hatchet_sdk/context/context.py
+++ b/hatchet_sdk/context/context.py
@@ -4,6 +4,7 @@
 from concurrent.futures import Future, ThreadPoolExecutor
 
 from hatchet_sdk.clients.events import EventClient
+from hatchet_sdk.clients.rest_client import RestApi
 from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
 from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
 from hatchet_sdk.context.worker_context import WorkerContext
@@ -65,6 +66,7 @@ def __init__(
         dispatcher_client: DispatcherClient,
         admin_client: AdminClient,
         event_client: EventClient,
+        rest_client: RestApi,
         workflow_listener: PooledWorkflowRunListener,
         workflow_run_event_listener: RunEventListenerClient,
         worker: WorkerContext,
@@ -74,6 +76,7 @@ def __init__(
         self.dispatcher_client = dispatcher_client
         self.admin_client = admin_client
         self.event_client = event_client
+        self.rest_client = rest_client
         self.workflow_listener = workflow_listener
         self.workflow_run_event_listener = workflow_run_event_listener
         self.namespace = namespace
@@ -116,6 +119,7 @@ def __init__(
         dispatcher_client: DispatcherClient,
         admin_client: AdminClient,
         event_client: EventClient,
+        rest_client: RestApi,
         workflow_listener: PooledWorkflowRunListener,
         workflow_run_event_listener: RunEventListenerClient,
         worker: WorkerContext,
@@ -128,6 +132,7 @@ def __init__(
             dispatcher_client,
             admin_client,
             event_client,
+            rest_client,
             workflow_listener,
             workflow_run_event_listener,
             worker,
@@ -157,6 +162,7 @@ def __init__(
         self.dispatcher_client = dispatcher_client
         self.admin_client = admin_client
         self.event_client = event_client
+        self.rest_client = rest_client
         self.workflow_listener = workflow_listener
         self.workflow_run_event_listener = workflow_run_event_listener
         self.namespace = namespace
@@ -291,3 +297,23 @@ def child_key(self):
 
     def parent_workflow_run_id(self):
         return self.action.parent_workflow_run_id
+
+    def fetch_run_failures(self):
+        data = self.rest_client.workflow_run_get(self.action.workflow_run_id)
+        other_job_runs = [
+            run for run in data.job_runs if run.job_id != self.action.job_id
+        ]
+        # TODO: Parse Step Runs using a Pydantic Model rather than a hand crafted dictionary
+        failed_step_runs = [
+            {
+                "step_id": step_run.step_id,
+                "step_run_action_name": step_run.step.action,
+                "error": step_run.error,
+            }
+            for job_run in other_job_runs
+            if job_run.step_runs
+            for step_run in job_run.step_runs
+            if step_run.error
+        ]
+
+        return failed_step_runs
diff --git a/hatchet_sdk/worker/runner/runner.py b/hatchet_sdk/worker/runner/runner.py
index af2f896a..c9f005a6 100644
--- a/hatchet_sdk/worker/runner/runner.py
+++ b/hatchet_sdk/worker/runner/runner.py
@@ -334,6 +334,7 @@ async def handle_start_step_run(self, action: Action):
                 self.dispatcher_client,
                 self.admin_client,
                 self.client.event,
+                self.client.rest,
                 self.client.workflow_listener,
                 self.workflow_run_event_listener,
                 self.worker_context,
@@ -345,6 +346,7 @@ async def handle_start_step_run(self, action: Action):
                 self.dispatcher_client,
                 self.admin_client,
                 self.client.event,
+                self.client.rest,
                 self.client.workflow_listener,
                 self.workflow_run_event_listener,
                 self.worker_context,
@@ -373,7 +375,7 @@ async def handle_start_step_run(self, action: Action):
 
         try:
             await task
-        except Exception as e:
+        except Exception:
             # do nothing, this should be caught in the callback
             pass
 
@@ -384,6 +386,7 @@ async def handle_start_group_key_run(self, action: Action):
             self.dispatcher_client,
             self.admin_client,
             self.client.event,
+            self.client.rest,
             self.client.workflow_listener,
             self.workflow_run_event_listener,
             self.worker_context,
@@ -415,7 +418,7 @@ async def handle_start_group_key_run(self, action: Action):
 
         try:
             await task
-        except Exception as e:
+        except Exception:
             # do nothing, this should be caught in the callback
             pass