Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pullImageSecrets for authenticated registries #8444

Merged
merged 6 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ metadata:
app.kubernetes.io/managed-by: Helm
rules:
- apiGroups: [""]
resources: ["pods", "configmaps"]
resources: ["pods", "configmaps", "secrets"]
verbs: ["create", "get", "list", "watch", "update", "patch", "delete"]
- apiGroups: [""]
resources: ["pods/log"]
Expand Down
92 changes: 89 additions & 3 deletions packages/syft/src/syft/custom_worker/runner_k8s.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# stdlib
import base64
import copy
import json
import os
from time import sleep
from typing import List
Expand All @@ -10,6 +12,7 @@
import kr8s
from kr8s.objects import APIObject
from kr8s.objects import Pod
from kr8s.objects import Secret
from kr8s.objects import StatefulSet

# relative
Expand All @@ -27,16 +30,35 @@ def create_pool(
tag: str,
replicas: int = 1,
env_vars: Optional[dict] = None,
reg_username: Optional[str] = None,
reg_password: Optional[str] = None,
reg_url: Optional[str] = None,
**kwargs,
) -> StatefulSet:
# create pull secret if registry credentials are passed
pull_secret = None
if reg_username and reg_password and reg_url:
pull_secret = self._create_image_pull_secret(
pool_name,
reg_username,
reg_password,
reg_url,
)

# create a stateful set deployment
deployment = self._create_stateful_set(
pool_name,
tag,
replicas,
env_vars,
pull_secret=pull_secret,
**kwargs,
)

# wait for replicas to be available and ready
self.wait(deployment, available_replicas=replicas)

# return
return deployment

def scale_pool(self, pool_name: str, replicas: int) -> Optional[StatefulSet]:
Expand All @@ -57,8 +79,11 @@ def delete_pool(self, pool_name: str) -> bool:
selector = {"app.kubernetes.io/component": pool_name}
for _set in self.client.get("statefulsets", label_selector=selector):
_set.delete()
return True
return False

for _secret in self.client.get("secrets", label_selector=selector):
_secret.delete()

return True

def delete_pod(self, pod_name: str) -> bool:
pods = self.client.get("pods", pod_name)
Expand Down Expand Up @@ -99,7 +124,7 @@ def wait(
self,
deployment: StatefulSet,
available_replicas: int,
timeout: int = 60,
timeout: int = 300,
) -> None:
# TODO: Report wait('jsonpath=') bug to kr8s
# Until then this is the substitute implementation
Expand Down Expand Up @@ -133,17 +158,50 @@ def _get_obj_from_list(self, objs: List[dict], name: str) -> dict:
if obj.name == name:
return obj

def _create_image_pull_secret(
self,
pool_name: str,
reg_username: str,
reg_password: str,
reg_url: str,
**kwargs,
):
_secret = Secret(
{
"metadata": {
"name": f"pull-secret-{pool_name}",
"labels": {
"app.kubernetes.io/name": KUBERNETES_NAMESPACE,
"app.kubernetes.io/component": pool_name,
"app.kubernetes.io/managed-by": "kr8s",
},
},
"type": "kubernetes.io/dockerconfigjson",
"data": {
".dockerconfigjson": self._create_dockerconfig_json(
reg_username,
reg_password,
reg_url,
)
},
}
)

return self._create_or_get(_secret)

def _create_stateful_set(
self,
pool_name: str,
tag: str,
replicas=1,
env_vars: Optional[dict] = None,
pull_secret: Optional[Secret] = None,
**kwargs,
) -> StatefulSet:
"""Create a stateful set for a pool"""

env_vars = env_vars or {}
pull_secret_obj = None

_pod = Pod.get(self._current_pod_name())

