Skip to content

Commit

Permalink
Improve Usage of Sticky, Labels, and Durable (#152)
Browse files Browse the repository at this point in the history
* Ensure Labels and Sticky are Usable within v2 Functions

* Cursed Example

* Fixes

---------

Co-authored-by: srhinos <[email protected]>
  • Loading branch information
macwilk and srhinos authored Aug 23, 2024
1 parent f70e09b commit 803385e
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 16 deletions.
72 changes: 72 additions & 0 deletions examples/durable_sticky_with_affinity/worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import asyncio

from dotenv import load_dotenv

from hatchet_sdk import Context, StickyStrategy, WorkerLabelComparator
from hatchet_sdk.v2.callable import DurableContext
from hatchet_sdk.v2.hatchet import Hatchet

load_dotenv()

hatchet = Hatchet(debug=True)


@hatchet.durable(
sticky=StickyStrategy.HARD,
desired_worker_labels={
"running_workflow": {
"value": "True",
"required": True,
"comparator": WorkerLabelComparator.NOT_EQUAL,
},
},
)
async def my_durable_func(context: DurableContext):
try:
ref = await context.aio.spawn_workflow(
"StickyChildWorkflow", {}, options={"sticky": True}
)
result = await ref.result()
except Exception as e:
result = str(e)

await context.worker.async_upsert_labels({"running_workflow": "False"})
return {"worker_result": result}


@hatchet.workflow(on_events=["sticky:child"], sticky=StickyStrategy.HARD)
class StickyChildWorkflow:
@hatchet.step(
desired_worker_labels={
"running_workflow": {
"value": "True",
"required": True,
"comparator": WorkerLabelComparator.NOT_EQUAL,
},
},
)
async def child(self, context: Context):
await context.worker.async_upsert_labels({"running_workflow": "True"})

print(f"Heavy work started on {context.worker.id()}")
await asyncio.sleep(15)
print(f"Finished Heavy work on {context.worker.id()}")

return {"worker": context.worker.id()}


def main():

worker = hatchet.worker(
"sticky-worker",
max_runs=10,
labels={"running_workflow": "False"},
)

worker.register_workflow(StickyChildWorkflow())

worker.start()


if __name__ == "__main__":
main()
19 changes: 19 additions & 0 deletions hatchet_sdk/clients/dispatcher/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,22 @@ def upsert_worker_labels(self, worker_id: str, labels: dict[str, str | int]):
timeout=DEFAULT_REGISTER_TIMEOUT,
metadata=get_metadata(self.token),
)

async def async_upsert_worker_labels(
self,
worker_id: str,
labels: dict[str, str | int],
):
worker_labels = {}

for key, value in labels.items():
if isinstance(value, int):
worker_labels[key] = WorkerLabels(intValue=value)
else:
worker_labels[key] = WorkerLabels(strValue=str(value))

await self.aio_client.UpsertWorkerLabels(
UpsertWorkerLabelsRequest(workerId=worker_id, labels=worker_labels),
timeout=DEFAULT_REGISTER_TIMEOUT,
metadata=get_metadata(self.token),
)
7 changes: 6 additions & 1 deletion hatchet_sdk/context/worker_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ def labels(self):
return self._labels

def upsert_labels(self, labels: dict[str, str | int]):
return self.client.upsert_worker_labels(self._worker_id, labels)
self.client.upsert_worker_labels(self._worker_id, labels)
self._labels.update(labels)

async def async_upsert_labels(self, labels: dict[str, str | int]):
await self.client.async_upsert_worker_labels(self._worker_id, labels)
self._labels.update(labels)

def id(self):
return self._worker_id
Expand Down
17 changes: 5 additions & 12 deletions hatchet_sdk/v2/callable.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,5 @@
import asyncio
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Optional,
TypedDict,
TypeVar,
Union,
)
from typing import Callable, Dict, Generic, List, Optional, TypedDict, TypeVar, Union

