diff --git a/notebooks/api/0.8/11-container-images-k8s.ipynb b/notebooks/api/0.8/11-container-images-k8s.ipynb
index 07c26a78b6a..c28b0187567 100644
--- a/notebooks/api/0.8/11-container-images-k8s.ipynb
+++ b/notebooks/api/0.8/11-container-images-k8s.ipynb
@@ -74,6 +74,14 @@
     "domain_client"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "fe3d0aa7",
+   "metadata": {},
+   "source": [
+    "### Scaling Default Worker Pool"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "55439eb5-1e92-46a6-a45a-471917a86265",
@@ -92,6 +100,101 @@
     "domain_client.worker_pools"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "0ff8e268",
+   "metadata": {},
+   "source": [
+    "Scale up to 3 workers"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "de9872be",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "result = domain_client.api.services.worker_pool.scale(\n",
+    "    number=3, pool_name=\"default-pool\"\n",
+    ")\n",
+    "assert not isinstance(result, sy.SyftError), str(result)\n",
+    "result"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "da6a499b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "result = domain_client.api.services.worker_pool.get_by_name(pool_name=\"default-pool\")\n",
+    "assert len(result.workers) == 3, str(result.to_dict())\n",
+    "result"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "27761f0c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# stdlib\n",
+    "# wait for a few seconds for the scale-up to be ready\n",
+    "from time import sleep\n",
+    "\n",
+    "sleep(5)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c1276b5c",
+   "metadata": {},
+   "source": [
+    "Scale down to 1 worker"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7f0aa94c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "default_worker_pool = domain_client.api.services.worker_pool.scale(\n",
+    "    number=1, pool_name=\"default-pool\"\n",
+    ")\n",
+    "assert not isinstance(default_worker_pool, sy.SyftError), str(default_worker_pool)\n",
+    "default_worker_pool"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "52acc6f6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "result = domain_client.api.services.worker_pool.get_by_name(pool_name=\"default-pool\")\n",
+    "assert len(result.workers) == 1, str(result.to_dict())\n",
+    "result"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9a7b40a3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "default_worker_pool = domain_client.api.services.worker_pool.get_by_name(\n",
+    "    pool_name=\"default-pool\"\n",
+    ")\n",
+    "default_worker_pool"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "3c7a124a",
@@ -1153,7 +1256,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.2"
+   "version": "3.11.7"
   }
  },
 "nbformat": 4,
diff --git a/packages/syft/src/syft/custom_worker/builder_docker.py b/packages/syft/src/syft/custom_worker/builder_docker.py
index 9446a5530a4..3f7f16cf185 100644
--- a/packages/syft/src/syft/custom_worker/builder_docker.py
+++ b/packages/syft/src/syft/custom_worker/builder_docker.py
@@ -9,6 +9,7 @@
 import docker
 
 # relative
+from .builder_types import BUILD_IMAGE_TIMEOUT_SEC
 from .builder_types import BuilderBase
 from .builder_types import ImageBuildResult
 from .builder_types import ImagePushResult
@@ -18,8 +19,6 @@
 
 
 class DockerBuilder(BuilderBase):
-    BUILD_MAX_WAIT = 30 * 60
-
     def build_image(
         self,
         tag: str,
@@ -40,7 +39,7 @@ def build_image(
         with contextlib.closing(docker.from_env()) as client:
             image_result, logs = client.images.build(
                 tag=tag,
-                timeout=self.BUILD_MAX_WAIT,
+                timeout=BUILD_IMAGE_TIMEOUT_SEC,
                 buildargs=buildargs,
                 **kwargs,
             )
diff --git a/packages/syft/src/syft/custom_worker/builder_k8s.py b/packages/syft/src/syft/custom_worker/builder_k8s.py
index 395028d69b4..1be16d3c0ac 100644
--- a/packages/syft/src/syft/custom_worker/builder_k8s.py
+++ b/packages/syft/src/syft/custom_worker/builder_k8s.py
@@ -11,9 +11,11 @@
 from kr8s.objects import Secret
 
 # relative
+from .builder_types import BUILD_IMAGE_TIMEOUT_SEC
 from .builder_types import BuilderBase
 from .builder_types import ImageBuildResult
 from .builder_types import ImagePushResult
+from .builder_types import PUSH_IMAGE_TIMEOUT_SEC
 from .k8s import INTERNAL_REGISTRY_HOST
 from .k8s import JOB_COMPLETION_TTL
 from .k8s import KUBERNETES_NAMESPACE
@@ -66,7 +68,10 @@ def build_image(
         )
 
         # wait for job to complete/fail
-        job.wait(["condition=Complete", "condition=Failed"])
+        job.wait(
+            ["condition=Complete", "condition=Failed"],
+            timeout=BUILD_IMAGE_TIMEOUT_SEC,
+        )
 
         # get logs
         logs = self._get_logs(job)
@@ -119,7 +124,10 @@ def push_image(
                 push_secret=push_secret,
             )
 
-            job.wait(["condition=Complete", "condition=Failed"])
+            job.wait(
+                ["condition=Complete", "condition=Failed"],
+                timeout=PUSH_IMAGE_TIMEOUT_SEC,
+            )
             exit_code = self._get_exit_code(job)[0]
             logs = self._get_logs(job)
         except Exception:
diff --git a/packages/syft/src/syft/custom_worker/builder_types.py b/packages/syft/src/syft/custom_worker/builder_types.py
index 53c27788791..8007bf476e9 100644
--- a/packages/syft/src/syft/custom_worker/builder_types.py
+++ b/packages/syft/src/syft/custom_worker/builder_types.py
@@ -7,7 +7,17 @@
 # third party
 from pydantic import BaseModel
 
-__all__ = ["BuilderBase", "ImageBuildResult", "ImagePushResult"]
+__all__ = [
+    "BuilderBase",
+    "ImageBuildResult",
+    "ImagePushResult",
+    "BUILD_IMAGE_TIMEOUT_SEC",
+    "PUSH_IMAGE_TIMEOUT_SEC",
+]
+
+
+BUILD_IMAGE_TIMEOUT_SEC = 30 * 60
+PUSH_IMAGE_TIMEOUT_SEC = 10 * 60
 
 
 class ImageBuildResult(BaseModel):
diff --git a/packages/syft/src/syft/custom_worker/runner_k8s.py b/packages/syft/src/syft/custom_worker/runner_k8s.py
index ff2c3120ebb..3b35830c0f4 100644
--- a/packages/syft/src/syft/custom_worker/runner_k8s.py
+++ b/packages/syft/src/syft/custom_worker/runner_k8s.py
@@ -16,6 +16,8 @@
 from .k8s import get_kr8s_client
 
 JSONPATH_AVAILABLE_REPLICAS = "{.status.availableReplicas}"
+CREATE_POOL_TIMEOUT_SEC = 60
+SCALE_POOL_TIMEOUT_SEC = 60
 
 
 class KubernetesRunner:
@@ -57,7 +59,10 @@ def create_pool(
             )
 
             # wait for replicas to be available and ready
-            deployment.wait(f"jsonpath='{JSONPATH_AVAILABLE_REPLICAS}'={replicas}")
+            deployment.wait(
+                f"jsonpath='{JSONPATH_AVAILABLE_REPLICAS}'={replicas}",
+                timeout=CREATE_POOL_TIMEOUT_SEC,
+            )
         except Exception:
             raise
         finally:
@@ -72,9 +77,15 @@ def scale_pool(self, pool_name: str, replicas: int) -> Optional[StatefulSet]:
         if not deployment:
             return None
         deployment.scale(replicas)
-        deployment.wait(f"jsonpath='{JSONPATH_AVAILABLE_REPLICAS}'={replicas}")
+        deployment.wait(
+            f"jsonpath='{JSONPATH_AVAILABLE_REPLICAS}'={replicas}",
+            timeout=SCALE_POOL_TIMEOUT_SEC,
+        )
         return deployment
 
+    def exists(self, pool_name: str) -> bool:
+        return bool(self.get_pool(pool_name))
+
     def get_pool(self, pool_name: str) -> Optional[StatefulSet]:
         selector = {"app.kubernetes.io/component": pool_name}
         for _set in self.client.get("statefulsets", label_selector=selector):
diff --git a/packages/syft/src/syft/service/worker/utils.py b/packages/syft/src/syft/service/worker/utils.py
index 14b5799d825..fa216e4e6f7 100644
--- a/packages/syft/src/syft/service/worker/utils.py
+++ b/packages/syft/src/syft/service/worker/utils.py
@@ -384,7 +384,7 @@ def run_workers_in_kubernetes(
     spawn_status = []
     runner = KubernetesRunner()
 
-    if start_idx == 0:
+    if not runner.exists(pool_name=pool_name):
         pool_pods = create_kubernetes_pool(
             runner=runner,
             tag=worker_image.image_identifier.full_name_with_tag,
diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py
index 806881547da..11e83b01112 100644
--- a/packages/syft/src/syft/service/worker/worker_pool_service.py
+++ b/packages/syft/src/syft/service/worker/worker_pool_service.py
@@ -11,6 +11,7 @@
 from ...custom_worker.config import CustomWorkerConfig
 from ...custom_worker.config import WorkerConfig
 from ...custom_worker.k8s import IN_KUBERNETES
+from ...custom_worker.runner_k8s import KubernetesRunner
 from ...serde.serializable import serializable
 from ...store.document_store import DocumentStore
 from ...store.linked_obj import LinkedObject
@@ -36,6 +37,7 @@
 from .utils import get_orchestration_type
 from .utils import run_containers
 from .utils import run_workers_in_threads
+from .utils import scale_kubernetes_pool
 from .worker_image import SyftWorkerImage
 from .worker_image_stash import SyftWorkerImageStash
 from .worker_pool import ContainerSpawnStatus
@@ -430,6 +432,94 @@
 
         return container_statuses
 
+    @service_method(
+        path="worker_pool.scale",
+        name="scale",
+        roles=DATA_OWNER_ROLE_LEVEL,
+    )
+    def scale(
+        self,
+        context: AuthedServiceContext,
+        number: int,
+        pool_id: Optional[UID] = None,
+        pool_name: Optional[str] = None,
+    ):
+        """
+        Scale the worker pool to the given number of workers in Kubernetes.
+        Allows both scaling up and scaling down the worker pool.
+        """
+
+        if not IN_KUBERNETES:
+            return SyftError(message="Scaling is only supported in Kubernetes mode")
+        elif number < 0:
+            # zero is a valid scale-down target
+            return SyftError(message=f"Invalid number of workers: {number}")
+
+        result = self._get_worker_pool(context, pool_id, pool_name)
+        if isinstance(result, SyftError):
+            return result
+
+        worker_pool = result
+        current_worker_count = len(worker_pool.worker_list)
+
+        if current_worker_count == number:
+            return SyftSuccess(message=f"Worker pool already has {number} workers")
+        elif number > current_worker_count:
+            workers_to_add = number - current_worker_count
+            return self.add_workers(
+                context=context,
+                number=workers_to_add,
+                pool_id=pool_id,
+                pool_name=pool_name,
+                # kube scaling doesn't require a password as it replicates an existing deployment
+                reg_username=None,
+                reg_password=None,
+            )
+        else:
+            # scale down at the kubernetes control plane
+            runner = KubernetesRunner()
+            result = scale_kubernetes_pool(
+                runner,
+                pool_name=worker_pool.name,
+                replicas=number,
+            )
+            if isinstance(result, SyftError):
+                return result
+
+            # scale down removes the last "n" workers
+            # workers to delete = len(workers) - number
+            workers_to_delete = worker_pool.worker_list[
+                -(current_worker_count - number) :
+            ]
+
+            worker_stash = context.node.get_service("WorkerService").stash
+            # delete the linked worker objects
+            for worker in workers_to_delete:
+                delete_result = worker_stash.delete_by_uid(
+                    credentials=context.credentials,
+                    uid=worker.object_uid,
+                )
+                if delete_result.is_err():
+                    print(f"Failed to delete worker: {worker.object_uid}")
+
+            # update the worker pool
+            worker_pool.max_count = number
+            worker_pool.worker_list = worker_pool.worker_list[:number]
+            update_result = self.stash.update(
+                credentials=context.credentials,
+                obj=worker_pool,
+            )
+
+            if update_result.is_err():
+                return SyftError(
+                    message=(
+                        f"Pool {worker_pool.name} was scaled down, "
+                        f"but failed to update the stash with err: {update_result.err()}"
+                    )
+                )
+
+        return SyftSuccess(message=f"Worker pool scaled to {number} workers")
+
     @service_method(
         path="worker_pool.filter_by_image_id",
         name="filter_by_image_id",
diff --git a/tox.ini b/tox.ini
index 570bad4c0e4..0362e12bfcb 100644
--- a/tox.ini
+++ b/tox.ini
@@ -783,19 +783,9 @@ commands =
     # ignore 06 because of opendp on arm64
 
     # Run 0.8 notebooks
-    bash -c 'echo Gateway Cluster Info; kubectl describe all -A --context k3d-testgateway1 --namespace testgateway1'
-    bash -c 'echo Gateway Logs; kubectl logs -l app.kubernetes.io/name!=random --prefix=true --context k3d-testgateway1 --namespace testgateway1'
-    bash -c 'echo Domain Cluster Info; kubectl describe all -A --context k3d-testdomain1 --namespace testdomain1'
-    bash -c 'echo Domain Logs; kubectl logs -l app.kubernetes.io/name!=random --prefix=true --context k3d-testdomain1 --namespace testdomain1'
-
     bash -c " source ./scripts/get_k8s_secret_ci.sh; \
         pytest --nbmake notebooks/api/0.8 -p no:randomly -k 'not 10-container-images.ipynb' -vvvv --nbmake-timeout=1000"
 
-    bash -c 'echo Gateway Cluster Info; kubectl describe all -A --context k3d-testgateway1 --namespace testgateway1'
-    bash -c 'echo Gateway Logs; kubectl logs -l app.kubernetes.io/name!=random --prefix=true --context k3d-testgateway1 --namespace testgateway1'
-    bash -c 'echo Domain Cluster Info; kubectl describe all -A --context k3d-testdomain1 --namespace testdomain1'
-    bash -c 'echo Domain Logs; kubectl logs -l app.kubernetes.io/name!=random --prefix=true --context k3d-testdomain1 --namespace testdomain1'
-
     # Integration + Gateway Connection Tests
     # Gateway tests are not run in Kubernetes, as it currently does not have a way to configure
     # the high/low side warning flag.
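
Usage note (reviewer addition, not part of the patch): a minimal sketch of how the new `worker_pool.scale` service introduced above is exercised from a client, mirroring the notebook cells in this diff. It assumes an authenticated data-owner client named `domain_client` and an existing pool named "default-pool" (both stand-ins); only `scale`, `get_by_name`, `workers`, and `SyftError` come from this patch.

# a minimal sketch, assuming `domain_client` is a logged-in data-owner client
# and a worker pool named "default-pool" already exists
import syft as sy

# scale the pool to 2 workers; outside Kubernetes this returns a SyftError
result = domain_client.api.services.worker_pool.scale(number=2, pool_name="default-pool")
if isinstance(result, sy.SyftError):
    print(result)
else:
    # confirm the pool now reports exactly 2 workers
    pool = domain_client.api.services.worker_pool.get_by_name(pool_name="default-pool")
    assert len(pool.workers) == 2

Per the service code above, scaling up routes through add_workers() (replicating the existing deployment, so no registry credentials are needed), while scaling down goes through the Kubernetes control plane and then prunes the last workers from the pool's stash entry.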