Skip to content

Commit

Permalink
Merge branch 'dev' into madhava/torch_cp12
Browse files Browse the repository at this point in the history
  • Loading branch information
kiendang authored Mar 6, 2024
2 parents 11f7b94 + 06226e1 commit 17e8a4a
Show file tree
Hide file tree
Showing 70 changed files with 1,823 additions and 2,078 deletions.
43 changes: 23 additions & 20 deletions packages/grid/backend/grid/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
import os
import secrets
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

# third party
from pydantic import AnyHttpUrl
from pydantic import BaseSettings
from pydantic import EmailStr
from pydantic import HttpUrl
from pydantic import validator
from pydantic import field_validator
from pydantic import model_validator
from pydantic_settings import BaseSettings
from pydantic_settings import SettingsConfigDict
from typing_extensions import Self

_truthy = {"yes", "y", "true", "t", "on", "1"}
_falsy = {"no", "n", "false", "f", "off", "0"}
Expand Down Expand Up @@ -50,7 +52,8 @@ class Settings(BaseSettings):
# "http://localhost:8080", "http://local.dockertoolbox.tiangolo.com"]'
BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = []

@validator("BACKEND_CORS_ORIGINS", pre=True)
@field_validator("BACKEND_CORS_ORIGINS", mode="before")
@classmethod
def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]:
if isinstance(v, str) and not v.startswith("["):
return [i.strip() for i in v.split(",")]
Expand All @@ -62,7 +65,8 @@ def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str

SENTRY_DSN: Optional[HttpUrl] = None

@validator("SENTRY_DSN", pre=True)
@field_validator("SENTRY_DSN", mode="before")
@classmethod
def sentry_dsn_can_be_blank(cls, v: str) -> Optional[str]:
if v is None or len(v) == 0:
return None
Expand All @@ -76,27 +80,28 @@ def sentry_dsn_can_be_blank(cls, v: str) -> Optional[str]:
EMAILS_FROM_EMAIL: Optional[EmailStr] = None
EMAILS_FROM_NAME: Optional[str] = None

@validator("EMAILS_FROM_NAME")
def get_project_name(cls, v: Optional[str], values: Dict[str, Any]) -> str:
if not v:
return values["PROJECT_NAME"]
return v
@model_validator(mode="after")
def get_project_name(self) -> Self:
if not self.EMAILS_FROM_NAME:
self.EMAILS_FROM_NAME = self.PROJECT_NAME

return self

EMAIL_RESET_TOKEN_EXPIRE_HOURS: int = 48
EMAIL_TEMPLATES_DIR: str = os.path.expandvars(
"$HOME/app/grid/email-templates/build"
)
EMAILS_ENABLED: bool = False

@validator("EMAILS_ENABLED", pre=True)
def get_emails_enabled(cls, v: bool, values: Dict[str, Any]) -> bool:
return bool(
values.get("SMTP_HOST")
and values.get("SMTP_PORT")
and values.get("EMAILS_FROM_EMAIL")
@model_validator(mode="after")
def get_emails_enabled(self) -> Self:
self.EMAILS_ENABLED = bool(
self.SMTP_HOST and self.SMTP_PORT and self.EMAILS_FROM_EMAIL
)

DEFAULT_ROOT_EMAIL: EmailStr = EmailStr("[email protected]")
return self

DEFAULT_ROOT_EMAIL: EmailStr = "[email protected]"
DEFAULT_ROOT_PASSWORD: str = "changethis"
USERS_OPEN_REGISTRATION: bool = False

Expand Down Expand Up @@ -149,9 +154,7 @@ def get_emails_enabled(cls, v: bool, values: Dict[str, Any]) -> bool:
True if os.getenv("TEST_MODE", "false").lower() == "true" else False
)
ASSOCIATION_TIMEOUT: int = 10

class Config:
case_sensitive = True
model_config = SettingsConfigDict(case_sensitive=True)


settings = Settings()
13 changes: 4 additions & 9 deletions packages/grid/backend/grid/logger/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Union

# third party
from pydantic import BaseSettings
from pydantic_settings import BaseSettings


# LOGURU_LEVEL type for version>3.8
Expand Down Expand Up @@ -40,14 +40,9 @@ class LogConfig(BaseSettings):

