From 255b4645472652f8503ce16cd007fce9563b30d9 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 9 Feb 2024 17:03:11 +0530 Subject: [PATCH] remove syft imports from Orchestra - define NodeType and NodeSideType in hagrid util - Add a util class to import syft modules --- packages/hagrid/hagrid/orchestra.py | 22 ++++---------- packages/hagrid/hagrid/util.py | 45 +++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/packages/hagrid/hagrid/orchestra.py b/packages/hagrid/hagrid/orchestra.py index eecb5ca51bd..32af7fb198b 100644 --- a/packages/hagrid/hagrid/orchestra.py +++ b/packages/hagrid/hagrid/orchestra.py @@ -17,27 +17,13 @@ # relative from .cli import str_to_bool -from .dummynum import DummyNum from .grammar import find_available_port from .names import random_name +from .util import ImportFromSyft +from .util import NodeSideType +from .util import NodeType from .util import shell -try: - # syft absolute - from syft.abstract_node import NodeSideType - from syft.abstract_node import NodeType - from syft.protocol.data_protocol import stage_protocol_changes - from syft.service.response import SyftError -except Exception: # nosec - NodeSideType = DummyNum - NodeType = DummyNum - - def stage_protocol_changes(*args: Any, **kwargs: Any) -> None: - pass - - SyftError = DummyNum - # print("Please install syft with `pip install syft`") - DEFAULT_PORT = 8080 DEFAULT_URL = "http://localhost" # Gevent used instead of threading module ,as we monkey patch gevent in syft @@ -203,6 +189,7 @@ def register( institution: Optional[str] = None, website: Optional[str] = None, ) -> Any: + SyftError = ImportFromSyft.import_syft_error() if not email: email = input("Email: ") if not password: @@ -248,6 +235,7 @@ def deploy_to_python( create_producer: bool = False, queue_port: Optional[int] = None, ) -> Optional[NodeHandle]: + stage_protocol_changes = ImportFromSyft.import_stage_protocol_changes() sy = get_syft_client() if sy is None: return sy diff --git a/packages/hagrid/hagrid/util.py b/packages/hagrid/hagrid/util.py index 58bcdebf724..c4a9c82e83c 100644 --- a/packages/hagrid/hagrid/util.py +++ b/packages/hagrid/hagrid/util.py @@ -1,12 +1,57 @@ # stdlib +from enum import Enum import os import subprocess # nosec import sys from typing import Any +from typing import Callable from typing import Tuple from typing import Union from urllib.parse import urlparse +# relative +from .dummynum import DummyNum + + +class NodeSideType(str, Enum): + LOW_SIDE = "low" + HIGH_SIDE = "high" + + +class NodeType(Enum): + DOMAIN = "domain" + NETWORK = "network" + ENCLAVE = "enclave" + GATEWAY = "gateway" + + def __str__(self) -> str: + # Use values when transforming NodeType to str + return self.value + + +class ImportFromSyft: + @staticmethod + def import_syft_error() -> Callable: + try: + # syft absolute + from syft.service.response import SyftError + except Exception: + SyftError = DummyNum + + return SyftError + + @staticmethod + def import_stage_protocol_changes() -> Callable: + try: + # syft absolute + from syft.protocol.data_protocol import stage_protocol_changes + except Exception: + + def stage_protocol_changes(*args: Any, **kwargs: Any) -> None: + pass + + return stage_protocol_changes + def from_url(url: str) -> Tuple[str, str, int, str, Union[Any, str]]: try: