From 950f9f71db6b0112bd2f897f13f6dd242a6c7cba Mon Sep 17 00:00:00 2001 From: srhinos <6531393+srhinos@users.noreply.github.com> Date: Tue, 20 Aug 2024 11:57:18 -0400 Subject: [PATCH] Update Example / Add Test --- examples/on_failure/test_on_failure.py | 38 +++++++++++++++++++------- examples/on_failure/worker.py | 10 +++++-- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/examples/on_failure/test_on_failure.py b/examples/on_failure/test_on_failure.py index 55103aa6..2bd74280 100644 --- a/examples/on_failure/test_on_failure.py +++ b/examples/on_failure/test_on_failure.py @@ -1,14 +1,32 @@ -# from hatchet_sdk import Hatchet -# import pytest +import asyncio -# from tests.utils import fixture_bg_worker -# from tests.utils.hatchet_client import hatchet_client_fixture +import pytest +from hatchet_sdk import Hatchet +from hatchet_sdk.clients.rest.models.job_run_status import JobRunStatus +from tests.utils import fixture_bg_worker +from tests.utils.hatchet_client import hatchet_client_fixture -# hatchet = hatchet_client_fixture() -# worker = fixture_bg_worker(["poetry", "run", "manual_trigger"]) +hatchet = hatchet_client_fixture() +worker = fixture_bg_worker(["poetry", "run", "on_failure"]) -# # requires scope module or higher for shared event loop -# @pytest.mark.asyncio(scope="session") -# async def test_run(hatchet: Hatchet): -# # TODO + +# requires scope module or higher for shared event loop +@pytest.mark.asyncio(scope="session") +async def test_run_timeout(hatchet: Hatchet): + run = hatchet.admin.run_workflow("OnFailureWorkflow", {}) + try: + await run.result() + assert False, "Expected workflow to timeout" + except Exception as e: + assert "step1 failed" in str(e) + + await asyncio.sleep(2) # Wait for the on_failure job to finish + + job_runs = hatchet.rest.workflow_run_get(run.workflow_run_id).job_runs + assert len(job_runs) == 2 + + successful_job_runs = [jr for jr in job_runs if jr.status == JobRunStatus.SUCCEEDED] + failed_job_runs = [jr for jr in job_runs if jr.status == JobRunStatus.FAILED] + assert len(successful_job_runs) == 1 + assert len(failed_job_runs) == 1 diff --git a/examples/on_failure/worker.py b/examples/on_failure/worker.py index e1b36e3e..804203a3 100644 --- a/examples/on_failure/worker.py +++ b/examples/on_failure/worker.py @@ -11,14 +11,18 @@ @hatchet.workflow(on_events=["user:create"]) class OnFailureWorkflow: - @hatchet.step() + @hatchet.step(timeout="1s") def step1(self, context: Context): raise Exception("step1 failed") @hatchet.on_failure_step() - def on_failure(self, context): + def on_failure(self, context: Context): + failures = context.fetch_run_failures() print("executed on_failure") - print(context) + print(json.dumps(failures, indent=2)) + if len(failures) == 1 and "step1 failed" in failures[0]["error"]: + return {"status": "success"} + raise Exception("unexpected failure") def main():