from hatchet_sdk.context import Context
from hatchet_sdk.contracts.workflows_pb2 import (
Expand All @@ -18,6 +8,7 @@
CreateWorkflowStepOpts,
CreateWorkflowVersionOpts,
DesiredWorkerLabels,
StickyStrategy,
WorkflowConcurrencyOpts,
WorkflowKind,
)
Expand All @@ -41,6 +32,7 @@ def __init__(
version: str = "",
timeout: str = "60m",
schedule_timeout: str = "5m",
sticky: StickyStrategy = None,
retries: int = 0,
rate_limits: List[RateLimit] | None = None,
concurrency: ConcurrencyFunction | None = None,
Expand Down Expand Up @@ -70,7 +62,7 @@ def __init__(
weight=d["weight"] if "weight" in d else None,
comparator=d["comparator"] if "comparator" in d else None,
)

self.sticky = sticky
self.durable = durable
self.function_name = name.lower() or str(func.__name__).lower()
self.function_version = version
Expand Down Expand Up @@ -131,6 +123,7 @@ def to_workflow_opts(self) -> CreateWorkflowVersionOpts:
event_triggers=self.function_on_events,
cron_triggers=self.function_on_crons,
schedule_timeout=self.function_schedule_timeout,
sticky=self.sticky,
on_failure_job=on_failure_job,
concurrency=concurrency,
jobs=[
Expand Down
24 changes: 21 additions & 3 deletions hatchet_sdk/v2/hatchet.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Any, Callable, List, Optional, TypeVar
from typing import Callable, List, Optional, TypeVar

from hatchet_sdk.context import Context
from hatchet_sdk.contracts.workflows_pb2 import ConcurrencyLimitStrategy
from hatchet_sdk.contracts.workflows_pb2 import ConcurrencyLimitStrategy, StickyStrategy
from hatchet_sdk.hatchet import Hatchet as HatchetV1
from hatchet_sdk.hatchet import workflow
from hatchet_sdk.labels import DesiredWorkerLabel
from hatchet_sdk.rate_limit import RateLimit
from hatchet_sdk.v2.callable import HatchetCallable
from hatchet_sdk.v2.concurrency import ConcurrencyFunction
Expand All @@ -22,8 +23,10 @@ def function(
version: str = "",
timeout: str = "60m",
schedule_timeout: str = "5m",
sticky: StickyStrategy = None,
retries: int = 0,
rate_limits: List[RateLimit] | None = None,
desired_worker_labels: dict[str:DesiredWorkerLabel] = {},
concurrency: ConcurrencyFunction | None = None,
on_failure: Optional["HatchetCallable"] = None,
):
Expand All @@ -37,8 +40,10 @@ def inner(func: Callable[[Context], T]) -> HatchetCallable[T]:
version=version,
timeout=timeout,
schedule_timeout=schedule_timeout,
sticky=sticky,
retries=retries,
rate_limits=rate_limits,
desired_worker_labels=desired_worker_labels,
concurrency=concurrency,
on_failure=on_failure,
)
Expand All @@ -54,8 +59,10 @@ def durable(
version: str = "",
timeout: str = "60m",
schedule_timeout: str = "5m",
sticky: StickyStrategy = None,
retries: int = 0,
rate_limits: List[RateLimit] | None = None,
desired_worker_labels: dict[str:DesiredWorkerLabel] = {},
concurrency: ConcurrencyFunction | None = None,
on_failure: HatchetCallable | None = None,
):
Expand All @@ -70,8 +77,10 @@ def inner(func: HatchetCallable) -> HatchetCallable:
version=version,
timeout=timeout,
schedule_timeout=schedule_timeout,
sticky=sticky,
retries=retries,
rate_limits=rate_limits,
desired_worker_labels=desired_worker_labels,
concurrency=concurrency,
on_failure=on_failure,
)
Expand Down Expand Up @@ -113,6 +122,7 @@ def function(
schedule_timeout: str = "5m",
retries: int = 0,
rate_limits: List[RateLimit] | None = None,
desired_worker_labels: dict[str:DesiredWorkerLabel] = {},
concurrency: ConcurrencyFunction | None = None,
on_failure: Optional["HatchetCallable"] = None,
):
Expand All @@ -126,6 +136,7 @@ def function(
schedule_timeout=schedule_timeout,
retries=retries,
rate_limits=rate_limits,
desired_worker_labels=desired_worker_labels,
concurrency=concurrency,
on_failure=on_failure,
)
Expand All @@ -151,8 +162,10 @@ def durable(
version: str = "",
timeout: str = "60m",
schedule_timeout: str = "5m",
sticky: StickyStrategy = None,
retries: int = 0,
rate_limits: List[RateLimit] | None = None,
desired_worker_labels: dict[str:DesiredWorkerLabel] = {},
concurrency: ConcurrencyFunction | None = None,
on_failure: Optional["HatchetCallable"] = None,
) -> Callable[[HatchetCallable], HatchetCallable]:
Expand All @@ -164,8 +177,10 @@ def durable(
version=version,
timeout=timeout,
schedule_timeout=schedule_timeout,
sticky=sticky,
retries=retries,
rate_limits=rate_limits,
desired_worker_labels=desired_worker_labels,
concurrency=concurrency,
on_failure=on_failure,
)
Expand All @@ -182,10 +197,13 @@ def wrapper(func: Callable[[Context], T]) -> HatchetCallable[T]:

return wrapper

def worker(self, name: str, max_runs: int | None = None):
def worker(
self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {}
):
worker = Worker(
name=name,
max_runs=max_runs,
labels=labels,
config=self._client.config,
debug=self._client.debug,
)
Expand Down
2 changes: 2 additions & 0 deletions hatchet_sdk/worker/action_listener_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class WorkerActionListenerProcess:
event_queue: Queue
handle_kill: bool = True
debug: bool = False
labels: dict = field(default_factory=dict)

listener: ActionListener = field(init=False, default=None)

Expand Down Expand Up @@ -93,6 +94,7 @@ async def start(self, retry_attempt=0):
services=["default"],
actions=self.actions,
max_runs=self.max_runs,
_labels=self.labels,
)
)
)
Expand Down
1 change: 1 addition & 0 deletions hatchet_sdk/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def _start_listener(self):
self.event_queue,
self.handle_kill,
self.client.debug,
self.labels,
),
)
process.start()
Expand Down

0 comments on commit 803385e

Please sign in to comment.