diff --git a/examples/durable_sticky_with_affinity/worker.py b/examples/durable_sticky_with_affinity/worker.py new file mode 100644 index 00000000..ef18406d --- /dev/null +++ b/examples/durable_sticky_with_affinity/worker.py @@ -0,0 +1,72 @@ +import asyncio + +from dotenv import load_dotenv + +from hatchet_sdk import Context, StickyStrategy, WorkerLabelComparator +from hatchet_sdk.v2.callable import DurableContext +from hatchet_sdk.v2.hatchet import Hatchet + +load_dotenv() + +hatchet = Hatchet(debug=True) + + +@hatchet.durable( + sticky=StickyStrategy.HARD, + desired_worker_labels={ + "running_workflow": { + "value": "True", + "required": True, + "comparator": WorkerLabelComparator.NOT_EQUAL, + }, + }, +) +async def my_durable_func(context: DurableContext): + try: + ref = await context.aio.spawn_workflow( + "StickyChildWorkflow", {}, options={"sticky": True} + ) + result = await ref.result() + except Exception as e: + result = str(e) + + await context.worker.async_upsert_labels({"running_workflow": "False"}) + return {"worker_result": result} + + +@hatchet.workflow(on_events=["sticky:child"], sticky=StickyStrategy.HARD) +class StickyChildWorkflow: + @hatchet.step( + desired_worker_labels={ + "running_workflow": { + "value": "True", + "required": True, + "comparator": WorkerLabelComparator.NOT_EQUAL, + }, + }, + ) + async def child(self, context: Context): + await context.worker.async_upsert_labels({"running_workflow": "True"}) + + print(f"Heavy work started on {context.worker.id()}") + await asyncio.sleep(15) + print(f"Finished Heavy work on {context.worker.id()}") + + return {"worker": context.worker.id()} + + +def main(): + + worker = hatchet.worker( + "sticky-worker", + max_runs=10, + labels={"running_workflow": "False"}, + ) + + worker.register_workflow(StickyChildWorkflow()) + + worker.start() + + +if __name__ == "__main__": + main() diff --git a/hatchet_sdk/clients/dispatcher/dispatcher.py b/hatchet_sdk/clients/dispatcher/dispatcher.py index f9b3e40d..30f26494 100644 --- a/hatchet_sdk/clients/dispatcher/dispatcher.py +++ b/hatchet_sdk/clients/dispatcher/dispatcher.py @@ -145,3 +145,22 @@ def upsert_worker_labels(self, worker_id: str, labels: dict[str, str | int]): timeout=DEFAULT_REGISTER_TIMEOUT, metadata=get_metadata(self.token), ) + + async def async_upsert_worker_labels( + self, + worker_id: str, + labels: dict[str, str | int], + ): + worker_labels = {} + + for key, value in labels.items(): + if isinstance(value, int): + worker_labels[key] = WorkerLabels(intValue=value) + else: + worker_labels[key] = WorkerLabels(strValue=str(value)) + + await self.aio_client.UpsertWorkerLabels( + UpsertWorkerLabelsRequest(workerId=worker_id, labels=worker_labels), + timeout=DEFAULT_REGISTER_TIMEOUT, + metadata=get_metadata(self.token), + ) diff --git a/hatchet_sdk/context/worker_context.py b/hatchet_sdk/context/worker_context.py index 27771d2a..96cc76bc 100644 --- a/hatchet_sdk/context/worker_context.py +++ b/hatchet_sdk/context/worker_context.py @@ -14,7 +14,12 @@ def labels(self): return self._labels def upsert_labels(self, labels: dict[str, str | int]): - return self.client.upsert_worker_labels(self._worker_id, labels) + self.client.upsert_worker_labels(self._worker_id, labels) + self._labels.update(labels) + + async def async_upsert_labels(self, labels: dict[str, str | int]): + await self.client.async_upsert_worker_labels(self._worker_id, labels) + self._labels.update(labels) def id(self): return self._worker_id diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py index 7f92d430..c39a1c96 100644 --- a/hatchet_sdk/v2/callable.py +++ b/hatchet_sdk/v2/callable.py @@ -1,15 +1,5 @@ import asyncio -from typing import ( - Any, - Callable, - Dict, - Generic, - List, - Optional, - TypedDict, - TypeVar, - Union, -) +from typing import Callable, Dict, Generic, List, Optional, TypedDict, TypeVar, Union from hatchet_sdk.context import Context from hatchet_sdk.contracts.workflows_pb2 import ( @@ -18,6 +8,7 @@ CreateWorkflowStepOpts, CreateWorkflowVersionOpts, DesiredWorkerLabels, + StickyStrategy, WorkflowConcurrencyOpts, WorkflowKind, ) @@ -41,6 +32,7 @@ def __init__( version: str = "", timeout: str = "60m", schedule_timeout: str = "5m", + sticky: StickyStrategy = None, retries: int = 0, rate_limits: List[RateLimit] | None = None, concurrency: ConcurrencyFunction | None = None, @@ -70,7 +62,7 @@ def __init__( weight=d["weight"] if "weight" in d else None, comparator=d["comparator"] if "comparator" in d else None, ) - + self.sticky = sticky self.durable = durable self.function_name = name.lower() or str(func.__name__).lower() self.function_version = version @@ -131,6 +123,7 @@ def to_workflow_opts(self) -> CreateWorkflowVersionOpts: event_triggers=self.function_on_events, cron_triggers=self.function_on_crons, schedule_timeout=self.function_schedule_timeout, + sticky=self.sticky, on_failure_job=on_failure_job, concurrency=concurrency, jobs=[ diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py index 67e5a0d2..35e6ce92 100644 --- a/hatchet_sdk/v2/hatchet.py +++ b/hatchet_sdk/v2/hatchet.py @@ -1,9 +1,10 @@ -from typing import Any, Callable, List, Optional, TypeVar +from typing import Callable, List, Optional, TypeVar from hatchet_sdk.context import Context -from hatchet_sdk.contracts.workflows_pb2 import ConcurrencyLimitStrategy +from hatchet_sdk.contracts.workflows_pb2 import ConcurrencyLimitStrategy, StickyStrategy from hatchet_sdk.hatchet import Hatchet as HatchetV1 from hatchet_sdk.hatchet import workflow +from hatchet_sdk.labels import DesiredWorkerLabel from hatchet_sdk.rate_limit import RateLimit from hatchet_sdk.v2.callable import HatchetCallable from hatchet_sdk.v2.concurrency import ConcurrencyFunction @@ -22,8 +23,10 @@ def function( version: str = "", timeout: str = "60m", schedule_timeout: str = "5m", + sticky: StickyStrategy = None, retries: int = 0, rate_limits: List[RateLimit] | None = None, + desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, concurrency: ConcurrencyFunction | None = None, on_failure: Optional["HatchetCallable"] = None, ): @@ -37,8 +40,10 @@ def inner(func: Callable[[Context], T]) -> HatchetCallable[T]: version=version, timeout=timeout, schedule_timeout=schedule_timeout, + sticky=sticky, retries=retries, rate_limits=rate_limits, + desired_worker_labels=desired_worker_labels, concurrency=concurrency, on_failure=on_failure, ) @@ -54,8 +59,10 @@ def durable( version: str = "", timeout: str = "60m", schedule_timeout: str = "5m", + sticky: StickyStrategy = None, retries: int = 0, rate_limits: List[RateLimit] | None = None, + desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, concurrency: ConcurrencyFunction | None = None, on_failure: HatchetCallable | None = None, ): @@ -70,8 +77,10 @@ def inner(func: HatchetCallable) -> HatchetCallable: version=version, timeout=timeout, schedule_timeout=schedule_timeout, + sticky=sticky, retries=retries, rate_limits=rate_limits, + desired_worker_labels=desired_worker_labels, concurrency=concurrency, on_failure=on_failure, ) @@ -113,6 +122,7 @@ def function( schedule_timeout: str = "5m", retries: int = 0, rate_limits: List[RateLimit] | None = None, + desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, concurrency: ConcurrencyFunction | None = None, on_failure: Optional["HatchetCallable"] = None, ): @@ -126,6 +136,7 @@ def function( schedule_timeout=schedule_timeout, retries=retries, rate_limits=rate_limits, + desired_worker_labels=desired_worker_labels, concurrency=concurrency, on_failure=on_failure, ) @@ -151,8 +162,10 @@ def durable( version: str = "", timeout: str = "60m", schedule_timeout: str = "5m", + sticky: StickyStrategy = None, retries: int = 0, rate_limits: List[RateLimit] | None = None, + desired_worker_labels: dict[str:DesiredWorkerLabel] = {}, concurrency: ConcurrencyFunction | None = None, on_failure: Optional["HatchetCallable"] = None, ) -> Callable[[HatchetCallable], HatchetCallable]: @@ -164,8 +177,10 @@ def durable( version=version, timeout=timeout, schedule_timeout=schedule_timeout, + sticky=sticky, retries=retries, rate_limits=rate_limits, + desired_worker_labels=desired_worker_labels, concurrency=concurrency, on_failure=on_failure, ) @@ -182,10 +197,13 @@ def wrapper(func: Callable[[Context], T]) -> HatchetCallable[T]: return wrapper - def worker(self, name: str, max_runs: int | None = None): + def worker( + self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {} + ): worker = Worker( name=name, max_runs=max_runs, + labels=labels, config=self._client.config, debug=self._client.debug, ) diff --git a/hatchet_sdk/worker/action_listener_process.py b/hatchet_sdk/worker/action_listener_process.py index 8f768beb..33463f04 100644 --- a/hatchet_sdk/worker/action_listener_process.py +++ b/hatchet_sdk/worker/action_listener_process.py @@ -55,6 +55,7 @@ class WorkerActionListenerProcess: event_queue: Queue handle_kill: bool = True debug: bool = False + labels: dict = field(default_factory=dict) listener: ActionListener = field(init=False, default=None) @@ -93,6 +94,7 @@ async def start(self, retry_attempt=0): services=["default"], actions=self.actions, max_runs=self.max_runs, + _labels=self.labels, ) ) ) diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index 10a219b4..315f2f4a 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -191,6 +191,7 @@ def _start_listener(self): self.event_queue, self.handle_kill, self.client.debug, + self.labels, ), ) process.start()