diff --git a/notebooks/api/0.8/10-container-images.ipynb b/notebooks/api/0.8/10-container-images.ipynb index 996a638e943..804ea8aa75d 100644 --- a/notebooks/api/0.8/10-container-images.ipynb +++ b/notebooks/api/0.8/10-container-images.ipynb @@ -1428,7 +1428,7 @@ "\n", "# Adding some sleep to allow containers to be fully removed,\n", "# before removing the image\n", - "time.sleep(10)" + "time.sleep(15)" ] }, { diff --git a/packages/grid/enclave/attestation/attestation.dockerfile b/packages/grid/enclave/attestation/attestation.dockerfile new file mode 100644 index 00000000000..3ddb0377ca0 --- /dev/null +++ b/packages/grid/enclave/attestation/attestation.dockerfile @@ -0,0 +1,85 @@ +ARG AZ_GUEST_LIB_VERSION="1.0.5" +ARG AZ_CLIENT_COMMIT="b613bcd" +ARG PYTHON_VERSION="3.10" +ARG NVTRUST_VERSION="1.3.0" + + +FROM ubuntu:22.04 as builder +ARG AZ_GUEST_LIB_VERSION +ARG AZ_CLIENT_COMMIT + +# ======== [Stage 1] Install Dependencies ========== # + +ENV DEBIAN_FRONTEND=noninteractive +RUN --mount=type=cache,target=/var/cache/apt/archives \ + apt update && apt upgrade -y && \ + apt-get install -y \ + build-essential \ + libcurl4-openssl-dev \ + libjsoncpp-dev \ + libboost-all-dev \ + nlohmann-json3-dev \ + cmake \ + wget \ + git + +RUN wget https://packages.microsoft.com/repos/azurecore/pool/main/a/azguestattestation1/azguestattestation1_${AZ_GUEST_LIB_VERSION}_amd64.deb && \ + dpkg -i azguestattestation1_${AZ_GUEST_LIB_VERSION}_amd64.deb + +# ======== [Stage 2] Build Attestation Client ========== # + +RUN git clone https://github.com/Azure/confidential-computing-cvm-guest-attestation.git && \ + cd confidential-computing-cvm-guest-attestation && \ + git checkout ${AZ_CLIENT_COMMIT} && \ + cd cvm-attestation-sample-app && \ + cmake . && make && cp ./AttestationClient / + + +# ======== [Step 3] Build Final Image ========== # +FROM python:${PYTHON_VERSION}-slim +ARG AZ_GUEST_LIB_VERSION +ARG NVTRUST_VERSION +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y \ + wget \ + git + +WORKDIR /app + +RUN wget https://packages.microsoft.com/repos/azurecore/pool/main/a/azguestattestation1/azguestattestation1_${AZ_GUEST_LIB_VERSION}_amd64.deb && \ + dpkg -i azguestattestation1_${AZ_GUEST_LIB_VERSION}_amd64.deb + +COPY --from=builder /AttestationClient /app + +# Clone Nvidia nvtrust Repo +RUN git clone -b v${NVTRUST_VERSION} https://github.com/NVIDIA/nvtrust.git + + +# Install Nvidia Local Verifier +RUN --mount=type=cache,target=/root/.cache \ + cd nvtrust/guest_tools/gpu_verifiers/local_gpu_verifier && \ + pip install . + +# Install Nvidia Attestation SDK +RUN --mount=type=cache,target=/root/.cache \ + cd nvtrust/guest_tools/attestation_sdk/dist && \ + pip install ./nv_attestation_sdk-${NVTRUST_VERSION}-py3-none-any.whl + + +COPY ./requirements.txt /app/requirements.txt +RUN --mount=type=cache,target=/root/.cache \ + pip install --user -r requirements.txt + +COPY ./start.sh /app/start.sh +RUN chmod +x /app/start.sh +COPY ./server /app/server + +# ========== [Step 4] Start Python Web Server ========== # + +CMD ["sh", "-c", "/app/start.sh"] +EXPOSE 4455 + +# Cleanup +RUN rm -rf /var/lib/apt/lists/* && \ + rm -rf /app/nvtrust \ No newline at end of file diff --git a/packages/grid/enclave/attestation/enclave-development.md b/packages/grid/enclave/attestation/enclave-development.md new file mode 100644 index 00000000000..f7217a5f7b7 --- /dev/null +++ b/packages/grid/enclave/attestation/enclave-development.md @@ -0,0 +1,104 @@ +# Enclave Development + +## Building Attestion Containers + +NOTE: Even on Arm machines, we build x64 images. +As some dependent packages in the dockerfile do not have arm64 equivalent. +It would take 10 minutes to build the image in emulation for the first time +in Arm machines.After which , the subsequent builds would be instant. + +```sh +cd packages/grid/enclave/attestation && \ +docker build -f attestation.dockerfile . -t attestation:0.1 --platform linux/amd64 +``` + +## Running the container in development mode + +```sh +cd packages/grid/enclave/attestation && \ +docker run -it --rm -e DEV_MODE=True -p 4455:4455 -v $(pwd)/server:/app/server attestation:0.1 +``` + +## For fetching attestation report by FastAPI + +### CPU Attestation + +```sh +docker run -it --rm --privileged \ + -p 4455:4455 \ + -v /sys/kernel/security:/sys/kernel/security \ + -v /dev/tpmrm0:/dev/tpmrm0 attestation:0.1 +``` + +```sh +curl localhost:4455/attest/cpu +``` + +### GPU Attestation + +#### Nvidia GPU Requirements + +We would need to install Nvidia Container Toolkit on host system and ensure we have CUDA Drivers installed. +Link: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/index.html + +```sh +docker run -it --rm --privileged --gpus all --runtime=nvidia \ + -p 4455:4455 \ + -v /sys/kernel/security:/sys/kernel/security \ + -v /dev/tpmrm0:/dev/tpmrm0 attestation:0.1 +``` + +```sh +curl localhost:4455/attest/gpu +``` + +## For fetching attestation report directly by docker + +### CPU Attestation + +```sh +docker run -it --rm --privileged \ + -v /sys/kernel/security:/sys/kernel/security \ + -v /dev/tpmrm0:/dev/tpmrm0 attestation:0.1 /bin/bash +``` + +In the shell run + +```sh +./AttestationClient +``` + +This would return either True or False indicating status of attestation + +This could also be customized with Appraisal Policy + +To retrieve JWT from Microsoft Azure Attestation (MAA) + +```sh +./AttestationClient -o token +``` + +### For GPU Attestation + +```sh +docker run -it --rm --privileged --gpus all --runtime=nvidia \ + -v /sys/kernel/security:/sys/kernel/security \ + -v /dev/tpmrm0:/dev/tpmrm0 attestation:0.1 /bin/bash +``` + +Invoke python shell +In the python shell run + +```python3 +from nv_attestation_sdk import attestation + + +NRAS_URL="https://nras.attestation.nvidia.com/v1/attest/gpu" +client = attestation.Attestation() +client.set_name("thisNode1") +client.set_nonce("931d8dd0add203ac3d8b4fbde75e115278eefcdceac5b87671a748f32364dfcb") +print ("[RemoteGPUTest] node name :", client.get_name()) + +client.add_verifier(attestation.Devices.GPU, attestation.Environment.REMOTE, NRAS_URL, "") +client.attest() +``` diff --git a/packages/grid/enclave/attestation/requirements.txt b/packages/grid/enclave/attestation/requirements.txt new file mode 100644 index 00000000000..bd5059ad68d --- /dev/null +++ b/packages/grid/enclave/attestation/requirements.txt @@ -0,0 +1,3 @@ +fastapi==0.110.0 +loguru==0.7.2 +uvicorn[standard]==0.27.1 diff --git a/packages/grid/enclave/attestation/server/cpu_attestation.py b/packages/grid/enclave/attestation/server/cpu_attestation.py new file mode 100644 index 00000000000..af13ff259a1 --- /dev/null +++ b/packages/grid/enclave/attestation/server/cpu_attestation.py @@ -0,0 +1,20 @@ +# stdlib +import subprocess + +# third party +from loguru import logger + + +def attest_cpu() -> str: + # Fetch report from Micrsoft Attestation library + cpu_report = subprocess.run( + ["/app/AttestationClient"], capture_output=True, text=True + ) + logger.debug(f"Stdout: {cpu_report.stdout}") + logger.debug(f"Stderr: {cpu_report.stderr}") + + logger.info("Attestation Return Code: {}", cpu_report.returncode) + if cpu_report.returncode == 0 and cpu_report.stdout == "true": + return "True" + + return "False" diff --git a/packages/grid/enclave/attestation/server/gpu_attestation.py b/packages/grid/enclave/attestation/server/gpu_attestation.py new file mode 100644 index 00000000000..ec0acd14c05 --- /dev/null +++ b/packages/grid/enclave/attestation/server/gpu_attestation.py @@ -0,0 +1,21 @@ +# third party +from loguru import logger +from nv_attestation_sdk import attestation + +NRAS_URL = "https://nras.attestation.nvidia.com/v1/attest/gpu" + + +def attest_gpu() -> str: + # Fetch report from Nvidia Attestation SDK + client = attestation.Attestation("Attestation Node") + + # TODO: Add the ability to generate nonce later. + logger.info("[RemoteGPUTest] node name : {}", client.get_name()) + + client.add_verifier( + attestation.Devices.GPU, attestation.Environment.REMOTE, NRAS_URL, "" + ) + gpu_report = client.attest() + logger.info("[RemoteGPUTest] report : {}, {}", gpu_report, type(gpu_report)) + + return str(gpu_report) diff --git a/packages/grid/enclave/attestation/server/main.py b/packages/grid/enclave/attestation/server/main.py new file mode 100644 index 00000000000..408c1dcb9fe --- /dev/null +++ b/packages/grid/enclave/attestation/server/main.py @@ -0,0 +1,38 @@ +# stdlib +import os +import sys + +# third party +from fastapi import FastAPI +from loguru import logger + +# relative +from .cpu_attestation import attest_cpu +from .gpu_attestation import attest_gpu +from .models import CPUAttestationResponseModel +from .models import GPUAttestationResponseModel +from .models import ResponseModel + +# Logging Configuration +log_level = os.getenv("APP_LOG_LEVEL", "INFO").upper() +logger.remove() +logger.add(sys.stderr, colorize=True, level=log_level) + +app = FastAPI(title="Attestation API") + + +@app.get("/", response_model=ResponseModel) +async def read_root() -> ResponseModel: + return ResponseModel(message="Server is running") + + +@app.get("/attest/cpu", response_model=CPUAttestationResponseModel) +async def attest_cpu_endpoint() -> CPUAttestationResponseModel: + cpu_attest_res = attest_cpu() + return CPUAttestationResponseModel(result=cpu_attest_res) + + +@app.get("/attest/gpu", response_model=GPUAttestationResponseModel) +async def attest_gpu_endpoint() -> GPUAttestationResponseModel: + gpu_attest_res = attest_gpu() + return GPUAttestationResponseModel(result=gpu_attest_res) diff --git a/packages/grid/enclave/attestation/server/models.py b/packages/grid/enclave/attestation/server/models.py new file mode 100644 index 00000000000..01ffdb01b19 --- /dev/null +++ b/packages/grid/enclave/attestation/server/models.py @@ -0,0 +1,16 @@ +# third party +from pydantic import BaseModel + + +class ResponseModel(BaseModel): + message: str + + +class CPUAttestationResponseModel(BaseModel): + result: str + vendor: str | None = None # Hardware Manufacturer + + +class GPUAttestationResponseModel(BaseModel): + result: str + vendor: str | None = None # Hardware Manufacturer diff --git a/packages/grid/enclave/attestation/start.sh b/packages/grid/enclave/attestation/start.sh new file mode 100644 index 00000000000..e0dd2b15ab5 --- /dev/null +++ b/packages/grid/enclave/attestation/start.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash +set -e +export PATH="/root/.local/bin:${PATH}" + +APP_MODULE=server.main:app +APP_LOG_LEVEL=${APP_LOG_LEVEL:-info} +UVICORN_LOG_LEVEL=${UVICORN_LOG_LEVEL:-info} +HOST=${HOST:-0.0.0.0} +PORT=${PORT:-4455} +RELOAD="" + +if [[ ${DEV_MODE} == "True" ]]; +then + echo "DEV_MODE Enabled" + RELOAD="--reload" +fi + + +exec uvicorn $RELOAD --host $HOST --port $PORT --log-level $UVICORN_LOG_LEVEL "$APP_MODULE" \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/NOTES.txt b/packages/grid/helm/syft/templates/NOTES.txt index 7629a68fee3..e5fbfaf912b 100644 --- a/packages/grid/helm/syft/templates/NOTES.txt +++ b/packages/grid/helm/syft/templates/NOTES.txt @@ -95,7 +95,7 @@ "action": "add" } }, - "SyncView": { + "SyncTableObject": { "1": { "version": 1, "hash": "4e87744e86cd7781e3d5cf4618e63516f3d26309a4da919033dacc5ed338d76d", diff --git a/packages/hagrid/.bumpversion.cfg b/packages/hagrid/.bumpversion.cfg index cfb8f9c7286..72e8b1b4f24 100644 --- a/packages/hagrid/.bumpversion.cfg +++ b/packages/hagrid/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.3.114 +current_version = 0.3.115 tag = False tag_name = {new_version} commit = True diff --git a/packages/hagrid/hagrid/manifest_template.yml b/packages/hagrid/hagrid/manifest_template.yml index a2fa362e8b9..0444c49bf28 100644 --- a/packages/hagrid/hagrid/manifest_template.yml +++ b/packages/hagrid/hagrid/manifest_template.yml @@ -1,9 +1,9 @@ manifestVersion: 0.1 -hagrid_version: 0.3.114 +hagrid_version: 0.3.115 syft_version: 0.8.7-beta.1 dockerTag: 0.8.7-beta.1 baseUrl: https://raw.githubusercontent.com/OpenMined/PySyft/ -hash: 6503b79943cbdf27c701e0a2d3f9308f4f3a76a4 +hash: a2f8839726edd94a5759407d63c900e77bb3b466 target_dir: ~/.hagrid/PySyft/ files: grid: diff --git a/packages/hagrid/hagrid/version.py b/packages/hagrid/hagrid/version.py index 772d3d29c59..5086b560c42 100644 --- a/packages/hagrid/hagrid/version.py +++ b/packages/hagrid/hagrid/version.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # HAGrid Version -__version__ = "0.3.114" +__version__ = "0.3.115" if __name__ == "__main__": print(__version__) diff --git a/packages/hagrid/setup.py b/packages/hagrid/setup.py index 1a852c52843..37a26b212d5 100644 --- a/packages/hagrid/setup.py +++ b/packages/hagrid/setup.py @@ -5,7 +5,7 @@ from setuptools import find_packages from setuptools import setup -__version__ = "0.3.114" +__version__ = "0.3.115" DATA_FILES = {"img": ["hagrid/img/*.png"], "hagrid": ["*.yml"]} diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 0de91ffb8cf..cdc143c7b5d 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -67,6 +67,7 @@ # relative from ..service.network.node_peer import NodePeer + # use to enable mitm proxy # from syft.grid.connections.http_connection import HTTPConnection # HTTPConnection.proxies = {"http": "http://127.0.0.1:8080"} @@ -106,7 +107,7 @@ def forward_message_to_proxy( # generate a random signing key credentials = SyftSigningKey.generate() - signed_message = call.sign(credentials=credentials) + signed_message: SignedSyftAPICall = call.sign(credentials=credentials) signed_result = make_call(signed_message) response = debox_signed_syftapicall_response(signed_result) return response @@ -205,7 +206,9 @@ def _make_post( return response.content - def get_node_metadata(self, credentials: SyftSigningKey) -> NodeMetadataJSON: + def get_node_metadata( + self, credentials: SyftSigningKey + ) -> NodeMetadataJSON | SyftError: if self.proxy_target_uid: response = forward_message_to_proxy( make_call=self.make_call, @@ -304,7 +307,7 @@ def __str__(self) -> str: def __hash__(self) -> int: return hash(self.proxy_target_uid) + hash(self.url) - def get_client_type(self) -> type[SyftClient]: + def get_client_type(self) -> type[SyftClient] | SyftError: # TODO: Rasswanth, should remove passing in credentials # when metadata are proxy forwarded in the grid routes # in the gateway fixes PR @@ -335,7 +338,9 @@ class PythonConnection(NodeConnection): def with_proxy(self, proxy_target_uid: UID) -> Self: return PythonConnection(node=self.node, proxy_target_uid=proxy_target_uid) - def get_node_metadata(self, credentials: SyftSigningKey) -> NodeMetadataJSON: + def get_node_metadata( + self, credentials: SyftSigningKey + ) -> NodeMetadataJSON | SyftError: if self.proxy_target_uid: response = forward_message_to_proxy( make_call=self.make_call, @@ -434,7 +439,7 @@ def __repr__(self) -> str: def __str__(self) -> str: return f"{type(self).__name__}" - def get_client_type(self) -> type[SyftClient]: + def get_client_type(self) -> type[SyftClient] | SyftError: # relative from .domain_client import DomainClient from .enclave_client import EnclaveClient diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index 10760223bd2..d94381bc166 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -306,10 +306,10 @@ def connect_to_gateway( if isinstance(res, SyftSuccess): if self.metadata: return SyftSuccess( - message=f"Connected {self.metadata.node_type} to {client.name} gateway" + message=f"Connected {self.metadata.node_type} '{self.metadata.name}' to gateway '{client.name}'" ) else: - return SyftSuccess(message=f"Connected to {client.name} gateway") + return SyftSuccess(message=f"Connected to '{client.name}' gateway") return res @property diff --git a/packages/syft/src/syft/client/enclave_client.py b/packages/syft/src/syft/client/enclave_client.py index 7a29b37fd7b..dd8302c0741 100644 --- a/packages/syft/src/syft/client/enclave_client.py +++ b/packages/syft/src/syft/client/enclave_client.py @@ -96,7 +96,7 @@ def connect_to_gateway( if isinstance(res, SyftSuccess): return SyftSuccess( - message=f"Connected {self.metadata.node_type} to {client.name} gateway" + message=f"Connected {self.metadata.node_type} {self.metadata.name} to {client.name} gateway" ) return res diff --git a/packages/syft/src/syft/client/gateway_client.py b/packages/syft/src/syft/client/gateway_client.py index 3957f141bba..2e989b1bc7c 100644 --- a/packages/syft/src/syft/client/gateway_client.py +++ b/packages/syft/src/syft/client/gateway_client.py @@ -7,6 +7,7 @@ from ..img.base64 import base64read from ..node.credentials import SyftSigningKey from ..serde.serializable import serializable +from ..service.metadata.node_metadata import NodeMetadataJSON from ..service.network.node_peer import NodePeer from ..service.response import SyftError from ..service.response import SyftException @@ -14,6 +15,7 @@ from ..types.syft_object import SyftObject from ..util.fonts import FONT_CSS from .client import SyftClient +from .connection import NodeConnection @serializable() @@ -25,8 +27,12 @@ def proxy_to(self, peer: Any) -> SyftClient: from .domain_client import DomainClient from .enclave_client import EnclaveClient - connection = self.connection.with_proxy(peer.id) - metadata = connection.get_node_metadata(credentials=SyftSigningKey.generate()) + connection: type[NodeConnection] = self.connection.with_proxy(peer.id) + metadata: NodeMetadataJSON | SyftError = connection.get_node_metadata( + credentials=SyftSigningKey.generate() + ) + if isinstance(metadata, SyftError): + return metadata if metadata.node_type == NodeType.DOMAIN.value: client_type: type[SyftClient] = DomainClient elif metadata.node_type == NodeType.ENCLAVE.value: diff --git a/packages/syft/src/syft/client/registry.py b/packages/syft/src/syft/client/registry.py index 125f46b6f8d..018c101de36 100644 --- a/packages/syft/src/syft/client/registry.py +++ b/packages/syft/src/syft/client/registry.py @@ -50,20 +50,21 @@ def __init__(self) -> None: def load_network_registry_json() -> dict: try: # Get the environment variable - network_registry_json = os.getenv("NETWORK_REGISTRY_JSON") + network_registry_json: str | None = os.getenv("NETWORK_REGISTRY_JSON") # If the environment variable exists, use it if network_registry_json is not None: network_json: dict = json.loads(network_registry_json) else: # Load the network registry from the NETWORK_REGISTRY_URL - response = requests.get(NETWORK_REGISTRY_URL, timeout=10) # nosec + response = requests.get(NETWORK_REGISTRY_URL, timeout=30) # nosec + response.raise_for_status() # raise an exception if the HTTP request returns an error network_json = response.json() return network_json except Exception as e: warning( - f"Failed to get Network Registry, go checkout: {NETWORK_REGISTRY_REPO}. {e}" + f"Failed to get Network Registry from {NETWORK_REGISTRY_REPO}. Exception: {e}" ) return {} diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 8f24b1ab3cd..4a71e5fbd58 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -1030,7 +1030,7 @@ def resolve_future( def forward_message( self, api_call: SyftAPICall | SignedSyftAPICall - ) -> Result[QueueItem | SyftObject, Err]: + ) -> Result | QueueItem | SyftObject | SyftError | Any: node_uid = api_call.message.node_uid if "networkservice" not in self.service_path_map: return SyftError( @@ -1051,14 +1051,21 @@ def forward_message( # Since we have several routes to a peer # we need to cache the client for a given node_uid along with the route peer_cache_key = hash(node_uid) + hash(peer.pick_highest_priority_route()) - if peer_cache_key in self.peer_client_cache: client = self.peer_client_cache[peer_cache_key] else: context = AuthedServiceContext( node=self, credentials=api_call.credentials ) + client = peer.client_with_context(context=context) + if client.is_err(): + return SyftError( + message=f"Failed to create remote client for peer: " + f"{peer.id}. Error: {client.err()}" + ) + client = client.ok() + self.peer_client_cache[peer_cache_key] = client if client: @@ -1129,6 +1136,7 @@ def handle_api_call_with_unsigned_result( if api_call.message.node_uid != self.id and check_call_location: return self.forward_message(api_call=api_call) + if api_call.message.path == "queue": return self.resolve_future( credentials=api_call.credentials, uid=api_call.message.kwargs["uid"] diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 4e587772a00..53ae91cdd3d 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -98,13 +98,6 @@ "action": "add" } }, - "SyncView": { - "1": { - "version": 1, - "hash": "4e87744e86cd7781e3d5cf4618e63516f3d26309a4da919033dacc5ed338d76d", - "action": "add" - } - }, "SyncStateItem": { "1": { "version": 1, diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 0fc971d2612..c92c41b15a2 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -3,6 +3,7 @@ # stdlib from collections.abc import Callable +from collections.abc import Iterable from enum import Enum import inspect from io import BytesIO @@ -678,7 +679,15 @@ def syft_get_diffs(self, ext_obj: Any) -> list[AttrDiff]: low_data = ext_obj.syft_action_data high_data = self.syft_action_data - if low_data != high_data: + + try: + cmp = low_data != high_data + if isinstance(cmp, Iterable): + cmp = all(cmp) + except Exception: + cmp = False + + if cmp: diff_attr = AttrDiff( attr_name="syft_action_data", low_attr=low_data, high_attr=high_data ) diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index b1a47f5636d..6ad47ccb820 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -326,6 +326,10 @@ def subjobs(self) -> list["Job"] | SyftError: ) return api.services.job.get_subjobs(self.id) + def get_subjobs(self, context: AuthedServiceContext) -> list["Job"] | SyftError: + job_service = context.node.get_service("jobservice") + return job_service.get_subjobs(context, self.id) + @property def owner(self) -> UserView | SyftError: api = APIRegistry.api_for( @@ -517,11 +521,11 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: # if self.log_id: dependencies.append(self.log_id) - subjobs = self.subjobs + subjobs = self.get_subjobs(context) if isinstance(subjobs, SyftError): return subjobs - subjob_ids = [subjob.id for subjob in self.subjobs] + subjob_ids = [subjob.id for subjob in subjobs] dependencies.extend(subjob_ids) if self.user_code_id is not None: diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 6ae9e681943..98b05f50bdd 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -4,6 +4,7 @@ from typing import Any # third party +from result import Err from result import Result # relative @@ -27,10 +28,12 @@ from ...types.transforms import transform_method from ...types.uid import UID from ...util.telemetry import instrument +from ...util.util import prompt_warning_message from ..context import AuthedServiceContext from ..data_subject.data_subject import NamePartitionKey from ..metadata.node_metadata import NodeMetadataV3 from ..response import SyftError +from ..response import SyftInfo from ..response import SyftSuccess from ..service import AbstractService from ..service import SERVICE_TO_TYPES @@ -42,6 +45,7 @@ from .node_peer import NodePeer from .routes import HTTPNodeRoute from .routes import NodeRoute +from .routes import NodeRouteType from .routes import PythonNodeRoute VerifyKeyPartitionKey = PartitionKey(key="verify_key", type_=SyftVerifyKey) @@ -74,12 +78,24 @@ def update( ) -> Result[NodePeer, str]: valid = self.check_type(peer, NodePeer) if valid.is_err(): - return SyftError(message=valid.err()) + return Err(message=valid.err()) return super().update(credentials, peer) - def update_peer( + def create_or_update_peer( self, credentials: SyftVerifyKey, peer: NodePeer ) -> Result[NodePeer, str]: + """ + Update the selected peer and its route priorities if the peer already exists + If the peer does not exist, simply adds it to the database. + + Args: + credentials (SyftVerifyKey): The credentials used to authenticate the request. + peer (NodePeer): The peer to be updated or added. + + Returns: + Result[NodePeer, str]: The updated or added peer if the operation + was successful, or an error message if the operation failed. + """ valid = self.check_type(peer, NodePeer) if valid.is_err(): return SyftError(message=valid.err()) @@ -95,9 +111,9 @@ def update_peer( result = self.set(credentials, peer) return result - def get_for_verify_key( + def get_by_verify_key( self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey - ) -> Result[NodePeer, SyftError]: + ) -> Result[NodePeer | None, SyftError]: qks = QueryKeys(qks=[VerifyKeyPartitionKey.with_obj(verify_key)]) return self.query_one(credentials, qks) @@ -151,6 +167,7 @@ def exchange_credentials_with( ) random_challenge = secrets.token_bytes(16) + # ask the remote client to add this node (represented by `self_node_peer`) as a peer remote_res = remote_client.api.services.network.add_peer( peer=self_node_peer, challenge=random_challenge, @@ -164,7 +181,6 @@ def exchange_credentials_with( challenge_signature, remote_node_peer = remote_res # Verifying if the challenge is valid - try: remote_node_verify_key.verify_key.verify( random_challenge, challenge_signature @@ -173,8 +189,7 @@ def exchange_credentials_with( return SyftError(message=str(e)) # save the remote peer for later - - result = self.stash.update_peer( + result = self.stash.create_or_update_peer( context.node.verify_key, remote_node_peer, ) @@ -192,13 +207,15 @@ def add_peer( self_node_route: NodeRoute, verify_key: SyftVerifyKey, ) -> list | SyftError: - """Add a Network Node Peer""" + """Add a Network Node Peer. Called by a remote node to add + itself as a peer for the current node. + """ # Using the verify_key of the peer to verify the signature # It is also our single source of truth for the peer if peer.verify_key != context.credentials: return SyftError( message=( - f"The {type(peer)}.verify_key: " + f"The {type(peer).__name__}.verify_key: " f"{peer.verify_key} does not match the signature of the message" ) ) @@ -209,7 +226,14 @@ def add_peer( ) try: - remote_client: SyftClient = peer.client_with_context(context=context) + remote_client = peer.client_with_context(context=context) + if remote_client.is_err(): + return SyftError( + message=f"Failed to create remote client for peer: " + f"{peer.id}. Error: {remote_client.err()}" + ) + remote_client = remote_client.ok() + random_challenge = secrets.token_bytes(16) remote_res = remote_client.api.services.network.ping( challenge=random_challenge @@ -228,7 +252,7 @@ def add_peer( except Exception as e: return SyftError(message=str(e)) - result = self.stash.update_peer(context.node.verify_key, peer) + result = self.stash.create_or_update_peer(context.node.verify_key, peer) if result.is_err(): return SyftError(message=str(result.err())) @@ -266,52 +290,6 @@ def ping( return challenge_signature - @service_method(path="network.add_route_for", name="add_route_for") - def add_route_for( - self, - context: AuthedServiceContext, - route: NodeRoute, - peer: NodePeer, - ) -> SyftSuccess | SyftError: - """Add Route for this Node to another Node""" - # check root user is asking for the exchange - client = peer.client_with_context(context=context) - result = client.api.services.network.verify_route(route) - - if not isinstance(result, SyftSuccess): - return result - return SyftSuccess(message="Route Verified") - - @service_method( - path="network.verify_route", name="verify_route", roles=GUEST_ROLE_LEVEL - ) - def verify_route( - self, context: AuthedServiceContext, route: NodeRoute - ) -> SyftSuccess | SyftError: - """Add a Network Node Route""" - # get the peer asking for route verification from its verify_key - - peer = self.stash.get_for_verify_key( - context.node.verify_key, - context.credentials, - ) - if peer.is_err(): - return SyftError(message=peer.err()) - peer = peer.ok() - - if peer.verify_key != context.credentials: - return SyftError( - message=( - f"verify_key: {context.credentials} at route {route} " - f"does not match listed peer: {peer}" - ) - ) - peer.update_routes([route]) - result = self.stash.update_peer(context.node.verify_key, peer) - if result.is_err(): - return SyftError(message=str(result.err())) - return SyftSuccess(message="Network Route Verified") - @service_method( path="network.get_all_peers", name="get_all_peers", roles=GUEST_ROLE_LEVEL ) @@ -377,7 +355,373 @@ def delete_peer_by_id( result = self.stash.delete_by_uid(context.credentials, uid) if result.is_err(): return SyftError(message=str(result.err())) - return SyftSuccess(message="Node Peer Deleted") + # TODO: Notify the peer (either by email or by other form of notifications) + # that it has been deleted from the network + return SyftSuccess(message=f"Node Peer with id {uid} Deleted") + + @service_method(path="network.add_route_on_peer", name="add_route_on_peer") + def add_route_on_peer( + self, + context: AuthedServiceContext, + peer: NodePeer, + route: NodeRoute, + ) -> SyftSuccess | SyftError: + """ + Add or update the route information on the remote peer. + + Args: + context (AuthedServiceContext): The authentication context. + peer (NodePeer): The peer representing the remote node. + route (NodeRoute): The route to be added. + + Returns: + SyftSuccess | SyftError: A success message if the route is verified, + otherwise an error message. + """ + # creates a client on the remote node based on the credentials + # of the current node's client + remote_client = peer.client_with_context(context=context) + if remote_client.is_err(): + return SyftError( + message=f"Failed to create remote client for peer: " + f"{peer.id}. Error: {remote_client.err()}" + ) + remote_client = remote_client.ok() + # ask the remote node to add the route to the self node + result = remote_client.api.services.network.add_route( + peer_verify_key=context.credentials, + route=route, + called_by_peer=True, + ) + return result + + @service_method(path="network.add_route", name="add_route", roles=GUEST_ROLE_LEVEL) + def add_route( + self, + context: AuthedServiceContext, + peer_verify_key: SyftVerifyKey, + route: NodeRoute, + called_by_peer: bool = False, + ) -> SyftSuccess | SyftError: + """ + Add a route to the peer. If the route already exists, update its priority. + + Args: + context (AuthedServiceContext): The authentication context of the remote node. + peer_verify_key (SyftVerifyKey): The verify key of the remote node peer. + route (NodeRoute): The route to be added. + called_by_peer (bool): The flag to indicate that it's called by a remote peer. + + Returns: + SyftSuccess | SyftError + """ + # verify if the peer is truly the one sending the request to add the route to itself + if called_by_peer and peer_verify_key != context.credentials: + return SyftError( + message=( + f"The {type(peer_verify_key).__name__}: " + f"{peer_verify_key} does not match the signature of the message" + ) + ) + # get the full peer object from the store to update its routes + remote_node_peer: NodePeer | SyftError = ( + self._get_remote_node_peer_by_verify_key(context, peer_verify_key) + ) + if isinstance(remote_node_peer, SyftError): + return remote_node_peer + # add and update the priority for the peer + existed_route: NodeRoute | None = remote_node_peer.update_route(route) + # update the peer in the store with the updated routes + result = self.stash.update( + credentials=context.node.verify_key, + peer=remote_node_peer, + ) + if result.is_err(): + return SyftError(message=str(result.err())) + if existed_route: + return SyftSuccess( + message=f"The route already exists between '{context.node.name}' and " + f"peer '{remote_node_peer.name}' with id '{existed_route.id}', so its priority was updated" + ) + return SyftSuccess( + message=f"New route ({str(route)}) with id '{route.id}' " + f"to peer {remote_node_peer.node_type.value} '{remote_node_peer.name}' " + f"was added for {str(context.node.node_type)} '{context.node.name}'" + ) + + @service_method(path="network.delete_route_on_peer", name="delete_route_on_peer") + def delete_route_on_peer( + self, + context: AuthedServiceContext, + peer: NodePeer, + route: NodeRoute | None = None, + route_id: UID | None = None, + ) -> SyftSuccess | SyftError | SyftInfo: + """ + Delete the route on the remote peer. + + Args: + context (AuthedServiceContext): The authentication context for the service. + peer (NodePeer): The peer for which the route will be deleted. + route (NodeRoute): The route to be deleted. + route_id (UID): The UID of the route to be deleted. + + Returns: + SyftSuccess: If the route is successfully deleted. + SyftError: If there is an error deleting the route. + SyftInfo: If there is only one route left for the peer and + the admin chose not to remove it + """ + if route is None and route_id is None: + return SyftError( + message="Either `route` or `route_id` arg must be provided" + ) + + if route and route_id and route.id != route_id: + return SyftError( + message=f"Both `route` and `route_id` are provided, but " + f"route's id ({route.id}) and route_id ({route_id}) do not match" + ) + + # creates a client on the remote node based on the credentials + # of the current node's client + remote_client = peer.client_with_context(context=context) + if remote_client.is_err(): + return SyftError( + message=f"Failed to create remote client for peer: " + f"{peer.id}. Error: {remote_client.err()}" + ) + remote_client = remote_client.ok() + # ask the remote node to delete the route to the self node, + result = remote_client.api.services.network.delete_route( + peer_verify_key=context.credentials, + route=route, + route_id=route_id, + called_by_peer=True, + ) + return result + + @service_method(path="network.", name="delete_route", roles=GUEST_ROLE_LEVEL) + def delete_route( + self, + context: AuthedServiceContext, + peer_verify_key: SyftVerifyKey, + route: NodeRoute | None = None, + route_id: UID | None = None, + called_by_peer: bool = False, + ) -> SyftSuccess | SyftError | SyftInfo: + """ + Delete a route for a given peer. + If a peer has no routes left, there will be a prompt asking if the user want to remove it. + If the answer is yes, it will be removed from the stash and will no longer be a peer. + + Args: + context (AuthedServiceContext): The authentication context for the service. + peer_verify_key (SyftVerifyKey): The verify key of the remote node peer. + route (NodeRoute): The route to be deleted. + route_id (UID): The UID of the route to be deleted. + called_by_peer (bool): The flag to indicate that it's called by a remote peer. + + Returns: + SyftSuccess: If the route is successfully deleted. + SyftError: If there is an error deleting the route. + SyftInfo: If there is only one route left for the peer and + the admin chose not to remove it + """ + if called_by_peer and peer_verify_key != context.credentials: + # verify if the peer is truly the one sending the request to delete the route to itself + return SyftError( + message=( + f"The {type(peer_verify_key).__name__}: " + f"{peer_verify_key} does not match the signature of the message" + ) + ) + + remote_node_peer: NodePeer | SyftError = ( + self._get_remote_node_peer_by_verify_key( + context=context, peer_verify_key=peer_verify_key + ) + ) + + if len(remote_node_peer.node_routes) == 1: + warning_message = ( + f"There is only one route left to peer " + f"{remote_node_peer.node_type.value} '{remote_node_peer.name}'. " + f"Removing this route will remove the peer for " + f"{str(context.node.node_type)} '{context.node.name}'." + ) + response: bool = prompt_warning_message( + message=warning_message, + confirm=False, + ) + if not response: + return SyftInfo( + message=f"The last route to {remote_node_peer.node_type.value} " + f"'{remote_node_peer.name}' with id " + f"'{remote_node_peer.node_routes[0].id}' was not deleted." + ) + + if route: + result = remote_node_peer.delete_route(route=route) + return_message = ( + f"Route '{str(route)}' with id '{route.id}' to peer " + f"{remote_node_peer.node_type.value} '{remote_node_peer.name}' " + f"was deleted for {str(context.node.node_type)} '{context.node.name}'." + ) + if route_id: + result = remote_node_peer.delete_route(route_id=route_id) + return_message = ( + f"Route with id '{route_id}' to peer " + f"{remote_node_peer.node_type.value} '{remote_node_peer.name}' " + f"was deleted for {str(context.node.node_type)} '{context.node.name}'." + ) + if isinstance(result, SyftError): + return result + + if len(remote_node_peer.node_routes) == 0: + # remove the peer + # TODO: should we do this as we are deleting the peer with a guest role level? + result = self.stash.delete_by_uid( + credentials=context.node.verify_key, uid=remote_node_peer.id + ) + if isinstance(result, SyftError): + return result + return_message += ( + f" There is no routes left to connect to peer " + f"{remote_node_peer.node_type.value} '{remote_node_peer.name}', so it is deleted for " + f"{str(context.node.node_type)} '{context.node.name}'." + ) + else: + # update the peer with the route removed + result = self.stash.update( + credentials=context.node.verify_key, peer=remote_node_peer + ) + if result.is_err(): + return SyftError(message=str(result.err())) + + return SyftSuccess(message=return_message) + + @service_method( + path="network.update_route_priority_on_peer", + name="update_route_priority_on_peer", + ) + def update_route_priority_on_peer( + self, + context: AuthedServiceContext, + peer: NodePeer, + route: NodeRoute, + priority: int | None = None, + ) -> SyftSuccess | SyftError: + """ + Update the route priority on the remote peer. + + Args: + context (AuthedServiceContext): The authentication context. + peer (NodePeer): The peer representing the remote node. + route (NodeRoute): The route to be added. + priority (int | None): The new priority value for the route. If not + provided, it will be assigned the highest priority among all peers + + Returns: + SyftSuccess | SyftError: A success message if the route is verified, + otherwise an error message. + """ + # creates a client on the remote node based on the credentials + # of the current node's client + remote_client = peer.client_with_context(context=context) + if remote_client.is_err(): + return SyftError( + message=f"Failed to create remote client for peer: " + f"{peer.id}. Error: {remote_client.err()}" + ) + remote_client = remote_client.ok() + result = remote_client.api.services.network.update_route_priority( + peer_verify_key=context.credentials, + route=route, + priority=priority, + called_by_peer=True, + ) + return result + + @service_method( + path="network.update_route_priority", + name="update_route_priority", + roles=GUEST_ROLE_LEVEL, + ) + def update_route_priority( + self, + context: AuthedServiceContext, + peer_verify_key: SyftVerifyKey, + route: NodeRoute, + priority: int | None = None, + called_by_peer: bool = False, + ) -> SyftSuccess | SyftError: + """ + Updates a route's priority for the given peer + + Args: + context (AuthedServiceContext): The authentication context for the service. + peer_verify_key (SyftVerifyKey): The verify key of the peer whose route priority needs to be updated. + route (NodeRoute): The route for which the priority needs to be updated. + priority (int | None): The new priority value for the route. If not + provided, it will be assigned the highest priority among all peers + + Returns: + SyftSuccess | SyftError: Successful / Error response + """ + if called_by_peer and peer_verify_key != context.credentials: + return SyftError( + message=( + f"The {type(peer_verify_key).__name__}: " + f"{peer_verify_key} does not match the signature of the message" + ) + ) + # get the full peer object from the store to update its routes + remote_node_peer: NodePeer | SyftError = ( + self._get_remote_node_peer_by_verify_key(context, peer_verify_key) + ) + if isinstance(remote_node_peer, SyftError): + return remote_node_peer + # update the route's priority for the peer + updated_node_route: NodeRouteType | SyftError = ( + remote_node_peer.update_existed_route_priority( + route=route, priority=priority + ) + ) + if isinstance(updated_node_route, SyftError): + return updated_node_route + new_priority: int = updated_node_route.priority + # update the peer in the store + result = self.stash.update(context.node.verify_key, remote_node_peer) + if result.is_err(): + return SyftError(message=str(result.err())) + + return SyftSuccess( + message=f"Route {route.id}'s priority updated to " + f"{new_priority} for peer {remote_node_peer.name}" + ) + + def _get_remote_node_peer_by_verify_key( + self, context: AuthedServiceContext, peer_verify_key: SyftVerifyKey + ) -> NodePeer | SyftError: + """ + Helper function to get the full node peer object from t + he stash using its verify key + """ + remote_node_peer: Result[NodePeer | None, SyftError] = ( + self.stash.get_by_verify_key( + credentials=context.node.verify_key, + verify_key=peer_verify_key, + ) + ) + if remote_node_peer.is_err(): + return SyftError(message=str(remote_node_peer.err())) + remote_node_peer = remote_node_peer.ok() + if remote_node_peer is None: + return SyftError( + message=f"Can't retrive {remote_node_peer.name} from the store of peers (None)." + ) + return remote_node_peer TYPE_TO_SERVICE[NodePeer] = NetworkService diff --git a/packages/syft/src/syft/service/network/node_peer.py b/packages/syft/src/syft/service/network/node_peer.py index ecf4c08193c..70e6f9bfb40 100644 --- a/packages/syft/src/syft/service/network/node_peer.py +++ b/packages/syft/src/syft/service/network/node_peer.py @@ -1,10 +1,14 @@ -# stdlib - # stdlib from collections.abc import Callable +# third party +from result import Err +from result import Ok +from result import Result + # relative from ...abstract_node import NodeType +from ...client.client import NodeConnection from ...client.client import SyftClient from ...node.credentials import SyftSigningKey from ...node.credentials import SyftVerifyKey @@ -63,98 +67,167 @@ class NodePeer(SyftObject): node_type: NodeType admin_email: str - def update_routes(self, new_routes: list[NodeRoute]) -> None: - add_routes = [] - new_routes = self.update_route_priorities(new_routes) - for new_route in new_routes: - existed, index = self.existed_route(new_route) - if existed and index is not None: - # if the route already exists, we do not append it to self.new_route, - # but update its priority - self.node_routes[index].priority = new_route.priority - else: - add_routes.append(new_route) - - self.node_routes += add_routes - - def update_route_priorities(self, new_routes: list[NodeRoute]) -> list[NodeRoute]: - """ - Since we pick the newest route has the highest priority, we - update the priority of the newly added routes here to be increments of - current routes' highest priority. - """ - current_max_priority = max(route.priority for route in self.node_routes) - for route in new_routes: - route.priority = current_max_priority + 1 - current_max_priority += 1 - return new_routes - - def existed_route(self, route: NodeRoute) -> tuple[bool, int | None]: + def existed_route( + self, route: NodeRouteType | None = None, route_id: UID | None = None + ) -> tuple[bool, int | None]: """Check if a route exists in self.node_routes - - For HTTPNodeRoute: check based on protocol, host_or_ip (url) and port - - For PythonNodeRoute: check if the route exists in the set of all node_routes + Args: - route: the route to be checked + route: the route to be checked. For now it can be either + HTTPNodeRoute or PythonNodeRoute or VeilidNodeRoute + route_id: the id of the route to be checked + Returns: if the route exists, returns (True, index of the existed route in self.node_routes) if the route does not exist returns (False, None) """ - if isinstance(route, HTTPNodeRoute): + if route_id is None and route is None: + raise ValueError("Either route or route_id should be provided in args") + + if route: + if not isinstance(route, HTTPNodeRoute | PythonNodeRoute | VeilidNodeRoute): + raise ValueError(f"Unsupported route type: {type(route)}") for i, r in enumerate(self.node_routes): - if ( - (route.host_or_ip == r.host_or_ip) - and (route.port == r.port) - and (route.protocol == r.protocol) - ): + if route == r: return (True, i) - return (False, None) - elif isinstance(route, PythonNodeRoute): # PythonNodeRoute - for i, r in enumerate(self.node_routes): # something went wrong here - if ( - (route.worker_settings.id == r.worker_settings.id) - and (route.worker_settings.name == r.worker_settings.name) - and (route.worker_settings.node_type == r.worker_settings.node_type) - and ( - route.worker_settings.node_side_type - == r.worker_settings.node_side_type - ) - and ( - route.worker_settings.signing_key - == r.worker_settings.signing_key - ) - ): + + elif route_id: + for i, r in enumerate(self.node_routes): + if r.id == route_id: return (True, i) - return (False, None) + + return (False, None) + + def assign_highest_priority(self, route: NodeRoute) -> NodeRoute: + """ + Assign the new_route's to have the highest priority + + Args: + route (NodeRoute): The new route whose priority is to be updated. + + Returns: + NodeRoute: The new route with the updated priority + """ + current_max_priority: int = max(route.priority for route in self.node_routes) + route.priority = current_max_priority + 1 + return route + + def update_route(self, new_route: NodeRoute) -> NodeRoute | None: + """ + Update the route for the node. + If the route already exists, updates the priority of the existing route. + If it doesn't, it append the new route to the peer's list of node routes. + + Args: + new_route (NodeRoute): The new route to be added to the node. + + Returns: + NodeRoute | None: if the route already exists, return it, else returns None + """ + new_route = self.assign_highest_priority(new_route) + existed, index = self.existed_route(new_route) + if existed and index is not None: + self.node_routes[index].priority = new_route.priority + return self.node_routes[index] + else: + self.node_routes.append(new_route) + return None + + def update_routes(self, new_routes: list[NodeRoute]) -> None: + """ + Update multiple routes of the node peer. + + This method takes a list of new routes as input. + It first updates the priorities of the new routes. + Then, for each new route, it checks if the route already exists for the node peer. + If it does, it updates the priority of the existing route. + If it doesn't, it adds the new route to the node. + + Args: + new_routes (list[NodeRoute]): The new routes to be added to the node. + + Returns: + None + """ + for new_route in new_routes: + self.update_route(new_route) + + def update_existed_route_priority( + self, route: NodeRoute, priority: int | None = None + ) -> NodeRouteType | SyftError: + """ + Update the priority of an existed route. + + Args: + route (NodeRoute): The route whose priority is to be updated. + priority (int | None): The new priority of the route. If not given, + the route will be assigned with the highest priority. + + Returns: + NodeRoute: The route with updated priority if the route exists + SyftError: If the route does not exist or the priority is invalid + """ + if priority is not None and priority <= 0: + return SyftError( + message="Priority must be greater than 0. Now it is {priority}." + ) + + existed, index = self.existed_route(route_id=route.id) + + if not existed or index is None: + return SyftError(message=f"Route with id {route.id} does not exist.") + + if priority is not None: + self.node_routes[index].priority = priority else: - raise ValueError(f"Unsupported route type: {type(route)}") + self.node_routes[index].priority = self.assign_highest_priority( + route + ).priority + + return self.node_routes[index] @staticmethod def from_client(client: SyftClient) -> "NodePeer": if not client.metadata: - raise Exception("Client has to have metadata first") + raise ValueError("Client has to have metadata first") peer = client.metadata.to(NodeMetadataV3).to(NodePeer) route = connection_to_route(client.connection) peer.node_routes.append(route) return peer - def client_with_context(self, context: NodeServiceContext) -> SyftClient: - if len(self.node_routes) < 1: - raise Exception(f"No routes to peer: {self}") - # select the latest added route - final_route = self.pick_highest_priority_route() - connection = route_to_connection(route=final_route) + def client_with_context( + self, context: NodeServiceContext + ) -> Result[type[SyftClient], str]: + # third party + from loguru import logger - client_type = connection.get_client_type() + if len(self.node_routes) < 1: + raise ValueError(f"No routes to peer: {self}") + # select the highest priority route (i.e. added or updated the latest) + final_route: NodeRoute = self.pick_highest_priority_route() + connection: NodeConnection = route_to_connection(route=final_route) + try: + client_type = connection.get_client_type() + except Exception as e: + logger.error( + f"Failed to establish a connection with {self.node_type} '{self.name}'. Exception: {e}" + ) + return Err( + f"Failed to establish a connection with {self.node_type} '{self.name}'" + ) if isinstance(client_type, SyftError): - return client_type - return client_type(connection=connection, credentials=context.node.signing_key) + return Err(client_type.message) + return Ok( + client_type(connection=connection, credentials=context.node.signing_key) + ) - def client_with_key(self, credentials: SyftSigningKey) -> SyftClient: + def client_with_key(self, credentials: SyftSigningKey) -> SyftClient | SyftError: if len(self.node_routes) < 1: - raise Exception(f"No routes to peer: {self}") + raise ValueError(f"No routes to peer: {self}") # select the latest added route - final_route = self.pick_highest_priority_route() + final_route: NodeRoute = self.pick_highest_priority_route() + connection = route_to_connection(route=final_route) client_type = connection.get_client_type() if isinstance(client_type, SyftError): @@ -171,11 +244,43 @@ def proxy_from(self, client: SyftClient) -> SyftClient: return client.proxy_to(self) def pick_highest_priority_route(self) -> NodeRoute: - final_route: NodeRoute = self.node_routes[-1] + highest_priority_route: NodeRoute = self.node_routes[-1] for route in self.node_routes: - if route.priority > final_route.priority: - final_route = route - return final_route + if route.priority > highest_priority_route.priority: + highest_priority_route = route + return highest_priority_route + + def delete_route( + self, route: NodeRouteType | None = None, route_id: UID | None = None + ) -> SyftError | None: + """ + Deletes a route from the peer's route list. + Takes O(n) where is n is the number of routes in self.node_routes. + + Args: + route (NodeRouteType): The route to be deleted; + route_id (UID): The id of the route to be deleted; + + Returns: + SyftError: If deleting failed + """ + if route_id: + try: + self.node_routes = [r for r in self.node_routes if r.id != route_id] + except Exception as e: + return SyftError( + message=f"Error deleting route with id {route_id}. Exception: {e}" + ) + + if route: + try: + self.node_routes = [r for r in self.node_routes if r != route] + except Exception as e: + return SyftError( + message=f"Error deleting route with id {route.id}. Exception: {e}" + ) + + return None def drop_veilid_route() -> Callable: diff --git a/packages/syft/src/syft/service/network/routes.py b/packages/syft/src/syft/service/network/routes.py index eea17b6b835..f3fa9b1ad1a 100644 --- a/packages/syft/src/syft/service/network/routes.py +++ b/packages/syft/src/syft/service/network/routes.py @@ -32,14 +32,29 @@ class NodeRoute: - def client_with_context(self, context: NodeServiceContext) -> SyftClient: + def client_with_context( + self, context: NodeServiceContext + ) -> SyftClient | SyftError: + """ + Convert the current route (self) to a connection (either HTTP, Veilid or Python) + and create a SyftClient from the connection. + + Args: + context (NodeServiceContext): The NodeServiceContext containing the node information. + + Returns: + SyftClient | SyftError: Returns the created SyftClient, or SyftError + if the client type is not valid or if the context's node is None. + """ connection = route_to_connection(route=self, context=context) client_type = connection.get_client_type() if isinstance(client_type, SyftError): return client_type return client_type(connection=connection, credentials=context.node.signing_key) - def validate_with_context(self, context: AuthedServiceContext) -> NodePeer: + def validate_with_context( + self, context: AuthedServiceContext + ) -> NodePeer | SyftError: # relative from .node_peer import NodePeer @@ -63,7 +78,7 @@ def validate_with_context(self, context: AuthedServiceContext) -> NodePeer: return SyftError(message="Signature Verification Failed in ping") # Step 2: Create a Node Peer with the given route - self_node_peer = context.node.settings.to(NodePeer) + self_node_peer: NodePeer = context.node.settings.to(NodePeer) self_node_peer.node_routes.append(self) return self_node_peer @@ -82,12 +97,20 @@ class HTTPNodeRoute(SyftObject, NodeRoute): priority: int = 1 def __eq__(self, other: Any) -> bool: - if isinstance(other, HTTPNodeRoute): - return hash(self) == hash(other) - return self == other + if not isinstance(other, HTTPNodeRoute): + return False + return hash(self) == hash(other) def __hash__(self) -> int: - return hash(self.host_or_ip) + hash(self.port) + hash(self.protocol) + return ( + hash(self.host_or_ip) + + hash(self.port) + + hash(self.protocol) + + hash(self.proxy_target_uid) + ) + + def __str__(self) -> str: + return f"{self.protocol}://{self.host_or_ip}:{self.port}" @serializable() @@ -122,12 +145,21 @@ def with_node(cls, node: AbstractNode) -> Self: return cls(id=worker_settings.id, worker_settings=worker_settings) def __eq__(self, other: Any) -> bool: - if isinstance(other, PythonNodeRoute): - return hash(self) == hash(other) - return self == other + if not isinstance(other, PythonNodeRoute): + return False + return hash(self) == hash(other) def __hash__(self) -> int: - return hash(self.worker_settings.id) + return ( + hash(self.worker_settings.id) + + hash(self.worker_settings.name) + + hash(self.worker_settings.node_type) + + hash(self.worker_settings.node_side_type) + + hash(self.worker_settings.signing_key) + ) + + def __str__(self) -> str: + return "PythonNodeRoute" @serializable() @@ -140,12 +172,12 @@ class VeilidNodeRoute(SyftObject, NodeRoute): priority: int = 1 def __eq__(self, other: Any) -> bool: - if isinstance(other, VeilidNodeRoute): - return hash(self) == hash(other) - return self == other + if not isinstance(other, VeilidNodeRoute): + return False + return hash(self) == hash(other) def __hash__(self) -> int: - return hash(self.vld_key) + return hash(self.vld_key) + hash(self.proxy_target_uid) NodeRouteTypeV1 = HTTPNodeRoute | PythonNodeRoute | VeilidNodeRoute diff --git a/packages/syft/src/syft/service/project/project_service.py b/packages/syft/src/syft/service/project/project_service.py index 67b7b2df744..3b8cef606ac 100644 --- a/packages/syft/src/syft/service/project/project_service.py +++ b/packages/syft/src/syft/service/project/project_service.py @@ -91,7 +91,7 @@ def create_project( # For followers the leader node route is retrieved from its peer if leader_node.verify_key != context.node.verify_key: network_service = context.node.get_service("networkservice") - peer = network_service.stash.get_for_verify_key( + peer = network_service.stash.get_by_verify_key( credentials=context.node.verify_key, verify_key=leader_node.verify_key, ) @@ -228,7 +228,7 @@ def broadcast_event( for member in project.members: if member.verify_key != context.node.verify_key: # Retrieving the NodePeer Object to communicate with the node - peer = network_service.stash.get_for_verify_key( + peer = network_service.stash.get_by_verify_key( credentials=context.node.verify_key, verify_key=member.verify_key, ) @@ -239,8 +239,17 @@ def broadcast_event( + " Kindly exchange routes with the peer" ) peer = peer.ok() - client = peer.client_with_context(context) - event_result = client.api.services.project.add_event(project_event) + remote_client = peer.client_with_context(context=context) + if remote_client.is_err(): + return SyftError( + message=f"Failed to create remote client for peer: " + f"{peer.id}. Error: {remote_client.err()}" + ) + remote_client = remote_client.ok() + + event_result = remote_client.api.services.project.add_event( + project_event + ) if isinstance(event_result, SyftError): return event_result diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index ad3d4e62411..64dbb12ceee 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -32,6 +32,7 @@ from ...util.colors import SURFACE from ...util.fonts import FONT_CSS from ...util.fonts import ITABLES_CSS +from ...util.notebook_ui.components.sync import SyncTableObject from ...util.notebook_ui.notebook_addons import ARROW_ICON from ..action.action_object import ActionObject from ..action.action_permissions import ActionObjectPermission @@ -45,7 +46,6 @@ from ..request.request import Request from ..response import SyftError from .sync_state import SyncState -from .sync_state import SyncView sketchy_tab = "‎ " * 4 @@ -182,6 +182,7 @@ class ObjectDiff(SyftObject): # StateTuple (compare 2 objects) "low_state", "high_state", ] + __syft_include_id_coll_repr__ = False def is_mock(self, side: str) -> bool: # An object is a mock object if it exists on both sides, @@ -742,12 +743,12 @@ def _coll_repr_(self) -> dict[str, Any]: if self.root_diff.low_obj is None: low_html = no_obj_html else: - low_html = SyncView(object=self.root_diff.low_obj).summary_html() + low_html = SyncTableObject(object=self.root_diff.low_obj).to_html() if self.root_diff.high_obj is None: high_html = no_obj_html else: - high_html = SyncView(object=self.root_diff.high_obj).summary_html() + high_html = SyncTableObject(object=self.root_diff.high_obj).to_html() return { "Merge status": self.status_badge(), diff --git a/packages/syft/src/syft/service/sync/resolve_widget.py b/packages/syft/src/syft/service/sync/resolve_widget.py index a884cc40faf..24866d70922 100644 --- a/packages/syft/src/syft/service/sync/resolve_widget.py +++ b/packages/syft/src/syft/service/sync/resolve_widget.py @@ -2,15 +2,11 @@ from enum import Enum from enum import auto import html -import json from typing import Any # third party -from IPython.display import Javascript -from IPython.display import display import ipywidgets as widgets from ipywidgets import Button -from ipywidgets import HBox from ipywidgets import HTML from ipywidgets import Layout from ipywidgets import VBox @@ -22,6 +18,7 @@ from ...client.sync_decision import SyncDirection from ...node.credentials import SyftVerifyKey from ...types.uid import UID +from ...util.notebook_ui.components.sync import SyncWidgetHeader from ...util.notebook_ui.notebook_addons import CSS_CODE from ..action.action_object import ActionObject from ..log.log import SyftLog @@ -31,7 +28,6 @@ from .diff_state import ObjectDiffBatch from .diff_state import ResolvedSyncState from .diff_state import SyncInstruction -from .sync_state import SyncView # Standard div Jupyter Lab uses for notebook outputs # This is needed to use alert styles from SyftSuccess and SyftError @@ -70,120 +66,6 @@ class DiffStatus(Enum): } -class HeaderWidget: - def __init__( - self, - item_type: str, - item_name: str, - item_id: str, - num_diffs: int, - source_side: str, - target_side: str, - ): - self.item_type = item_type - self.item_name = item_name - self.item_id = item_id - self.num_diffs = num_diffs - self.source_side = source_side - self.target_side = target_side - self.widget = self.create_widget() - - @classmethod - def from_object_diff_batch(cls, obj_diff_batch: ObjectDiffBatch) -> Self: - """ - ( - diff=self.obj_diff_batch.root_diff, - item_type=self.obj_diff_batch.root_type_name, - item_name="compute_mean", - item_id=self.obj_diff_batch.root_id, - num_diffs=2, - source_side="Low", - target_side="High", - ) - - """ - if obj_diff_batch.sync_direction == SyncDirection.LOW_TO_HIGH: - source_side = "Low side" - target_side = "High side" - else: - source_side = "High side" - target_side = "Low side" - - root_diff = obj_diff_batch.root_diff - root_obj = ( - root_diff.low_obj if root_diff.low_obj is not None else root_diff.high_obj - ) - obj_view = SyncView(object=root_obj) - return cls( - item_type=obj_view.object_type_name, - item_name=obj_view.main_object_description_str(), - item_id=str(root_obj.id.id), # type: ignore - num_diffs=len(obj_diff_batch.get_dependencies(include_roots=True)), - source_side=source_side, - target_side=target_side, - ) - - def copy_text_button(self, text: str) -> widgets.Widget: - button = widgets.Button( - icon="clone", - layout=widgets.Layout(width="25px", height="25px", margin="0", padding="0"), - ) - output = widgets.Output(layout=widgets.Layout(display="none")) - copy_js = Javascript(f"navigator.clipboard.writeText({json.dumps(text)})") - - def on_click(_: widgets.Button) -> None: - output.clear_output() - with output: - display(copy_js) - - button.on_click(on_click) - - return widgets.Box( - (button, output), - layout=widgets.Layout(display="flex", align_items="center"), - ) - - def create_item_type_label(self, item_type: str) -> HTML: - # TODO different bg for different types (levels?) - style = ( - "background-color: #C2DEF0; " - "border-radius: 4px; " - "padding: 4px 6px; " - "color: #373B7B;" - ) - return HTML( - value=f"{item_type.upper()}", - layout=Layout(margin="0 5px 0 0"), - ) - - def create_name_id_label(self, item_name: str, item_id: str) -> HTML: - item_id_short = item_id[:4] + "..." if len(item_id) > 4 else item_id - return HTML( - value=( - f"{item_name} " - f"#{item_id_short}" - ) - ) - - def create_widget(self) -> VBox: - type_box = self.create_item_type_label(self.item_type) - name_id_label = self.create_name_id_label(self.item_name, self.item_id) - copy_button = self.copy_text_button(self.item_id) - - first_line = HTML( - value="Syncing changes on" - ) - second_line = HBox( - [type_box, name_id_label, copy_button], layout=Layout(align_items="center") - ) - third_line = HTML( - value=f"This would sync {self.num_diffs} changes from {self.source_side} Node to {self.target_side} Node" # noqa: E501 - ) - fourth_line = HTML(value="
") - header = VBox([first_line, second_line, third_line, fourth_line]) - return header - - # TODO use ObjectDiff instead class ObjectDiffWidget: def __init__( @@ -579,7 +461,7 @@ def build(self) -> VBox: full_widget = widgets.VBox( [ - self.build_header().widget, + self.build_header(), self.main_object_diff_widget.widget, self.spacer(16), main_batch_items, @@ -613,5 +495,6 @@ def separator(self) -> widgets.HTML: layout=Layout(width="100%"), ) - def build_header(self) -> HeaderWidget: - return HeaderWidget.from_object_diff_batch(self.obj_diff_batch) + def build_header(self) -> HTML: + header_html = SyncWidgetHeader(diff_batch=self.obj_diff_batch).to_html() + return HTML(value=header_html) diff --git a/packages/syft/src/syft/service/sync/sync_state.py b/packages/syft/src/syft/service/sync/sync_state.py index c1e4e70e2f3..89c9791f0ca 100644 --- a/packages/syft/src/syft/service/sync/sync_state.py +++ b/packages/syft/src/syft/service/sync/sync_state.py @@ -21,10 +21,8 @@ from ...util.colors import SURFACE from ...util.fonts import FONT_CSS from ...util.fonts import ITABLES_CSS -from ..code.user_code import UserCode +from ...util.notebook_ui.components.sync import SyncTableObject from ..context import AuthedServiceContext -from ..job.job_stash import Job -from ..request.request import Request def get_hierarchy_level_prefix(level: int) -> str: @@ -34,78 +32,6 @@ def get_hierarchy_level_prefix(level: int) -> str: return "--" * level + " " -@serializable() -class SyncView(SyftObject): - __canonical_name__ = "SyncView" - __version__ = SYFT_OBJECT_VERSION_1 - - object: SyftObject - - def main_object_description_str(self) -> str: - if isinstance(self.object, UserCode): - return self.object.service_func_name - elif isinstance(self.object, Job): # type: ignore - return self.object.user_code_name - elif isinstance(self.object, Request): # type: ignore - # TODO: handle other requests - return f"Execute {self.object.code.service_func_name}" - else: - return "" - - @property - def object_type_name(self) -> str: - return type(self.object).__name__ - - def type_badge_class(self) -> str: - if isinstance(self.object, UserCode): - return "label-light-blue" - elif isinstance(self.object, Job): # type: ignore - return "label-light-blue" - elif isinstance(self.object, Request): # type: ignore - # TODO: handle other requests - return "label-light-purple" - else: - return "" - - def get_status_str(self) -> str: - if isinstance(self.object, UserCode): - return "" - elif isinstance(self.object, Job): # type: ignore - return f"Status: {self.object.status.value}" - elif isinstance(self.object, Request): - code = self.object.code - statusses = list(code.status.status_dict.values()) - if len(statusses) != 1: - raise ValueError("Request code should have exactly one status") - status_tuple = statusses[0] - status, _ = status_tuple - return status.value - else: - return "" - - def summary_html(self) -> str: - try: - type_html = f'
{self.object_type_name.upper()}
' - description_html = f"{self.main_object_description_str()}" - updated_delta_str = "29m ago" - updated_by = "john@doe.org" - status_str = self.get_status_str() - status_seperator = " • " if len(status_str) else "" - summary_html = f""" -
- {type_html} {description_html} -
-
- {status_str}{status_seperator}Updated by {updated_by} {updated_delta_str} -
- """ - summary_html = summary_html.replace("\n", "") - except Exception as e: - print("Failed to build table", e) - raise - return summary_html - - class SyncStateRow(SyftObject): """A row in the SyncState table""" @@ -139,7 +65,7 @@ def status_badge(self) -> dict[str, str]: return {"value": status.upper(), "type": badge_color} def _coll_repr_(self) -> dict[str, Any]: - obj_view = SyncView(object=self.object) + obj_view = SyncTableObject(object=self.object) if self.last_sync_date is not None: last_sync_date = self.last_sync_date @@ -154,7 +80,7 @@ def _coll_repr_(self) -> dict[str, Any]: last_sync_html = "

n/a

" return { "Status": self.status_badge(), - "Summary": obj_view.summary_html(), + "Summary": obj_view.to_html(), "Last Sync": last_sync_html, } diff --git a/packages/syft/src/syft/util/notebook_ui/__init__.py b/packages/syft/src/syft/util/notebook_ui/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/packages/syft/src/syft/util/notebook_ui/components/__init__.py b/packages/syft/src/syft/util/notebook_ui/components/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/packages/syft/src/syft/util/notebook_ui/components/base.py b/packages/syft/src/syft/util/notebook_ui/components/base.py new file mode 100644 index 00000000000..6fd09c458b4 --- /dev/null +++ b/packages/syft/src/syft/util/notebook_ui/components/base.py @@ -0,0 +1,20 @@ +# third party +import ipywidgets as widgets + +# relative +from ....types.syft_object import SYFT_OBJECT_VERSION_1 +from ....types.syft_object import SyftBaseObject + + +class HTMLComponentBase(SyftBaseObject): + __canonical_name__ = "HTMLComponentBase" + __version__ = SYFT_OBJECT_VERSION_1 + + def to_html(self) -> str: + raise NotImplementedError() + + def to_widget(self) -> widgets.Widget: + return widgets.HTML(value=self.to_html()) + + def _repr_html_(self) -> str: + return self.to_html() diff --git a/packages/syft/src/syft/util/notebook_ui/components/sync.py b/packages/syft/src/syft/util/notebook_ui/components/sync.py new file mode 100644 index 00000000000..37935b7479c --- /dev/null +++ b/packages/syft/src/syft/util/notebook_ui/components/sync.py @@ -0,0 +1,217 @@ +# stdlib +from typing import Any + +# third party +from pydantic import model_validator + +# relative +from ....client.sync_decision import SyncDirection +from ....service.code.user_code import UserCode +from ....service.job.job_stash import Job +from ....service.request.request import Request +from ....types.syft_object import SYFT_OBJECT_VERSION_1 +from ....types.syft_object import SyftObject +from ..notebook_addons import CSS_CODE +from .base import HTMLComponentBase + +COPY_ICON = ( + '' + '' + "" +) + +COPY_CSS = """ +.copy-container { + cursor: pointer; + border-radius: 3px; + padding: 0px 3px; + display: inline-block; + transition: background-color 0.3s; + user-select: none; + color: #B4B0BF; + overflow: hidden; + white-space: nowrap; +; +} + +.copy-container:hover { + background-color: #f5f5f5; +} + +.copy-container:active { + background-color: #ebebeb; +} + +.copy-text-display { + display: inline-block; + max-width: 50px; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + vertical-align: bottom; +} +""" + + +class CopyIDButton(HTMLComponentBase): + __canonical_name__ = "CopyButton" + __version__ = SYFT_OBJECT_VERSION_1 + copy_text: str + max_width: int = 50 + + def to_html(self) -> str: + button_html = f""" + +
+ #{self.copy_text}{COPY_ICON} +
+ """ + return button_html + + +class SyncTableObject(HTMLComponentBase): + __canonical_name__ = "SyncTableObject" + __version__ = SYFT_OBJECT_VERSION_1 + + object: SyftObject + + def main_object_description_str(self) -> str: + if isinstance(self.object, UserCode): + return self.object.service_func_name + elif isinstance(self.object, Job): # type: ignore + return self.object.user_code_name + elif isinstance(self.object, Request): # type: ignore + # TODO: handle other requests + return f"Execute {self.object.code.service_func_name}" + return "" # type: ignore + + def type_badge_class(self) -> str: + if isinstance(self.object, UserCode): + return "label-light-blue" + elif isinstance(self.object, Job): # type: ignore + return "label-light-blue" + elif isinstance(self.object, Request): # type: ignore + # TODO: handle other requests + return "label-light-purple" + return "label-light-blue" # type: ignore + + @property + def object_type_name(self) -> str: + return type(self.object).__name__ + + def get_status_str(self) -> str: + if isinstance(self.object, UserCode): + return "" + elif isinstance(self.object, Job): # type: ignore + return f"Status: {self.object.status.value}" + elif isinstance(self.object, Request): + code = self.object.code + statusses = list(code.status.status_dict.values()) + if len(statusses) != 1: + raise ValueError("Request code should have exactly one status") + status_tuple = statusses[0] + status, _ = status_tuple + return status.value + return "" # type: ignore + + def to_html(self) -> str: + badge_class = self.type_badge_class() + object_type = self.object_type_name.upper() + type_html = ( + f'
{object_type}
' + ) + + description_str = self.main_object_description_str() + description_style = "white-space: nowrap; overflow: ellipsis; flex-grow: 1;" + description_html = f'{description_str}' + + copy_id_button = CopyIDButton(copy_text=str(self.object.id.id), max_width=60) + + updated_delta_str = "29m ago" + updated_by = "john@doe.org" + status_str = self.get_status_str() + status_seperator = " • " if len(status_str) else "" + summary_html = f""" +
+
+ {type_html} {description_html} +
+ {copy_id_button.to_html()} +
+
+ + {status_str}{status_seperator}Updated by {updated_by} {updated_delta_str} + +
+ """ # noqa: E501 + summary_html = summary_html.replace("\n", "").replace(" ", "") + return summary_html + + +class SyncWidgetHeader(SyncTableObject): + diff_batch: Any + + @model_validator(mode="before") + @classmethod + def add_object(cls, values: dict) -> dict: + if "diff_batch" not in values: + raise ValueError("diff_batch is required") + diff_batch = values["diff_batch"] + values["object"] = diff_batch.root_diff.non_empty_object + return values + + def to_html(self) -> str: + # CSS Styles + style = CSS_CODE + + first_line_html = "Syncing changes on" + + badge_class = self.type_badge_class() + object_type = self.object_type_name.upper() + type_html = ( + f'
{object_type}
' + ) + + description_str = self.main_object_description_str() + description_style = "white-space: nowrap; overflow: ellipsis; flex-grow: 1;" + description_html = f'{description_str}' + + copy_id_button = CopyIDButton(copy_text=str(self.object.id.id), max_width=60) + + second_line_html = f""" +
+
+ {type_html} {description_html} +
+ {copy_id_button.to_html()} +
+ """ # noqa: E501 + + num_diffs = len(self.diff_batch.get_dependencies(include_roots=True)) + if self.diff_batch.sync_direction == SyncDirection.HIGH_TO_LOW: + source_side = "High" + target_side = "Low" + else: + source_side = "Low" + target_side = "High" + + # Third line HTML + third_line_html = f"This would sync {num_diffs} changes from {source_side} Node to {target_side} Node" # noqa: E501 + + header_html = f""" + + {first_line_html} + {second_line_html} + {third_line_html} +
+ """ + + return header_html diff --git a/packages/syft/src/syft/util/notebook_ui/notebook_addons.py b/packages/syft/src/syft/util/notebook_ui/notebook_addons.py index d045a14e745..c1b3b13b9c5 100644 --- a/packages/syft/src/syft/util/notebook_ui/notebook_addons.py +++ b/packages/syft/src/syft/util/notebook_ui/notebook_addons.py @@ -326,7 +326,6 @@ code-text; border-radius: 4px; padding: 0px 4px; - } .label-light-purple { diff --git a/scripts/hagrid_hash b/scripts/hagrid_hash index d99652fc424..fbfec796a91 100644 --- a/scripts/hagrid_hash +++ b/scripts/hagrid_hash @@ -1 +1 @@ -118a684a37c514b125ad9b67327bf999 +b5899a371c339fe23cc841f44c1a8f20 diff --git a/tests/integration/network/gateway_test.py b/tests/integration/network/gateway_test.py index 007aed1aa56..44fe5477a45 100644 --- a/tests/integration/network/gateway_test.py +++ b/tests/integration/network/gateway_test.py @@ -11,13 +11,17 @@ import syft as sy from syft.abstract_node import NodeType from syft.client.client import HTTPConnection +from syft.client.client import SyftClient from syft.client.domain_client import DomainClient from syft.client.gateway_client import GatewayClient from syft.client.registry import NetworkRegistry from syft.client.search import SearchResults from syft.service.dataset.dataset import Dataset from syft.service.network.node_peer import NodePeer +from syft.service.network.routes import HTTPNodeRoute +from syft.service.network.routes import NodeRouteType from syft.service.request.request import Request +from syft.service.response import SyftError from syft.service.response import SyftSuccess from syft.service.user.user_roles import ServiceRole @@ -58,6 +62,17 @@ def _random_hash() -> str: return uuid.uuid4().hex[:16] +def _remove_existing_peers(client: SyftClient) -> SyftSuccess | SyftError: + peers: list[NodePeer] | SyftError = client.api.services.network.get_all_peers() + if isinstance(peers, SyftError): + return peers + for peer in peers: + res = client.api.services.network.delete_peer_by_id(peer.id) + if isinstance(res, SyftError): + return res + return SyftSuccess(message="All peers removed.") + + @pytest.mark.skip(reason="Will be tested when the network registry URL works.") def test_network_registry_from_url() -> None: assert isinstance(sy.gateways, NetworkRegistry) @@ -74,21 +89,25 @@ def test_network_registry_env_var(set_env_var) -> None: def test_domain_connect_to_gateway( set_env_var, domain_1_port: int, gateway_port: int ) -> None: + # check if we can see the online gateways assert isinstance(sy.gateways, NetworkRegistry) assert len(sy.gateways.all_networks) == len(sy.gateways.online_networks) == 1 - gateway_client: GatewayClient = sy.login_as_guest(port=gateway_port) - + # login to the domain and gateway + gateway_client: GatewayClient = sy.login( + port=gateway_port, email="info@openmined.org", password="changethis" + ) domain_client: DomainClient = sy.login( port=domain_1_port, email="info@openmined.org", password="changethis" ) + # connecting the domain to the gateway result = domain_client.connect_to_gateway(gateway_client) assert isinstance(result, SyftSuccess) - assert len(domain_client.peers) == 1 assert len(gateway_client.peers) == 1 + # check that the domain is online on the network assert len(sy.domains.all_domains) == 1 assert len(sy.domains.online_domains) == 1 @@ -124,15 +143,21 @@ def test_domain_connect_to_gateway( proxy_domain_client.api.endpoints.keys() == domain_client.api.endpoints.keys() ) + # Remove existing peers + assert isinstance(_remove_existing_peers(domain_client), SyftSuccess) + assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess) + -def test_dataset_search(set_env_var, domain_1_port, gateway_port): +def test_dataset_search(set_env_var, gateway_port: int, domain_1_port: int) -> None: """ Scenario: Connecting a domain node to a gateway node. The domain client then upload a dataset, which should be searchable by the syft network. People who install syft can see the mock data and metadata of the uploaded datasets """ # login to the domain and gateway - gateway_client: GatewayClient = sy.login_as_guest(port=gateway_port) + gateway_client: GatewayClient = sy.login( + port=gateway_port, email="info@openmined.org", password="changethis" + ) domain_client: DomainClient = sy.login( port=domain_1_port, email="info@openmined.org", password="changethis" ) @@ -170,10 +195,18 @@ def test_dataset_search(set_env_var, domain_1_port, gateway_port): # the domain client delete the dataset domain_client.api.services.dataset.delete_by_uid(uid=dataset.id) + # Remove existing peers + assert isinstance(_remove_existing_peers(domain_client), SyftSuccess) + assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess) -def test_domain_gateway_user_code(set_env_var, domain_1_port, gateway_port): + +def test_domain_gateway_user_code( + set_env_var, domain_1_port: int, gateway_port: int +) -> None: # login to the domain and gateway - gateway_client: GatewayClient = sy.login_as_guest(port=gateway_port) + gateway_client: GatewayClient = sy.login( + port=gateway_port, email="info@openmined.org", password="changethis" + ) domain_client: DomainClient = sy.login( port=domain_1_port, email="info@openmined.org", password="changethis" ) @@ -232,3 +265,459 @@ def mock_function(asset): result = proxy_ds.code.mock_function(asset=asset) final_result = result.get() assert (final_result == input_data + 1).all() + + # the domain client delete the dataset + domain_client.api.services.dataset.delete_by_uid(uid=dataset.id) + + # Remove existing peers + assert isinstance(_remove_existing_peers(domain_client), SyftSuccess) + assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess) + + +def test_deleting_peers(set_env_var, domain_1_port: int, gateway_port: int) -> None: + # login to the domain and gateway + gateway_client: GatewayClient = sy.login( + port=gateway_port, email="info@openmined.org", password="changethis" + ) + domain_client: DomainClient = sy.login( + port=domain_1_port, email="info@openmined.org", password="changethis" + ) + + # connecting the domain to the gateway + result = domain_client.connect_to_gateway(gateway_client) + assert isinstance(result, SyftSuccess) + assert len(domain_client.peers) == 1 + assert len(gateway_client.peers) == 1 + # check that the domain is online on the network + assert len(sy.domains.all_domains) == 1 + assert len(sy.domains.online_domains) == 1 + + # Remove existing peers + assert isinstance(_remove_existing_peers(domain_client), SyftSuccess) + assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess) + # check that removing peers work as expected + assert len(sy.gateways.all_networks) == 1 + assert len(sy.domains.all_domains) == 0 + assert len(sy.domains.all_domains) == 0 + assert len(sy.domains.online_domains) == 0 + assert len(domain_client.peers) == 0 + assert len(gateway_client.peers) == 0 + + # reconnect the domain to the gateway + result = domain_client.connect_to_gateway(gateway_client) + assert isinstance(result, SyftSuccess) + assert len(domain_client.peers) == 1 + assert len(gateway_client.peers) == 1 + # check that the domain + assert len(sy.domains.all_domains) == 1 + assert len(sy.domains.online_domains) == 1 + + # Remove existing peers + assert isinstance(_remove_existing_peers(domain_client), SyftSuccess) + assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess) + # check that removing peers work as expected + assert len(sy.domains.all_domains) == 0 + assert len(sy.domains.all_domains) == 0 + assert len(sy.domains.online_domains) == 0 + assert len(domain_client.peers) == 0 + assert len(gateway_client.peers) == 0 + + +def test_add_route(set_env_var, gateway_port: int, domain_1_port: int) -> None: + """ + Test the network service's `add_route` functionalities to add routes directly + for a self domain. + Scenario: Connect a domain to a gateway. The gateway adds 2 new routes to the domain + and check their priorities. + Then add an existed route and check if its priority gets updated. + Check for the gateway if the proxy client to connect to the domain uses the + route with the highest priority. + """ + # login to the domain and gateway + gateway_client: GatewayClient = sy.login( + port=gateway_port, email="info@openmined.org", password="changethis" + ) + domain_client: DomainClient = sy.login( + port=domain_1_port, email="info@openmined.org", password="changethis" + ) + # Remove existing peers + assert isinstance(_remove_existing_peers(domain_client), SyftSuccess) + assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess) + + # connecting the domain to the gateway + result = domain_client.connect_to_gateway(gateway_client) + assert isinstance(result, SyftSuccess) + assert len(domain_client.peers) == 1 + assert len(gateway_client.peers) == 1 + + # add a new route to connect to the domain + new_route = HTTPNodeRoute(host_or_ip="localhost", port=10000) + domain_peer: NodePeer = gateway_client.api.services.network.get_all_peers()[0] + res = gateway_client.api.services.network.add_route( + peer_verify_key=domain_peer.verify_key, route=new_route + ) + assert isinstance(res, SyftSuccess) + domain_peer = gateway_client.api.services.network.get_all_peers()[0] + assert len(domain_peer.node_routes) == 2 + assert domain_peer.node_routes[-1].port == new_route.port + + # adding another route to the domain + new_route2 = HTTPNodeRoute(host_or_ip="localhost", port=10001) + res = gateway_client.api.services.network.add_route( + peer_verify_key=domain_peer.verify_key, route=new_route2 + ) + assert isinstance(res, SyftSuccess) + domain_peer = gateway_client.api.services.network.get_all_peers()[0] + assert len(domain_peer.node_routes) == 3 + assert domain_peer.node_routes[-1].port == new_route2.port + assert domain_peer.node_routes[-1].priority == 3 + + # add an existed route to the domain and check its priority gets updated + res = gateway_client.api.services.network.add_route( + peer_verify_key=domain_peer.verify_key, route=domain_peer.node_routes[0] + ) + assert "route already exists" in res.message + assert isinstance(res, SyftSuccess) + domain_peer = gateway_client.api.services.network.get_all_peers()[0] + assert len(domain_peer.node_routes) == 3 + assert domain_peer.node_routes[0].priority == 4 + + # the gateway gets the proxy client to the domain + # the proxy client should use the route with the highest priority + proxy_domain_client = gateway_client.peers[0] + assert isinstance(proxy_domain_client, DomainClient) + + # add another existed route (port 10000) + res = gateway_client.api.services.network.add_route( + peer_verify_key=domain_peer.verify_key, route=domain_peer.node_routes[1] + ) + assert "route already exists" in res.message + assert isinstance(res, SyftSuccess) + domain_peer = gateway_client.api.services.network.get_all_peers()[0] + assert len(domain_peer.node_routes) == 3 + assert domain_peer.node_routes[1].priority == 5 + # getting the proxy client using the current highest priority route should + # give back an error since it is a route with a random port (10000) + proxy_domain_client = gateway_client.peers[0] + assert isinstance(proxy_domain_client, SyftError) + assert "Failed to establish a connection with" in proxy_domain_client.message + + # the routes the domain client uses to connect to the gateway should stay the same + gateway_peer: NodePeer = domain_client.peers[0] + assert len(gateway_peer.node_routes) == 1 + + # Remove existing peers + assert isinstance(_remove_existing_peers(domain_client), SyftSuccess) + assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess) + + +def test_delete_route(set_env_var, gateway_port: int, domain_1_port: int) -> None: + # login to the domain and gateway + gateway_client: GatewayClient = sy.login( + port=gateway_port, email="info@openmined.org", password="changethis" + ) + domain_client: DomainClient = sy.login( + port=domain_1_port, email="info@openmined.org", password="changethis" + ) + + # connecting the domain to the gateway + result = domain_client.connect_to_gateway(gateway_client) + assert isinstance(result, SyftSuccess) + assert len(domain_client.peers) == 1 + assert len(gateway_client.peers) == 1 + + # add a new route to connect to the domain + new_route = HTTPNodeRoute(host_or_ip="localhost", port=10000) + domain_peer: NodePeer = gateway_client.api.services.network.get_all_peers()[0] + res = gateway_client.api.services.network.add_route( + peer_verify_key=domain_peer.verify_key, route=new_route + ) + assert isinstance(res, SyftSuccess) + domain_peer = gateway_client.api.services.network.get_all_peers()[0] + assert len(domain_peer.node_routes) == 2 + assert domain_peer.node_routes[-1].port == new_route.port + + # delete the added route + res = gateway_client.api.services.network.delete_route( + peer_verify_key=domain_peer.verify_key, route=new_route + ) + assert isinstance(res, SyftSuccess) + domain_peer = gateway_client.api.services.network.get_all_peers()[0] + assert len(domain_peer.node_routes) == 1 + assert domain_peer.node_routes[-1].port == domain_1_port + + # Remove existing peers + assert isinstance(_remove_existing_peers(domain_client), SyftSuccess) + assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess) + + +def test_add_route_on_peer(set_env_var, gateway_port: int, domain_1_port: int) -> None: + """ + Test the `add_route_on_peer` of network service. + Connect a domain to a gateway. + The gateway adds 2 new routes for the domain and check their priorities. + Then add an existed route and check if its priority gets updated. + Then the domain adds a route to itself for the gateway. + """ + # login to the domain and gateway + gateway_client: GatewayClient = sy.login( + port=gateway_port, email="info@openmined.org", password="changethis" + ) + domain_client: DomainClient = sy.login( + port=domain_1_port, email="info@openmined.org", password="changethis" + ) + + # Remove existing peers + assert isinstance(_remove_existing_peers(domain_client), SyftSuccess) + assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess) + + # connecting the domain to the gateway + result = domain_client.connect_to_gateway(gateway_client) + assert isinstance(result, SyftSuccess) + assert len(domain_client.peers) == 1 + assert len(gateway_client.peers) == 1 + gateway_peer: NodePeer = domain_client.peers[0] + assert len(gateway_peer.node_routes) == 1 + assert gateway_peer.node_routes[-1].priority == 1 + + # adding a new route for the domain + new_route = HTTPNodeRoute(host_or_ip="localhost", port=10000) + domain_peer: NodePeer = gateway_client.api.services.network.get_all_peers()[0] + res = gateway_client.api.services.network.add_route_on_peer( + peer=domain_peer, route=new_route + ) + assert isinstance(res, SyftSuccess) + gateway_peer = domain_client.peers[0] + assert len(gateway_peer.node_routes) == 2 + assert gateway_peer.node_routes[-1].port == new_route.port + assert gateway_peer.node_routes[-1].priority == 2 + + # adding another route for the domain + new_route2 = HTTPNodeRoute(host_or_ip="localhost", port=10001) + res = gateway_client.api.services.network.add_route_on_peer( + peer=domain_peer, route=new_route2 + ) + assert isinstance(res, SyftSuccess) + gateway_peer = domain_client.peers[0] + assert len(gateway_peer.node_routes) == 3 + assert gateway_peer.node_routes[-1].port == new_route2.port + assert gateway_peer.node_routes[-1].priority == 3 + + # add an existed route for the domain and check its priority gets updated + existed_route = gateway_peer.node_routes[0] + res = gateway_client.api.services.network.add_route_on_peer( + peer=domain_peer, route=existed_route + ) + assert "route already exists" in res.message + assert isinstance(res, SyftSuccess) + gateway_peer = domain_client.peers[0] + assert len(gateway_peer.node_routes) == 3 + assert gateway_peer.node_routes[0].priority == 4 + + # the domain calls `add_route_on_peer` to to add a route to itself for the gateway + assert len(domain_peer.node_routes) == 1 + res = domain_client.api.services.network.add_route_on_peer( + peer=gateway_peer, route=new_route + ) + assert isinstance(res, SyftSuccess) + domain_peer = gateway_client.api.services.network.get_all_peers()[0] + assert domain_peer.node_routes[-1].port == new_route.port + assert len(domain_peer.node_routes) == 2 + + # Remove existing peers + assert isinstance(_remove_existing_peers(domain_client), SyftSuccess) + assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess) + + +def test_delete_route_on_peer( + set_env_var, gateway_port: int, domain_1_port: int +) -> None: + """ + Connect a domain to a gateway, the gateway adds 2 new routes for the domain + , then delete them. + """ + # login to the domain and gateway + gateway_client: GatewayClient = sy.login( + port=gateway_port, email="info@openmined.org", password="changethis" + ) + domain_client: DomainClient = sy.login( + port=domain_1_port, email="info@openmined.org", password="changethis" + ) + + # connecting the domain to the gateway + result = domain_client.connect_to_gateway(gateway_client) + assert isinstance(result, SyftSuccess) + + # gateway adds 2 new routes for the domain + new_route = HTTPNodeRoute(host_or_ip="localhost", port=10000) + new_route2 = HTTPNodeRoute(host_or_ip="localhost", port=10001) + domain_peer: NodePeer = gateway_client.api.services.network.get_all_peers()[0] + res = gateway_client.api.services.network.add_route_on_peer( + peer=domain_peer, route=new_route + ) + assert isinstance(res, SyftSuccess) + res = gateway_client.api.services.network.add_route_on_peer( + peer=domain_peer, route=new_route2 + ) + assert isinstance(res, SyftSuccess) + + gateway_peer: NodePeer = domain_client.peers[0] + assert len(gateway_peer.node_routes) == 3 + + # gateway delete the routes for the domain + res = gateway_client.api.services.network.delete_route_on_peer( + peer=domain_peer, route_id=new_route.id + ) + assert isinstance(res, SyftSuccess) + gateway_peer = domain_client.peers[0] + assert len(gateway_peer.node_routes) == 2 + + res = gateway_client.api.services.network.delete_route_on_peer( + peer=domain_peer, route=new_route2 + ) + assert isinstance(res, SyftSuccess) + gateway_peer = domain_client.peers[0] + assert len(gateway_peer.node_routes) == 1 + + # gateway deletes the last the route to it for the domain + last_route: NodeRouteType = gateway_peer.node_routes[0] + res = gateway_client.api.services.network.delete_route_on_peer( + peer=domain_peer, route=last_route + ) + assert isinstance(res, SyftSuccess) + assert "There is no routes left" in res.message + assert len(domain_client.peers) == 0 # gateway is no longer a peer of the domain + + # The gateway client also removes the domain as a peer + assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess) + + +def test_update_route_priority( + set_env_var, gateway_port: int, domain_1_port: int +) -> None: + # login to the domain and gateway + gateway_client: GatewayClient = sy.login( + port=gateway_port, email="info@openmined.org", password="changethis" + ) + domain_client: DomainClient = sy.login( + port=domain_1_port, email="info@openmined.org", password="changethis" + ) + + # Remove existing peers + assert isinstance(_remove_existing_peers(domain_client), SyftSuccess) + assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess) + + # connecting the domain to the gateway + result = domain_client.connect_to_gateway(gateway_client) + assert isinstance(result, SyftSuccess) + + # gateway adds 2 new routes to the domain + new_route = HTTPNodeRoute(host_or_ip="localhost", port=10000) + new_route2 = HTTPNodeRoute(host_or_ip="localhost", port=10001) + domain_peer: NodePeer = gateway_client.api.services.network.get_all_peers()[0] + res = gateway_client.api.services.network.add_route( + peer_verify_key=domain_peer.verify_key, route=new_route + ) + assert isinstance(res, SyftSuccess) + res = gateway_client.api.services.network.add_route( + peer_verify_key=domain_peer.verify_key, route=new_route2 + ) + assert isinstance(res, SyftSuccess) + + # check if the priorities of the routes are correct + domain_peer = gateway_client.api.services.network.get_all_peers()[0] + routes_port_priority: dict = { + route.port: route.priority for route in domain_peer.node_routes + } + assert routes_port_priority[domain_1_port] == 1 + assert routes_port_priority[new_route.port] == 2 + assert routes_port_priority[new_route2.port] == 3 + + # update the priorities for the routes + res = gateway_client.api.services.network.update_route_priority( + peer_verify_key=domain_peer.verify_key, route=new_route, priority=5 + ) + assert isinstance(res, SyftSuccess) + domain_peer = gateway_client.api.services.network.get_all_peers()[0] + routes_port_priority: dict = { + route.port: route.priority for route in domain_peer.node_routes + } + assert routes_port_priority[new_route.port] == 5 + + res = gateway_client.api.services.network.update_route_priority( + peer_verify_key=domain_peer.verify_key, route=new_route2 + ) + assert isinstance(res, SyftSuccess) + domain_peer = gateway_client.api.services.network.get_all_peers()[0] + routes_port_priority: dict = { + route.port: route.priority for route in domain_peer.node_routes + } + assert routes_port_priority[new_route2.port] == 6 + + # Remove existing peers + assert isinstance(_remove_existing_peers(domain_client), SyftSuccess) + assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess) + + +def test_update_route_priority_on_peer( + set_env_var, gateway_port: int, domain_1_port: int +) -> None: + # login to the domain and gateway + gateway_client: GatewayClient = sy.login( + port=gateway_port, email="info@openmined.org", password="changethis" + ) + domain_client: DomainClient = sy.login( + port=domain_1_port, email="info@openmined.org", password="changethis" + ) + + # Remove existing peers + assert isinstance(_remove_existing_peers(domain_client), SyftSuccess) + assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess) + + # connecting the domain to the gateway + result = domain_client.connect_to_gateway(gateway_client) + assert isinstance(result, SyftSuccess) + + # gateway adds 2 new routes for the domain to itself + domain_peer: NodePeer = gateway_client.api.services.network.get_all_peers()[0] + new_route = HTTPNodeRoute(host_or_ip="localhost", port=10000) + res = gateway_client.api.services.network.add_route_on_peer( + peer=domain_peer, route=new_route + ) + assert isinstance(res, SyftSuccess) + + new_route2 = HTTPNodeRoute(host_or_ip="localhost", port=10001) + res = gateway_client.api.services.network.add_route_on_peer( + peer=domain_peer, route=new_route2 + ) + assert isinstance(res, SyftSuccess) + + # check if the priorities of the routes are correct + gateway_peer = domain_client.api.services.network.get_all_peers()[0] + routes_port_priority: dict = { + route.port: route.priority for route in gateway_peer.node_routes + } + assert routes_port_priority[gateway_port] == 1 + assert routes_port_priority[new_route.port] == 2 + assert routes_port_priority[new_route2.port] == 3 + + # gateway updates the route priorities for the domain remotely + res = gateway_client.api.services.network.update_route_priority_on_peer( + peer=domain_peer, route=new_route, priority=5 + ) + assert isinstance(res, SyftSuccess) + res = gateway_client.api.services.network.update_route_priority_on_peer( + peer=domain_peer, route=gateway_peer.node_routes[0] + ) + assert isinstance(res, SyftSuccess) + + gateway_peer = domain_client.api.services.network.get_all_peers()[0] + routes_port_priority: dict = { + route.port: route.priority for route in gateway_peer.node_routes + } + assert routes_port_priority[new_route.port] == 5 + assert routes_port_priority[gateway_port] == 6 + + # Remove existing peers + assert isinstance(_remove_existing_peers(domain_client), SyftSuccess) + assert isinstance(_remove_existing_peers(gateway_client), SyftSuccess)