Expand All @@ -170,6 +228,13 @@ def _create_stateful_set(
for k, v in env_vars.items():
env_clone.append({"name": k, "value": v})

if pull_secret:
pull_secret_obj = [
{
"name": pull_secret.name,
}
]

stateful_set = StatefulSet(
{
"metadata": {
Expand Down Expand Up @@ -198,12 +263,14 @@ def _create_stateful_set(
"containers": [
{
"name": pool_name,
"imagePullPolicy": "IfNotPresent",
"image": tag,
"env": env_clone,
"volumeMounts": [creds_volume_mount],
}
],
"volumes": [creds_volume],
"imagePullSecrets": pull_secret_obj,
},
},
},
Expand All @@ -217,3 +284,22 @@ def _create_or_get(self, obj: APIObject) -> APIObject:
else:
obj.refresh()
return obj

def _create_dockerconfig_json(
self,
reg_username: str,
reg_password: str,
reg_url: str,
):
config = {
"auths": {
reg_url: {
"username": reg_username,
"password": reg_password,
"auth": base64.b64encode(
f"{reg_username}:{reg_password}".encode()
).decode(),
}
}
}
return base64.b64encode(json.dumps(config).encode()).decode()
13 changes: 11 additions & 2 deletions packages/syft/src/syft/service/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,8 @@ def _run(
push_result = worker_image_service.push(
service_context,
image=worker_image.id,
username=context.extra_kwargs.get("reg_username", None),
password=context.extra_kwargs.get("reg_password", None),
)

if isinstance(push_result, SyftError):
Expand Down Expand Up @@ -299,6 +301,8 @@ def _run(
name=self.pool_name,
image_uid=self.image_uid,
num_workers=self.num_workers,
reg_username=context.extra_kwargs.get("reg_username", None),
reg_password=context.extra_kwargs.get("reg_password", None),
)
if isinstance(result, SyftError):
return Err(result)
Expand Down Expand Up @@ -487,7 +491,12 @@ def status(self) -> RequestStatus:

return request_status

def approve(self, disable_warnings: bool = False, approve_nested: bool = False):
def approve(
self,
disable_warnings: bool = False,
approve_nested: bool = False,
**kwargs: dict,
):
api = APIRegistry.api_for(
self.node_uid,
self.syft_client_verify_key,
Expand Down Expand Up @@ -518,7 +527,7 @@ def approve(self, disable_warnings: bool = False, approve_nested: bool = False):
prompt_warning_message(message=message, confirm=True)

print(f"Approving request for domain {api.node_name}")
return api.services.request.apply(self.id)
return api.services.request.apply(self.id, **kwargs)

def deny(self, reason: str):
"""Denies the particular request.
Expand Down
7 changes: 6 additions & 1 deletion packages/syft/src/syft/service/request/request_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,16 @@ def filter_all_info(
name="apply",
)
def apply(
self, context: AuthedServiceContext, uid: UID
self,
context: AuthedServiceContext,
uid: UID,
**kwargs: dict,
) -> Union[SyftSuccess, SyftError]:
request = self.stash.get_by_uid(context.credentials, uid)
if request.is_ok():
request = request.ok()

context.extra_kwargs = kwargs
result = request.apply(context=context)

filter_by_obj = context.node.get_service_method(
Expand Down
22 changes: 18 additions & 4 deletions packages/syft/src/syft/service/worker/image_registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# stdlib
from urllib.parse import urlparse

# third party
from pydantic import validator

# relative
from ...serde.serializable import serializable
from ...types.syft_object import SYFT_OBJECT_VERSION_1
Expand All @@ -18,13 +24,21 @@ class SyftImageRegistry(SyftObject):
id: UID
url: str

@validator("url")
def validate_url(cls, val: str):
if val.startswith("http") or "://" in val:
raise ValueError("Registry URL must be a valid RFC 3986 URI")
return val

@classmethod
def from_url(cls, full_str: str):
return cls(id=UID(), url=full_str)
if "://" not in full_str:
full_str = f"http://{full_str}"

parsed = urlparse(full_str)

@property
def tls_enabled(self) -> bool:
return self.url.startswith("https")
# netloc includes the host & port, so local dev should work as expected
return cls(id=UID(), url=parsed.netloc)

def __hash__(self) -> int:
return hash(self.url + str(self.tls_enabled))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,15 @@ def add(
context: AuthedServiceContext,
url: str,
) -> Union[SyftSuccess, SyftError]:
registry = SyftImageRegistry.from_url(url)
try:
registry = SyftImageRegistry.from_url(url)
except Exception as e:
return SyftError(message=f"Failed to create registry. {e}")

res = self.stash.set(context.credentials, registry)

if res.is_err():
return SyftError(message=res.err())
return SyftError(message=f"Failed to create registry. {res.err()}")

return SyftSuccess(
message=f"Image Registry ID: {registry.id} created successfully"
Expand Down
43 changes: 28 additions & 15 deletions packages/syft/src/syft/service/worker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,10 @@ def create_kubernetes_pool(
replicas: int,
queue_port: int,
debug: bool,
reg_username: Optional[str] = None,
reg_password: Optional[str] = None,
reg_url: Optional[str] = None,
**kwargs,
):
pool = None
error = False
Expand All @@ -285,6 +289,9 @@ def create_kubernetes_pool(
"CREATE_PRODUCER": "False",
"INMEMORY_WORKERS": "False",
},
reg_username=reg_username,
reg_password=reg_password,
reg_url=reg_url,
)
except Exception as e:
error = True
Expand Down Expand Up @@ -321,22 +328,25 @@ def run_workers_in_kubernetes(
queue_port: int,
start_idx=0,
debug: bool = False,
username: Optional[str] = None,
password: Optional[str] = None,
registry_url: Optional[str] = None,
reg_username: Optional[str] = None,
reg_password: Optional[str] = None,
reg_url: Optional[str] = None,
**kwargs,
) -> Union[List[ContainerSpawnStatus], SyftError]:
spawn_status = []
runner = KubernetesRunner()

if start_idx == 0:
pool_pods = create_kubernetes_pool(
runner,
worker_image,
pool_name,
worker_count,
queue_port,
debug,
runner=runner,
worker_image=worker_image,
pool_name=pool_name,
replicas=worker_count,
queue_port=queue_port,
debug=debug,
reg_username=reg_username,
reg_password=reg_password,
reg_url=reg_url,
)
else:
pool_pods = scale_kubernetes_pool(runner, pool_name, worker_count)
Expand Down Expand Up @@ -412,9 +422,9 @@ def run_containers(
queue_port: int,
dev_mode: bool = False,
start_idx: int = 0,
username: Optional[str] = None,
password: Optional[str] = None,
registry_url: Optional[str] = None,
reg_username: Optional[str] = None,
reg_password: Optional[str] = None,
reg_url: Optional[str] = None,
) -> Union[List[ContainerSpawnStatus], SyftError]:
results = []

Expand All @@ -435,9 +445,9 @@ def run_containers(
pool_name=pool_name,
queue_port=queue_port,
debug=dev_mode,
username=username,
password=password,
registry_url=registry_url,
username=reg_username,
password=reg_password,
registry_url=reg_url,
)
results.append(spawn_result)
elif orchestration == WorkerOrchestrationType.KUBERNETES:
Expand All @@ -448,6 +458,9 @@ def run_containers(
queue_port=queue_port,
debug=dev_mode,
start_idx=start_idx,
reg_username=reg_username,
reg_password=reg_password,
reg_url=reg_url,
)

return results
Expand Down
Loading
Loading