Skip to content

Commit

Permalink
MyPy Fixes for Worker, Workflow metaclass, Context, and some ot…
Browse files Browse the repository at this point in the history
…hers (#273)

* fix: lint

* feat: fixing mypy for the worker class

* cleanup: print statements, unused imports

* feat: exclude examples explicitly

* chore: version

* MyPy, Part II or III: `Workflow` metaclass (#274)

* feat: fix mypy for Workflow

* fix: type stub

* fix: remove any in a couple places

* fix: attempting using `Union`

* fix: try with `.Value`

* fix: name attr

* fix: remove more type: ignore

* MyPy for Context, etc. (#276)

* fix: types for config, etc.

* fix: queue type hint

* fix: typeddict items

* fix: install all the deps

* fix: backoff factors

* Feat: First Pass at Pydantic (#275)

* fix: types for config, etc.

* fix: typeddict items

* feat: initial pass at pydantic support

* fix: types

* fix: test

* fix: more types

* debug: adding prints, fixing validator not present

* fix: function sigs

* fix: typed dict key

* fix: comment

* feat: clean up validators

* cleanup: import

* cleanup: type hint improvements

* feat: add simple example

* feat: type casts

* fix: tests and example

* fix: worker fixture

* fix: deprecation warning for overrides

* feat: try out using a validator registry, part I

* feat: validator registry, part ii

* feat: clean up pydantic example

* fix: tests and lint

* fix: one more type cast

* cleanup: cruft and print statements

* fix: worker input val and validator key

* feat: expand tests a little

* fix: more typing

* feat: documentation updates for Pydantic

* cleanup: cruft
  • Loading branch information
hatchet-temporary authored Dec 6, 2024
1 parent e504fda commit 960846a
Show file tree
Hide file tree
Showing 24 changed files with 578 additions and 245 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
virtualenvs-in-project: true

- name: Install linting tools
run: poetry install --no-root --only lint
run: poetry install --no-root

- name: Run Black
run: poetry run black . --check --verbose --diff --color
Expand Down
30 changes: 30 additions & 0 deletions examples/pydantic/test_pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest

from hatchet_sdk import Hatchet


# requires scope module or higher for shared event loop
@pytest.mark.asyncio(scope="session")
@pytest.mark.parametrize("worker", ["pydantic"], indirect=True)
async def test_run_validation_error(hatchet: Hatchet, worker):
run = hatchet.admin.run_workflow(
"Parent",
{},
)

with pytest.raises(Exception, match="1 validation error for ParentInput"):
await run.result()


# requires scope module or higher for shared event loop
@pytest.mark.asyncio(scope="session")
@pytest.mark.parametrize("worker", ["pydantic"], indirect=True)
async def test_run(hatchet: Hatchet, worker):
run = hatchet.admin.run_workflow(
"Parent",
{"x": "foobar"},
)

result = await run.result()

assert len(result["spawn"]) == 3
19 changes: 19 additions & 0 deletions examples/pydantic/trigger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import asyncio

from dotenv import load_dotenv

from hatchet_sdk import new_client


async def main():
load_dotenv()
hatchet = new_client()

hatchet.admin.run_workflow(
"Parent",
{"x": "foo bar baz"},
)


if __name__ == "__main__":
asyncio.run(main())
82 changes: 82 additions & 0 deletions examples/pydantic/worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from typing import cast

from dotenv import load_dotenv
from pydantic import BaseModel

from hatchet_sdk import Context, Hatchet

load_dotenv()

hatchet = Hatchet(debug=True)


# ❓ Pydantic
# This workflow shows example usage of Pydantic within Hatchet
class ParentInput(BaseModel):
x: str


@hatchet.workflow(input_validator=ParentInput)
class Parent:
@hatchet.step(timeout="5m")
async def spawn(self, context: Context):
## Use `typing.cast` to cast your `workflow_input`
## to the type of your `input_validator`
input = cast(ParentInput, context.workflow_input()) ## This is a `ParentInput`

child = await context.aio.spawn_workflow(
"Child",
{"a": 1, "b": "10"},
)

return await child.result()


class ChildInput(BaseModel):
a: int
b: int


class StepResponse(BaseModel):
status: str


@hatchet.workflow(input_validator=ChildInput)
class Child:
@hatchet.step()
def process(self, context: Context) -> StepResponse:
## This is an instance `ChildInput`
input = cast(ChildInput, context.workflow_input())

return StepResponse(status="success")

@hatchet.step(parents=["process"])
def process2(self, context: Context) -> StepResponse:
## This is an instance of `StepResponse`
process_output = cast(StepResponse, context.step_output("process"))

return {"status": "step 2 - success"}

@hatchet.step(parents=["process2"])
def process3(self, context: Context) -> StepResponse:
## This is an instance of `StepResponse`, even though the
## response of `process2` was a dictionary. Note that
## Hatchet will attempt to parse that dictionary into
## an object of type `StepResponse`
process_2_output = cast(StepResponse, context.step_output("process2"))

return StepResponse(status="step 3 - success")


# ‼️


def main():
worker = hatchet.worker("pydantic-worker")
worker.register_workflow(Parent())
worker.register_workflow(Child())
worker.start()


if __name__ == "__main__":
main()
18 changes: 10 additions & 8 deletions hatchet_sdk/clients/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,10 @@ class ChildTriggerWorkflowOptions(TypedDict):
sticky: bool | None = None


class WorkflowRunDict(TypedDict):
workflow_name: str
input: Any
options: Optional[dict]


class ChildWorkflowRunDict(TypedDict):
workflow_name: str
input: Any
options: ChildTriggerWorkflowOptions[dict]
options: ChildTriggerWorkflowOptions
key: str | None = None


Expand All @@ -73,6 +67,12 @@ class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, TypedDict):
namespace: str | None = None


class WorkflowRunDict(TypedDict):
workflow_name: str
input: Any
options: TriggerWorkflowOptions | None


class DedupeViolationErr(Exception):
"""Raised by the Hatchet library to indicate that a workflow has already been run with this deduplication value."""

Expand Down Expand Up @@ -260,7 +260,9 @@ async def run_workflow(

@tenacity_retry
async def run_workflows(
self, workflows: List[WorkflowRunDict], options: TriggerWorkflowOptions = None
self,
workflows: list[WorkflowRunDict],
options: TriggerWorkflowOptions | None = None,
) -> List[WorkflowRunRef]:
if len(workflows) == 0:
raise ValueError("No workflows to run")
Expand Down
2 changes: 1 addition & 1 deletion hatchet_sdk/clients/dispatcher/action_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class Action:
worker_id: str
tenant_id: str
workflow_run_id: str
get_group_key_run_id: Optional[str]
get_group_key_run_id: str
job_id: str
job_name: str
job_run_id: str
Expand Down
4 changes: 2 additions & 2 deletions hatchet_sdk/clients/dispatcher/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,14 @@ def put_overrides_data(self, data: OverridesData):

return response

def release_slot(self, step_run_id: str):
def release_slot(self, step_run_id: str) -> None:
self.client.ReleaseSlot(
ReleaseSlotRequest(stepRunId=step_run_id),
timeout=DEFAULT_REGISTER_TIMEOUT,
metadata=get_metadata(self.token),
)

def refresh_timeout(self, step_run_id: str, increment_by: str):
def refresh_timeout(self, step_run_id: str, increment_by: str) -> None:
self.client.RefreshTimeout(
RefreshTimeoutRequest(
stepRunId=step_run_id,
Expand Down
7 changes: 6 additions & 1 deletion hatchet_sdk/clients/rest/tenacity_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from typing import Callable, ParamSpec, TypeVar

import grpc
import tenacity

from hatchet_sdk.logger import logger

P = ParamSpec("P")
R = TypeVar("R")


def tenacity_retry(func):
def tenacity_retry(func: Callable[P, R]) -> Callable[P, R]:
return tenacity.retry(
reraise=True,
wait=tenacity.wait_exponential_jitter(),
Expand Down
Loading

0 comments on commit 960846a

Please sign in to comment.