LOGURU_LEVEL: str = LogLevel.INFO.value
LOGURU_SINK: Optional[str] = "/var/log/pygrid/grid.log"
LOGURU_COMPRESSION: Optional[str]
LOGURU_ROTATION: Union[
Optional[str],
Optional[int],
Optional[time],
Optional[timedelta],
]
LOGURU_RETENTION: Union[Optional[str], Optional[int], Optional[timedelta]]
LOGURU_COMPRESSION: Optional[str] = None
LOGURU_ROTATION: Union[str, int, time, timedelta, None] = None
LOGURU_RETENTION: Union[str, int, timedelta, None] = None
LOGURU_COLORIZE: Optional[bool] = True
LOGURU_SERIALIZE: Optional[bool] = False
LOGURU_BACKTRACE: Optional[bool] = True
Expand Down
6 changes: 5 additions & 1 deletion packages/hagrid/hagrid/orchestra.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import inspect
import os
import subprocess # nosec
import sys
from threading import Thread
from typing import Any
from typing import Callable
Expand Down Expand Up @@ -598,7 +599,10 @@ def shutdown(
elif "No resource found to remove for project" in land_output:
print(f" ✅ {snake_name} Container does not exist")
else:
print(f"❌ Unable to remove container: {snake_name} :{land_output}")
print(
f"❌ Unable to remove container: {snake_name} :{land_output}",
file=sys.stderr,
)

@staticmethod
def reset(name: str, deployment_type_enum: DeploymentType) -> None:
Expand Down
3 changes: 2 additions & 1 deletion packages/syft/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ syft =
pyarrow==14.0.1
# pycapnp is beta version, update to stable version when available
pycapnp==2.0.0b2
pydantic[email]==1.10.13
pydantic[email]==2.6.0
pydantic-settings==2.2.1
pymongo==4.6.1
pynacl==1.5.0
pyzmq>=23.2.1,<=25.1.1
Expand Down
20 changes: 10 additions & 10 deletions packages/syft/src/syft/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,11 @@ class APIEndpoint(SyftObject):
module_path: str
name: str
description: str
doc_string: Optional[str]
doc_string: Optional[str] = None
signature: Signature
has_self: bool = False
pre_kwargs: Optional[Dict[str, Any]]
warning: Optional[APIEndpointWarning]
pre_kwargs: Optional[Dict[str, Any]] = None
warning: Optional[APIEndpointWarning] = None


@serializable()
Expand All @@ -134,10 +134,10 @@ class LibEndpoint(SyftBaseObject):
module_path: str
name: str
description: str
doc_string: Optional[str]
doc_string: Optional[str] = None
signature: Signature
has_self: bool = False
pre_kwargs: Optional[Dict[str, Any]]
pre_kwargs: Optional[Dict[str, Any]] = None


@serializable(attrs=["signature", "credentials", "serialized_message"])
Expand Down Expand Up @@ -207,7 +207,7 @@ class SyftAPIData(SyftBaseObject):
__version__ = SYFT_OBJECT_VERSION_1

# fields
data: Any
data: Any = None

def sign(self, credentials: SyftSigningKey) -> SignedSyftAPICall:
signed_message = credentials.signing_key.sign(_serialize(self, to_bytes=True))
Expand All @@ -233,9 +233,9 @@ class RemoteFunction(SyftObject):
signature: Signature
path: str
make_call: Callable
pre_kwargs: Optional[Dict[str, Any]]
pre_kwargs: Optional[Dict[str, Any]] = None
communication_protocol: PROTOCOL_TYPE
warning: Optional[APIEndpointWarning]
warning: Optional[APIEndpointWarning] = None

@property
def __ipython_inspector_signature_override__(self) -> Optional[Signature]:
Expand Down Expand Up @@ -1078,5 +1078,5 @@ def validate_callable_args_and_kwargs(
return _valid_args, _valid_kwargs


RemoteFunction.update_forward_refs()
RemoteUserCodeFunction.update_forward_refs()
RemoteFunction.model_rebuild(force=True)
RemoteUserCodeFunction.model_rebuild(force=True)
19 changes: 12 additions & 7 deletions packages/syft/src/syft/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

# third party
from argon2 import PasswordHasher
import pydantic
from pydantic import field_validator
import requests
from requests import Response
from requests import Session
Expand Down Expand Up @@ -135,14 +135,19 @@ class HTTPConnection(NodeConnection):
__canonical_name__ = "HTTPConnection"
__version__ = SYFT_OBJECT_VERSION_1

proxy_target_uid: Optional[UID]
url: GridURL
proxy_target_uid: Optional[UID] = None
routes: Type[Routes] = Routes
session_cache: Optional[Session]
session_cache: Optional[Session] = None

@pydantic.validator("url", pre=True, always=True)
def make_url(cls, v: Union[GridURL, str]) -> GridURL:
return GridURL.from_url(v).as_container_host()
@field_validator("url", mode="before")
@classmethod
def make_url(cls, v: Any) -> Any:
return (
GridURL.from_url(v).as_container_host()
if isinstance(v, (str, GridURL))
else v
)

def with_proxy(self, proxy_target_uid: UID) -> Self:
return HTTPConnection(url=self.url, proxy_target_uid=proxy_target_uid)
Expand Down Expand Up @@ -329,7 +334,7 @@ class PythonConnection(NodeConnection):
__version__ = SYFT_OBJECT_VERSION_1

node: AbstractNode
proxy_target_uid: Optional[UID]
proxy_target_uid: Optional[UID] = None

def with_proxy(self, proxy_target_uid: UID) -> Self:
return PythonConnection(node=self.node, proxy_target_uid=proxy_target_uid)
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/client/gateway_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class ProxyClient(SyftObject):
__version__ = SYFT_OBJECT_VERSION_1

routing_client: GatewayClient
node_type: Optional[NodeType]
node_type: Optional[NodeType] = None

def retrieve_nodes(self) -> List[NodePeer]:
if self.node_type in [NodeType.DOMAIN, NodeType.ENCLAVE]:
Expand Down
14 changes: 8 additions & 6 deletions packages/syft/src/syft/custom_worker/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# third party
import docker
from packaging import version
from pydantic import validator
from pydantic import field_validator
from typing_extensions import Self
import yaml

Expand Down Expand Up @@ -54,7 +54,8 @@ class CustomBuildConfig(SyftBaseModel):
# f"Python version must be between {PYTHON_MIN_VER} and {PYTHON_MAX_VER}"
# )

@validator("python_packages")
@field_validator("python_packages")
@classmethod
def validate_python_packages(cls, pkgs: List[str]) -> List[str]:
for pkg in pkgs:
ver_parts: Union[tuple, list] = ()
Expand Down Expand Up @@ -114,7 +115,7 @@ def get_signature(self) -> str:
class PrebuiltWorkerConfig(WorkerConfig):
# tag that is already built and pushed in some registry
tag: str
description: Optional[str]
description: Optional[str] = None

def __str__(self) -> str:
if self.description:
Expand All @@ -129,10 +130,11 @@ def set_description(self, description_text: str) -> None:
@serializable()
class DockerWorkerConfig(WorkerConfig):
dockerfile: str
file_name: Optional[str]
description: Optional[str]
file_name: Optional[str] = None
description: Optional[str] = None

@validator("dockerfile")
@field_validator("dockerfile")
@classmethod
def validate_dockerfile(cls, dockerfile: str) -> str:
if not dockerfile:
raise ValueError("Dockerfile cannot be empty")
Expand Down
6 changes: 3 additions & 3 deletions packages/syft/src/syft/custom_worker/k8s.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ class ContainerStatus(BaseModel):
ready: bool
running: bool
waiting: bool
reason: Optional[str] # when waiting=True
message: Optional[str] # when waiting=True
startedAt: Optional[str] # when running=True
reason: Optional[str] = None # when waiting=True
message: Optional[str] = None # when waiting=True
startedAt: Optional[str] = None # when running=True

@classmethod
def from_status(cls, cstatus: dict) -> Self:
Expand Down
12 changes: 7 additions & 5 deletions packages/syft/src/syft/external/oblv/deployment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

# third party
from oblv_ctl import OblvClient
from pydantic import validator
from pydantic import field_validator
import requests

# relative
Expand Down Expand Up @@ -46,10 +46,11 @@
class OblvMetadata(EnclaveMetadata):
"""Contains Metadata to connect to Oblivious Enclave"""

deployment_id: Optional[str]
oblv_client: Optional[OblvClient]
deployment_id: Optional[str] = None
oblv_client: Optional[OblvClient] = None

@validator("deployment_id")
@field_validator("deployment_id")
@classmethod
def check_valid_deployment_id(cls, deployment_id: str) -> str:
if not deployment_id and not LOCAL_MODE:
raise ValueError(
Expand All @@ -59,7 +60,8 @@ def check_valid_deployment_id(cls, deployment_id: str) -> str:
)
return deployment_id

@validator("oblv_client")
@field_validator("oblv_client")
@classmethod
def check_valid_oblv_client(cls, oblv_client: OblvClient) -> OblvClient:
if not oblv_client and not LOCAL_MODE:
raise ValueError(
Expand Down
7 changes: 4 additions & 3 deletions packages/syft/src/syft/node/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from nacl.encoding import HexEncoder
from nacl.signing import SigningKey
from nacl.signing import VerifyKey
import pydantic
from pydantic import field_validator

# relative
from ..serde.serializable import serializable
Expand Down Expand Up @@ -54,8 +54,9 @@ def __hash__(self) -> int:
class SyftSigningKey(SyftBaseModel):
signing_key: SigningKey

@pydantic.validator("signing_key", pre=True, always=True)
def make_signing_key(cls, v: Union[str, SigningKey]) -> SigningKey:
@field_validator("signing_key", mode="before")
@classmethod
def make_signing_key(cls, v: Any) -> Any:
return SigningKey(bytes.fromhex(v)) if isinstance(v, str) else v

@property
Expand Down
4 changes: 2 additions & 2 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,8 +1213,8 @@ def handle_api_call_with_unsigned_result(
if api_call.path not in user_config_registry:
if ServiceConfigRegistry.path_exists(api_call.path):
return SyftError(
message=f"As a `{role}`,"
f"you have has no access to: {api_call.path}"
message=f"As a `{role}`, "
f"you have no access to: {api_call.path}"
)
else:
return SyftError(
Expand Down
Loading

0 comments on commit 17e8a4a

Please sign in to comment.