|
5 | 5 | """
|
6 | 6 | from __future__ import annotations
|
7 | 7 |
|
| 8 | +import time |
8 | 9 | from pathlib import Path
|
9 | 10 | from types import ModuleType
|
10 | 11 | from typing import Any, Dict, List, Optional, Union
|
|
44 | 45 | WorkflowLintRequest,
|
45 | 46 | WorkflowMetadata,
|
46 | 47 | WorkflowSpec as _ModelWorkflowSpec,
|
47 |
| - WorkflowStatus, |
| 48 | + WorkflowStatus as _ModelWorkflowStatus, |
48 | 49 | WorkflowTemplateRef,
|
49 | 50 | )
|
50 | 51 | from hera.workflows.protocol import Templatable, TTemplate, TWorkflow, VolumeClaimable
|
51 | 52 | from hera.workflows.service import WorkflowsService
|
| 53 | +from hera.workflows.workflow_status import WorkflowStatus |
52 | 54 |
|
53 | 55 | _yaml: Optional[ModuleType] = None
|
54 | 56 | try:
|
@@ -82,7 +84,7 @@ class Workflow(
|
82 | 84 | # Workflow fields - https://argoproj.github.io/argo-workflows/fields/#workflow
|
83 | 85 | api_version: Optional[str] = None
|
84 | 86 | kind: Optional[str] = None
|
85 |
| - status: Optional[WorkflowStatus] = None |
| 87 | + status: Optional[_ModelWorkflowStatus] = None |
86 | 88 |
|
87 | 89 | # ObjectMeta fields - https://argoproj.github.io/argo-workflows/fields/#objectmeta
|
88 | 90 | annotations: Optional[Dict[str, str]] = None
|
@@ -330,13 +332,58 @@ def to_yaml(self, *args, **kwargs) -> str:
|
330 | 332 | kwargs.setdefault("sort_keys", False)
|
331 | 333 | return _yaml.dump(self.to_dict(), *args, **kwargs)
|
332 | 334 |
|
333 |
| - def create(self) -> TWorkflow: |
334 |
| - """Creates the Workflow on the Argo cluster.""" |
| 335 | + def create(self, wait: bool = False, poll_interval: int = 5) -> TWorkflow: |
| 336 | + """Creates the Workflow on the Argo cluster. |
| 337 | +
|
| 338 | + Parameters |
| 339 | + ---------- |
| 340 | + wait: bool = False |
| 341 | + If true then the workflow is created asynchronously and the function returns immediately. |
| 342 | + If false then the workflow is created and the function blocks until the workflow is done executing. |
| 343 | + poll_interval: int = 5 |
| 344 | + The interval in seconds to poll the workflow status if wait is true. Ignored when wait is true. |
| 345 | + """ |
335 | 346 | assert self.workflows_service, "workflow service not initialized"
|
336 | 347 | assert self.namespace, "workflow namespace not defined"
|
337 |
| - return self.workflows_service.create_workflow( |
| 348 | + |
| 349 | + wf = self.workflows_service.create_workflow( |
338 | 350 | WorkflowCreateRequest(workflow=self.build()), namespace=self.namespace
|
339 | 351 | )
|
| 352 | + # set the workflow name to the name returned by the API, which helps cover the case of users relying on |
| 353 | + # `generate_name=True` |
| 354 | + self.name = wf.metadata.name |
| 355 | + |
| 356 | + if wait: |
| 357 | + return self.wait(poll_interval=poll_interval) |
| 358 | + return wf |
| 359 | + |
| 360 | + def wait(self, poll_interval: int = 5) -> TWorkflow: |
| 361 | + """Waits for the Workflow to complete execution. |
| 362 | +
|
| 363 | + Parameters |
| 364 | + ---------- |
| 365 | + poll_interval: int = 5 |
| 366 | + The interval in seconds to poll the workflow status. |
| 367 | + """ |
| 368 | + assert self.workflows_service is not None, "workflow service not initialized" |
| 369 | + assert self.namespace is not None, "workflow namespace not defined" |
| 370 | + assert self.name is not None, "workflow name not defined" |
| 371 | + |
| 372 | + wf = self.workflows_service.get_workflow(self.name, namespace=self.namespace) |
| 373 | + assert wf.metadata.name is not None, f"workflow name not defined for workflow {self.name}" |
| 374 | + |
| 375 | + assert wf.status is not None, f"workflow status not defined for workflow {wf.metadata.name}" |
| 376 | + assert wf.status.phase is not None, f"workflow phase not defined for workflow status {wf.status}" |
| 377 | + status = WorkflowStatus.from_argo_status(wf.status.phase) |
| 378 | + |
| 379 | + # keep polling for workflow status until completed, at the interval dictated by the user |
| 380 | + while status == WorkflowStatus.running: |
| 381 | + time.sleep(poll_interval) |
| 382 | + wf = self.workflows_service.get_workflow(wf.metadata.name, namespace=self.namespace) |
| 383 | + assert wf.status is not None, f"workflow status not defined for workflow {wf.metadata.name}" |
| 384 | + assert wf.status.phase is not None, f"workflow phase not defined for workflow status {wf.status}" |
| 385 | + status = WorkflowStatus.from_argo_status(wf.status.phase) |
| 386 | + return wf |
340 | 387 |
|
341 | 388 | def lint(self) -> TWorkflow:
|
342 | 389 | """Lints the Workflow using the Argo cluster."""
|
|
0 commit comments