diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e78c3f534f5..9b38fa96129 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -150,15 +150,3 @@ repos: "--install-types", "--non-interactive", ] - - repo: local - hooks: - - id: flynt - always_run: true - name: flynt - entry: flynt - exclude: ^(packages/grid/ansible) - args: [--fail-on-change] - types: [python] - language: python - additional_dependencies: - - flynt diff --git a/notebooks/course3/DO Notebook - Upload Dataset Revamped.ipynb b/notebooks/course3/mock_api_notebooks/DO Notebook - Upload Dataset Revamped.ipynb similarity index 100% rename from notebooks/course3/DO Notebook - Upload Dataset Revamped.ipynb rename to notebooks/course3/mock_api_notebooks/DO Notebook - Upload Dataset Revamped.ipynb diff --git a/notebooks/course3/DO Reviews Requests.ipynb b/notebooks/course3/mock_api_notebooks/DO Reviews Requests.ipynb similarity index 100% rename from notebooks/course3/DO Reviews Requests.ipynb rename to notebooks/course3/mock_api_notebooks/DO Reviews Requests.ipynb diff --git a/notebooks/course3/DO Selects a Dataset.ipynb b/notebooks/course3/mock_api_notebooks/DO Selects a Dataset.ipynb similarity index 100% rename from notebooks/course3/DO Selects a Dataset.ipynb rename to notebooks/course3/mock_api_notebooks/DO Selects a Dataset.ipynb diff --git a/notebooks/course3/DO Uploads a Dataset.ipynb b/notebooks/course3/mock_api_notebooks/DO Uploads a Dataset.ipynb similarity index 100% rename from notebooks/course3/DO Uploads a Dataset.ipynb rename to notebooks/course3/mock_api_notebooks/DO Uploads a Dataset.ipynb diff --git a/notebooks/course3/DS Access a Logged Session.ipynb b/notebooks/course3/mock_api_notebooks/DS Access a Logged Session.ipynb similarity index 100% rename from notebooks/course3/DS Access a Logged Session.ipynb rename to notebooks/course3/mock_api_notebooks/DS Access a Logged Session.ipynb diff --git a/notebooks/course3/DS Performs Analysis on the Dataset.ipynb b/notebooks/course3/mock_api_notebooks/DS Performs Analysis on the Dataset.ipynb similarity index 100% rename from notebooks/course3/DS Performs Analysis on the Dataset.ipynb rename to notebooks/course3/mock_api_notebooks/DS Performs Analysis on the Dataset.ipynb diff --git a/notebooks/course3/DS Search Domains via Networks and Login.ipynb b/notebooks/course3/mock_api_notebooks/DS Search Domains via Networks and Login.ipynb similarity index 100% rename from notebooks/course3/DS Search Domains via Networks and Login.ipynb rename to notebooks/course3/mock_api_notebooks/DS Search Domains via Networks and Login.ipynb diff --git a/notebooks/course3/DS Search and Select a Network.ipynb b/notebooks/course3/mock_api_notebooks/DS Search and Select a Network.ipynb similarity index 100% rename from notebooks/course3/DS Search and Select a Network.ipynb rename to notebooks/course3/mock_api_notebooks/DS Search and Select a Network.ipynb diff --git a/notebooks/course3/DS Searches for DataSet.ipynb b/notebooks/course3/mock_api_notebooks/DS Searches for DataSet.ipynb similarity index 100% rename from notebooks/course3/DS Searches for DataSet.ipynb rename to notebooks/course3/mock_api_notebooks/DS Searches for DataSet.ipynb diff --git a/notebooks/course3/DS Selects a Dataset.ipynb b/notebooks/course3/mock_api_notebooks/DS Selects a Dataset.ipynb similarity index 100% rename from notebooks/course3/DS Selects a Dataset.ipynb rename to notebooks/course3/mock_api_notebooks/DS Selects a Dataset.ipynb diff --git a/notebooks/course3/mock_api_notebooks/README.md b/notebooks/course3/mock_api_notebooks/README.md new file mode 100644 index 00000000000..3862111d991 --- /dev/null +++ b/notebooks/course3/mock_api_notebooks/README.md @@ -0,0 +1 @@ +These notebooks represent Mock APIs for the syft library. Actual implementation of the API may slightly differ from the mock ones. \ No newline at end of file diff --git a/notebooks/course3/Request Results (Without AutoDP).ipynb b/notebooks/course3/mock_api_notebooks/Request Results (Without AutoDP).ipynb similarity index 100% rename from notebooks/course3/Request Results (Without AutoDP).ipynb rename to notebooks/course3/mock_api_notebooks/Request Results (Without AutoDP).ipynb diff --git a/notebooks/course3/Requests Results (WithAutoDP).ipynb b/notebooks/course3/mock_api_notebooks/Requests Results (WithAutoDP).ipynb similarity index 100% rename from notebooks/course3/Requests Results (WithAutoDP).ipynb rename to notebooks/course3/mock_api_notebooks/Requests Results (WithAutoDP).ipynb diff --git a/notebooks/trade_demo/Part 1 - Setup Use Case.ipynb b/notebooks/trade_demo/mock_notebooks/Part 1 - Setup Use Case.ipynb similarity index 100% rename from notebooks/trade_demo/Part 1 - Setup Use Case.ipynb rename to notebooks/trade_demo/mock_notebooks/Part 1 - Setup Use Case.ipynb diff --git a/notebooks/trade_demo/Part 2 -Search for relevant data across a network.ipynb b/notebooks/trade_demo/mock_notebooks/Part 2 -Search for relevant data across a network.ipynb similarity index 100% rename from notebooks/trade_demo/Part 2 -Search for relevant data across a network.ipynb rename to notebooks/trade_demo/mock_notebooks/Part 2 -Search for relevant data across a network.ipynb diff --git a/notebooks/trade_demo/Part 3 - Select and ETL data into proper format.ipynb b/notebooks/trade_demo/mock_notebooks/Part 3 - Select and ETL data into proper format.ipynb similarity index 100% rename from notebooks/trade_demo/Part 3 - Select and ETL data into proper format.ipynb rename to notebooks/trade_demo/mock_notebooks/Part 3 - Select and ETL data into proper format.ipynb diff --git a/notebooks/trade_demo/Part 4 - Perform JOIN (backed by SMPC).ipynb b/notebooks/trade_demo/mock_notebooks/Part 4 - Perform JOIN (backed by SMPC).ipynb similarity index 100% rename from notebooks/trade_demo/Part 4 - Perform JOIN (backed by SMPC).ipynb rename to notebooks/trade_demo/mock_notebooks/Part 4 - Perform JOIN (backed by SMPC).ipynb diff --git a/notebooks/trade_demo/Part 5 - Perform analysis producting remote result.ipynb b/notebooks/trade_demo/mock_notebooks/Part 5 - Perform analysis producting remote result.ipynb similarity index 100% rename from notebooks/trade_demo/Part 5 - Perform analysis producting remote result.ipynb rename to notebooks/trade_demo/mock_notebooks/Part 5 - Perform analysis producting remote result.ipynb diff --git a/notebooks/trade_demo/Part 6 - Publish analysis result with DP noise.ipynb b/notebooks/trade_demo/mock_notebooks/Part 6 - Publish analysis result with DP noise.ipynb similarity index 100% rename from notebooks/trade_demo/Part 6 - Publish analysis result with DP noise.ipynb rename to notebooks/trade_demo/mock_notebooks/Part 6 - Publish analysis result with DP noise.ipynb diff --git a/notebooks/trade_demo/Part 7 - Download noisy results.ipynb b/notebooks/trade_demo/mock_notebooks/Part 7 - Download noisy results.ipynb similarity index 100% rename from notebooks/trade_demo/Part 7 - Download noisy results.ipynb rename to notebooks/trade_demo/mock_notebooks/Part 7 - Download noisy results.ipynb diff --git a/notebooks/trade_demo/mock_notebooks/README.md b/notebooks/trade_demo/mock_notebooks/README.md new file mode 100644 index 00000000000..dc5d219803e --- /dev/null +++ b/notebooks/trade_demo/mock_notebooks/README.md @@ -0,0 +1,5 @@ +**Note** + +These notebooks represent only mock APIs. They represent a vision on how we want to drive the user experience of the syft library in the notebooks in the near future. + +The functionalities presented here may or may not be present in the current or future versions of the syft library. \ No newline at end of file diff --git a/packages/grid/backend/requirements.txt b/packages/grid/backend/requirements.txt index 145b642d50e..96c13987f8c 100644 --- a/packages/grid/backend/requirements.txt +++ b/packages/grid/backend/requirements.txt @@ -1,5 +1,5 @@ alembic==1.6.5; (python_version >= "2.7" and python_full_version < "3.0.0") or (python_full_version >= "3.6.0") -amqp==5.0.6; python_version >= "3.6" +amqp==5.0.9; python_version >= "3.6" anyio==3.3.0; python_full_version >= "3.6.2" and python_version >= "3.6" appdirs==1.4.4; python_version >= "3.6" asgi-lifespan==1.0.1; python_version >= "3.6" @@ -11,13 +11,13 @@ bcrypt==3.2.0; python_version >= "3.6" billiard==3.6.4.0; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" black==19.10b0; python_version >= "3.6" cachetools==4.2.2; python_version >= "3.5" and python_version < "4.0" -celery==5.1.2; python_version >= "3.6" +celery==5.2.3; python_version >= "3.6" certifi==2021.5.30; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version >= "3.6" cffi==1.14.6 chardet==4.0.0; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" charset-normalizer==2.0.4; python_full_version >= "3.6.0" and python_version >= "3" checksumdir==1.2.0; python_version >= "3.6" and python_version < "4.0" -click==7.1.2; python_version >= "3.6" +click==8.0.3; python_version >= "3.6" colorama==0.4.4; python_version >= "3.6" and python_full_version < "3.0.0" and sys_platform == "win32" and platform_system == "Windows" and (python_version >= "3.5" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.5") or sys_platform == "win32" and python_version >= "3.6" and python_full_version >= "3.5.0" and platform_system == "Windows" and (python_version >= "3.5" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.5") coverage==5.5; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version < "4" cryptography==3.4.7; python_version >= "3.6" @@ -39,9 +39,9 @@ idna==3.2; python_full_version >= "3.6.2" and python_version >= "3.6" importlib-metadata==4.6.4; python_version >= "3.6" and python_full_version < "3.0.0" and python_version < "3.8" and (python_version >= "3.5" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.5") or python_full_version >= "3.6.0" and python_version < "3.8" and python_version >= "3.6" and (python_version >= "3.5" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.5") isort==4.3.21; (python_version >= "2.7" and python_full_version < "3.0.0") or (python_full_version >= "3.4.0") jinja2==2.11.3; (python_version >= "2.7" and python_full_version < "3.0.0") or (python_full_version >= "3.5.0") -kombu==5.1.0; python_version >= "3.6" +kombu==5.2.3; python_version >= "3.6" loguru==0.5.3; python_version >= "3.5" -lxml==4.6.3; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" +lxml==4.7.1; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" mako==1.1.5; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" markupsafe==2.0.1; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version >= "3.6" mccabe==0.6.1; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" @@ -69,7 +69,7 @@ python-dateutil==2.8.2; python_version >= "2.7" and python_full_version < "3.0.0 python-editor==1.0.4; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" python-jose==3.3.0 python-multipart==0.0.5 -pytz==2021.1; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" +pytz==2021.3; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" raven==6.10.0 regex==2021.8.21; python_version >= "3.6" requests==2.26.0; (python_version >= "2.7" and python_full_version < "3.0.0") or (python_full_version >= "3.6.0") diff --git a/packages/grid/frontend/yarn.lock b/packages/grid/frontend/yarn.lock index dfbf955fda3..3c0fd1b4288 100644 --- a/packages/grid/frontend/yarn.lock +++ b/packages/grid/frontend/yarn.lock @@ -4671,10 +4671,10 @@ json-schema-traverse@^1.0.0: resolved "https://registry.yarnpkg.com/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz#ae7bcb3656ab77a73ba5c49bf654f38e6b6860e2" integrity sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug== -json-schema@0.2.3: - version "0.2.3" - resolved "https://registry.yarnpkg.com/json-schema/-/json-schema-0.2.3.tgz#b480c892e59a2f05954ce727bd3f2a4e882f9e13" - integrity sha1-tIDIkuWaLwWVTOcnvT8qTogvnhM= +json-schema@0.4.0: + version "0.4.0" + resolved "https://registry.yarnpkg.com/json-schema/-/json-schema-0.4.0.tgz#f7de4cf6efab838ebaeb3236474cbba5a1930ab5" + integrity sha512-es94M3nTIfsEPisRafak+HDLfHXnKBhV3vU5eqPcS3flIWqcxJWgXHXiey3YrpaNsanY5ei1VoYEbOzijuq9BA== json-stable-stringify-without-jsonify@^1.0.1: version "1.0.1" diff --git a/packages/grid/vpn/tailscale.dockerfile b/packages/grid/vpn/tailscale.dockerfile index 0be26ed4cb9..67d7dd33124 100644 --- a/packages/grid/vpn/tailscale.dockerfile +++ b/packages/grid/vpn/tailscale.dockerfile @@ -1,4 +1,4 @@ -FROM shaynesweeney/tailscale:latest +FROM tailscale/tailscale:v1.16 RUN --mount=type=cache,target=/var/cache/apk \ apk add --no-cache python3 py3-pip ca-certificates diff --git a/packages/syft/proto/core/node/domain/service/get_all_requests_message.proto b/packages/syft/proto/core/node/domain/service/get_all_requests_message.proto deleted file mode 100644 index 5336c5cdc59..00000000000 --- a/packages/syft/proto/core/node/domain/service/get_all_requests_message.proto +++ /dev/null @@ -1,19 +0,0 @@ -syntax = "proto3"; - -package syft.core.node.domain.service; - -import "proto/core/node/domain/service/request_message.proto"; -import "proto/core/common/common_object.proto"; -import "proto/core/io/address.proto"; - -message GetAllRequestsMessage { - syft.core.common.UID msg_id = 1; - syft.core.io.Address address = 2; - syft.core.io.Address reply_to = 3; -} - -message GetAllRequestsResponseMessage { - syft.core.common.UID msg_id = 1; - syft.core.io.Address address = 2; - repeated syft.core.node.domain.service.RequestMessage requests = 3; -} diff --git a/packages/syft/proto/grid/messages/request_messages.proto b/packages/syft/proto/grid/messages/request_messages.proto index 6a989a928bd..71ffd950655 100644 --- a/packages/syft/proto/grid/messages/request_messages.proto +++ b/packages/syft/proto/grid/messages/request_messages.proto @@ -5,6 +5,7 @@ package syft.grid.messages; import "proto/core/common/common_object.proto"; import "proto/core/io/address.proto"; import "proto/lib/python/dict.proto"; +import "proto/core/node/domain/service/request_message.proto"; // CREATE Request @@ -105,3 +106,16 @@ message UpdateRequestResponse { string request_id = 4; syft.core.io.Address address = 5; } + + +message GetAllRequestsMessage { + syft.core.common.UID msg_id = 1; + syft.core.io.Address address = 2; + syft.core.io.Address reply_to = 3; +} + +message GetAllRequestsResponseMessage { + syft.core.common.UID msg_id = 1; + syft.core.io.Address address = 2; + repeated syft.core.node.domain.service.RequestMessage requests = 3; +} diff --git a/packages/syft/src/syft/core/adp/entity.py b/packages/syft/src/syft/core/adp/entity.py index 4cbf145a325..8b9ef61d959 100644 --- a/packages/syft/src/syft/core/adp/entity.py +++ b/packages/syft/src/syft/core/adp/entity.py @@ -119,9 +119,7 @@ class DataSubjectGroup: def __init__(self, list_of_entities: Optional[Union[list, set]] = None): self.entity_set: set = set() # Ensure each entity being tracked is unique - if isinstance(list_of_entities, list): - self.entity_set = self.entity_set.union(set(list_of_entities)) - elif isinstance(list_of_entities, set): + if isinstance(list_of_entities, (list, set)): self.entity_set = self.entity_set.union(list_of_entities) elif isinstance(list_of_entities, Entity): self.entity_set.add(list_of_entities) # type: ignore @@ -143,10 +141,8 @@ def __contains__(self, item: Entity) -> bool: return item in self.entity_set def to_string(self) -> str: - output_string = "" - for item in self.entity_set: - output_string += item.to_string() + ";" - return output_string[:-1] + output_string = ";".join(item.to_string() for item in self.entity_set) + return output_string @staticmethod def from_string(blob: str) -> DataSubjectGroup: @@ -155,7 +151,7 @@ def from_string(blob: str) -> DataSubjectGroup: entity_set = set() for entity_blob in entity_list: entity_set.add(Entity.from_string(entity_blob)) - return DataSubjectGroup(list[entity_set]) # type: ignore + return DataSubjectGroup(entity_set) # type: ignore def __add__( self, other: Union[DataSubjectGroup, Entity, int, float] diff --git a/packages/syft/src/syft/core/adp/scalar/intermediate_gamma_scalar.py b/packages/syft/src/syft/core/adp/scalar/intermediate_gamma_scalar.py index 4ca72b1744f..c4c8a97bbb8 100644 --- a/packages/syft/src/syft/core/adp/scalar/intermediate_gamma_scalar.py +++ b/packages/syft/src/syft/core/adp/scalar/intermediate_gamma_scalar.py @@ -127,7 +127,7 @@ def max_lipschitz_via_explicit_search( r2_diffs = np.array( [ - GammaScalar(x.min_val, x.value, x.max_val, entity=x.entity).poly # type: ignore + GammaScalar(x.min_val, x.value, x.max_val, entity=x.entity, prime=x.prime).poly # type: ignore for x in self.input_scalars ] ) diff --git a/packages/syft/src/syft/core/node/common/node_manager/constants.py b/packages/syft/src/syft/core/node/common/node_manager/constants.py new file mode 100644 index 00000000000..a8ebb0495ad --- /dev/null +++ b/packages/syft/src/syft/core/node/common/node_manager/constants.py @@ -0,0 +1,8 @@ +# stdlib +from enum import Enum + + +class UserApplicationStatus(Enum): + PENDING = "pending" + ACCEPTED = "accepted" + REJECTED = "rejected" diff --git a/packages/syft/src/syft/core/node/common/node_manager/user_manager.py b/packages/syft/src/syft/core/node/common/node_manager/user_manager.py index 1ba1fa2edc6..30f3e3b6d8e 100644 --- a/packages/syft/src/syft/core/node/common/node_manager/user_manager.py +++ b/packages/syft/src/syft/core/node/common/node_manager/user_manager.py @@ -1,3 +1,5 @@ +"""This file defines classes and methods which are used to manage database queries on the SyftUser table.""" + # stdlib from datetime import datetime from typing import Any @@ -13,8 +15,6 @@ from nacl.encoding import HexEncoder from nacl.signing import SigningKey from nacl.signing import VerifyKey -from pydantic import BaseModel -from pydantic import EmailStr from sqlalchemy.engine import Engine from sqlalchemy.orm import Query from sqlalchemy.orm import sessionmaker @@ -26,47 +26,13 @@ from ..node_table.roles import Role from ..node_table.user import SyftUser from ..node_table.user import UserApplication +from .constants import UserApplicationStatus from .database_manager import DatabaseManager from .role_manager import RoleManager -# Shared properties -class UserBase(BaseModel): - email: Optional[EmailStr] = None - is_active: Optional[bool] = True - is_superuser: bool = False - full_name: Optional[str] = None - - -# Properties to receive via API on creation -class UserCreate(UserBase): - email: EmailStr - password: str - - -# Properties to receive via API on update -class UserUpdate(UserBase): - password: Optional[str] = None - - -class UserInDBBase(UserBase): - id: Optional[int] = None - - class Config: - orm_mode = True - - -# Additional properties to return via API -class User(UserInDBBase): - pass - - -# Additional properties stored in DB -class UserInDB(UserInDBBase): - hashed_password: str - - class UserManager(DatabaseManager): + """Class to manage user database actions.""" schema = SyftUser @@ -76,6 +42,7 @@ def __init__(self, database: Engine) -> None: @property def common_users(self) -> list: + """Return users having the common role access.""" common_users: List[SyftUser] = [] for role in self.roles.common_roles: common_users = common_users + list(super().query(role=role.id)) @@ -84,6 +51,7 @@ def common_users(self) -> list: @property def org_users(self) -> list: + """Return all the users in the organization.""" org_users: List[SyftUser] = [] for role in self.roles.org_roles: org_users = org_users + list(super().query(role=role.id)) @@ -99,6 +67,21 @@ def create_user_application( website: Optional[str] = "", budget: Optional[float] = 0.0, ) -> int: + """Stores the information of the application submitted by the user. + + Args: + name (str): name of the user. + email (str): email of the user. + password (str): password of the user. + daa_pdf (Optional[bytes]): data access agreement. + institution (Optional[str], optional): name of the institution to which the user belongs. Defaults to "". + website (Optional[str], optional): website link of the institution. Defaults to "". + budget (Optional[float], optional): privacy budget allocated to the user. Defaults to 0.0. + + Returns: + int: Id of the application of the user. + """ + salt, hashed = self.__salt_and_hash_password(password, 12) session_local = sessionmaker(autocommit=False, autoflush=False, bind=self.db)() _pdf_obj = PDFObject(binary=daa_pdf) @@ -125,6 +108,11 @@ def create_user_application( return _obj_id def get_all_applicant(self) -> List[UserApplication]: + """Returns the application data of all the applicants in the database. + + Returns: + List[UserApplication]: All user applications. + """ session_local = sessionmaker(autocommit=False, autoflush=False, bind=self.db)() result = list(session_local.query(UserApplication).all()) session_local.close() @@ -133,13 +121,26 @@ def get_all_applicant(self) -> List[UserApplication]: def process_user_application( self, candidate_id: int, status: str, verify_key: VerifyKey ) -> None: + """Process the application for the given candidate. + + If the application of the user was accepted, then register the user + and its details in the database. Finally update the application status + for the given user/candidate in the database. + + Args: + candidate_id (int): user id of the candidate. + status (str): application status. + verify_key (VerifyKey): public digital signature of the user. + """ session_local = sessionmaker(autocommit=False, autoflush=False, bind=self.db)() candidate = ( session_local.query(UserApplication).filter_by(id=candidate_id).first() ) session_local.close() - if status == "accepted": + if ( + status == UserApplicationStatus.ACCEPTED.value + ): # If application was accepted # Generate a new signing key _private_key = SigningKey.generate() @@ -148,6 +149,8 @@ def process_user_application( "utf-8" ) added_by = self.get_user(verify_key).name # type: ignore + + # Register the user in the database self.register( name=candidate.name, email=candidate.email, @@ -164,7 +167,7 @@ def process_user_application( created_at=datetime.now(), ) else: - status = "rejected" + status = UserApplicationStatus.REJECTED.value session_local = sessionmaker(autocommit=False, autoflush=False, bind=self.db)() candidate = ( @@ -185,6 +188,20 @@ def signup( private_key: str, verify_key: str, ) -> SyftUser: + """Registers a user in the database, when they signup on a domain. + + Args: + name (str): name of the user. + email (str): email of the user. + password (str): password set by the user. + budget (float): privacy budget alloted to the user. + role (int): role of the user when they signup on the domain. + private_key (str): private digital signature of the user. + verify_key (str): public digital signature of the user. + + Returns: + SyftUser: the registered user object. + """ salt, hashed = self.__salt_and_hash_password(password, 12) return self.register( name=name, @@ -209,6 +226,15 @@ def first(self, **kwargs: Any) -> SyftUser: return result def login(self, email: str, password: str) -> SyftUser: + """Returns the user object for the given the email and password. + + Args: + email (str): email of the user. + password (str): password of the user. + + Returns: + SyftUser: user object for the given email and password. + """ return self.__login_validation(email, password) def set( # nosec @@ -222,6 +248,22 @@ def set( # nosec institution: str = "", budget: float = 0.0, ) -> None: + """Updates the information for the given user id. + + Args: + user_id (str): unique id of the user in the database. + email (str, optional): email of the user. Defaults to "". + password (str, optional): password of the user. Defaults to "". + role (int, optional): role of the user. Defaults to 0. + name (str, optional): name of the user. Defaults to "". + website (str, optional): website of the institution of the user. Defaults to "". + institution (str, optional): name of the institution of the user. Defaults to "". + budget (float, optional): privacy budget allocated to the user. Defaults to 0.0. + + Raises: + UserNotFoundError: Raised when a user does not exits for the given user id. + Exception: Raised when an invalid argument/property is passed. + """ if not self.contain(id=user_id): raise UserNotFoundError @@ -259,47 +301,67 @@ def set( # nosec self.modify({"id": user_id}, {key: value}) def can_create_users(self, verify_key: VerifyKey) -> bool: + """Checks if a user has permissions to create new users.""" try: return self.role(verify_key=verify_key).can_create_users except UserNotFoundError: return False def can_upload_data(self, verify_key: VerifyKey) -> bool: + """Checks if a user has permissions to upload data to the node.""" try: return self.role(verify_key=verify_key).can_upload_data except UserNotFoundError: return False def can_triage_requests(self, verify_key: VerifyKey) -> bool: + """Checks if a user has permissions to triage requests.""" try: return self.role(verify_key=verify_key).can_triage_data_requests except UserNotFoundError: return False def can_manage_infrastructure(self, verify_key: VerifyKey) -> bool: + """Checks if a user has permissions to manage the deployed infrastructure.""" try: return self.role(verify_key=verify_key).can_manage_infrastructure except UserNotFoundError: return False def can_edit_roles(self, verify_key: VerifyKey) -> bool: + """Checks if a user has permission to edit roles of other users.""" try: return self.role(verify_key=verify_key).can_edit_roles except UserNotFoundError: return False def role(self, verify_key: VerifyKey) -> Role: + """Returns the role of the given user.""" user = self.get_user(verify_key) if not user: raise UserNotFoundError return self.roles.first(id=user.role) def get_user(self, verify_key: VerifyKey) -> Optional[SyftUser]: + """Returns the user for the given public digital signature.""" return self.first( verify_key=verify_key.encode(encoder=HexEncoder).decode("utf-8") ) def __login_validation(self, email: str, password: str) -> SyftUser: + """Validates and returns the user object for the given credentials. + + Args: + email (str): email of the user. + password (str): password of the user. + + Raises: + UserNotFoundError: Raised if the user does not exist for the email. + InvalidCredentialsError: Raised if either the password or email is incorrect. + + Returns: + SyftUser: Returns the user for the given credentials. + """ try: user = self.first(email=email) if not user: diff --git a/packages/syft/src/syft/core/node/common/node_service/get_all_requests/__init__.py b/packages/syft/src/syft/core/node/common/node_service/get_all_requests/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/packages/syft/src/syft/core/node/common/node_service/get_all_requests/get_all_requests_messages.py b/packages/syft/src/syft/core/node/common/node_service/get_all_requests/get_all_requests_messages.py deleted file mode 100644 index 7fc907d7aa6..00000000000 --- a/packages/syft/src/syft/core/node/common/node_service/get_all_requests/get_all_requests_messages.py +++ /dev/null @@ -1,168 +0,0 @@ -# stdlib -from typing import List -from typing import Optional - -# third party -from google.protobuf.reflection import GeneratedProtocolMessageType - -# relative -from ...... import deserialize -from ...... import serialize -from ......proto.core.node.domain.service.get_all_requests_message_pb2 import ( - GetAllRequestsMessage as GetAllRequestsMessage_PB, -) -from ......proto.core.node.domain.service.get_all_requests_message_pb2 import ( - GetAllRequestsResponseMessage as GetAllRequestsResponseMessage_PB, -) -from .....common import UID -from .....common.message import ImmediateSyftMessageWithReply -from .....common.message import ImmediateSyftMessageWithoutReply -from .....common.serde.serializable import serializable -from .....io.address import Address -from ..request_receiver.request_receiver_messages import RequestMessage - - -@serializable() -class GetAllRequestsMessage(ImmediateSyftMessageWithReply): - def __init__( - self, address: Address, reply_to: Address, msg_id: Optional[UID] = None - ): - super().__init__(address=address, msg_id=msg_id, reply_to=reply_to) - - def _object2proto(self) -> GetAllRequestsMessage_PB: - """Returns a protobuf serialization of self. - - As a requirement of all objects which inherit from Serializable, - this method transforms the current object into the corresponding - Protobuf object so that it can be further serialized. - - :return: returns a protobuf object - :rtype: GetAllRequestsMessage_PB - - .. note:: - This method is purely an internal method. Please use serialize(object) or one of - the other public serialization methods if you wish to serialize an - object. - """ - return GetAllRequestsMessage_PB( - msg_id=serialize(self.id), - address=serialize(self.address), - reply_to=serialize(self.reply_to), - ) - - @staticmethod - def _proto2object(proto: GetAllRequestsMessage_PB) -> "GetAllRequestsMessage": - """Creates a ReprMessage from a protobuf - - As a requirement of all objects which inherit from Serializable, - this method transforms a protobuf object into an instance of this class. - - :return: returns an instance of ReprMessage - :rtype: ReprMessage - - .. note:: - This method is purely an internal method. Please use syft.deserialize() - if you wish to deserialize an object. - """ - - return GetAllRequestsMessage( - msg_id=deserialize(blob=proto.msg_id), - address=deserialize(blob=proto.address), - reply_to=deserialize(blob=proto.reply_to), - ) - - @staticmethod - def get_protobuf_schema() -> GeneratedProtocolMessageType: - """Return the type of protobuf object which stores a class of this type - - As a part of serialization and deserialization, we need the ability to - lookup the protobuf object type directly from the object type. This - static method allows us to do this. - - Importantly, this method is also used to create the reverse lookup ability within - the metaclass of Serializable. In the metaclass, it calls this method and then - it takes whatever type is returned from this method and adds an attribute to it - with the type of this class attached to it. See the MetaSerializable class for details. - - :return: the type of protobuf object which corresponds to this class. - :rtype: GeneratedProtocolMessageType - - """ - - return GetAllRequestsMessage_PB - - -@serializable() -class GetAllRequestsResponseMessage(ImmediateSyftMessageWithoutReply): - def __init__( - self, - requests: List[RequestMessage], - address: Address, - msg_id: Optional[UID] = None, - ): - super().__init__(address=address, msg_id=msg_id) - self.requests = requests - - def _object2proto(self) -> GetAllRequestsResponseMessage_PB: - """Returns a protobuf serialization of self. - - As a requirement of all objects which inherit from Serializable, - this method transforms the current object into the corresponding - Protobuf object so that it can be further serialized. - - :return: returns a protobuf object - :rtype: ReprMessage_PB - - .. note:: - This method is purely an internal method. Please use serialize(object) or one of - the other public serialization methods if you wish to serialize an - object. - """ - return GetAllRequestsResponseMessage_PB( - msg_id=serialize(self.id), - address=serialize(self.address), - requests=list(map(lambda x: serialize(x), self.requests)), - ) - - @staticmethod - def _proto2object( - proto: GetAllRequestsResponseMessage_PB, - ) -> "GetAllRequestsResponseMessage": - """Creates a GetAllRequestsResponseMessage from a protobuf - - As a requirement of all objects which inherit from Serializable, - this method transforms a protobuf object into an instance of this class. - - :return: returns an instance of GetAllRequestsResponseMessage - :rtype: GetAllRequestsResponseMessage - - .. note:: - This method is purely an internal method. Please use syft.deserialize() - if you wish to deserialize an object. - """ - - return GetAllRequestsResponseMessage( - msg_id=deserialize(blob=proto.msg_id), - address=deserialize(blob=proto.address), - requests=[deserialize(blob=x) for x in proto.requests], - ) - - @staticmethod - def get_protobuf_schema() -> GeneratedProtocolMessageType: - """Return the type of protobuf object which stores a class of this type - - As a part of serialization and deserialization, we need the ability to - lookup the protobuf object type directly from the object type. This - static method allows us to do this. - - Importantly, this method is also used to create the reverse lookup ability within - the metaclass of Serializable. In the metaclass, it calls this method and then - it takes whatever type is returned from this method and adds an attribute to it - with the type of this class attached to it. See the MetaSerializable class for details. - - :return: the type of protobuf object which corresponds to this class. - :rtype: GeneratedProtocolMessageType - - """ - - return GetAllRequestsResponseMessage_PB diff --git a/packages/syft/src/syft/core/node/common/node_service/get_all_requests/get_all_requests_service.py b/packages/syft/src/syft/core/node/common/node_service/get_all_requests/get_all_requests_service.py deleted file mode 100644 index 3163f746ed6..00000000000 --- a/packages/syft/src/syft/core/node/common/node_service/get_all_requests/get_all_requests_service.py +++ /dev/null @@ -1,52 +0,0 @@ -# # stdlib -# from typing import List -# from typing import Optional -# -# # third party -# from nacl.signing import VerifyKey -# -# # relative -# from ......logger import traceback_and_raise -# from ....abstract.node import AbstractNode -# from ....common.node_service.node_service import ImmediateNodeServiceWithoutReply -# from ..request_receiver.request_receiver_messages import RequestMessage -# from .get_all_requests_messages import GetAllRequestsMessage -# from .get_all_requests_messages import GetAllRequestsResponseMessage -# -# -# class GetAllRequestsService(ImmediateNodeServiceWithoutReply): -# @staticmethod -# def message_handler_types() -> List[type]: -# return [GetAllRequestsMessage] -# -# @staticmethod -# def process( -# node: AbstractNode, -# msg: GetAllRequestsMessage, -# verify_key: Optional[VerifyKey] = None, -# ) -> GetAllRequestsResponseMessage: -# try: -# if verify_key is None: -# traceback_and_raise( -# ValueError( -# "Can't process Request service without a given " "verification key" -# ) -# ) -# -# if verify_key == node.root_verify_key: -# return GetAllRequestsResponseMessage( -# requests=node.requests, address=msg.reply_to -# ) -# -# # only return requests which concern the user asking -# valid_requests: List[RequestMessage] = list() -# for request in node.requests: -# if request.requester_verify_key == verify_key: -# valid_requests.append(request) -# -# return GetAllRequestsResponseMessage( -# requests=valid_requests, address=msg.reply_to -# ) -# except Exception as e: -# print('\n\nSOMETHING WENT WRONG!!!\n\n') -# print(e) diff --git a/packages/syft/src/syft/core/node/common/node_service/object_request/object_request_messages.py b/packages/syft/src/syft/core/node/common/node_service/object_request/object_request_messages.py index 6bc1ac7f202..2bf1a8410a8 100644 --- a/packages/syft/src/syft/core/node/common/node_service/object_request/object_request_messages.py +++ b/packages/syft/src/syft/core/node/common/node_service/object_request/object_request_messages.py @@ -26,6 +26,12 @@ from ......proto.grid.messages.request_messages_pb2 import ( DeleteRequestResponse as DeleteRequestResponse_PB, ) +from ......proto.grid.messages.request_messages_pb2 import ( + GetAllRequestsMessage as GetAllRequestsMessage_PB, +) +from ......proto.grid.messages.request_messages_pb2 import ( + GetAllRequestsResponseMessage as GetAllRequestsResponseMessage_PB, +) from ......proto.grid.messages.request_messages_pb2 import ( GetBudgetRequestsMessage as GetBudgetRequestsMessage_PB, ) @@ -56,6 +62,7 @@ from .....common.serde.serializable import serializable from .....common.uid import UID from .....io.address import Address +from ..request_receiver.request_receiver_messages import RequestMessage @serializable() @@ -988,3 +995,149 @@ def get_protobuf_schema() -> GeneratedProtocolMessageType: """ return DeleteRequestResponse_PB + + +@serializable() +class GetAllRequestsMessage(ImmediateSyftMessageWithReply): + def __init__( + self, address: Address, reply_to: Address, msg_id: Optional[UID] = None + ): + super().__init__(address=address, msg_id=msg_id, reply_to=reply_to) + + def _object2proto(self) -> GetAllRequestsMessage_PB: + """Returns a protobuf serialization of self. + + As a requirement of all objects which inherit from Serializable, + this method transforms the current object into the corresponding + Protobuf object so that it can be further serialized. + + :return: returns a protobuf object + :rtype: GetAllRequestsMessage_PB + + .. note:: + This method is purely an internal method. Please use serialize(object) or one of + the other public serialization methods if you wish to serialize an + object. + """ + return GetAllRequestsMessage_PB( + msg_id=serialize(self.id), + address=serialize(self.address), + reply_to=serialize(self.reply_to), + ) + + @staticmethod + def _proto2object(proto: GetAllRequestsMessage_PB) -> "GetAllRequestsMessage": + """Creates a ReprMessage from a protobuf + + As a requirement of all objects which inherit from Serializable, + this method transforms a protobuf object into an instance of this class. + + :return: returns an instance of ReprMessage + :rtype: ReprMessage + + .. note:: + This method is purely an internal method. Please use syft.deserialize() + if you wish to deserialize an object. + """ + + return GetAllRequestsMessage( + msg_id=_deserialize(blob=proto.msg_id), + address=_deserialize(blob=proto.address), + reply_to=_deserialize(blob=proto.reply_to), + ) + + @staticmethod + def get_protobuf_schema() -> GeneratedProtocolMessageType: + """Return the type of protobuf object which stores a class of this type + + As a part of serialization and deserialization, we need the ability to + lookup the protobuf object type directly from the object type. This + static method allows us to do this. + + Importantly, this method is also used to create the reverse lookup ability within + the metaclass of Serializable. In the metaclass, it calls this method and then + it takes whatever type is returned from this method and adds an attribute to it + with the type of this class attached to it. See the MetaSerializable class for details. + + :return: the type of protobuf object which corresponds to this class. + :rtype: GeneratedProtocolMessageType + + """ + + return GetAllRequestsMessage_PB + + +@serializable() +class GetAllRequestsResponseMessage(ImmediateSyftMessageWithoutReply): + def __init__( + self, + requests: List[RequestMessage], + address: Address, + msg_id: Optional[UID] = None, + ): + super().__init__(address=address, msg_id=msg_id) + self.requests = requests + + def _object2proto(self) -> GetAllRequestsResponseMessage_PB: + """Returns a protobuf serialization of self. + + As a requirement of all objects which inherit from Serializable, + this method transforms the current object into the corresponding + Protobuf object so that it can be further serialized. + + :return: returns a protobuf object + :rtype: ReprMessage_PB + + .. note:: + This method is purely an internal method. Please use serialize(object) or one of + the other public serialization methods if you wish to serialize an + object. + """ + return GetAllRequestsResponseMessage_PB( + msg_id=serialize(self.id), + address=serialize(self.address), + requests=list(map(lambda x: serialize(x), self.requests)), + ) + + @staticmethod + def _proto2object( + proto: GetAllRequestsResponseMessage_PB, + ) -> "GetAllRequestsResponseMessage": + """Creates a GetAllRequestsResponseMessage from a protobuf + + As a requirement of all objects which inherit from Serializable, + this method transforms a protobuf object into an instance of this class. + + :return: returns an instance of GetAllRequestsResponseMessage + :rtype: GetAllRequestsResponseMessage + + .. note:: + This method is purely an internal method. Please use syft.deserialize() + if you wish to deserialize an object. + """ + + return GetAllRequestsResponseMessage( + msg_id=_deserialize(blob=proto.msg_id), + address=_deserialize(blob=proto.address), + requests=[_deserialize(blob=x) for x in proto.requests], + ) + + @staticmethod + def get_protobuf_schema() -> GeneratedProtocolMessageType: + """Return the type of protobuf object which stores a class of this type + + As a part of serialization and deserialization, we need the ability to + lookup the protobuf object type directly from the object type. This + static method allows us to do this. + + Importantly, this method is also used to create the reverse lookup ability within + the metaclass of Serializable. In the metaclass, it calls this method and then + it takes whatever type is returned from this method and adds an attribute to it + with the type of this class attached to it. See the MetaSerializable class for details. + + :return: the type of protobuf object which corresponds to this class. + :rtype: GeneratedProtocolMessageType + + """ + + return GetAllRequestsResponseMessage_PB diff --git a/packages/syft/src/syft/core/node/common/node_service/object_request/object_request_service.py b/packages/syft/src/syft/core/node/common/node_service/object_request/object_request_service.py index 7f2a07b6c9c..902f9c332cf 100644 --- a/packages/syft/src/syft/core/node/common/node_service/object_request/object_request_service.py +++ b/packages/syft/src/syft/core/node/common/node_service/object_request/object_request_service.py @@ -31,8 +31,6 @@ AcceptOrDenyRequestMessage, ) from ..auth import service_auth -from ..get_all_requests.get_all_requests_messages import GetAllRequestsMessage -from ..get_all_requests.get_all_requests_messages import GetAllRequestsResponseMessage from ..node_service import ImmediateNodeServiceWithReply from ..node_service import ImmediateNodeServiceWithoutReply from ..request_answer.request_answer_messages import RequestAnswerMessage @@ -49,6 +47,8 @@ from .object_request_messages import CreateRequestResponse from .object_request_messages import DeleteRequestMessage from .object_request_messages import DeleteRequestResponse +from .object_request_messages import GetAllRequestsMessage +from .object_request_messages import GetAllRequestsResponseMessage from .object_request_messages import GetBudgetRequestsMessage from .object_request_messages import GetBudgetRequestsResponse from .object_request_messages import GetRequestMessage @@ -402,8 +402,6 @@ def get_all_requests( _can_triage_request = node.users.can_triage_requests(verify_key=verify_key) - _requests = node.data_requests.all() - if _can_triage_request: _requests = node.data_requests.all() else: diff --git a/packages/syft/src/syft/core/node/common/node_service/role_manager/role_manager_service.py b/packages/syft/src/syft/core/node/common/node_service/role_manager/role_manager_service.py index 3ecb014cfef..479c7b73150 100644 --- a/packages/syft/src/syft/core/node/common/node_service/role_manager/role_manager_service.py +++ b/packages/syft/src/syft/core/node/common/node_service/role_manager/role_manager_service.py @@ -1,3 +1,7 @@ +"""This file defines all the functions/classes to perform any CRUD operation on +the Role table, for a given domain node, in a RESTful manner. +""" + # stdlib from typing import Callable from typing import Dict @@ -51,7 +55,22 @@ def create_role_msg( node: DomainInterface, verify_key: VerifyKey, ) -> SuccessResponseMessage: - # Check key permissions + """Creates a new role in the database. + + Args: + msg (CreateRoleMessage): details of the role. + node (DomainInterface): domain node. + verify_key (VerifyKey): public digital signature/key of the user. + + Raises: + MissingRequestKeyError: If name of the role is missing. + RequestError: If role already exists. + AuthorizationError: If user does not have permissions to create new role. + + Returns: + SuccessResponseMessage: Success message on role creation. + """ + # Check if user has permissions to create new roles _allowed = node.users.can_edit_roles(verify_key=verify_key) if not msg.name: @@ -94,6 +113,20 @@ def update_role_msg( node: DomainInterface, verify_key: VerifyKey, ) -> SuccessResponseMessage: + """Updates the properties of the given role. + + Args: + msg (UpdateRoleMessage): stores msg address and properties of the role to be updated. + node (DomainInterface): domain node. + verify_key (VerifyKey): public digital signature of the user. + + Raises: + MissingRequestKeyError: If the role id does not exist in the `msg`. + AuthorizationError: If user does not have permissions to perform the update operation. + + Returns: + SuccessResponseMessage: Message on successfully updating the role. + """ params = { "name": msg.name, @@ -112,7 +145,7 @@ def update_role_msg( if not msg.role_id: raise MissingRequestKeyError - # Check Key permissions + # Check if user has permissions to edit roles _allowed = node.users.can_edit_roles(verify_key=verify_key) if _allowed: @@ -131,8 +164,21 @@ def get_role_msg( node: DomainInterface, verify_key: VerifyKey, ) -> GetRoleResponse: + """Retrieves details of a role. - # Check Key permissions + Args: + msg (GetRoleMessage): stores msg address and role id. + node (DomainInterface): domain node. + verify_key (VerifyKey): public digital signature of the user. + + Raises: + AuthorizationError: If user does not have permissions to get role information. + + Returns: + GetRoleResponse: details of the role. + """ + + # Check if user has permissions to triage requests _allowed = node.users.can_triage_requests(verify_key=verify_key) if _allowed: @@ -151,7 +197,21 @@ def get_all_roles_msg( node: DomainInterface, verify_key: VerifyKey, ) -> GetRolesResponse: + """Retrieves details of the all available roles. + + Args: + msg (GetRolesMessage): stores the address of the message. + node (DomainInterface): domain node. + verify_key (VerifyKey): public digital signature of the user. + Raises: + AuthorizationError: If user does not have permissions to access roles. + + Returns: + GetRolesResponse: stores the details for all the roles as a list. + """ + + # Check if user has permissions to view roles _allowed = node.users.can_triage_requests(verify_key=verify_key) if _allowed: @@ -168,6 +228,21 @@ def del_role_msg( node: DomainInterface, verify_key: VerifyKey, ) -> SuccessResponseMessage: + """Delete the role corresponding to the given role id. + + Args: + msg (DeleteRoleMessage): stores the msg address and id of the role to be deleted. + node (DomainInterface): domain node. + verify_key (VerifyKey): public digital signature of the user. + + Raises: + AuthorizationError: If user does not have permissions to edit roles. + + Returns: + SuccessResponseMessage: stores the response msg on successful role deletion. + """ + + # Check if user has permissions to edit roles _allowed = node.users.can_edit_roles(verify_key=verify_key) if _allowed: @@ -182,6 +257,8 @@ def del_role_msg( class RoleManagerService(ImmediateNodeServiceWithReply): + """A class to handle all operations performed on the Role table.""" + msg_handler_map: Dict[INPUT_TYPE, Callable[..., OUTPUT_MESSAGES]] = { CreateRoleMessage: create_role_msg, UpdateRoleMessage: update_role_msg, diff --git a/packages/syft/src/syft/core/node/domain/client.py b/packages/syft/src/syft/core/node/domain/client.py index 1adedda9aca..27b58c2decf 100644 --- a/packages/syft/src/syft/core/node/domain/client.py +++ b/packages/syft/src/syft/core/node/domain/client.py @@ -16,7 +16,6 @@ import pandas as pd # relative -from .... import deserialize from ....logger import traceback_and_raise from ....util import validate_field from ...common.message import SyftMessage @@ -72,7 +71,7 @@ def __init__(self, client: Client) -> None: def requests(self) -> List[RequestMessage]: # relative - from ..common.node_service.get_all_requests.get_all_requests_messages import ( + from ..common.node_service.object_request.object_request_messages import ( GetAllRequestsMessage, ) @@ -80,9 +79,6 @@ def requests(self) -> List[RequestMessage]: address=self.client.address, reply_to=self.client.address ) - blob = serialize(msg, to_bytes=True) - msg = deserialize(blob, from_bytes=True) - requests = self.client.send_immediate_msg_with_reply(msg=msg).requests # type: ignore for request in requests: diff --git a/packages/syft/src/syft/lib/python/dict.py b/packages/syft/src/syft/lib/python/dict.py index 268004258c0..79b5e704e61 100644 --- a/packages/syft/src/syft/lib/python/dict.py +++ b/packages/syft/src/syft/lib/python/dict.py @@ -46,7 +46,7 @@ def __init__(*args: Any, **kwargs: Any) -> None: self, *args = args # type: ignore if len(args) > 1: traceback_and_raise( - TypeError("expected at most 1 arguments, got %d" % len(args)) + TypeError(f"expected at most 1 arguments, got {len(args)}") ) if args: args_dict = args[0] diff --git a/packages/syft/src/syft/lib/python/slice.py b/packages/syft/src/syft/lib/python/slice.py index 5c3898036e9..23b20bb014c 100644 --- a/packages/syft/src/syft/lib/python/slice.py +++ b/packages/syft/src/syft/lib/python/slice.py @@ -1,3 +1,8 @@ +""" +This source file aims to replace the standard slice object/function provided by Python +to be handled by the PySyft's abstract syntax tree data structure during a remote call. +""" + # stdlib from typing import Any from typing import Optional @@ -15,9 +20,10 @@ from .primitive_factory import PrimitiveFactory from .primitive_interface import PyPrimitive from .types import SyPrimitiveRet +from .util import upcast -@serializable() +@serializable() # This decorator turns this class serializable. class Slice(PyPrimitive): def __init__( self, @@ -26,6 +32,16 @@ def __init__( step: Optional[Any] = None, id: Optional[UID] = None, ): + """ + This class will receive start, stop, step and ID as valid parameters. + + Args: + start (Any): Index/position where the slicing of the object starts. + stop (Any): Index/position which the slicing takes place. The slicing stops at index stop-1. + step (Any): Determines the increment between each index for slicing. + id (UID): PySyft's objects have an unique ID related to them. + """ + # first, second, third if stop is None and step is None: # slice treats 1 arg as stop not start @@ -37,62 +53,157 @@ def __init__( @property def id(self) -> UID: - """We reveal PyPrimitive.id as a property to discourage users and + """ + We reveal PyPrimitive.id as a property to discourage users and developers of Syft from modifying .id attributes after an object has been initialized. - :return: returns the unique id of the object - :rtype: UID + Returns: + UID: The unique ID of the object. """ return self._id def __eq__(self, other: Any) -> SyPrimitiveRet: - res = self.value.__eq__(other) + """ + Compare if self == other. + + Args: + other (Any): Object to be compared. + Returns: + SyPrimitiveRet: returns a PySyft boolean format checking if self == other. + """ + res = self.value.__eq__(upcast(other)) return PrimitiveFactory.generate_primitive(value=res) def __ge__(self, other: Any) -> SyPrimitiveRet: - res = self.value.__ge__(other) # type: ignore + """ + Compare if self >= other. + + Args: + other (Any): Object to be compared. + Returns: + SyPrimitiveRet: returns a PySyft boolean format checking if self >= other. + """ + res = self.value.__ge__(upcast(other)) # type: ignore return PrimitiveFactory.generate_primitive(value=res) def __gt__(self, other: Any) -> SyPrimitiveRet: - res = self.value.__gt__(other) # type: ignore + """ + Compare if self > other. + + Args: + other (Any): Object to be compared. + Returns: + SyPrimitiveRet: returns a PySyft boolean format checking if self > other. + """ + res = self.value.__gt__(upcast(other)) # type: ignore return PrimitiveFactory.generate_primitive(value=res) def __le__(self, other: Any) -> SyPrimitiveRet: - res = self.value.__le__(other) # type: ignore + """ + Compare if self <= other. + + Args: + other (Any): Object to be compared. + Returns: + SyPrimitiveRet: returns a PySyft boolean format checking if self <= other. + """ + res = self.value.__le__(upcast(other)) # type: ignore return PrimitiveFactory.generate_primitive(value=res) def __lt__(self, other: Any) -> SyPrimitiveRet: - res = self.value.__lt__(other) # type: ignore + """ + Compare if self < other. + + Args: + other (Any): Object to be compared. + Returns: + SyPrimitiveRet: returns a PySyft boolean format checking if self < other. + """ + res = self.value.__lt__(upcast(other)) # type: ignore return PrimitiveFactory.generate_primitive(value=res) def __ne__(self, other: Any) -> SyPrimitiveRet: - res = self.value.__ne__(other) + """ + Compare if self != other. + + Args: + other (Any): Object to be compared. + Returns: + SyPrimitiveRet: returns a PySyft boolean format checking if self != other. + """ + res = self.value.__ne__(upcast(other)) return PrimitiveFactory.generate_primitive(value=res) def __str__(self) -> str: + """Slice's string representation + + Returns: + str: The string representation of this Slice object. + """ return self.value.__str__() def indices(self, index: int) -> tuple: + """ + Assuming a sequence of length len, calculate the start and stop + indices, and the stride length of the extended slice described by + the Slice object. Out of bounds indices are clipped in + a manner consistent with the handling of normal slices. + + Args: + index (int): Input index. + Returns: + Tuple: A tuple of concrete indices for a range of length len. + """ res = self.value.indices(index) return PrimitiveFactory.generate_primitive(value=res) @property def start(self) -> Optional[int]: + """ + Index/position where the slicing of the object starts. + + Returns: + int: Index where the slicing starts. + """ return self.value.start @property def step(self) -> Optional[int]: + """ + Increment between each index for slicing. + + Returns: + int: Slices' increment value. + """ return self.value.step @property def stop(self) -> Optional[int]: + """ + Index/position which the slicing takes place. + + Returns: + int: Slices' stop value. + """ return self.value.stop def upcast(self) -> slice: + """ + Returns the standard python slice object. + + Returns: + slice: returns a default python slice object represented by this object instance. + """ return self.value def _object2proto(self) -> Slice_PB: + """ + Serialize the Slice object instance returning a protobuf. + + Returns: + Slice_PB: returns a protobuf object class representing this Slice object. + """ slice_pb = Slice_PB() if self.start: slice_pb.start = self.start @@ -112,6 +223,14 @@ def _object2proto(self) -> Slice_PB: @staticmethod def _proto2object(proto: Slice_PB) -> "Slice": + """ + Deserialize a protobuf object creating a new Slice object instance. + + Args: + proto (Slice_PB): Protobuf object representing a serialized slice object. + Returns: + Slice: PySyft Slice object instance. + """ id_: UID = sy.deserialize(blob=proto.id) start = None stop = None @@ -134,4 +253,10 @@ def _proto2object(proto: Slice_PB) -> "Slice": @staticmethod def get_protobuf_schema() -> GeneratedProtocolMessageType: + """ + Returns the proper Slice protobuf schema. + + Returns: + Slice_PB: Returns the Slice's Protobuf class definition. + """ return Slice_PB diff --git a/packages/syft/src/syft/proto/core/node/domain/service/get_all_requests_message_pb2.py b/packages/syft/src/syft/proto/core/node/domain/service/get_all_requests_message_pb2.py deleted file mode 100644 index da258ef49e3..00000000000 --- a/packages/syft/src/syft/proto/core/node/domain/service/get_all_requests_message_pb2.py +++ /dev/null @@ -1,64 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: proto/core/node/domain/service/get_all_requests_message.proto -"""Generated protocol buffer code.""" -# third party -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database - -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -# syft absolute -from syft.proto.core.common import ( - common_object_pb2 as proto_dot_core_dot_common_dot_common__object__pb2, -) -from syft.proto.core.io import address_pb2 as proto_dot_core_dot_io_dot_address__pb2 -from syft.proto.core.node.domain.service import ( - request_message_pb2 as proto_dot_core_dot_node_dot_domain_dot_service_dot_request__message__pb2, -) - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n=proto/core/node/domain/service/get_all_requests_message.proto\x12\x1dsyft.core.node.domain.service\x1a\x34proto/core/node/domain/service/request_message.proto\x1a%proto/core/common/common_object.proto\x1a\x1bproto/core/io/address.proto"\x8f\x01\n\x15GetAllRequestsMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12\'\n\x08reply_to\x18\x03 \x01(\x0b\x32\x15.syft.core.io.Address"\xaf\x01\n\x1dGetAllRequestsResponseMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12?\n\x08requests\x18\x03 \x03(\x0b\x32-.syft.core.node.domain.service.RequestMessageb\x06proto3' -) - - -_GETALLREQUESTSMESSAGE = DESCRIPTOR.message_types_by_name["GetAllRequestsMessage"] -_GETALLREQUESTSRESPONSEMESSAGE = DESCRIPTOR.message_types_by_name[ - "GetAllRequestsResponseMessage" -] -GetAllRequestsMessage = _reflection.GeneratedProtocolMessageType( - "GetAllRequestsMessage", - (_message.Message,), - { - "DESCRIPTOR": _GETALLREQUESTSMESSAGE, - "__module__": "proto.core.node.domain.service.get_all_requests_message_pb2" - # @@protoc_insertion_point(class_scope:syft.core.node.domain.service.GetAllRequestsMessage) - }, -) -_sym_db.RegisterMessage(GetAllRequestsMessage) - -GetAllRequestsResponseMessage = _reflection.GeneratedProtocolMessageType( - "GetAllRequestsResponseMessage", - (_message.Message,), - { - "DESCRIPTOR": _GETALLREQUESTSRESPONSEMESSAGE, - "__module__": "proto.core.node.domain.service.get_all_requests_message_pb2" - # @@protoc_insertion_point(class_scope:syft.core.node.domain.service.GetAllRequestsResponseMessage) - }, -) -_sym_db.RegisterMessage(GetAllRequestsResponseMessage) - -if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - _GETALLREQUESTSMESSAGE._serialized_start = 219 - _GETALLREQUESTSMESSAGE._serialized_end = 362 - _GETALLREQUESTSRESPONSEMESSAGE._serialized_start = 365 - _GETALLREQUESTSRESPONSEMESSAGE._serialized_end = 540 -# @@protoc_insertion_point(module_scope) diff --git a/packages/syft/src/syft/proto/grid/messages/request_messages_pb2.py b/packages/syft/src/syft/proto/grid/messages/request_messages_pb2.py index b028be3a289..017cf484a23 100644 --- a/packages/syft/src/syft/proto/grid/messages/request_messages_pb2.py +++ b/packages/syft/src/syft/proto/grid/messages/request_messages_pb2.py @@ -19,10 +19,13 @@ common_object_pb2 as proto_dot_core_dot_common_dot_common__object__pb2, ) from syft.proto.core.io import address_pb2 as proto_dot_core_dot_io_dot_address__pb2 +from syft.proto.core.node.domain.service import ( + request_message_pb2 as proto_dot_core_dot_node_dot_domain_dot_service_dot_request__message__pb2, +) from syft.proto.lib.python import dict_pb2 as proto_dot_lib_dot_python_dot_dict__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n*proto/grid/messages/request_messages.proto\x12\x12syft.grid.messages\x1a%proto/core/common/common_object.proto\x1a\x1bproto/core/io/address.proto\x1a\x1bproto/lib/python/dict.proto"\x9f\x01\n\x14\x43reateRequestMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\x12\'\n\x08reply_to\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\x8c\x01\n\x15\x43reateRequestResponse\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12\x13\n\x0bstatus_code\x18\x02 \x01(\x05\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\x12&\n\x07\x61\x64\x64ress\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\x8b\x01\n\x1a\x43reateBudgetRequestMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12\x0e\n\x06\x62udget\x18\x03 \x01(\x02\x12\x0e\n\x06reason\x18\x04 \x01(\t"\xa6\x01\n\x18GetBudgetRequestsMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12\x12\n\nrequest_id\x18\x03 \x01(\t\x12\'\n\x08reply_to\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\x92\x01\n\x19GetBudgetRequestsResponse\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x63ontent\x18\x03 \x03(\x0b\x32\x15.syft.lib.python.Dict\x12&\n\x07\x61\x64\x64ress\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\x9f\x01\n\x11GetRequestMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12\x12\n\nrequest_id\x18\x03 \x01(\t\x12\'\n\x08reply_to\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\x8c\x01\n\x12GetRequestResponse\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12\x13\n\x0bstatus_code\x18\x02 \x01(\x05\x12\x12\n\nrequest_id\x18\x03 \x01(\t\x12&\n\x07\x61\x64\x64ress\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\x8c\x01\n\x12GetRequestsMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12\'\n\x08reply_to\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\xa1\x01\n\x13GetRequestsResponse\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12\x13\n\x0bstatus_code\x18\x02 \x01(\x05\x12&\n\x07\x63ontent\x18\x03 \x03(\x0b\x32\x15.syft.lib.python.Dict\x12&\n\x07\x61\x64\x64ress\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\xa2\x01\n\x14\x44\x65leteRequestMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12\x12\n\nrequest_id\x18\x03 \x01(\t\x12\'\n\x08reply_to\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\x8f\x01\n\x15\x44\x65leteRequestResponse\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12\x13\n\x0bstatus_code\x18\x02 \x01(\x05\x12\x12\n\nrequest_id\x18\x03 \x01(\t\x12&\n\x07\x61\x64\x64ress\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\xb2\x01\n\x14UpdateRequestMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x12\n\nrequest_id\x18\x04 \x01(\t\x12\'\n\x08reply_to\x18\x05 \x01(\x0b\x32\x15.syft.core.io.Address"\x9f\x01\n\x15UpdateRequestResponse\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12\x13\n\x0bstatus_code\x18\x02 \x01(\x05\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x12\n\nrequest_id\x18\x04 \x01(\t\x12&\n\x07\x61\x64\x64ress\x18\x05 \x01(\x0b\x32\x15.syft.core.io.Addressb\x06proto3' + b'\n*proto/grid/messages/request_messages.proto\x12\x12syft.grid.messages\x1a%proto/core/common/common_object.proto\x1a\x1bproto/core/io/address.proto\x1a\x1bproto/lib/python/dict.proto\x1a\x34proto/core/node/domain/service/request_message.proto"\x9f\x01\n\x14\x43reateRequestMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\x12\'\n\x08reply_to\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\x8c\x01\n\x15\x43reateRequestResponse\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12\x13\n\x0bstatus_code\x18\x02 \x01(\x05\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\x12&\n\x07\x61\x64\x64ress\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\x8b\x01\n\x1a\x43reateBudgetRequestMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12\x0e\n\x06\x62udget\x18\x03 \x01(\x02\x12\x0e\n\x06reason\x18\x04 \x01(\t"\xa6\x01\n\x18GetBudgetRequestsMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12\x12\n\nrequest_id\x18\x03 \x01(\t\x12\'\n\x08reply_to\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\x92\x01\n\x19GetBudgetRequestsResponse\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x63ontent\x18\x03 \x03(\x0b\x32\x15.syft.lib.python.Dict\x12&\n\x07\x61\x64\x64ress\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\x9f\x01\n\x11GetRequestMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12\x12\n\nrequest_id\x18\x03 \x01(\t\x12\'\n\x08reply_to\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\x8c\x01\n\x12GetRequestResponse\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12\x13\n\x0bstatus_code\x18\x02 \x01(\x05\x12\x12\n\nrequest_id\x18\x03 \x01(\t\x12&\n\x07\x61\x64\x64ress\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\x8c\x01\n\x12GetRequestsMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12\'\n\x08reply_to\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\xa1\x01\n\x13GetRequestsResponse\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12\x13\n\x0bstatus_code\x18\x02 \x01(\x05\x12&\n\x07\x63ontent\x18\x03 \x03(\x0b\x32\x15.syft.lib.python.Dict\x12&\n\x07\x61\x64\x64ress\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\xa2\x01\n\x14\x44\x65leteRequestMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12\x12\n\nrequest_id\x18\x03 \x01(\t\x12\'\n\x08reply_to\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\x8f\x01\n\x15\x44\x65leteRequestResponse\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12\x13\n\x0bstatus_code\x18\x02 \x01(\x05\x12\x12\n\nrequest_id\x18\x03 \x01(\t\x12&\n\x07\x61\x64\x64ress\x18\x04 \x01(\x0b\x32\x15.syft.core.io.Address"\xb2\x01\n\x14UpdateRequestMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x12\n\nrequest_id\x18\x04 \x01(\t\x12\'\n\x08reply_to\x18\x05 \x01(\x0b\x32\x15.syft.core.io.Address"\x9f\x01\n\x15UpdateRequestResponse\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12\x13\n\x0bstatus_code\x18\x02 \x01(\x05\x12\x0e\n\x06status\x18\x03 \x01(\t\x12\x12\n\nrequest_id\x18\x04 \x01(\t\x12&\n\x07\x61\x64\x64ress\x18\x05 \x01(\x0b\x32\x15.syft.core.io.Address"\x8f\x01\n\x15GetAllRequestsMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12\'\n\x08reply_to\x18\x03 \x01(\x0b\x32\x15.syft.core.io.Address"\xaf\x01\n\x1dGetAllRequestsResponseMessage\x12%\n\x06msg_id\x18\x01 \x01(\x0b\x32\x15.syft.core.common.UID\x12&\n\x07\x61\x64\x64ress\x18\x02 \x01(\x0b\x32\x15.syft.core.io.Address\x12?\n\x08requests\x18\x03 \x03(\x0b\x32-.syft.core.node.domain.service.RequestMessageb\x06proto3' ) @@ -43,6 +46,10 @@ _DELETEREQUESTRESPONSE = DESCRIPTOR.message_types_by_name["DeleteRequestResponse"] _UPDATEREQUESTMESSAGE = DESCRIPTOR.message_types_by_name["UpdateRequestMessage"] _UPDATEREQUESTRESPONSE = DESCRIPTOR.message_types_by_name["UpdateRequestResponse"] +_GETALLREQUESTSMESSAGE = DESCRIPTOR.message_types_by_name["GetAllRequestsMessage"] +_GETALLREQUESTSRESPONSEMESSAGE = DESCRIPTOR.message_types_by_name[ + "GetAllRequestsResponseMessage" +] CreateRequestMessage = _reflection.GeneratedProtocolMessageType( "CreateRequestMessage", (_message.Message,), @@ -186,33 +193,59 @@ ) _sym_db.RegisterMessage(UpdateRequestResponse) +GetAllRequestsMessage = _reflection.GeneratedProtocolMessageType( + "GetAllRequestsMessage", + (_message.Message,), + { + "DESCRIPTOR": _GETALLREQUESTSMESSAGE, + "__module__": "proto.grid.messages.request_messages_pb2" + # @@protoc_insertion_point(class_scope:syft.grid.messages.GetAllRequestsMessage) + }, +) +_sym_db.RegisterMessage(GetAllRequestsMessage) + +GetAllRequestsResponseMessage = _reflection.GeneratedProtocolMessageType( + "GetAllRequestsResponseMessage", + (_message.Message,), + { + "DESCRIPTOR": _GETALLREQUESTSRESPONSEMESSAGE, + "__module__": "proto.grid.messages.request_messages_pb2" + # @@protoc_insertion_point(class_scope:syft.grid.messages.GetAllRequestsResponseMessage) + }, +) +_sym_db.RegisterMessage(GetAllRequestsResponseMessage) + if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _CREATEREQUESTMESSAGE._serialized_start = 164 - _CREATEREQUESTMESSAGE._serialized_end = 323 - _CREATEREQUESTRESPONSE._serialized_start = 326 - _CREATEREQUESTRESPONSE._serialized_end = 466 - _CREATEBUDGETREQUESTMESSAGE._serialized_start = 469 - _CREATEBUDGETREQUESTMESSAGE._serialized_end = 608 - _GETBUDGETREQUESTSMESSAGE._serialized_start = 611 - _GETBUDGETREQUESTSMESSAGE._serialized_end = 777 - _GETBUDGETREQUESTSRESPONSE._serialized_start = 780 - _GETBUDGETREQUESTSRESPONSE._serialized_end = 926 - _GETREQUESTMESSAGE._serialized_start = 929 - _GETREQUESTMESSAGE._serialized_end = 1088 - _GETREQUESTRESPONSE._serialized_start = 1091 - _GETREQUESTRESPONSE._serialized_end = 1231 - _GETREQUESTSMESSAGE._serialized_start = 1234 - _GETREQUESTSMESSAGE._serialized_end = 1374 - _GETREQUESTSRESPONSE._serialized_start = 1377 - _GETREQUESTSRESPONSE._serialized_end = 1538 - _DELETEREQUESTMESSAGE._serialized_start = 1541 - _DELETEREQUESTMESSAGE._serialized_end = 1703 - _DELETEREQUESTRESPONSE._serialized_start = 1706 - _DELETEREQUESTRESPONSE._serialized_end = 1849 - _UPDATEREQUESTMESSAGE._serialized_start = 1852 - _UPDATEREQUESTMESSAGE._serialized_end = 2030 - _UPDATEREQUESTRESPONSE._serialized_start = 2033 - _UPDATEREQUESTRESPONSE._serialized_end = 2192 + _CREATEREQUESTMESSAGE._serialized_start = 218 + _CREATEREQUESTMESSAGE._serialized_end = 377 + _CREATEREQUESTRESPONSE._serialized_start = 380 + _CREATEREQUESTRESPONSE._serialized_end = 520 + _CREATEBUDGETREQUESTMESSAGE._serialized_start = 523 + _CREATEBUDGETREQUESTMESSAGE._serialized_end = 662 + _GETBUDGETREQUESTSMESSAGE._serialized_start = 665 + _GETBUDGETREQUESTSMESSAGE._serialized_end = 831 + _GETBUDGETREQUESTSRESPONSE._serialized_start = 834 + _GETBUDGETREQUESTSRESPONSE._serialized_end = 980 + _GETREQUESTMESSAGE._serialized_start = 983 + _GETREQUESTMESSAGE._serialized_end = 1142 + _GETREQUESTRESPONSE._serialized_start = 1145 + _GETREQUESTRESPONSE._serialized_end = 1285 + _GETREQUESTSMESSAGE._serialized_start = 1288 + _GETREQUESTSMESSAGE._serialized_end = 1428 + _GETREQUESTSRESPONSE._serialized_start = 1431 + _GETREQUESTSRESPONSE._serialized_end = 1592 + _DELETEREQUESTMESSAGE._serialized_start = 1595 + _DELETEREQUESTMESSAGE._serialized_end = 1757 + _DELETEREQUESTRESPONSE._serialized_start = 1760 + _DELETEREQUESTRESPONSE._serialized_end = 1903 + _UPDATEREQUESTMESSAGE._serialized_start = 1906 + _UPDATEREQUESTMESSAGE._serialized_end = 2084 + _UPDATEREQUESTRESPONSE._serialized_start = 2087 + _UPDATEREQUESTRESPONSE._serialized_end = 2246 + _GETALLREQUESTSMESSAGE._serialized_start = 2249 + _GETALLREQUESTSMESSAGE._serialized_end = 2392 + _GETALLREQUESTSRESPONSEMESSAGE._serialized_start = 2395 + _GETALLREQUESTSRESPONSEMESSAGE._serialized_end = 2570 # @@protoc_insertion_point(module_scope) diff --git a/packages/syft/tests/syft/core/node/common/service/role_manager_service_test.py b/packages/syft/tests/syft/core/node/common/service/role_manager_service_test.py new file mode 100644 index 00000000000..e279656c56f --- /dev/null +++ b/packages/syft/tests/syft/core/node/common/service/role_manager_service_test.py @@ -0,0 +1,138 @@ +# stdlib +from unittest.mock import patch + +# third party +from nacl.signing import SigningKey + +# syft absolute +import syft as sy +from syft.core.node.common.node_service.role_manager.role_manager_messages import ( + CreateRoleMessage, +) +from syft.core.node.common.node_service.role_manager.role_manager_messages import ( + DeleteRoleMessage, +) +from syft.core.node.common.node_service.role_manager.role_manager_messages import ( + GetRoleMessage, +) +from syft.core.node.common.node_service.role_manager.role_manager_messages import ( + GetRoleResponse, +) +from syft.core.node.common.node_service.role_manager.role_manager_messages import ( + GetRolesMessage, +) +from syft.core.node.common.node_service.role_manager.role_manager_messages import ( + GetRolesResponse, +) +from syft.core.node.common.node_service.role_manager.role_manager_messages import ( + UpdateRoleMessage, +) +from syft.core.node.common.node_service.role_manager.role_manager_service import ( + RoleManagerService, +) +from syft.core.node.common.node_service.success_resp_message import ( + SuccessResponseMessage, +) + + +def test_create_role_message() -> None: + domain = sy.Domain(name="Domain Name") + role_name = "New Role" + user_key = SigningKey(domain.verify_key.encode()) + + msg = CreateRoleMessage( + address=domain.address, + name=role_name, + reply_to=domain.address, + ) + + reply = None + with patch.object(domain.users, "can_edit_roles", return_value=True): + reply = RoleManagerService.process(node=domain, msg=msg, verify_key=user_key) + + assert reply is not None + assert isinstance(reply, SuccessResponseMessage) is True + assert reply.resp_msg == "Role created successfully!" + + +def test_update_role_message() -> None: + domain = sy.Domain(name="Domain Name") + role = domain.roles.first() + new_name = "New Role Name" + user_key = SigningKey(domain.verify_key.encode()) + + msg = UpdateRoleMessage( + address=domain.address, + role_id=role.id, + name=new_name, + reply_to=domain.address, + ) + + reply = None + with patch.object(domain.users, "can_edit_roles", return_value=True): + reply = RoleManagerService.process(node=domain, msg=msg, verify_key=user_key) + + assert reply is not None + assert reply.resp_msg == "Role updated successfully!" + role_obj = domain.roles.first() + assert role_obj.name == new_name + + +def test_get_role_message() -> None: + domain = sy.Domain(name="Domain Name") + role = domain.roles.first() + user_key = SigningKey(domain.verify_key.encode()) + + msg = GetRoleMessage( + address=domain.address, + role_id=role.id, + reply_to=domain.address, + ) + + reply = None + with patch.object(domain.users, "can_triage_requests", return_value=True): + reply = RoleManagerService.process(node=domain, msg=msg, verify_key=user_key) + + assert reply is not None + assert isinstance(reply, GetRoleResponse) is True + assert reply.content is not None + assert reply.content["name"] == role.name + + +def test_get_roles_message() -> None: + domain = sy.Domain(name="Domain Name") + user_key = SigningKey(domain.verify_key.encode()) + + msg = GetRolesMessage( + address=domain.address, + reply_to=domain.address, + ) + + reply = None + with patch.object(domain.users, "can_triage_requests", return_value=True): + reply = RoleManagerService.process(node=domain, msg=msg, verify_key=user_key) + + assert reply is not None + assert isinstance(reply, GetRolesResponse) is True + assert reply.content is not None + assert type(reply.content) == list + + +def test_del_role_manager() -> None: + domain = sy.Domain(name="Domain Name") + user_key = SigningKey(domain.verify_key.encode()) + role = domain.roles.first() + + msg = DeleteRoleMessage( + address=domain.address, + reply_to=domain.address, + role_id=role.id, + ) + + reply = None + with patch.object(domain.users, "can_edit_roles", return_value=True): + reply = RoleManagerService.process(node=domain, msg=msg, verify_key=user_key) + + assert reply is not None + assert isinstance(reply, SuccessResponseMessage) is True + assert reply.resp_msg == "Role has been deleted!" diff --git a/packages/syft/tests/syft/core/tensor/passthrough_test.py b/packages/syft/tests/syft/core/tensor/passthrough_test.py index 620843b1f8b..2fd6d51d236 100644 --- a/packages/syft/tests/syft/core/tensor/passthrough_test.py +++ b/packages/syft/tests/syft/core/tensor/passthrough_test.py @@ -3,6 +3,7 @@ import numpy as np import torch + # syft absolute from syft.core.tensor.passthrough import PassthroughTensor diff --git a/packages/syft/tests/syft/lib/python/slice/slice_test.py b/packages/syft/tests/syft/lib/python/slice/slice_test.py index f330883a2b6..9adc2a19eca 100644 --- a/packages/syft/tests/syft/lib/python/slice/slice_test.py +++ b/packages/syft/tests/syft/lib/python/slice/slice_test.py @@ -10,6 +10,7 @@ import pytest # syft absolute +from syft.core.common.uid import UID from syft.lib.python.list import List from syft.lib.python.slice import Slice from syft.lib.python.string import String @@ -102,6 +103,29 @@ def test_cmp(self): self.assertNotEqual(s1.value, (1, 2, 3)) self.assertNotEqual(s1, "") + # Check if PySyft's Slice object gives the same cmp results as python slice object. + # Check __lt__ + assert (s1.value < s2.value) == (s1 < s2) + + # Check __gt__ + assert (s1.value > s2.value) == (s1 > s2) + + # Check __ne__ + assert (s1.value != s2.value) == (s1 != s2) + + # Check __ge__ + assert (s1.value >= s2.value) == (s1 >= s2) + + # Check __le__ + assert (s1.value <= s2.value) == (s1 <= s2) + + # Check if Slice against slice can retrieve proper boolean return + std_slice = slice(1, 2, 3) + + assert std_slice == s1 + + assert std_slice != s3 + class Exc(Exception): pass @@ -124,6 +148,17 @@ def __eq__(self, other): self.assertEqual(s1, s1) self.assertRaises(Exc, lambda: s1.value == s2.value) + def test_id(self): + new_id = UID() + s1 = Slice(1, 2, 3, id=new_id) + assert new_id == s1.id + + def test_upcast(self): + python_s1 = slice(1, 2, 3) + pysyft_s1 = Slice(1, 2, 3) + + assert pysyft_s1.upcast() == python_s1 + def test_members(self): s = Slice(1) self.assertEqual(s.start, None)