Skip to content

Commit c42f773

Browse files
committed
Add support for all local containerized orchestrators
1 parent 3646195 commit c42f773

File tree

3 files changed

+29
-17
lines changed

3 files changed

+29
-17
lines changed

src/zenml/integrations/airflow/flavors/airflow_orchestrator_flavor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,15 @@ class AirflowOrchestratorConfig(
119119

120120
local: bool = True
121121

122+
@property
123+
def is_local(self) -> bool:
124+
"""Checks if this stack component is running locally.
125+
126+
Returns:
127+
True if this config is for a local component, False otherwise.
128+
"""
129+
return self.local
130+
122131
@property
123132
def is_schedulable(self) -> bool:
124133
"""Whether the orchestrator is schedulable or not.

src/zenml/orchestrators/containerized_orchestrator.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
"""Containerized orchestrator class."""
1515

1616
from abc import ABC
17-
from typing import List, Optional
17+
from typing import List, Optional, Set
1818

19+
import zenml
1920
from zenml.config.build_configuration import BuildConfiguration
21+
from zenml.config.global_config import GlobalConfiguration
2022
from zenml.constants import ORCHESTRATOR_DOCKER_IMAGE_KEY
2123
from zenml.models import PipelineSnapshotBase, PipelineSnapshotResponse
2224
from zenml.orchestrators import BaseOrchestrator
@@ -25,6 +27,22 @@
2527
class ContainerizedOrchestrator(BaseOrchestrator, ABC):
2628
"""Base class for containerized orchestrators."""
2729

30+
@property
31+
def requirements(self) -> Set[str]:
32+
"""Set of PyPI requirements for the component.
33+
34+
Returns:
35+
A set of PyPI requirements for the component.
36+
"""
37+
requirements = super().requirements
38+
39+
if self.config.is_local and GlobalConfiguration().uses_sql_store:
40+
# If we're directly connected to a DB, we need to install the
41+
# `local` extra in the Docker image to include the DB dependencies.
42+
requirements.add(f"'zenml[local]=={zenml.__version__}'")
43+
44+
return requirements
45+
2846
@staticmethod
2947
def get_image(
3048
snapshot: "PipelineSnapshotResponse",

src/zenml/orchestrators/local_docker/local_docker_orchestrator.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@
1717
import os
1818
import sys
1919
import time
20-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, cast
20+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, cast
2121
from uuid import uuid4
2222

2323
from docker.errors import ContainerError
2424

25-
import zenml
2625
from zenml.config.base_settings import BaseSettings
2726
from zenml.config.global_config import GlobalConfiguration
2827
from zenml.constants import (
@@ -84,20 +83,6 @@ def validator(self) -> Optional[StackValidator]:
8483
required_components={StackComponentType.IMAGE_BUILDER}
8584
)
8685

87-
@property
88-
def requirements(self) -> Set[str]:
89-
"""Set of PyPI requirements for the component.
90-
91-
Returns:
92-
A set of PyPI requirements for the component.
93-
"""
94-
if GlobalConfiguration().uses_sql_store:
95-
# If we're directly connected to a DB, we need to install the
96-
# `local` extra in the Docker image to include the DB dependencies.
97-
return {f'"zenml[local]=={zenml.__version__}"'}
98-
99-
return set()
100-
10186
def get_orchestrator_run_id(self) -> str:
10287
"""Returns the active orchestrator run id.
10388

0 commit comments

Comments
 (0)