diff --git a/notebooks/api/0.8/05-custom-policy.ipynb b/notebooks/api/0.8/05-custom-policy.ipynb index b223311365b..1dddcf26ab4 100644 --- a/notebooks/api/0.8/05-custom-policy.ipynb +++ b/notebooks/api/0.8/05-custom-policy.ipynb @@ -87,6 +87,8 @@ " return self.state[\"counts\"]\n", " \n", " def apply_output(self, context, outputs):\n", + " if hasattr(outputs, \"syft_action_data\"):\n", + " outputs = outputs.syft_action_data\n", " output_dict = {}\n", " if self.state[\"counts\"] < self.n_calls:\n", " for output_arg in self.downloadable_output_args:\n", @@ -159,7 +161,7 @@ "print(policy.init_kwargs)\n", "a_obj = sy.ActionObject.from_obj({'y': [1,2,3]})\n", "x = policy.apply_output(None, a_obj)\n", - "x" + "x['y']" ] }, { @@ -185,39 +187,34 @@ "source": [ "x = np.array([1,2,3])\n", "x_pointer = sy.ActionObject.from_obj(x)\n", - "domain_client.api.services.action.save(x_pointer)" + "x_pointer" ] }, { "cell_type": "code", "execution_count": null, - "id": "5da4428a-0fed-41e3-b770-02fbaca20bfc", - "metadata": { - "tags": [] - }, + "id": "e82409e4", + "metadata": {}, "outputs": [], "source": [ - "@sy.syft_function(\n", - " input_policy=sy.ExactMatch(x=x_pointer),\n", - " output_policy=RepeatedCallPolicy(n_calls=10, downloadable_output_args=['y']),\n", - ")\n", - "def func(x):\n", - " return {\"y\": x+1}" + "domain_client.api.services.action.set(x_pointer)" ] }, { "cell_type": "code", "execution_count": null, - "id": "44565122-4ff4-4169-8e0a-db3b86bf53e1", + "id": "5da4428a-0fed-41e3-b770-02fbaca20bfc", "metadata": { "tags": [] }, "outputs": [], "source": [ - "@sy.syft_function(input_policy=sy.ExactMatch(x=x_pointer),\n", - " output_policy=sy.SingleExecutionExactOutput())\n", - "def train_mlp(x):\n", - " return x" + "@sy.syft_function(\n", + " input_policy=sy.ExactMatch(x=x_pointer),\n", + " output_policy=RepeatedCallPolicy(n_calls=10, downloadable_output_args=['y']),\n", + ")\n", + "def func(x):\n", + " return {\"y\": x+1}" ] }, { @@ -293,19 +290,17 @@ "outputs": [], "source": [ "res_ptr = domain_client.code.func(x=x_pointer)\n", - "res = res_ptr.get()\n", - "res" + "res_ptr" ] }, { "cell_type": "code", "execution_count": null, - "id": "b8c74835", - "metadata": { - "tags": [] - }, + "id": "31e706e4", + "metadata": {}, "outputs": [], "source": [ + "res = res_ptr.get()\n", "res" ] }, diff --git a/notebooks/api/0.8/09-blob-storage.ipynb b/notebooks/api/0.8/09-blob-storage.ipynb index 6f67340b608..268e9581532 100644 --- a/notebooks/api/0.8/09-blob-storage.ipynb +++ b/notebooks/api/0.8/09-blob-storage.ipynb @@ -18,6 +18,7 @@ "outputs": [], "source": [ "import syft as sy\n", + "import io\n", "sy.requires(SYFT_VERSION)\n", "from syft import autocache" ] @@ -40,6 +41,13 @@ "domain_client = node.login(email=\"info@openmined.org\", password=\"changethis\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Upload and retrieve SyftObject with blob storage (low level API)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -80,8 +88,17 @@ "outputs": [], "source": [ "# write/deposit object\n", - "blob_deposit = allocate_object(domain_client, user_object)\n", - "write_result = blob_deposit.write(sy.serialize(user_object, to_bytes=True))\n", + "blob_deposit = allocate_object(domain_client, user_object)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data = io.BytesIO(sy.serialize(user_object, to_bytes=True))\n", + "write_result = blob_deposit.write(data)\n", "write_result" ] }, @@ -96,11 +113,248 @@ 
"user_object_read = blob_retrieval.read()\n", "user_object_read" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# delete object in blob storage\n", + "domain_client.api.services.blob_storage.delete(blob_deposit.blob_storage_entry_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Upload and retrieve files with blob storage (low level API)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def allocate_file(client: SyftClient, path: Path) -> BlobDeposit:\n", + " create_blob_storage_entry = CreateBlobStorageEntry.from_path(path)\n", + " return client.api.services.blob_storage.allocate(create_blob_storage_entry)\n", + "\n", + "\n", + "def upload_file(client: SyftClient, path: Path) -> sy.UID:\n", + " blob_deposit = allocate_file(client, path)\n", + " with open(path, \"rb\") as f:\n", + " blob_deposit.write(f)\n", + " return blob_deposit.blob_storage_entry_id\n", + "\n", + "\n", + "def retrieve_file(client, blob_storage_entry_id: sy.UID) -> Path:\n", + " blob_retrieval = client.api.services.blob_storage.read(blob_storage_entry_id)\n", + " file = blob_retrieval.read()\n", + " return Path(file.file_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "canada_dataset_url = \"https://github.com/OpenMined/datasets/blob/main/trade_flow/ca%20-%20feb%202021.csv?raw=True\"\n", + "data_file = autocache(canada_dataset_url, \"csv\")\n", + "data_file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "uploaded_file_storage_id = upload_file(domain_client, data_file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "retrieved_file = retrieve_file(domain_client, uploaded_file_storage_id)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Original file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pd.read_csv(data_file, nrows=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Retrieved file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pd.read_csv(retrieved_file, nrows=5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "retrieved_file.unlink()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# delete file from blob storage\n", + "domain_client.api.services.blob_storage.delete(uploaded_file_storage_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## From file using Action Object (Partial Functional)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "canada_dataset_url = \"https://github.com/OpenMined/datasets/blob/main/trade_flow/ca%20-%20feb%202021.csv?raw=True\"\n", + "data_file = autocache(canada_dataset_url, \"csv\")\n", + "data_file" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# creating an action object from file\n", + "action_object = sy.ActionObject.from_path(path=data_file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_ptr = action_object.send(domain_client)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@sy.syft_function_single_use(trade_data=data_ptr)\n", + "def sum_trade_value_mil(trade_data):\n", + " import pandas as pd\n", + " import opendp.prelude as dp\n", + " dp.enable_features(\"contrib\")\n", + " from opendp.measurements import make_laplace\n", + " aggregate = 0.\n", + " base_lap = dp.m.make_base_laplace(\n", + " dp.atom_domain(T=float),\n", + " dp.absolute_distance(T=float),\n", + " scale=10.\n", + " )\n", + " \n", + " noise = base_lap(aggregate)\n", + "\n", + " df = pd.read_csv(data_ptr.syft_action_data.file_name, low_memory=False)\n", + " total = df[\"Trade Value (US$)\"].sum()\n", + " return (float(total / 1_000_000), float(noise))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sum_trade_value_mil(trade_data=data_ptr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## clean up\n", + "## delete downloaded file\n", + "import os\n", + "os.unlink(data_ptr.syft_action_data.file_name)\n", + "\n", + "## delete file from blob storage\n", + "domain_client.api.services.blob_storage.delete(action_object.syft_blob_storage_entry_id)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Cleanup local domain server\n", + "if node.node_type.value == \"python\":\n", + " node.land()" + ] } ], "metadata": { "kernelspec": { - "display_name": "syft08", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -114,7 +368,20 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.9.7" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false } }, "nbformat": 4, diff --git a/notebooks/tutorials/data-owner/02-account-management.ipynb b/notebooks/tutorials/data-owner/02-account-management.ipynb index 7b697d1bacb..6c318f566be 100644 --- a/notebooks/tutorials/data-owner/02-account-management.ipynb +++ b/notebooks/tutorials/data-owner/02-account-management.ipynb @@ -178,7 +178,7 @@ "id": "a7eb3bff", "metadata": {}, "source": [ - "Lets update the user we just created, and change the role using the `user.update` service method" + "Lets update the user we just created, and change the role using the `users.update` service method" ] }, { @@ -198,9 +198,7 @@ "metadata": {}, "outputs": [], "source": [ - "updated_user = client.users.update(new_user.id, \n", - " UserUpdate(role=ServiceRole.DATA_SCIENTIST, password=\"123\")\n", - ")" + "updated_user = client.users.update(new_user.id, UserUpdate(role=ServiceRole.DATA_SCIENTIST, password=\"123\"))" ] }, { @@ -231,6 +229,16 @@ "ds_client = node.login(email=\"newuser@openmined.org\", password=\"123\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "75cc6719", 
+ "metadata": {}, + "outputs": [], + "source": [ + "ds_client" + ] + }, { "cell_type": "code", "execution_count": null, @@ -254,7 +262,7 @@ "id": "82d0802d", "metadata": {}, "source": [ - "Lastly, we can delete users using the `user.delete` service method" + "Lastly, we can delete users using the `users.delete` service method" ] }, { @@ -419,14 +427,6 @@ "source": [ "client.users" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b83dd65a", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -445,7 +445,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.11.3" }, "toc": { "base_numbering": 1, diff --git a/packages/grid/backend/grid/core/node.py b/packages/grid/backend/grid/core/node.py index ccfa0bdf506..d164debc14a 100644 --- a/packages/grid/backend/grid/core/node.py +++ b/packages/grid/backend/grid/core/node.py @@ -5,6 +5,9 @@ from syft.node.node import get_node_name from syft.node.node import get_node_side_type from syft.node.node import get_node_type +from syft.node.node import get_node_uid_env +from syft.store.blob_storage.seaweedfs import SeaweedFSClientConfig +from syft.store.blob_storage.seaweedfs import SeaweedFSConfig from syft.store.mongo_client import MongoStoreClientConfig from syft.store.mongo_document_store import MongoStoreConfig from syft.store.sqlite_document_store import SQLiteStoreClientConfig @@ -32,6 +35,17 @@ node_side_type = get_node_side_type() enable_warnings = get_enable_warnings() +seaweed_client_config = SeaweedFSClientConfig( + host=settings.S3_ENDPOINT, + port=settings.S3_PORT, + access_key=settings.S3_ROOT_USER, + secret_key=settings.S3_ROOT_PWD, + region=settings.S3_REGION, + bucket_name=get_node_uid_env(), +) + +blob_storage_config = SeaweedFSConfig(client_config=seaweed_client_config) + if node_type == "gateway" or node_type == "network": worker = Gateway( @@ -40,6 +54,7 @@ action_store_config=sql_store_config, document_store_config=mongo_store_config, enable_warnings=enable_warnings, + blob_storage_config=blob_storage_config, ) else: worker = Domain( @@ -48,4 +63,5 @@ action_store_config=sql_store_config, document_store_config=mongo_store_config, enable_warnings=enable_warnings, + blob_storage_config=blob_storage_config, ) diff --git a/packages/grid/default.env b/packages/grid/default.env index 676520c1569..69575d79f46 100644 --- a/packages/grid/default.env +++ b/packages/grid/default.env @@ -30,7 +30,8 @@ DOCKER_IMAGE_TRAEFIK=traefik TRAEFIK_VERSION=v2.10 REDIS_VERSION=6.2 RABBITMQ_VERSION=3 -DOCKER_IMAGE_SEAWEEDFS=chrislusf/seaweedfs:latest +SEAWEEDFS_VERSION=3.55 +DOCKER_IMAGE_SEAWEEDFS=chrislusf/seaweedfs VERSION=latest VERSION_HASH=unknown STACK_API_KEY="" diff --git a/packages/grid/docker-compose.dev.yml b/packages/grid/docker-compose.dev.yml index 046720b6b79..3faef69b269 100644 --- a/packages/grid/docker-compose.dev.yml +++ b/packages/grid/docker-compose.dev.yml @@ -46,6 +46,8 @@ services: - ${RELATIVE_PATH}./data/package-cache:/root/.cache environment: - DEV_MODE=True + stdin_open: true + tty: true # backend_stream: # volumes: @@ -71,12 +73,12 @@ services: ports: - "4000" - # seaweedfs: - # profiles: - # - blob-storage - # # volumes: - # # - ./data/seaweedfs:/data - # ports: - # - "9333" # admin web port - # - "8888" # filer web port - # # - "8333" # S3 API port + seaweedfs: + profiles: + - blob-storage + # volumes: + # - ./data/seaweedfs:/data + ports: + - "9333" # admin web port + - "8888" # filer web port + - "8333" # S3 API port 
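The grid/core/node.py hunk above is where these SeaweedFS pieces meet: the S3 endpoint, port, and root credentials from the environment become a SeaweedFSClientConfig, which is handed to the Domain or Gateway as its blob storage backend. A minimal sketch of the same wiring for a standalone Python node outside the grid stack follows; the host, port, credentials, region, and bucket name are illustrative placeholders, not values taken from this patch:

    # Hypothetical standalone equivalent of the grid/core/node.py wiring above.
    from syft.node.domain import Domain
    from syft.store.blob_storage.seaweedfs import SeaweedFSClientConfig
    from syft.store.blob_storage.seaweedfs import SeaweedFSConfig

    seaweed_client_config = SeaweedFSClientConfig(
        host="http://localhost",      # S3_ENDPOINT
        port=8333,                    # the S3 API port exposed above
        access_key="admin",           # S3_ROOT_USER (compose default)
        secret_key="admin",           # S3_ROOT_PWD (compose default)
        region="us-east-1",           # assumed S3_REGION
        bucket_name="test-node-uid",  # grid derives this from get_node_uid_env()
    )

    # The blob_storage_config keyword is the same one the Gateway/Domain
    # constructors gain in this patch.
    node = Domain(
        name="test-domain",
        blob_storage_config=SeaweedFSConfig(client_config=seaweed_client_config),
    )
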
diff --git a/packages/grid/docker-compose.pull.yml b/packages/grid/docker-compose.pull.yml index f69e67ed20c..a7d2cf9de22 100644 --- a/packages/grid/docker-compose.pull.yml +++ b/packages/grid/docker-compose.pull.yml @@ -6,8 +6,8 @@ services: # queue: # image: rabbitmq:${RABBITMQ_VERSION?Variable not Set}${RABBITMQ_MANAGEMENT:-} - # seaweedfs: - # image: "${DOCKER_IMAGE_SEAWEEDFS?Variable not set}" + seaweedfs: + image: "${DOCKER_IMAGE_SEAWEEDFS?Variable not set}:${SEAWEEDFS_VERSION}" # docker-host: # image: qoomon/docker-host diff --git a/packages/grid/docker-compose.test.yml b/packages/grid/docker-compose.test.yml index 9397c9af8f8..92716eaa469 100644 --- a/packages/grid/docker-compose.test.yml +++ b/packages/grid/docker-compose.test.yml @@ -27,13 +27,13 @@ services: ports: - "4000" - # seaweedfs: - # profiles: - # - blob-storage - # ports: - # - "9333" # admin - # - "8888" # filer - # - "8333" # S3 + seaweedfs: + profiles: + - blob-storage + ports: + - "9333" # admin + - "8888" # filer + - "8333" # S3 backend: environment: diff --git a/packages/grid/docker-compose.yml b/packages/grid/docker-compose.yml index ad88ad7474d..1ad72fc49eb 100644 --- a/packages/grid/docker-compose.yml +++ b/packages/grid/docker-compose.yml @@ -214,20 +214,24 @@ services: - NETWORK_NAME=omnet - STACK_API_KEY=$STACK_API_KEY - # seaweedfs: - # profiles: - # - blob-storage - # depends_on: - # - proxy - # - redis - # image: "${DOCKER_IMAGE_SEAWEEDFS?Variable not set}" - # environment: - # - S3_VOLUME_SIZE_MB=${S3_VOLUME_SIZE_MB:-1024} - # command: "server -s3 -s3.config=/etc/s3config.json -master.volumeSizeLimitMB=${S3_VOLUME_SIZE_MB}" - # volumes: - # - seaweedfs-data:/data - # - ./seaweedfs/s3config.json:/etc/s3config.json - # - ./seaweedfs/filer.toml:/etc/seaweedfs/filer.toml + seaweedfs: + profiles: + - blob-storage + depends_on: + - proxy + image: "${DOCKER_IMAGE_SEAWEEDFS?Variable not set}:${SEAWEEDFS_VERSION}" + environment: + - S3_VOLUME_SIZE_MB=${S3_VOLUME_SIZE_MB:-1024} + - S3_ROOT_USER=${S3_ROOT_USER:-admin} + - S3_ROOT_PWD=${S3_ROOT_PWD:-admin} + - S3_PORT=${S3_PORT:-8888} + entrypoint: ["/bin/sh"] + command: + - "/etc/seaweedfs/start.sh" + volumes: + - seaweedfs-data:/data/blob + - ./seaweedfs/filer.toml:/etc/seaweedfs/filer.toml + - ./seaweedfs/start.sh:/etc/seaweedfs/start.sh mongo: image: mongo:latest @@ -264,7 +268,7 @@ volumes: tailscale-data: headscale-data: # app-redis-data: - # seaweedfs-data: + seaweedfs-data: mongo-data: networks: diff --git a/packages/grid/podman/podman-kube/podman-syft-kube-config.yaml b/packages/grid/podman/podman-kube/podman-syft-kube-config.yaml index 81d108795be..1684f1c3dfe 100644 --- a/packages/grid/podman/podman-kube/podman-syft-kube-config.yaml +++ b/packages/grid/podman/podman-kube/podman-syft-kube-config.yaml @@ -34,7 +34,8 @@ data: TRAEFIK_VERSION: v2.8.1 REDIS_VERSION: 6.2 RABBITMQ_VERSION: 3 - DOCKER_IMAGE_SEAWEEDFS: chrislusf/seaweedfs:latest + SEAWEEDFS_VERSION: 3.55 + DOCKER_IMAGE_SEAWEEDFS: chrislusf/seaweedfs:3.55 VERSION: 0.8.2-beta.6 VERSION_HASH: unknown STACK_API_KEY: "" diff --git a/packages/grid/seaweedfs/filer.toml b/packages/grid/seaweedfs/filer.toml index 00f62836f9c..dd69566768f 100644 --- a/packages/grid/seaweedfs/filer.toml +++ b/packages/grid/seaweedfs/filer.toml @@ -1,5 +1,11 @@ -[redis] +# [redis] +# enabled = true +# address = "redis:6379" +# password = "" +# database = 15 + +[leveldb2] +# local on disk, mostly for simple single-machine setup, fairly scalable +# faster than previous leveldb, recommended. 
enabled = true -address = "redis:6379" -password = "" -database = 15 +dir = "./filerldb2" \ No newline at end of file diff --git a/packages/grid/seaweedfs/start.sh b/packages/grid/seaweedfs/start.sh new file mode 100644 index 00000000000..d6dc34f535d --- /dev/null +++ b/packages/grid/seaweedfs/start.sh @@ -0,0 +1,6 @@ +#! /usr/bin/env bash + +sleep 30 && +echo "s3.configure -access_key ${S3_ROOT_USER} -secret_key ${S3_ROOT_PWD} -user iam -actions Read,Write,List,Tagging,Admin -apply" \ +| weed shell > /dev/null 2>&1 \ +& weed server -s3 -s3.port=${S3_PORT} -master.volumeSizeLimitMB=${S3_VOLUME_SIZE_MB} \ No newline at end of file diff --git a/packages/hagrid/hagrid/cli.py b/packages/hagrid/hagrid/cli.py index 698ee862f62..17d67026b2b 100644 --- a/packages/hagrid/hagrid/cli.py +++ b/packages/hagrid/hagrid/cli.py @@ -443,6 +443,20 @@ def clean(location: str) -> None: is_flag=True, help="Launch a low side node type else a high side node type", ) +@click.option( + "--set-s3-username", + default=None, + required=False, + type=str, + help="Set root username for s3 blob storage", +) +@click.option( + "--set-s3-password", + default=None, + required=False, + type=str, + help="Set root password for s3 blob storage", +) def launch(args: TypeTuple[str], **kwargs: Any) -> None: verb = get_launch_verb() try: @@ -1236,6 +1250,10 @@ def create_launch_cmd( parsed_kwargs["use_blob_storage"] = not bool(kwargs["no_blob_storage"]) + if parsed_kwargs["use_blob_storage"]: + parsed_kwargs["set_s3_username"] = kwargs["set_s3_username"] + parsed_kwargs["set_s3_password"] = kwargs["set_s3_password"] + parsed_kwargs["node_count"] = ( int(kwargs["node_count"]) if "node_count" in kwargs else 1 ) @@ -1299,11 +1317,11 @@ def create_launch_cmd( else: parsed_kwargs["image_name"] = "default" - if "tag" in kwargs and kwargs["tag"] is not None and kwargs["tag"] != "": - parsed_kwargs["tag"] = kwargs["tag"] + if parsed_kwargs["dev"] is True: + parsed_kwargs["tag"] = "local" else: - if parsed_kwargs["dev"] is True: - parsed_kwargs["tag"] = "local" + if "tag" in kwargs and kwargs["tag"] is not None and kwargs["tag"] != "": + parsed_kwargs["tag"] = kwargs["tag"] else: parsed_kwargs["tag"] = "latest" @@ -2189,6 +2207,12 @@ def create_launch_docker_cmd( if "set_root_email" in kwargs and kwargs["set_root_email"] is not None: envs["DEFAULT_ROOT_EMAIL"] = kwargs["set_root_email"] + if "set_s3_username" in kwargs and kwargs["set_s3_username"] is not None: + envs["S3_ROOT_USER"] = kwargs["set_s3_username"] + + if "set_s3_password" in kwargs and kwargs["set_s3_password"] is not None: + envs["S3_ROOT_PWD"] = kwargs["set_s3_password"] + if "release" in kwargs: envs["RELEASE"] = kwargs["release"] diff --git a/packages/hagrid/hagrid/manifest_template.yml b/packages/hagrid/hagrid/manifest_template.yml index fa6cc2b4469..a38e5283189 100644 --- a/packages/hagrid/hagrid/manifest_template.yml +++ b/packages/hagrid/hagrid/manifest_template.yml @@ -12,7 +12,7 @@ files: - rabbitmq/rabbitmq.conf - redis/redis.conf - seaweedfs/filer.toml - - seaweedfs/s3config.json + - seaweedfs/start.sh - vpn/config.yaml - default.env docker: diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 749f705fa05..d5188c17572 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -27,6 +27,7 @@ package_dir = syft = bcrypt==4.0.1 + boto3==1.28.20 forbiddenfruit==0.1.4 gevent==22.10.2 gipc==1.5.0 diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index 7c3d1d099f2..c81c170076c 100644 --- 
a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -372,7 +372,6 @@ def debox_signed_syftapicall_response( if not signed_result.is_valid: return SyftError(message="The result signature is invalid") # type: ignore - return signed_result.message.data diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 1fa8381bc3d..920f2dba9b2 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -123,6 +123,7 @@ class Routes(Enum): ROUTE_LOGIN = f"{API_PATH}/login" ROUTE_REGISTER = f"{API_PATH}/register" ROUTE_API_CALL = f"{API_PATH}/api_call" + ROUTE_BLOB_STORE = "/blob" @serializable(attrs=["proxy_target_uid", "url"]) @@ -149,6 +150,10 @@ def get_cache_key(self) -> str: def api_url(self) -> GridURL: return self.url.with_path(self.routes.ROUTE_API_CALL.value) + def to_blob_route(self, path: str) -> GridURL: + _path = self.routes.ROUTE_BLOB_STORE.value + path + return self.url.with_path(_path) + @property def session(self) -> Session: if self.session_cache is None: diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index c8d0be36a89..7adaeaa8b39 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -75,7 +75,13 @@ def upload_dataset(self, dataset: CreateDataset) -> Union[SyftSuccess, SyftError for asset in tqdm(dataset.asset_list): print(f"Uploading: {asset.name}") try: - twin = TwinObject(private_obj=asset.data, mock_obj=asset.mock) + twin = TwinObject( + private_obj=asset.data, + mock_obj=asset.mock, + syft_node_location=self.id, + syft_client_verify_key=self.verify_key, + ) + twin._save_to_blob_storage() except Exception as e: return SyftError(message=f"Failed to create twin. 
{e}") response = self.api.services.action.set(twin) diff --git a/packages/syft/src/syft/external/oblv/oblv_service.py b/packages/syft/src/syft/external/oblv/oblv_service.py index b185c53006b..38d380be88a 100644 --- a/packages/syft/src/syft/external/oblv/oblv_service.py +++ b/packages/syft/src/syft/external/oblv/oblv_service.py @@ -387,29 +387,21 @@ def send_user_code_inputs_to_enclave( user_code.status = res.ok() user_code_service.update_code_state(context=context, code_item=user_code) + root_context = context.as_root_context() + if not action_service.exists(context=context, obj_id=user_code_id): dict_object = ActionObject.from_obj({}) dict_object.id = user_code_id dict_object[str(context.credentials)] = inputs - action_service.store.set( - uid=user_code_id, - credentials=context.node.verify_key, - syft_object=dict_object, - has_result_read_permission=True, - ) + root_context.extra_kwargs = {"has_result_read_permission": True} + action_service.set(root_context, dict_object) else: - res = action_service.store.get( - uid=user_code_id, credentials=context.node.verify_key - ) + res = action_service.get(uid=user_code_id, context=context) if res.is_ok(): dict_object = res.ok() dict_object[str(context.credentials)] = inputs - action_service.store.set( - uid=user_code_id, - credentials=context.node.verify_key, - syft_object=dict_object, - ) + action_service.set(root_context, dict_object) else: return res diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 5b87aa8ca91..b4584bf6a6f 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -3,6 +3,7 @@ # stdlib import binascii +from collections import OrderedDict import contextlib from datetime import datetime from functools import partial @@ -16,6 +17,7 @@ from typing import Dict from typing import List from typing import Optional +from typing import Tuple from typing import Type from typing import Union import uuid @@ -82,6 +84,7 @@ from ..service.user.user_service import UserService from ..service.user.user_stash import UserStash from ..store.blob_storage import BlobStorageConfig +from ..store.blob_storage.on_disk import OnDiskBlobStorageClientConfig from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig from ..store.dict_document_store import DictStoreConfig from ..store.document_store import StoreConfig @@ -93,6 +96,7 @@ from ..types.uid import UID from ..util.experimental_flags import flags from ..util.telemetry import instrument +from ..util.util import get_root_data_path from ..util.util import random_name from ..util.util import str_to_bool from ..util.util import thread_ident @@ -179,6 +183,40 @@ def get_venv_packages() -> str: default_root_password = get_default_root_password() +class AuthNodeContextRegistry: + __node_context_registry__: Dict[Tuple, Node] = OrderedDict() + + @classmethod + def set_node_context( + cls, + node_uid: Union[UID, str], + context: NodeServiceContext, + user_verify_key: Union[SyftVerifyKey, str], + ): + if isinstance(node_uid, str): + node_uid = UID.from_string(node_uid) + + if isinstance(user_verify_key, str): + user_verify_key = SyftVerifyKey.from_string(user_verify_key) + + key = cls._get_key(node_uid=node_uid, user_verify_key=user_verify_key) + + cls.__node_context_registry__[key] = context + + @staticmethod + def _get_key(node_uid: UID, user_verify_key: SyftVerifyKey) -> str: + return "-".join(str(x) for x in (node_uid, user_verify_key)) + + @classmethod + def auth_context_for_user( + cls, + node_uid: UID, + 
user_verify_key: SyftVerifyKey, + ) -> Optional[AuthedServiceContext]: + key = cls._get_key(node_uid=node_uid, user_verify_key=user_verify_key) + return cls.__node_context_registry__.get(key) + + @instrument class Node(AbstractNode): signing_key: Optional[SyftSigningKey] @@ -299,7 +337,14 @@ def __init__( NodeRegistry.set_node_for(self.id, self) def init_blob_storage(self, config: Optional[BlobStorageConfig] = None) -> None: - config_ = OnDiskBlobStorageConfig() if config is None else config + if config is None: + root_directory = get_root_data_path() + base_directory = root_directory / f"{self.id}" + client_config = OnDiskBlobStorageClientConfig(base_directory=base_directory) + config_ = OnDiskBlobStorageConfig(client_config=client_config) + else: + config_ = config + self.blob_store_config = config_ self.blob_storage_client = config_.client_type(config=config_.client_config) def init_queue_manager(self, queue_config: Optional[QueueConfig]): @@ -343,6 +388,7 @@ def named( raise Exception(f"Invalid UID: {name_hash_string} for name: {name}") uid = UID(name_hash_string) key = SyftSigningKey(signing_key=SigningKey(name_hash)) + blob_storage_config = None if reset: store_config = SQLiteStoreClientConfig() store_config.filename = f"{uid}.sqlite" @@ -366,6 +412,19 @@ def named( if os.path.exists(store_config.file_path): os.unlink(store_config.file_path) + # Reset blob storage + root_directory = get_root_data_path() + base_directory = root_directory / f"{uid}" + if base_directory.exists(): + for file in base_directory.iterdir(): + file.unlink() + blob_client_config = OnDiskBlobStorageClientConfig( + base_directory=base_directory + ) + blob_storage_config = OnDiskBlobStorageConfig( + client_config=blob_client_config + ) + return cls( name=name, id=uid, @@ -376,6 +435,7 @@ def named( node_type=node_type, node_side_type=node_side_type, enable_warnings=enable_warnings, + blob_storage_config=blob_storage_config, ) def is_root(self, credentials: SyftVerifyKey) -> bool: @@ -390,7 +450,9 @@ def root_client(self): client_type = connection.get_client_type() if isinstance(client_type, SyftError): return client_type - return client_type(connection=connection, credentials=self.signing_key) + root_client = client_type(connection=connection, credentials=self.signing_key) + root_client.api.refresh_api_callback() + return root_client @property def guest_client(self): @@ -411,7 +473,11 @@ def get_guest_client(self, verbose: bool = True): if isinstance(client_type, SyftError): return client_type - return client_type(connection=connection, credentials=SyftSigningKey.generate()) + guest_client = client_type( + connection=connection, credentials=SyftSigningKey.generate() + ) + guest_client.api.refresh_api_callback() + return guest_client def __repr__(self) -> str: service_string = "" @@ -424,7 +490,12 @@ def __repr__(self) -> str: return f"{type(self).__name__}: {self.name} - {self.id} - {self.node_type}{service_string}" def post_init(self) -> None: - context = AuthedServiceContext(node=self, credentials=self.verify_key) + context = AuthedServiceContext( + node=self, credentials=self.verify_key, role=ServiceRole.ADMIN + ) + AuthNodeContextRegistry.set_node_context( + node_uid=self.id, user_verify_key=self.verify_key, context=context + ) if UserCodeService in self.services: user_code_service = self.get_service(UserCodeService) @@ -718,6 +789,7 @@ def handle_api_call_with_unsigned_result( context = AuthedServiceContext( node=self, credentials=credentials, role=role ) + AuthNodeContextRegistry.set_node_context(self.id, 
context, credentials) user_config_registry = UserServiceConfigRegistry.from_role(role) @@ -835,6 +907,7 @@ def task_runner( signing_key=worker_settings.signing_key, document_store_config=worker_settings.document_store_config, action_store_config=worker_settings.action_store_config, + blob_storage_config=worker_settings.blob_store_config, is_subprocess=True, ) try: diff --git a/packages/syft/src/syft/node/worker_settings.py b/packages/syft/src/syft/node/worker_settings.py index b0d88ce92a6..6996fb411ee 100644 --- a/packages/syft/src/syft/node/worker_settings.py +++ b/packages/syft/src/syft/node/worker_settings.py @@ -1,6 +1,9 @@ # future from __future__ import annotations +# stdlib +from typing import Optional + # third party from typing_extensions import Self @@ -10,6 +13,7 @@ from ..abstract_node import NodeType from ..node.credentials import SyftSigningKey from ..serde.serializable import serializable +from ..store.blob_storage import BlobStorageConfig from ..store.document_store import StoreConfig from ..types.syft_object import SYFT_OBJECT_VERSION_1 from ..types.syft_object import SyftObject @@ -28,6 +32,7 @@ class WorkerSettings(SyftObject): signing_key: SyftSigningKey document_store_config: StoreConfig action_store_config: StoreConfig + blob_store_config: Optional[BlobStorageConfig] @staticmethod def from_node(node: AbstractNode) -> Self: @@ -39,4 +44,5 @@ def from_node(node: AbstractNode) -> Self: document_store_config=node.document_store_config, action_store_config=node.action_store_config, node_side_type=node.node_side_type.value, + blob_store_config=node.blob_store_config, ) diff --git a/packages/syft/src/syft/service/action/action_data_empty.py b/packages/syft/src/syft/service/action/action_data_empty.py index f45a784c396..d1e5ae44381 100644 --- a/packages/syft/src/syft/service/action/action_data_empty.py +++ b/packages/syft/src/syft/service/action/action_data_empty.py @@ -2,8 +2,13 @@ from __future__ import annotations # stdlib +from pathlib import Path from typing import Optional from typing import Type +from typing import Union + +# third party +import pydantic # relative from ...serde.serializable import serializable @@ -25,3 +30,21 @@ def __repr__(self) -> str: def __str__(self) -> str: return f"{type(self).__name__} UID: {self.id} <{self.syft_internal_type}>" + + +@serializable() +class ActionFileData(SyftObject): + __canonical_name__ = "ActionFileData" + __version__ = SYFT_OBJECT_VERSION_1 + + path: Path + + @pydantic.validator("path", pre=True) + def __validate_file_path(cls, v: Union[str, Path]) -> Path: + if isinstance(v, str): + v = Path(v) + + if v.exists() and v.is_file(): + return v + + raise ValueError(f"Not a valid path to file. 
{v}") diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 84059871b40..a49bcedd780 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -4,13 +4,14 @@ # stdlib from enum import Enum import inspect +from io import BytesIO +from pathlib import Path import traceback import types from typing import Any from typing import Callable from typing import ClassVar from typing import Dict -from typing import KeysView from typing import List from typing import Optional from typing import Tuple @@ -27,9 +28,13 @@ # relative from ...client.api import SyftAPI from ...client.client import SyftClient +from ...node.credentials import SyftVerifyKey from ...serde.serializable import serializable +from ...serde.serialize import _serialize as serialize from ...service.response import SyftError from ...store.linked_obj import LinkedObject +from ...types.blob_storage import CreateBlobStorageEntry +from ...types.datetime import DateTime from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SyftBaseObject from ...types.syft_object import SyftObject @@ -37,7 +42,9 @@ from ...types.uid import UID from ...util.logger import debug from ..response import SyftException +from ..service import from_api_or_context from .action_data_empty import ActionDataEmpty +from .action_data_empty import ActionFileData from .action_permissions import ActionPermission from .action_types import action_type_for_object from .action_types import action_type_for_type @@ -62,6 +69,10 @@ class ActionType(Enum): CREATEOBJECT = 16 +def repr_cls(c): + return f"{c.__module__}.{c.__name__}" + + @serializable() class Action(SyftObject): """Serializable Action object. 
@@ -179,6 +190,8 @@ class ActionObjectPointer: "get_from", # syft "get", # syft "delete_data", # syft + "_save_to_blob_storage_", # syft + "syft_action_data", # syft ] dont_wrap_output_attrs = [ "__repr__", @@ -259,7 +272,8 @@ def make_action_side_effect( action_type=context.action_type, ) context.action = action - except Exception: + except Exception as e: + raise e print(f"make_action_side_effect failed with {traceback.format_exc()}") return Err(f"make_action_side_effect failed with {traceback.format_exc()}") return Ok((context, args, kwargs)) @@ -268,6 +282,7 @@ def make_action_side_effect( class TraceResult: result = [] _client = None + is_tracing = False @classmethod def reset(cls): @@ -295,21 +310,31 @@ def convert_to_pointers( if args is not None: for arg in args: if not isinstance(arg, ActionObject): - arg = ActionObject.from_obj(arg) + arg = ActionObject.from_obj( + syft_action_data=arg, + syft_client_verify_key=api.signing_key.verify_key, + syft_node_location=api.node_uid, + ) arg.syft_node_uid = node_uid + r = arg._save_to_blob_storage() + if isinstance(r, SyftError): + print(r.message) arg = api.services.action.set(arg) - # arg = action_obj.send( - # client - # ) # make sure this doesn't break things later on in send_method_action arg_list.append(arg) if kwargs is not None: for k, arg in kwargs.items(): if not isinstance(arg, ActionObject): - arg = ActionObject.from_obj(arg) + arg = ActionObject.from_obj( + syft_action_data=arg, + syft_client_verify_key=api.signing_key.verify_key, + syft_node_location=api.node_uid, + ) arg.syft_node_uid = node_uid + r = arg._save_to_blob_storage() + if isinstance(r, SyftError): + print(r.message) arg = api.services.action.set(arg) - # arg = action_obj.send(client) kwarg_dict[k] = arg @@ -410,6 +435,14 @@ def debox_args_and_kwargs(args: Any, kwargs: Any) -> Tuple[Any, Any]: "_repr_debug_", "as_empty", "get", + "_save_to_blob_storage", + "_save_to_blob_storage_", + "syft_action_data", + "__check_action_data", + "as_empty_data", + "_set_obj_location_", + "syft_action_data_cache", + "reload_cache", ] @@ -420,8 +453,8 @@ class ActionObject(SyftObject): __version__ = SYFT_OBJECT_VERSION_1 __attr_searchable__: List[str] = [] - syft_action_data: Optional[Any] = None - # syft_action_proxy_reference: Optional[LinkedObject] = None + syft_action_data_cache: Optional[Any] = None + syft_blob_storage_entry_id: Optional[UID] = None syft_pointer_type: ClassVar[Type[ActionObjectPointer]] # Help with calculating history hash for code verification @@ -436,29 +469,107 @@ class ActionObject(SyftObject): _syft_post_hooks__: Dict[str, List] = {} syft_twin_type: TwinMode = TwinMode.NONE syft_passthrough_attrs = BASE_PASSTHROUGH_ATTRS + syft_action_data_type: Optional[Type] + syft_action_data_repr_: Optional[str] + syft_action_data_str_: Optional[str] + syft_has_bool_attr: Optional[bool] + syft_resolve_data: Optional[bool] + syft_created_at: Optional[DateTime] # syft_dont_wrap_attrs = ["shape"] - # @property - # def syft_action_proxy(self) -> Optional[BlobStorageEntry]: - # return ( - # self.syft_action_proxy_reference.resolve - # if self.syft_action_proxy_reference is not None - # else None - # ) + @property + def syft_action_data(self) -> Any: + if ( + self.syft_blob_storage_entry_id + and self.syft_created_at + and not TraceResult.is_tracing + ): + self.reload_cache() - # @property - # def syft_action_data(self) -> Any: - # # relative - # from ...client.api import APIRegistry + return self.syft_action_data_cache - # api = APIRegistry.api_for( - # 
node_uid=self.node_uid,
-    #         user_verify_key=self.syft_client_verify_key,
-    #     )
-    #     syft_object_resource = api.services.blob_storage.read(
-    #         uid=self.syft_action_proxy_reference.id
-    #     )
-    #     return syft_object_resource.read()
+
+    def reload_cache(self):
+        # If ActionDataEmpty then try to fetch it from store.
+        if isinstance(self.syft_action_data_cache, ActionDataEmpty):
+            blob_storage_read_method = from_api_or_context(
+                func_or_path="blob_storage.read",
+                syft_node_location=self.syft_node_location,
+                syft_client_verify_key=self.syft_client_verify_key,
+            )
+
+            if blob_storage_read_method is not None:
+                blob_retrieval_object = blob_storage_read_method(
+                    uid=self.syft_blob_storage_entry_id
+                )
+                if isinstance(blob_retrieval_object, SyftError):
+                    print(
+                        "Detached action object, object exists but is not linked to data in the blob storage",
+                        blob_retrieval_object,
+                    )
+                    return blob_retrieval_object
+                self.syft_action_data_cache = blob_retrieval_object.read()
+                self.syft_action_data_type = type(self.syft_action_data)
+
+    def _save_to_blob_storage_(self, data: Any) -> Optional[SyftError]:
+        if not isinstance(data, ActionDataEmpty):
+            if isinstance(data, ActionFileData):
+                storage_entry = CreateBlobStorageEntry.from_path(data.path)
+            else:
+                storage_entry = CreateBlobStorageEntry.from_obj(data)
+
+            allocate_method = from_api_or_context(
+                func_or_path="blob_storage.allocate",
+                syft_node_location=self.syft_node_location,
+                syft_client_verify_key=self.syft_client_verify_key,
+            )
+            if allocate_method is not None:
+                blob_deposit_object = allocate_method(storage_entry)
+
+                if isinstance(blob_deposit_object, SyftError):
+                    return blob_deposit_object
+
+                if isinstance(data, ActionFileData):
+                    with open(data.path, "rb") as f:
+                        result = blob_deposit_object.write(f)
+                else:
+                    result = blob_deposit_object.write(
+                        BytesIO(serialize(data, to_bytes=True))
+                    )
+
+                if isinstance(result, SyftError):
+                    return result
+                self.syft_blob_storage_entry_id = (
+                    blob_deposit_object.blob_storage_entry_id
+                )
+
+            self.syft_action_data_type = type(data)
+
+            if inspect.isclass(data):
+                self.syft_action_data_repr_ = repr_cls(data)
+            else:
+                self.syft_action_data_repr_ = (
+                    data._repr_markdown_()
+                    if hasattr(data, "_repr_markdown_")
+                    else data.__repr__()
+                )
+            self.syft_action_data_str_ = str(data)
+            self.syft_has_bool_attr = hasattr(data, "__bool__")
+        else:
+            debug("skipping writing action object to store, passed data was empty.")
+
+        self.syft_action_data_cache = data
+
+    def _save_to_blob_storage(self) -> Optional[SyftError]:
+        data = self.syft_action_data
+        if isinstance(data, SyftError):
+            return data
+        if isinstance(data, ActionDataEmpty):
+            return SyftError(message=f"cannot store empty object {self.id}")
+        result = self._save_to_blob_storage_(data)
+        if isinstance(result, SyftError):
+            return result
+        if not TraceResult.is_tracing:
+            self.syft_action_data_cache = self.as_empty_data()
 
     @property
     def is_pointer(self) -> bool:
@@ -474,6 +585,27 @@ def make_id(cls, v: Optional[UID]) -> UID:
         """Generate or reuse an UID"""
         return Action.make_id(v)
 
+    class Config:
+        validate_assignment = True
+
+    @pydantic.root_validator()
+    def __check_action_data(cls, values: dict) -> dict:
+        v = values.get("syft_action_data_cache")
+        if values.get("syft_action_data_type", None) is None:
+            values["syft_action_data_type"] = type(v)
+        if not isinstance(v, ActionDataEmpty):
+            if inspect.isclass(v):
+                values["syft_action_data_repr_"] = repr_cls(v)
+            else:
+                values["syft_action_data_repr_"] = (
+                    v._repr_markdown_()
+                    if hasattr(v, "_repr_markdown_")
+                    else v.__repr__()
+                )
+            values["syft_action_data_str_"] = str(v)
+            values["syft_has_bool_attr"] = hasattr(v, "__bool__")
+        return values
+
     @property
     def is_mock(self):
         return self.syft_twin_type == TwinMode.MOCK
@@ -486,17 +618,17 @@ def is_real(self):
     def is_twin(self):
         return self.syft_twin_type != TwinMode.NONE
 
-    @pydantic.validator("syft_action_data", pre=True, always=True)
-    def check_action_data(
-        cls, v: ActionObject.syft_pointer_type
-    ) -> ActionObject.syft_pointer_type:
-        if cls == AnyActionObject or isinstance(
-            v, (cls.syft_internal_type, ActionDataEmpty)
-        ):
-            return v
-        raise SyftException(
-            f"Must init {cls} with {cls.syft_internal_type} not {type(v)}"
-        )
+    # @pydantic.validator("syft_action_data", pre=True, always=True)
+    # def check_action_data(
+    #     cls, v: ActionObject.syft_pointer_type
+    # ) -> ActionObject.syft_pointer_type:
+    #     if cls == AnyActionObject or isinstance(
+    #         v, (cls.syft_internal_type, ActionDataEmpty)
+    #     ):
+    #         return v
+    #     raise SyftException(
+    #         f"Must init {cls} with {cls.syft_internal_type} not {type(v)}"
+    #     )
 
     def syft_point_to(self, node_uid: UID) -> ActionObject:
         """Set the syft_node_uid, used in the post hooks"""
@@ -579,6 +711,19 @@ def _syft_try_to_save_to_store(self, obj) -> None:
         # relative
         from ...client.api import APIRegistry
 
+        if obj.syft_node_location is None:
+            obj.syft_node_location = obj.syft_node_uid
+
+        api = None
+        if TraceResult._client is not None:
+            api = TraceResult._client.api
+
+        if api is not None:
+            obj._set_obj_location_(api.node_uid, api.signing_key.verify_key)
+        res = obj._save_to_blob_storage()
+        if isinstance(res, SyftError):
+            print(f"failed saving {obj} to blob storage, error: {res}")
+
         action = Action(
             path="",
             op="",
@@ -590,15 +735,16 @@ def _syft_try_to_save_to_store(self, obj) -> None:
             create_object=obj,
         )
 
-        if TraceResult._client is not None:
-            api = TraceResult._client.api
+        if api is not None:
             TraceResult.result += [action]
         else:
             api = APIRegistry.api_for(
-                node_uid=self.syft_node_uid,
+                node_uid=self.syft_node_location,
                 user_verify_key=self.syft_client_verify_key,
             )
-            api.services.action.execute(action)
+            res = api.services.action.execute(action)
+            if isinstance(res, SyftError):
+                print(f"Failed to store (arg) {obj} in the store, {res}")
 
     def _syft_prepare_obj_uid(self, obj) -> LineageID:
         # We got the UID
@@ -617,7 +763,13 @@ def _syft_prepare_obj_uid(self, obj) -> LineageID:
         # We got a raw object. We need to create the ActionObject from scratch and save it in the store. 
obj_id = Action.make_id(None) lin_obj_id = Action.make_result_id(obj_id) - act_obj = ActionObject.from_obj(obj, id=obj_id, syft_lineage_id=lin_obj_id) + act_obj = ActionObject.from_obj( + obj, + id=obj_id, + syft_lineage_id=lin_obj_id, + syft_client_verify_key=self.syft_client_verify_key, + syft_node_location=self.syft_node_location, + ) self._syft_try_to_save_to_store(act_obj) @@ -716,7 +868,7 @@ def syft_make_action_with_self( def syft_get_path(self) -> str: """Get the type path of the underlying object""" if isinstance(self, AnyActionObject) and self.syft_internal_type: - return f"{type(self.syft_action_data).__name__}" # avoids AnyActionObject errors + return f"{self.syft_action_data_type.__name__}" # avoids AnyActionObject errors return f"{type(self).__name__}" def syft_remote_method( @@ -743,9 +895,11 @@ def wrapper( def send(self, client: SyftClient) -> Self: """Send the object to a Syft Client""" + self._set_obj_location_(client.id, client.verify_key) + self._save_to_blob_storage() res = client.api.services.action.set(self) - res.syft_node_location = client.id - res.syft_client_verify_key = client.verify_key + if isinstance(res, ActionObject): + self.syft_created_at = res.syft_created_at return res def get_from(self, client: SyftClient) -> Any: @@ -779,11 +933,50 @@ def as_empty(self): id = id.id return ActionObject.empty(self.syft_internal_type, id, self.syft_lineage_id) + @staticmethod + def from_path( + path: Union[str, Path], + id: Optional[UID] = None, + syft_lineage_id: Optional[LineageID] = None, + syft_client_verify_key: Optional[SyftVerifyKey] = None, + syft_node_location: Optional[UID] = None, + ): + """Create an Action Object from a file.""" + + if id is not None and syft_lineage_id is not None and id != syft_lineage_id.id: + raise ValueError("UID and LineageID should match") + + syft_action_data = ActionFileData( + path=path if isinstance(path, Path) else Path(path) + ) + action_type = action_type_for_object(syft_action_data) + + action_object = action_type(syft_action_data_cache=syft_action_data) + + if id is not None: + action_object.id = id + + if syft_client_verify_key is not None: + action_object.syft_client_verify_key = syft_client_verify_key + + if syft_node_location is not None: + action_object.syft_node_location = syft_node_location + + if syft_lineage_id is not None: + action_object.id = syft_lineage_id.id + action_object.syft_history_hash = syft_lineage_id.syft_history_hash + elif id is not None: + action_object.syft_history_hash = hash(id) + + return action_object + @staticmethod def from_obj( syft_action_data: Any, id: Optional[UID] = None, syft_lineage_id: Optional[LineageID] = None, + syft_client_verify_key: Optional[SyftVerifyKey] = None, + syft_node_location: Optional[UID] = None, ) -> ActionObject: """Create an ActionObject from an existing object. @@ -795,19 +988,25 @@ def from_obj( syft_lineage_id: Optional[LineageID] Which LineageID to use for the ActionObject. 
Optional
         """
-        if id and syft_lineage_id and id != syft_lineage_id.id:
+        if id is not None and syft_lineage_id is not None and id != syft_lineage_id.id:
             raise ValueError("UID and LineageID should match")
 
         action_type = action_type_for_object(syft_action_data)
-        action_object = action_type(syft_action_data=syft_action_data)
+        action_object = action_type(syft_action_data_cache=syft_action_data)
 
-        if id:
+        if id is not None:
             action_object.id = id
 
-        if syft_lineage_id:
+        if syft_client_verify_key is not None:
+            action_object.syft_client_verify_key = syft_client_verify_key
+
+        if syft_node_location is not None:
+            action_object.syft_node_location = syft_node_location
+
+        if syft_lineage_id is not None:
             action_object.id = syft_lineage_id.id
             action_object.syft_history_hash = syft_lineage_id.syft_history_hash
-        elif id:
+        elif id is not None:
             action_object.syft_history_hash = hash(id)
 
         return action_object
@@ -823,6 +1022,9 @@ def remove_trace_hook(cls):
         return True
         # self._syft_pre_hooks__[HOOK_ALWAYS].pop(trace_action_side_effct, None)
 
+    def as_empty_data(self) -> ActionDataEmpty:
+        return ActionDataEmpty(syft_internal_type=self.syft_internal_type)
+
     @staticmethod
     def empty(
         syft_internal_type: Type[Any] = NoneType,
@@ -842,15 +1044,11 @@ def empty(
         empty = ActionDataEmpty(syft_internal_type=syft_internal_type)
         res = ActionObject.from_obj(
-            syft_action_data=empty, id=id, syft_lineage_id=syft_lineage_id
+            id=id, syft_lineage_id=syft_lineage_id, syft_action_data=empty
         )
         res.__dict__["syft_internal_type"] = syft_internal_type
         return res
 
-    def delete_data(self):
-        empty = ActionDataEmpty(syft_internal_type=self.syft_internal_type)
-        self.syft_action_data = empty
-
     def __post_init__(self) -> None:
         """Add pre/post hooks."""
         if HOOK_ALWAYS not in self._syft_pre_hooks__:
@@ -881,8 +1079,8 @@ def __post_init__(self) -> None:
             if side_effect not in self._syft_post_hooks__[HOOK_ALWAYS]:
                 self._syft_post_hooks__[HOOK_ALWAYS].append(side_effect)
 
-        if isinstance(self.syft_action_data, ActionObject):
-            raise Exception("Nested ActionObjects", self.syft_action_data)
+        if isinstance(self.syft_action_data_type, ActionObject):
+            raise Exception("Nested ActionObjects", self.syft_action_data_repr_)
 
         self.syft_history_hash = hash(self.id)
 
@@ -966,8 +1164,12 @@ def _syft_output_action_object(
         syft_twin_type = TwinMode.NONE
         if context.result_twin_type is not None:
             syft_twin_type = context.result_twin_type
-        result = constructor(syft_action_data=result, syft_twin_type=syft_twin_type)
-
+        result = constructor(
+            syft_twin_type=syft_twin_type,
+            syft_action_data_cache=result,
+            syft_node_location=self.syft_node_location,
+            syft_client_verify_key=self.syft_client_verify_key,
+        )
         return result
 
     def _syft_passthrough_attrs(self) -> List[str]:
@@ -1010,6 +1212,17 @@ def _syft_attr_propagate_ids(self, context, name: str, result: Any) -> Any:
             result.syft_node_location = context.syft_node_location
             result.syft_client_verify_key = context.syft_client_verify_key
 
+        # Propagate Syft blob storage entry id
+        object_attrs = [
+            "syft_blob_storage_entry_id",
+            "syft_action_data_repr_",
+            "syft_action_data_str_",
+            "syft_action_data_type",
+        ]
+        for attr_name in object_attrs:
+            attr_value = getattr(context.obj, attr_name, None)
+            setattr(result, attr_name, attr_value)
+
         # Propagate Result ID
         if context.result_id is not None:
             result.id = context.result_id
@@ -1023,7 +1236,7 @@ def _syft_wrap_attribute_for_bool_on_nonbools(self, name: str) -> Any:
             "[_wrap_attribute_for_bool_on_nonbools] Use this only for the __bool__ operator"
         )
 
-        if hasattr(self.syft_action_data, "__bool__"):
+        if self.syft_has_bool_attr:
             raise RuntimeError(
                 "[_wrap_attribute_for_bool_on_nonbools] self.syft_action_data already implements the bool operator"
             )
@@ -1038,7 +1251,14 @@ def _syft_wrap_attribute_for_bool_on_nonbools(self, name: str) -> Any:
         context, _, _ = self._syft_run_pre_hooks__(context, name, (), {})
 
         # no input needs to propagate
-        result = self._syft_run_post_hooks__(context, name, bool(self.syft_action_data))
+        result = self._syft_run_post_hooks__(
+            context,
+            name,
+            any(
+                x is not None
+                for x in (self.syft_blob_storage_entry_id, self.syft_action_data_cache)
+            ),
+        )
         result = self._syft_attr_propagate_ids(context, name, result)
 
         def __wrapper__bool__() -> bool:
@@ -1080,7 +1300,7 @@ def fake_func(*args: Any, **kwargs: Any) -> Any:
         debug(f"[__getattribute__] Handling method {name} ")
         if (
-            isinstance(self.syft_action_data, ActionDataEmpty)
+            issubclass(self.syft_action_data_type, ActionDataEmpty)
             and name not in action_data_empty_must_run
         ):
             original_func = fake_func
@@ -1139,6 +1359,7 @@ def wrapper(_self: Any, *args: Any, **kwargs: Any):
             except Exception:
                 debug("name", name, "has no signature")
 
+        # third party
         return wrapper
 
     def _syft_setattr(self, name, value):
@@ -1149,9 +1370,9 @@ def fake_func(*args: Any, **kwargs: Any) -> Any:
             return ActionDataEmpty(syft_internal_type=self.syft_internal_type)
 
-        if isinstance(self.syft_action_data, ActionDataEmpty) or has_action_data_empty(
-            args=args, kwargs=kwargs
-        ):
+        if issubclass(
+            self.syft_action_data_type, ActionDataEmpty
+        ) or has_action_data_empty(args=args, kwargs=kwargs):
             local_func = fake_func
         else:
             local_func = getattr(self.syft_action_data, op_name)
@@ -1201,7 +1422,7 @@ def __getattribute__(self, name: str) -> Any:
         context_self = self._syft_get_attr_context(name)
 
         # Handle bool operator on nonbools
-        if name == "__bool__" and not hasattr(self.syft_action_data, "__bool__"):
+        if name == "__bool__" and not self.syft_has_bool_attr:
             return self._syft_wrap_attribute_for_bool_on_nonbools(name)
 
         # Handle Properties
@@ -1209,7 +1430,8 @@ def __getattribute__(self, name: str) -> Any:
             return self._syft_wrap_attribute_for_properties(name)
 
         # Handle anything else
-        return self._syft_wrap_attribute_for_methods(name)
+        res = self._syft_wrap_attribute_for_methods(name)
+        return res
 
     def __setattr__(self, name: str, value: Any) -> Any:
         defined_on_self = name in self.__dict__ or name in self.__private_attributes__
@@ -1225,8 +1447,8 @@ def __setattr__(self, name: str, value: Any) -> Any:
             context_self = self.syft_action_data  # type: ignore
             return context_self.__setattr__(name, value)
 
-    def keys(self) -> KeysView[str]:
-        return self.syft_action_data.keys()  # type: ignore
+    # def keys(self) -> KeysView[str]:
+    #     return self.syft_action_data.keys()  # type: ignore
 
     ###### __DUNDER_MIFFLIN__
 
@@ -1240,13 +1462,20 @@ def _repr_markdown_(self) -> str:
             res = "TwinPointer(Real)"
         elif not self.is_twin:
             res = "Pointer"
-        child_repr = (
-            self.syft_action_data._repr_markdown_()
-            if hasattr(self.syft_action_data, "_repr_markdown_")
-            else self.syft_action_data.__repr__()
-        )
 
-        return f"```python\n{res}\n```\n{child_repr}"
+        if isinstance(self.syft_action_data_cache, ActionDataEmpty):
+            data_repr_ = self.syft_action_data_repr_
+        else:
+            if inspect.isclass(self.syft_action_data_cache):
+                data_repr_ = repr_cls(self.syft_action_data_cache)
+            else:
+                data_repr_ = (
+                    self.syft_action_data_cache._repr_markdown_()
+                    if hasattr(self.syft_action_data_cache, "_repr_markdown_")
+                    
else self.syft_action_data_cache.__repr__()
+                )
+
+        return f"```python\n{res}\n```\n{data_repr_}"
 
     def __repr__(self) -> str:
         if self.is_mock:
@@ -1255,13 +1484,23 @@ def __repr__(self) -> str:
             res = "TwinPointer(Real)"
         if not self.is_twin:
             res = "Pointer"
-        return f"{res}:\n{str(self.syft_action_data)}"
+        if isinstance(self.syft_action_data_cache, ActionDataEmpty):
+            data_repr_ = self.syft_action_data_repr_
+        else:
+            if inspect.isclass(self.syft_action_data_cache):
+                data_repr_ = repr_cls(self.syft_action_data_cache)
+            else:
+                data_repr_ = self.syft_action_data_cache.__repr__()
+        return f"{res}:\n{data_repr_}"
 
     def __call__(self, *args: Any, **kwds: Any) -> Any:
         return self.__call__(*args, **kwds)
 
     def __str__(self) -> str:
-        return self.__str__()
+        if not inspect.isclass(self.syft_action_data_cache):
+            return self.__str__()
+        else:
+            return self.syft_action_data_str_
 
     def __len__(self) -> int:
         return self.__len__()
@@ -1293,6 +1532,9 @@ def __matmul__(self, other: Any) -> Any:
     def __eq__(self, other: Any) -> Any:
         return self._syft_output_action_object(self.__eq__(other))
 
+    def __ne__(self, other: Any) -> Any:
+        return self._syft_output_action_object(self.__ne__(other))
+
     def __lt__(self, other: Any) -> Any:
         return self._syft_output_action_object(self.__lt__(other))
 
@@ -1437,8 +1679,8 @@ def debug_original_func(name: str, func: Callable) -> None:
 
 def is_action_data_empty(obj: Any) -> bool:
-    return isinstance(obj, AnyActionObject) and isinstance(
-        obj.syft_action_data, ActionDataEmpty
+    return isinstance(obj, AnyActionObject) and issubclass(
+        obj.syft_action_data_type, ActionDataEmpty
     )
 
diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py
index 1819b4d0ce4..149a896e5fd 100644
--- a/packages/syft/src/syft/service/action/action_service.py
+++ b/packages/syft/src/syft/service/action/action_service.py
@@ -13,8 +13,10 @@
 # relative
 from ...serde.serializable import serializable
+from ...types.datetime import DateTime
 from ...types.twin_object import TwinObject
 from ...types.uid import UID
+from ..blob_storage.service import BlobStorageService
 from ..code.user_code import UserCode
 from ..code.user_code import execute_byte_code
 from ..context import AuthedServiceContext
@@ -53,8 +55,16 @@ def np_array(self, context: AuthedServiceContext, data: Any) -> Any:
         if not isinstance(data, np.ndarray):
             data = np.array(data)
         np_obj = NumpyArrayObject(
-            syft_action_data=data, dtype=data.dtype, shape=data.shape
+            dtype=data.dtype,
+            shape=data.shape,
+            syft_action_data_cache=data,
+            syft_node_location=context.node.id,
+            syft_client_verify_key=context.credentials,
         )
+        blob_store_result = np_obj._save_to_blob_storage()
+        if isinstance(blob_store_result, SyftError):
+            return blob_store_result
+
         np_pointer = self.set(context, np_obj)
         return np_pointer
 
@@ -71,10 +81,22 @@ def set(
     ) -> Result[ActionObject, str]:
         """Save an object to the action store"""
         # 🟡 TODO 9: Create some kind of type checking / protocol for SyftSerializable
+
+        if isinstance(action_object, ActionObject):
+            action_object.syft_created_at = DateTime.now()
+        else:
+            action_object.private_obj.syft_created_at = DateTime.now()
+            action_object.mock_obj.syft_created_at = DateTime.now()
+
+        has_result_read_permission = context.extra_kwargs.get(
+            "has_result_read_permission", False
+        )
+
         result = self.store.set(
             uid=action_object.id,
             credentials=context.credentials,
             syft_object=action_object,
+            has_result_read_permission=has_result_read_permission,
        )
         if result.is_ok():
             if isinstance(action_object, 
@@ -83,23 +105,6 @@ def set(
             return Ok(action_object)
         return result.err()

-    @service_method(path="action.save", name="save")
-    def save(
-        self,
-        context: AuthedServiceContext,
-        action_object: Union[ActionObject, TwinObject],
-    ) -> Result[SyftSuccess, str]:
-        """Save an object to the action store"""
-        # 🟡 TODO 9: Create some kind of type checking / protocol for SyftSerializable
-        result = self.store.set(
-            uid=action_object.id,
-            credentials=context.credentials,
-            syft_object=action_object,
-        )
-        if result.is_ok():
-            return Ok(SyftSuccess(message=f"{type(action_object)} saved"))
-        return result.err()
-
     @service_method(path="action.get", name="get", roles=GUEST_ROLE_LEVEL)
     def get(
         self,
@@ -122,7 +127,11 @@ def _get(
             uid=uid, credentials=context.credentials, has_permission=has_permission
         )
         if result.is_ok():
-            obj = result.ok()
+            obj: Union[TwinObject, ActionObject] = result.ok()
+            obj._set_obj_location_(
+                context.node.id,
+                context.credentials,
+            )
             if isinstance(obj, TwinObject):
                 if twin_mode == TwinMode.PRIVATE:
                     obj = obj.private
@@ -148,6 +157,11 @@ def get_pointer(
             uid=uid, credentials=context.credentials, node_uid=context.node.id
         )
         if result.is_ok():
+            obj = result.ok()
+            obj._set_obj_location_(
+                context.node.id,
+                context.credentials,
+            )
             return Ok(result.ok())
         return Err(result.err())

@@ -165,6 +179,17 @@ def _user_code_execute(
         if filtered_kwargs.is_err():
             return filtered_kwargs
         filtered_kwargs = filtered_kwargs.ok()
+
+        expected_input_kwargs = set()
+        for _inp_kwarg in code_item.input_policy.inputs.values():
+            expected_input_kwargs.update(_inp_kwarg.keys())
+        permitted_input_kwargs = list(filtered_kwargs.keys())
+        not_approved_kwargs = set(expected_input_kwargs) - set(permitted_input_kwargs)
+        if len(not_approved_kwargs) > 0:
+            return Err(
+                f"Input arguments: {not_approved_kwargs} to the function are not approved yet."
+            )
+
         has_twin_inputs = False

         real_kwargs = {}
@@ -207,16 +232,26 @@ def _user_code_execute(
         except Exception as e:
             return Err(f"_user_code_execute failed. {e}")
{e}") - set_result = self.store.set( - uid=result_id, - credentials=context.credentials, - syft_object=result_action_object, - has_result_read_permission=True, + result_action_object._set_obj_location_( + context.node.id, + context.credentials, ) + blob_store_result = result_action_object._save_to_blob_storage() + if isinstance(blob_store_result, SyftError): + return blob_store_result + + # pass permission information to the action store as extra kwargs + context.extra_kwargs = {"has_result_read_permission": True} + + set_result = self.set(context, result_action_object) if set_result.is_err(): return set_result.err() + blob_storage_service: BlobStorageService = context.node.get_service( + BlobStorageService + ) + if len(code_item.output_policy.output_readers) > 0: self.store.add_permissions( [ @@ -224,8 +259,18 @@ def _user_code_execute( for x in code_item.output_policy.output_readers ] ) + blob_storage_service.stash.add_permissions( + [ + ActionObjectPermission( + result_action_object.syft_blob_storage_entry_id, + ActionPermission.READ, + x, + ) + for x in code_item.output_policy.output_readers + ] + ) - return Ok(result_action_object) + return set_result def execute_plan( self, plan, context: AuthedServiceContext, plan_kwargs: Dict[str, ActionObject] @@ -250,7 +295,7 @@ def execute_plan( for plan_action in plan.actions: action_res = self.execute(context, plan_action) - if action_res.is_err(): + if isinstance(action_res, SyftError): return action_res result_id = plan.outputs[0].id return self._get(context, result_id, TwinMode.MOCK, has_permission=True) @@ -298,7 +343,7 @@ def set_attribute( # depending on permisisons? public_args = filter_twin_args(args, twin_mode=TwinMode.MOCK) public_val = public_args[0] - setattr(resolved_self.mock, name, public_val) + setattr(resolved_self.mock.syft_action_data, name, public_val) return Ok( TwinObject( id=action.result_id, @@ -394,6 +439,7 @@ def execute( if action.action_type == ActionType.CREATEOBJECT: result_action_object = Ok(action.create_object) + # print(action.create_object, "already in blob storage") elif action.action_type == ActionType.FUNCTION: result_action_object = self.call_function(context, action) else: @@ -408,9 +454,7 @@ def execute( f"Failed executing action {action}, could not resolve self: {resolved_self.err()}" ) resolved_self = resolved_self.ok() - if action.op == "__call__" and isinstance( - resolved_self.syft_action_data, Plan - ): + if action.op == "__call__" and resolved_self.syft_action_data_type == Plan: result_action_object = self.execute_plan( plan=resolved_self.syft_action_data, context=context, @@ -440,24 +484,27 @@ def execute( context, action ) - set_result = self.store.set( - uid=action.result_id, - credentials=context.credentials, - syft_object=result_action_object, - has_result_read_permission=has_result_read_permission, + result_action_object._set_obj_location_( + context.node.id, + context.credentials, ) + + blob_store_result = result_action_object._save_to_blob_storage() + if isinstance(blob_store_result, SyftError): + return blob_store_result + + # pass permission information to the action store as extra kwargs + context.extra_kwargs = { + "has_result_read_permission": has_result_read_permission + } + + set_result = self.set(context, result_action_object) if set_result.is_err(): return Err( f"Failed executing action {action}, set result is an error: {set_result.err()}" ) - if isinstance(result_action_object, TwinObject): - result_action_object = result_action_object.mock - # we patch this on the object, 
-        result_action_object.id = action.result_id
-        result_action_object.syft_point_to(context.node.id)
-
-        return Ok(result_action_object)
+        return set_result

     def has_read_permission_for_action_result(
         self, context: AuthedServiceContext, action: Action
@@ -594,12 +641,13 @@ def execute_object(
     twin_mode: TwinMode = TwinMode.NONE,
 ) -> Result[Ok[Union[TwinObject, ActionObject]], Err[str]]:
     unboxed_resolved_self = resolved_self.syft_action_data
-    args, has_arg_twins = resolve_action_args(action, context, service)
+    _args, has_arg_twins = resolve_action_args(action, context, service)
+
     kwargs, has_kwargs_twins = resolve_action_kwargs(action, context, service)
-    if args.is_err():
-        return args
+    if _args.is_err():
+        return _args
     else:
-        args = args.ok()
+        args = _args.ok()
     if kwargs.is_err():
         return kwargs
     else:
@@ -620,15 +668,15 @@ def execute_object(
         result_action_object = wrap_result(action.result_id, result)
     elif twin_mode == TwinMode.NONE and has_twin_inputs:
         # self isn't a twin but one of the inputs is
-        private_args = filter_twin_args(args, twin_mode=twin_mode)
-        private_kwargs = filter_twin_kwargs(kwargs, twin_mode=twin_mode)
+        private_args = filter_twin_args(args, twin_mode=TwinMode.PRIVATE)
+        private_kwargs = filter_twin_kwargs(kwargs, twin_mode=TwinMode.PRIVATE)
         private_result = target_method(*private_args, **private_kwargs)
         result_action_object_private = wrap_result(
             action.result_id, private_result
         )

-        mock_args = filter_twin_args(args, twin_mode=twin_mode)
-        mock_kwargs = filter_twin_kwargs(kwargs, twin_mode=twin_mode)
+        mock_args = filter_twin_args(args, twin_mode=TwinMode.MOCK)
+        mock_kwargs = filter_twin_kwargs(kwargs, twin_mode=TwinMode.MOCK)
         mock_result = target_method(*mock_args, **mock_kwargs)
         result_action_object_mock = wrap_result(action.result_id, mock_result)

@@ -666,7 +714,7 @@ def execute_object(

 def wrap_result(result_id: UID, result: Any) -> ActionObject:
     # 🟡 TODO 11: Figure out how we want to store action object results
     action_type = action_type_for_type(result)
-    result_action_object = action_type(id=result_id, syft_action_data=result)
+    result_action_object = action_type(id=result_id, syft_action_data_cache=result)
     return result_action_object

diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py
index f9006abd3f2..ef7dd308fc3 100644
--- a/packages/syft/src/syft/service/action/action_store.py
+++ b/packages/syft/src/syft/service/action/action_store.py
@@ -226,9 +226,8 @@ def remove_permission(self, permission: ActionObjectPermission):
         self.permissions[permission.uid] = permissions

     def add_permissions(self, permissions: List[ActionObjectPermission]) -> None:
-        results = []
         for permission in permissions:
-            results.append(self.add_permission(permission))
+            self.add_permission(permission)


 @serializable()
diff --git a/packages/syft/src/syft/service/action/numpy.py b/packages/syft/src/syft/service/action/numpy.py
index c5b6e044e39..3c19aa61bc2 100644
--- a/packages/syft/src/syft/service/action/numpy.py
+++ b/packages/syft/src/syft/service/action/numpy.py
@@ -58,9 +58,6 @@ class NumpyArrayObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin):
     #     )
     #     return self == other

-    def __bool__(self) -> bool:
-        return bool(self.all())
-
     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
         inputs = tuple(
             np.array(x.syft_action_data, dtype=x.dtype)
@@ -72,12 +69,12 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
         result = getattr(ufunc, method)(*inputs, **kwargs)
         if type(result) is tuple:
             return tuple(
-                NumpyArrayObject(syft_action_data=x, dtype=x.dtype, shape=x.shape)
+                NumpyArrayObject(syft_action_data_cache=x, dtype=x.dtype, shape=x.shape)
                 for x in result
             )
         else:
             return NumpyArrayObject(
-                syft_action_data=result, dtype=result.dtype, shape=result.shape
+                syft_action_data_cache=result, dtype=result.dtype, shape=result.shape
             )

diff --git a/packages/syft/src/syft/service/action/pandas.py b/packages/syft/src/syft/service/action/pandas.py
index cac27b94f12..cd669ff1425 100644
--- a/packages/syft/src/syft/service/action/pandas.py
+++ b/packages/syft/src/syft/service/action/pandas.py
@@ -1,7 +1,6 @@
 # stdlib
 from typing import Any
 from typing import ClassVar
-from typing import Optional
 from typing import Type

 # third party
@@ -50,7 +49,7 @@ class PandasSeriesObject(ActionObject):
     syft_internal_type = Series
     syft_passthrough_attrs = BASE_PASSTHROUGH_ATTRS

-    name: Optional[str] = None
+    # name: Optional[str] = None
     # syft_dont_wrap_attrs = ["shape"]

     def __getattribute__(self, name: str) -> Any:
diff --git a/packages/syft/src/syft/service/action/plan.py b/packages/syft/src/syft/service/action/plan.py
index e1add8949d5..298f34693bc 100644
--- a/packages/syft/src/syft/service/action/plan.py
+++ b/packages/syft/src/syft/service/action/plan.py
@@ -58,6 +58,7 @@ def __call__(self, *args, **kwargs):
 def planify(func):
     TraceResult.reset()
     ActionObject.add_trace_hook()
+    TraceResult.is_tracing = True
     worker = Worker.named(name="plan_building", reset=True, processes=0)
     client = worker.root_client
     TraceResult._client = client
@@ -69,7 +70,12 @@ def planify(func):
     actions = TraceResult.result
     TraceResult.reset()
     code = inspect.getsource(func)
+    for a in actions:
+        if a.create_object is not None:
+            # warmup cache
+            a.create_object.syft_action_data  # noqa: B018
     plan = Plan(inputs=plan_kwargs, actions=actions, outputs=outputs, code=code)
+    TraceResult.is_tracing = False
     return ActionObject.from_obj(plan)
diff --git a/packages/syft/src/syft/service/blob_storage/service.py b/packages/syft/src/syft/service/blob_storage/service.py
index 542b4068f1f..30bf7fbee98 100644
--- a/packages/syft/src/syft/service/blob_storage/service.py
+++ b/packages/syft/src/syft/service/blob_storage/service.py
@@ -6,10 +6,13 @@
 # relative
 from ...serde.serializable import serializable
-from ...store.blob_storage import BlobDeposit
 from ...store.blob_storage import BlobRetrieval
+from ...store.blob_storage.on_disk import OnDiskBlobDeposit
+from ...store.blob_storage.seaweedfs import SeaweedFSBlobDeposit
 from ...store.document_store import DocumentStore
+from ...store.document_store import UIDPartitionKey
 from ...types.blob_storage import BlobStorageEntry
+from ...types.blob_storage import BlobStorageMetadata
 from ...types.blob_storage import CreateBlobStorageEntry
 from ...types.uid import UID
 from ..context import AuthedServiceContext
@@ -18,8 +21,11 @@
 from ..service import AbstractService
 from ..service import TYPE_TO_SERVICE
 from ..service import service_method
+from ..user.user_roles import GUEST_ROLE_LEVEL
 from .stash import BlobStorageStash

+BlobDepositType = Union[OnDiskBlobDeposit, SeaweedFSBlobDeposit]
+

 @serializable()
 class BlobStorageService(AbstractService):
@@ -48,24 +54,50 @@ def get_blob_storage_entry_by_uid(
             return result.ok()
         return SyftError(message=result.err())

-    @service_method(path="blob_storage.read", name="read")
+    @service_method(path="blob_storage.get_metadata", name="get_metadata")
+    def get_blob_storage_metadata_by_uid(
+        self, context: AuthedServiceContext, uid: UID
+    ) -> Union[BlobStorageEntry, SyftError]:
+        result = self.stash.get_by_uid(context.credentials, uid=uid)
+        if result.is_ok():
+            blob_storage_entry = result.ok()
+            return blob_storage_entry.to(BlobStorageMetadata)
+        return SyftError(message=result.err())
+
+    @service_method(
+        path="blob_storage.read",
+        name="read",
+        roles=GUEST_ROLE_LEVEL,
+    )
     def read(
         self, context: AuthedServiceContext, uid: UID
     ) -> Union[BlobRetrieval, SyftError]:
         result = self.stash.get_by_uid(context.credentials, uid=uid)
         if result.is_ok():
-            with context.node.blob_storage_client as conn:
-                return conn.read(result.ok().location)
+            obj: BlobStorageEntry = result.ok()
+            if obj is None:
+                return SyftError(message=f"No blob storage entry exists for uid: {uid}")
+
+            with context.node.blob_storage_client.connect() as conn:
+                return conn.read(obj.location, obj.type_)
         return SyftError(message=result.err())

-    @service_method(path="blob_storage.allocate", name="allocate")
+    @service_method(
+        path="blob_storage.allocate",
+        name="allocate",
+        roles=GUEST_ROLE_LEVEL,
+    )
     def allocate(
         self, context: AuthedServiceContext, obj: CreateBlobStorageEntry
-    ) -> Union[BlobDeposit, SyftError]:
-        with context.node.blob_storage_client as conn:
+    ) -> Union[BlobDepositType, SyftError]:
+        with context.node.blob_storage_client.connect() as conn:
             secure_location = conn.allocate(obj)

+            if isinstance(secure_location, SyftError):
+                return secure_location
+
             blob_storage_entry = BlobStorageEntry(
+                id=obj.id,
                 location=secure_location,
                 type_=obj.type_,
                 mimetype=obj.mimetype,
@@ -79,7 +111,11 @@ def allocate(
             return SyftError(message=f"{result.err()}")
         return blob_deposit

-    @service_method(path="blob_storage.write_to_disk", name="write_to_disk")
+    @service_method(
+        path="blob_storage.write_to_disk",
+        name="write_to_disk",
+        roles=GUEST_ROLE_LEVEL,
+    )
     def write_to_disk(
         self, context: AuthedServiceContext, uid: UID, data: bytes
     ) -> Union[SyftSuccess, SyftError]:
@@ -101,5 +137,59 @@ def write_to_disk(
         except Exception as e:
             return SyftError(message=f"Failed to write object to disk: {e}")

+    @service_method(
+        path="blob_storage.mark_write_complete",
+        name="mark_write_complete",
+        roles=GUEST_ROLE_LEVEL,
+    )
+    def mark_write_complete(
+        self,
+        context: AuthedServiceContext,
+        uid: UID,
+        etags: List,
+    ) -> Union[SyftError, SyftSuccess]:
+        result = self.stash.get_by_uid(
+            credentials=context.credentials,
+            uid=uid,
+        )
+        if result.is_err():
+            return SyftError(message=f"{result.err()}")
+
+        obj: Optional[BlobStorageEntry] = result.ok()
+
+        if obj is None:
+            return SyftError(message=f"No blob storage entry exists for uid: {uid}")
+
+        with context.node.blob_storage_client.connect() as conn:
+            result = conn.complete_multipart_upload(obj, etags)
+
+        return result
+
+    @service_method(path="blob_storage.delete", name="delete")
+    def delete(
+        self, context: AuthedServiceContext, uid: UID
+    ) -> Union[SyftSuccess, SyftError]:
+        result = self.stash.get_by_uid(context.credentials, uid=uid)
+        if result.is_ok():
+            obj = result.ok()
+
+            if obj is None:
+                return SyftError(message=f"No blob storage entry exists for uid: {uid}")
+            try:
+                with context.node.blob_storage_client.connect() as conn:
+                    file_unlinked_result = conn.delete(obj.location)
+            except Exception as e:
+                return SyftError(message=f"Failed to delete file: {e}")
+
+            if isinstance(file_unlinked_result, SyftError):
+                return file_unlinked_result
+            blob_storage_entry_deleted = self.stash.delete(
+                context.credentials, UIDPartitionKey.with_obj(uid), has_permission=True
+            )
+            if blob_storage_entry_deleted.is_ok():
+                return file_unlinked_result
+
+        return SyftError(message=result.err())
+

 TYPE_TO_SERVICE[BlobStorageEntry] = BlobStorageService
diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py
index 3d9fb656134..e4a9eba500b 100644
--- a/packages/syft/src/syft/service/code/user_code_service.py
+++ b/packages/syft/src/syft/service/code/user_code_service.py
@@ -189,7 +189,13 @@ def get_results(
                 connection=connection,
                 credentials=context.node.signing_key,
             )
-            return enclave_client.code.get_results(code.id)
+            outputs = enclave_client.code.get_results(code.id)
+            if isinstance(outputs, list):
+                for output in outputs:
+                    output.syft_action_data  # noqa: B018
+            else:
+                outputs.syft_action_data  # noqa: B018
+            return outputs

         # if the current node is the enclave
         else:
diff --git a/packages/syft/src/syft/service/context.py b/packages/syft/src/syft/service/context.py
index 2b9135e4f3a..5f8613ad3d0 100644
--- a/packages/syft/src/syft/service/context.py
+++ b/packages/syft/src/syft/service/context.py
@@ -1,4 +1,5 @@
 # stdlib
+from typing import Dict
 from typing import List
 from typing import Optional

@@ -32,10 +33,19 @@ class AuthedServiceContext(NodeServiceContext):
     credentials: SyftVerifyKey
     role: ServiceRole = ServiceRole.NONE
+    extra_kwargs: Dict = {}

     def capabilities(self) -> List[ServiceRoleCapability]:
         return ROLE_TO_CAPABILITIES.get(self.role, [])

+    def with_credentials(self, credentials: SyftVerifyKey, role: ServiceRole):
+        return AuthedServiceContext(credentials=credentials, role=role, node=self.node)
+
+    def as_root_context(self):
+        return AuthedServiceContext(
+            credentials=self.node.verify_key, role=ServiceRole.ADMIN, node=self.node
+        )
+

 class UnauthedServiceContext(NodeServiceContext):
     __canonical_name__ = "UnauthedServiceContext"
diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py
index 1ea2866c929..8650a37a667 100644
--- a/packages/syft/src/syft/service/dataset/dataset.py
+++ b/packages/syft/src/syft/service/dataset/dataset.py
@@ -361,7 +361,7 @@ def no_mock(self) -> None:
         # relative
         from ..action.action_object import ActionObject

-        self.mock = ActionObject.empty()
+        self.set_mock(ActionObject.empty(), False)

     def set_shape(self, shape: Tuple) -> None:
         self.shape = shape
diff --git a/packages/syft/src/syft/service/enclave/enclave_service.py b/packages/syft/src/syft/service/enclave/enclave_service.py
index df931ccaea4..c2050b9bf80 100644
--- a/packages/syft/src/syft/service/enclave/enclave_service.py
+++ b/packages/syft/src/syft/service/enclave/enclave_service.py
@@ -10,6 +10,7 @@
 from ...service.response import SyftSuccess
 from ...service.user.user_roles import GUEST_ROLE_LEVEL
 from ...store.document_store import DocumentStore
+from ...types.twin_object import TwinObject
 from ...types.uid import UID
 from ..action.action_object import ActionObject
 from ..code.user_code_service import UserCode
@@ -73,31 +74,22 @@ def send_user_code_inputs_to_enclave(
         if isinstance(user_code_update, SyftError):
             return user_code_update

+    root_context = context.as_root_context()
     if not action_service.exists(context=context, obj_id=user_code_id):
         dict_object = ActionObject.from_obj({})
         dict_object.id = user_code_id
         dict_object[str(context.credentials)] = inputs
+        root_context.extra_kwargs = {"has_result_read_permission": True}
         # TODO: Instead of using the action store, modify to
         # use the action service directly to store objects
-        action_service.store.set(
-            uid=user_code_id,
-            credentials=context.node.verify_key,
-            syft_object=dict_object,
-            has_result_read_permission=True,
-        )
+        action_service.set(root_context, dict_object)
     else:
-        res = action_service.store.get(
-            uid=user_code_id, credentials=context.node.verify_key
-        )
+        res = action_service.get(uid=user_code_id, context=root_context)
         if res.is_ok():
             dict_object = res.ok()
             dict_object[str(context.credentials)] = inputs
-            action_service.store.set(
-                uid=user_code_id,
-                credentials=context.node.verify_key,
-                syft_object=dict_object,
-            )
+            action_service.set(root_context, dict_object)
         else:
             return SyftError(
                 message=f"Error while fetching the object on Enclave: {res.err()}"
@@ -163,6 +155,18 @@ def propagate_inputs_to_enclave(user_code: UserCode, context: ChangeContext):
     if isinstance(inputs, SyftError):
         return inputs

+    # Save inputs to blob store
+    for var_name, var_value in inputs.items():
+        if isinstance(var_value, (ActionObject, TwinObject)):
+            # Set the obj location to enclave
+            var_value._set_obj_location_(
+                enclave_client.api.node_uid,
+                enclave_client.verify_key,
+            )
+            var_value._save_to_blob_storage()
+
+            inputs[var_name] = var_value
+
     # send data of the current node to enclave
     res = send_method(
         user_code_id=user_code.id,
diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py
index 0fdd2ba1ebd..af2be77fd5f 100644
--- a/packages/syft/src/syft/service/policy/policy.py
+++ b/packages/syft/src/syft/service/policy/policy.py
@@ -194,14 +194,23 @@ def _inputs_for_context(self, context: ChangeContext):
     user_node_view = NodeIdentity.from_change_context(context)
     inputs = self.inputs[user_node_view]

+    root_context = AuthedServiceContext(
+        node=context.node, credentials=context.approving_user_credentials
+    ).as_root_context()
+
     action_service = context.node.get_service("actionservice")
     for var_name, uid in inputs.items():
-        action_object = action_service.store.get(
-            uid=uid, credentials=user_node_view.verify_key
-        )
+        action_object = action_service.get(uid=uid, context=root_context)
         if action_object.is_err():
             return SyftError(message=action_object.err())
-        inputs[var_name] = action_object.ok()
+        action_object_value = action_object.ok()
+        # resolve syft action data from blob store
+        if isinstance(action_object_value, TwinObject):
+            action_object_value.private_obj.syft_action_data  # noqa: B018
+            action_object_value.mock_obj.syft_action_data  # noqa: B018
+        elif isinstance(action_object_value, ActionObject):
+            action_object_value.syft_action_data  # noqa: B018
+        inputs[var_name] = action_object_value
     return inputs
diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py
index e22062bf90c..e171db8ca41 100644
--- a/packages/syft/src/syft/service/queue/queue.py
+++ b/packages/syft/src/syft/service/queue/queue.py
@@ -77,6 +77,7 @@ def handle_message(message: bytes):
         signing_key=worker_settings.signing_key,
         document_store_config=worker_settings.document_store_config,
         action_store_config=worker_settings.action_store_config,
+        blob_storage_config=worker_settings.blob_store_config,
         is_subprocess=True,
     )

diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py
index b133d339534..2c9b5741fd1 100644
--- a/packages/syft/src/syft/service/request/request.py
+++ b/packages/syft/src/syft/service/request/request.py
@@ -30,6 +30,7 @@
 from ...types.transforms import add_node_uid_for_key
 from ...types.transforms import generate_id
 from ...types.transforms import transform
+from ...types.twin_object import TwinObject
 from ...types.uid import LineageID
 from ...types.uid import UID
 from ...util import options
@@ -41,6 +42,7 @@
 from ..action.action_service import ActionService
 from ..action.action_store import ActionObjectPermission
 from ..action.action_store import ActionPermission
+from ..blob_storage.service import BlobStorageService
 from ..code.user_code import UserCode
 from ..code.user_code import UserCodeStatus
 from ..context import AuthedServiceContext
@@ -97,29 +99,63 @@ def _run(
         self, context: ChangeContext, apply: bool
     ) -> Result[SyftSuccess, SyftError]:
         try:
-            action_service = context.node.get_service(ActionService)
+            action_service: ActionService = context.node.get_service(ActionService)
+            blob_storage_service = context.node.get_service(BlobStorageService)
             action_store = action_service.store

             # can we ever have a lineage ID in the store?
             obj_uid = self.linked_obj.object_uid
             obj_uid = obj_uid.id if isinstance(obj_uid, LineageID) else obj_uid

+            action_obj = action_store.get(
+                uid=obj_uid,
+                credentials=context.approving_user_credentials,
+            )
+
+            if action_obj.is_err():
+                return Err(SyftError(message=f"{action_obj.err()}"))
+
+            action_obj = action_obj.ok()
+
             owner_permission = ActionObjectPermission(
                 uid=obj_uid,
                 credentials=context.approving_user_credentials,
                 permission=self.apply_permission_type,
             )
             if action_store.has_permission(permission=owner_permission):
-                requesting_permission = ActionObjectPermission(
-                    uid=obj_uid,
+                id_action = (
+                    action_obj.id
+                    if not isinstance(action_obj.id, LineageID)
+                    else action_obj.id.id
+                )
+                requesting_permission_action_obj = ActionObjectPermission(
+                    uid=id_action,
+                    credentials=context.requesting_user_credentials,
+                    permission=self.apply_permission_type,
+                )
+                if isinstance(action_obj, TwinObject):
+                    uid_blob = action_obj.private.syft_blob_storage_entry_id
+                else:
+                    uid_blob = action_obj.syft_blob_storage_entry_id
+                requesting_permission_blob_obj = ActionObjectPermission(
+                    uid=uid_blob,
                     credentials=context.requesting_user_credentials,
                     permission=self.apply_permission_type,
                 )
                 if apply:
-                    action_store.add_permission(requesting_permission)
+                    action_store.add_permission(requesting_permission_action_obj)
+                    blob_storage_service.stash.add_permission(
+                        requesting_permission_blob_obj
+                    )
                 else:
-                    if action_store.has_permission(requesting_permission):
-                        action_store.remove_permission(requesting_permission)
+                    if action_store.has_permission(requesting_permission_action_obj):
+                        action_store.remove_permission(requesting_permission_action_obj)
+                    if blob_storage_service.stash.has_permission(
+                        requesting_permission_blob_obj
+                    ):
+                        blob_storage_service.stash.remove_permission(
+                            requesting_permission_blob_obj
+                        )
             else:
                 return Err(
                    SyftError(
@@ -128,7 +164,7 @@ def _run(
                 )
             return Ok(SyftSuccess(message=f"{type(self)} Success"))
         except Exception as e:
-            print(f"failed to apply {type(self)}")
+            print(f"failed to apply {type(self)}", e)
             return Err(SyftError(message=str(e)))

     def apply(self, context: ChangeContext) -> Result[SyftSuccess, SyftError]:
@@ -442,19 +478,34 @@ def accept_by_depositing_result(self, result: Any, force: bool = False):
                 message="Already approved, if you want to force updating the result use force=True"
             )
             action_obj_id = state.output_history[0].outputs[0]
-            action_object = ActionObject.from_obj(result, id=action_obj_id)
-            result = api.services.action.save(action_object)
-            if not result:
+            action_object = ActionObject.from_obj(
+                result,
+                id=action_obj_id,
+                syft_client_verify_key=api.signing_key.verify_key,
+                syft_node_location=api.node_uid,
+            )
+            blob_store_result = action_object._save_to_blob_storage()
+            if isinstance(blob_store_result, SyftError):
+                return blob_store_result
+            result = api.services.action.set(action_object)
+            if isinstance(result, SyftError):
                 return result
             return SyftSuccess(message="Request submitted for updating result.")
         else:
-            action_object = ActionObject.from_obj(result)
-            result = api.services.action.save(action_object)
-            if not result:
+            action_object = ActionObject.from_obj(
+                result,
+                syft_client_verify_key=api.signing_key.verify_key,
+                syft_node_location=api.node_uid,
+            )
+            blob_store_result = action_object._save_to_blob_storage()
+            if isinstance(blob_store_result, SyftError):
+                return blob_store_result
+            result = api.services.action.set(action_object)
+            if isinstance(result, SyftError):
                 return result

         ctx = AuthedServiceContext(credentials=api.signing_key.verify_key)
-        state.apply_output(context=ctx, outputs=action_object)
+        state.apply_output(context=ctx, outputs=result)
         policy_state_mutation = ObjectMutation(
             linked_obj=change.linked_obj,
             attr_name="output_policy",
@@ -462,9 +513,7 @@ def accept_by_depositing_result(self, result: Any, force: bool = False):
             value=state,
         )

-        action_object_link = LinkedObject.from_obj(
-            action_object, node_uid=self.node_uid
-        )
+        action_object_link = LinkedObject.from_obj(result, node_uid=self.node_uid)
         permission_change = ActionStoreChange(
             linked_obj=action_object_link,
             apply_permission_type=ActionPermission.READ,
@@ -603,7 +652,7 @@ def _run(
         try:
             obj = self.linked_obj.resolve_with_context(context)
             if obj.is_err():
-                return SyftError(message=obj.err())
+                return Err(SyftError(message=obj.err()))
             obj = obj.ok()
             if apply:
                 obj = self.mutate(obj, value=self.value)
@@ -792,7 +841,7 @@ def _run(
                 return Err(valid)
             obj = self.linked_obj.resolve_with_context(context)
             if obj.is_err():
-                return SyftError(message=obj.err())
+                return Err(SyftError(message=obj.err()))
             obj = obj.ok()
             if apply:
                 res = self.mutate(obj, context, undo=False)
diff --git a/packages/syft/src/syft/service/request/request_service.py b/packages/syft/src/syft/service/request/request_service.py
index 49ea8980ec3..c8df9106291 100644
--- a/packages/syft/src/syft/service/request/request_service.py
+++ b/packages/syft/src/syft/service/request/request_service.py
@@ -215,7 +215,10 @@ def apply(
             )
             send_notification(context=context, notification=notification)

-            return result.value
+            # TODO: wherever we return a SyftError, encapsulate it in a Result.
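+            # a Change may return a Result (unwrapped via .value below) or a
+            # bare SyftError, which is passed through unchanged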
+            if hasattr(result, "value"):
+                return result.value
+            return result
         return request.value

     @service_method(path="request.undo", name="undo")
@@ -236,7 +239,7 @@ def undo(

         if result.is_err():
             return SyftError(
-                f"Failed to undo Request: <{uid}> with error: {result.err()}"
+                message=f"Failed to undo Request: <{uid}> with error: {result.err()}"
             )

         link = LinkedObject.with_context(request, context=context)
diff --git a/packages/syft/src/syft/service/service.py b/packages/syft/src/syft/service/service.py
index e7533a71823..974fe2d02b6 100644
--- a/packages/syft/src/syft/service/service.py
+++ b/packages/syft/src/syft/service/service.py
@@ -1,6 +1,7 @@
 # stdlib
 from collections import defaultdict
 from copy import deepcopy
+from functools import partial
 import inspect
 from inspect import Parameter
 from typing import Any
@@ -407,3 +408,53 @@ def get_transform(
     version_to = type_to.__version__
     mapping_string = f"{klass_from}_{version_from}_x_{klass_to}_{version_to}"
     return cls.__object_transform_registry__[mapping_string]
+
+
+def from_api_or_context(
+    func_or_path: str,
+    syft_node_location: Optional[UID] = None,
+    syft_client_verify_key: Optional[SyftVerifyKey] = None,
+):
+    # relative
+    from ..client.api import APIRegistry
+    from ..node.node import AuthNodeContextRegistry
+
+    if callable(func_or_path):
+        func_or_path = func_or_path.__qualname__
+
+    if not (syft_node_location and syft_client_verify_key):
+        return None
+
+    api = APIRegistry.api_for(
+        node_uid=syft_node_location,
+        user_verify_key=syft_client_verify_key,
+    )
+    if api is not None:
+        service_method = api.services
+        for path in func_or_path.split("."):
+            service_method = getattr(service_method, path)
+        return service_method
+
+    node_context = AuthNodeContextRegistry.auth_context_for_user(
+        node_uid=syft_node_location,
+        user_verify_key=syft_client_verify_key,
+    )
+    if node_context is not None:
+        user_config_registry = UserServiceConfigRegistry.from_role(
+            node_context.role,
+        )
+        if func_or_path not in user_config_registry:
+            if ServiceConfigRegistry.path_exists(func_or_path):
+                return SyftError(
+                    message=f"As a `{node_context.role}` you have no access to: {func_or_path}"
+                )
+            else:
+                return SyftError(
+                    message=f"API call not in registered services: {func_or_path}"
+                )
+
+        _private_api_path = user_config_registry.private_path_for(func_or_path)
+        service_method = node_context.node.get_service_method(
+            _private_api_path,
+        )
+        return partial(service_method, node_context)
diff --git a/packages/syft/src/syft/service/user/user.py b/packages/syft/src/syft/service/user/user.py
index 55f8ef64fbf..1189c69eb16 100644
--- a/packages/syft/src/syft/service/user/user.py
+++ b/packages/syft/src/syft/service/user/user.py
@@ -112,11 +112,11 @@ class UserUpdate(PartialSyftObject):
     __version__ = SYFT_OBJECT_VERSION_1

     @pydantic.validator("email", pre=True)
-    def make_email(cls, v: EmailStr) -> Optional[EmailStr]:
-        return EmailStr(v) if isinstance(v, str) else v
+    def make_email(cls, v: Any) -> Any:
+        return EmailStr(v) if isinstance(v, str) and not isinstance(v, EmailStr) else v

     @pydantic.validator("role", pre=True)
-    def str_to_role(cls, v: Union[str, ServiceRole]) -> Optional[ServiceRole]:
+    def str_to_role(cls, v: Any) -> Any:
         if isinstance(v, str) and hasattr(ServiceRole, v.upper()):
             return getattr(ServiceRole, v.upper())
         return v
@@ -222,7 +222,7 @@ def set_email(self, email: str) -> Union[SyftSuccess, SyftError]:

         try:
             user_update = UserUpdate(email=email)
-        except ValidationError as e:  # noqa: F841
+        except ValidationError:
             return SyftError(message="{email} is not a valid email address.")
SyftError(message="{email} is not a valid email address.") result = api.services.user.update(uid=self.id, user_update=user_update) diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index b214313614d..f488147bc61 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -210,7 +210,11 @@ def get_current_user( SyftError(message="User not found!") return SyftError(message=str(result.err())) - @service_method(path="user.update", name="update", roles=GUEST_ROLE_LEVEL) + @service_method( + path="user.update", + name="update", + roles=GUEST_ROLE_LEVEL, + ) def update( self, context: AuthedServiceContext, uid: UID, user_update: UserUpdate ) -> Union[UserView, SyftError]: diff --git a/packages/syft/src/syft/store/blob_storage/__init__.py b/packages/syft/src/syft/store/blob_storage/__init__.py index a6de9989978..561fdd4a6b5 100644 --- a/packages/syft/src/syft/store/blob_storage/__init__.py +++ b/packages/syft/src/syft/store/blob_storage/__init__.py @@ -42,11 +42,15 @@ # stdlib +from typing import Optional from typing import Type from typing import Union +from urllib.request import urlretrieve # third party from pydantic import BaseModel +import requests +from typing_extensions import Self # relative from ...serde.deserialize import _deserialize as deserialize @@ -54,12 +58,16 @@ from ...service.response import SyftError from ...service.response import SyftSuccess from ...types.base import SyftBaseModel +from ...types.blob_storage import BlobFile +from ...types.blob_storage import BlobFileType from ...types.blob_storage import BlobStorageEntry from ...types.blob_storage import CreateBlobStorageEntry from ...types.blob_storage import SecureFilePathLocation +from ...types.grid_url import GridURL from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SyftObject from ...types.uid import UID +from ...util.constants import DEFAULT_TIMEOUT @serializable() @@ -67,7 +75,10 @@ class BlobRetrieval(SyftObject): __canonical_name__ = "BlobRetrieval" __version__ = SYFT_OBJECT_VERSION_1 - def read(self) -> SyftObject: + type_: Optional[Type] + file_name: str + + def read(self) -> Union[SyftObject, SyftError]: pass @@ -78,7 +89,11 @@ class SyftObjectRetrieval(BlobRetrieval): syft_object: bytes - def read(self) -> SyftObject: + def read(self) -> Union[SyftObject, SyftError]: + if self.type_ is BlobFileType: + with open(self.file_name, "wb") as fp: + fp.write(self.syft_object) + return BlobFile(file_name=self.file_name) return deserialize(self.syft_object, from_bytes=True) @@ -87,10 +102,29 @@ class BlobRetrievalByURL(BlobRetrieval): __canonical_name__ = "BlobRetrievalByURL" __version__ = SYFT_OBJECT_VERSION_1 - url: str - - def read(self) -> SyftObject: - pass + url: GridURL + + def read(self) -> Union[SyftObject, SyftError]: + # relative + from ...client.api import APIRegistry + + api = APIRegistry.api_for( + node_uid=self.syft_node_location, + user_verify_key=self.syft_client_verify_key, + ) + if api is not None: + blob_url = api.connection.to_blob_route(self.url.url_path) + else: + blob_url = self.url + try: + if self.type_ is BlobFileType: + urlretrieve(str(blob_url), filename=self.file_name) # nosec + return BlobFile(file_name=self.file_name) + response = requests.get(str(blob_url), timeout=DEFAULT_TIMEOUT) + response.raise_for_status() + return deserialize(response.content, from_bytes=True) + except requests.RequestException as e: + return 
SyftError(message=f"Failed to retrieve with Error: {e}") @serializable() @@ -110,27 +144,36 @@ class BlobStorageClientConfig(BaseModel): class BlobStorageConnection: - def read(self, fp: SecureFilePathLocation) -> BlobRetrieval: + def __enter__(self) -> Self: + raise NotImplementedError + + def __exit__(self, *exc) -> None: + raise NotImplementedError + + def read(self, fp: SecureFilePathLocation, type_: Optional[Type]) -> BlobRetrieval: raise NotImplementedError - def allocate(self, obj: CreateBlobStorageEntry) -> SecureFilePathLocation: + def allocate( + self, obj: CreateBlobStorageEntry + ) -> Union[SecureFilePathLocation, SyftError]: raise NotImplementedError def write(self, obj: BlobStorageEntry) -> BlobDeposit: raise NotImplementedError + def delete(self, fp: SecureFilePathLocation) -> bool: + raise NotImplementedError + @serializable() class BlobStorageClient(SyftBaseModel): config: BlobStorageClientConfig - def __enter__(self) -> BlobStorageConnection: - raise NotImplementedError - - def __exit__(self, *exc) -> None: + def connect(self) -> BlobStorageConnection: raise NotImplementedError +@serializable() class BlobStorageConfig(SyftBaseModel): client_type: Type[BlobStorageClient] client_config: BlobStorageClientConfig diff --git a/packages/syft/src/syft/store/blob_storage/on_disk.py b/packages/syft/src/syft/store/blob_storage/on_disk.py index 68ced67f0d0..81f990ca6e8 100644 --- a/packages/syft/src/syft/store/blob_storage/on_disk.py +++ b/packages/syft/src/syft/store/blob_storage/on_disk.py @@ -1,12 +1,14 @@ # stdlib +from io import BytesIO from pathlib import Path from tempfile import gettempdir from typing import Any +from typing import Optional from typing import Type from typing import Union # third party -from pydantic import PrivateAttr +from typing_extensions import Self # relative from . 
@@ -30,17 +32,16 @@ class OnDiskBlobDeposit(BlobDeposit):
     __canonical_name__ = "OnDiskBlobDeposit"
     __version__ = SYFT_OBJECT_VERSION_1

-    def write(self, data: bytes) -> Union[SyftSuccess, SyftError]:
+    def write(self, data: BytesIO) -> Union[SyftSuccess, SyftError]:
         # relative
-        from ...client.api import APIRegistry
+        from ...service.service import from_api_or_context

-        api = APIRegistry.api_for(
-            node_uid=self.syft_node_location,
-            user_verify_key=self.syft_client_verify_key,
-        )
-        return api.services.blob_storage.write_to_disk(
-            data=data, uid=self.blob_storage_entry_id
+        write_to_disk_method = from_api_or_context(
+            func_or_path="blob_storage.write_to_disk",
+            syft_node_location=self.syft_node_location,
+            syft_client_verify_key=self.syft_client_verify_key,
         )
+        return write_to_disk_method(data=data.read(), uid=self.blob_storage_entry_id)


 class OnDiskBlobStorageConnection(BlobStorageConnection):
@@ -49,19 +50,40 @@ class OnDiskBlobStorageConnection(BlobStorageConnection):
     def __init__(self, base_directory: Path) -> None:
         self._base_directory = base_directory

-    def read(self, fp: SecureFilePathLocation) -> BlobRetrieval:
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(self, *exc) -> None:
+        pass
+
+    def read(self, fp: SecureFilePathLocation, type_: Optional[Type]) -> BlobRetrieval:
+        file_path = self._base_directory / fp.path
         return SyftObjectRetrieval(
-            syft_object=(self._base_directory / fp.path).read_bytes()
+            syft_object=file_path.read_bytes(),
+            file_name=file_path.name,
+            type_=type_,
         )

-    def allocate(self, obj: CreateBlobStorageEntry) -> SecureFilePathLocation:
-        return SecureFilePathLocation(
-            path=str((self._base_directory / str(obj.id)).absolute())
-        )
+    def allocate(
+        self, obj: CreateBlobStorageEntry
+    ) -> Union[SecureFilePathLocation, SyftError]:
+        try:
+            return SecureFilePathLocation(
+                path=str((self._base_directory / obj.file_name).absolute())
+            )
+        except Exception as e:
+            return SyftError(message=f"Failed to allocate: {e}")

     def write(self, obj: BlobStorageEntry) -> BlobDeposit:
         return OnDiskBlobDeposit(blob_storage_entry_id=obj.id)

+    def delete(self, fp: SecureFilePathLocation) -> Union[SyftSuccess, SyftError]:
+        try:
+            (self._base_directory / fp.path).unlink()
+            return SyftSuccess(message="Successfully deleted file.")
+        except FileNotFoundError as e:
+            return SyftError(message=f"Failed to delete file: {e}")
+

 @serializable()
 class OnDiskBlobStorageClientConfig(BlobStorageClientConfig):
@@ -71,19 +93,16 @@ class OnDiskBlobStorageClientConfig(BlobStorageClientConfig):
 @serializable()
 class OnDiskBlobStorageClient(BlobStorageClient):
     config: OnDiskBlobStorageClientConfig
-    _connection: OnDiskBlobStorageConnection = PrivateAttr()

     def __init__(self, **data: Any):
         super().__init__(**data)
-        self._connection = OnDiskBlobStorageConnection(self.config.base_directory)
+        self.config.base_directory.mkdir(exist_ok=True)

-    def __enter__(self) -> BlobStorageConnection:
-        return self._connection
-
-    def __exit__(self, *exc) -> None:
-        pass
+    def connect(self) -> BlobStorageConnection:
+        return OnDiskBlobStorageConnection(self.config.base_directory)


+@serializable()
 class OnDiskBlobStorageConfig(BlobStorageConfig):
     client_type: Type[BlobStorageClient] = OnDiskBlobStorageClient
     client_config: OnDiskBlobStorageClientConfig = OnDiskBlobStorageClientConfig()
diff --git a/packages/syft/src/syft/store/blob_storage/seaweedfs.py b/packages/syft/src/syft/store/blob_storage/seaweedfs.py
index 386ea7ff68f..39b8c161bfe 100644
--- a/packages/syft/src/syft/store/blob_storage/seaweedfs.py
+++ b/packages/syft/src/syft/store/blob_storage/seaweedfs.py
@@ -1,16 +1,53 @@
 # stdlib
+from io import BytesIO
+import math
+from pathlib import Path
+from typing import Generator
+from typing import List
+from typing import Optional
+from typing import Type
 from typing import Union

+# third party
+import boto3
+from botocore.client import BaseClient as S3BaseClient
+from botocore.client import ClientError as BotoClientError
+from botocore.client import Config
+import requests
+from typing_extensions import Self
+
 # relative
 from . import BlobDeposit
+from . import BlobRetrieval
+from . import BlobRetrievalByURL
 from . import BlobStorageClient
 from . import BlobStorageClientConfig
 from . import BlobStorageConfig
 from . import BlobStorageConnection
 from ...serde.serializable import serializable
 from ...service.response import SyftError
+from ...service.response import SyftException
 from ...service.response import SyftSuccess
+from ...service.service import from_api_or_context
+from ...types.blob_storage import BlobStorageEntry
+from ...types.blob_storage import CreateBlobStorageEntry
+from ...types.blob_storage import SeaweedSecureFilePathLocation
+from ...types.blob_storage import SecureFilePathLocation
+from ...types.grid_url import GridURL
 from ...types.syft_object import SYFT_OBJECT_VERSION_1
+from ...util.constants import DEFAULT_TIMEOUT
+
+READ_EXPIRATION_TIME = 1800  # seconds
+WRITE_EXPIRATION_TIME = 900  # seconds
+DEFAULT_CHUNK_SIZE = 1024**3  # 1 GB
+
+
+def _byte_chunks(bytes: BytesIO, size: int) -> Generator[bytes, None, None]:
+    # BytesIO.read() returns b"" at EOF instead of raising, so on its own this
+    # generator is unbounded; write() below bounds it by zipping the chunks
+    # with the finite list of presigned part URLs.
+    while True:
+        try:
+            yield bytes.read(size)
+        except BlockingIOError:
+            return


 @serializable()
@@ -18,26 +55,175 @@ class SeaweedFSBlobDeposit(BlobDeposit):
     __canonical_name__ = "SeaweedFSBlobDeposit"
     __version__ = SYFT_OBJECT_VERSION_1

-    def write(self, data: bytes) -> Union[SyftSuccess, SyftError]:
-        pass
+    urls: List[GridURL]
+
+    def write(self, data: BytesIO) -> Union[SyftSuccess, SyftError]:
+        # relative
+        from ...client.api import APIRegistry
+
+        api = APIRegistry.api_for(
+            node_uid=self.syft_node_location,
+            user_verify_key=self.syft_client_verify_key,
+        )
+
+        etags = []
+
+        try:
+            for part_no, (byte_chunk, url) in enumerate(
+                zip(_byte_chunks(data, DEFAULT_CHUNK_SIZE), self.urls),
+                start=1,
+            ):
+                if api is not None:
+                    blob_url = api.connection.to_blob_route(url.url_path)
+                else:
+                    blob_url = url
+                response = requests.put(
+                    url=str(blob_url), data=byte_chunk, timeout=DEFAULT_TIMEOUT
+                )
+                response.raise_for_status()
+                etag = response.headers["ETag"]
+                etags.append({"ETag": etag, "PartNumber": part_no})
+        except requests.RequestException as e:
+            return SyftError(message=str(e))
+
+        mark_write_complete_method = from_api_or_context(
+            func_or_path="blob_storage.mark_write_complete",
+            syft_node_location=self.syft_node_location,
+            syft_client_verify_key=self.syft_client_verify_key,
+        )
+        return mark_write_complete_method(
+            etags=etags,
+            uid=self.blob_storage_entry_id,
+        )


 @serializable()
 class SeaweedFSClientConfig(BlobStorageClientConfig):
-    pass
+    host: str
+    port: int
+    access_key: str
+    secret_key: str
+    region: str
+    bucket_name: str
+
+    @property
+    def endpoint_url(self) -> str:
+        grid_url = GridURL(host_or_ip=self.host, port=self.port)
+        return grid_url.url


 @serializable()
 class SeaweedFSClient(BlobStorageClient):
     config: SeaweedFSClientConfig

-    def __enter__(self) -> BlobStorageConnection:
-        pass
+    def connect(self) -> BlobStorageConnection:
+        return SeaweedFSConnection(
+            client=boto3.client(
+                "s3",
+                endpoint_url=self.config.endpoint_url,
+                aws_access_key_id=self.config.access_key,
+                aws_secret_access_key=self.config.secret_key,
+                config=Config(signature_version="s3v4"),
+                region_name=self.config.region,
+            ),
+            bucket_name=self.config.bucket_name,
+        )
+
+
+@serializable()
+class SeaweedFSConnection(BlobStorageConnection):
+    client: S3BaseClient
+    bucket_name: str
+
+    def __init__(self, client: S3BaseClient, bucket_name: str):
+        self.client = client
+        self.bucket_name = bucket_name
+
+    def __enter__(self) -> Self:
+        return self

     def __exit__(self, *exc) -> None:
-        pass
+        self.client.close()
+
+    def read(self, fp: SecureFilePathLocation, type_: Optional[Type]) -> BlobRetrieval:
+        try:
+            url = self.client.generate_presigned_url(
+                ClientMethod="get_object",
+                Params={"Bucket": self.bucket_name, "Key": fp.path},
+                ExpiresIn=READ_EXPIRATION_TIME,
+            )
+
+            return BlobRetrievalByURL(
+                url=GridURL.from_url(url), file_name=Path(fp.path).name, type_=type_
+            )
+        except BotoClientError as e:
+            raise SyftException(e)
+
+    def allocate(
+        self, obj: CreateBlobStorageEntry
+    ) -> Union[SecureFilePathLocation, SyftError]:
+        try:
+            file_name = obj.file_name
+            result = self.client.create_multipart_upload(
+                Bucket=self.bucket_name,
+                Key=file_name,
+            )
+            upload_id = result["UploadId"]
+            return SeaweedSecureFilePathLocation(upload_id=upload_id, path=file_name)
+        except BotoClientError as e:
+            return SyftError(
+                message=f"Failed to allocate space for {obj} with error: {e}"
+            )

+    def write(self, obj: BlobStorageEntry) -> BlobDeposit:
+        total_parts = math.ceil(obj.file_size / DEFAULT_CHUNK_SIZE)
+
+        urls = [
+            GridURL.from_url(
+                self.client.generate_presigned_url(
+                    ClientMethod="upload_part",
+                    Params={
+                        "Bucket": self.bucket_name,
+                        "Key": obj.location.path,
+                        "UploadId": obj.location.upload_id,
+                        "PartNumber": i + 1,
+                    },
+                    ExpiresIn=WRITE_EXPIRATION_TIME,
+                )
+            )
+            for i in range(total_parts)
+        ]
+
+        return SeaweedFSBlobDeposit(blob_storage_entry_id=obj.id, urls=urls)
+
+    def complete_multipart_upload(
+        self,
+        blob_entry: BlobStorageEntry,
+        etags: List,
+    ) -> Union[SyftError, SyftSuccess]:
+        try:
+            self.client.complete_multipart_upload(
+                Bucket=self.bucket_name,
+                Key=blob_entry.location.path,
+                MultipartUpload={"Parts": etags},
+                UploadId=blob_entry.location.upload_id,
+            )
+            return SyftSuccess(message="Successfully saved file.")
+        except BotoClientError as e:
+            return SyftError(message=str(e))
+
+    def delete(
+        self,
+        fp: SecureFilePathLocation,
+    ) -> Union[SyftSuccess, SyftError]:
+        try:
+            self.client.delete_object(Bucket=self.bucket_name, Key=fp.path)
+            return SyftSuccess(message="Successfully deleted file.")
+        except BotoClientError as e:
+            return SyftError(message=str(e))
+
+
+@serializable()
 class SeaweedFSConfig(BlobStorageConfig):
-    client_type = SeaweedFSClient
+    client_type: Type[BlobStorageClient] = SeaweedFSClient
     client_config: SeaweedFSClientConfig
diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py
index d5018ca9e9e..98ca622c5bb 100644
--- a/packages/syft/src/syft/store/document_store.py
+++ b/packages/syft/src/syft/store/document_store.py
@@ -485,6 +485,18 @@ def _delete(self, qk: QueryKey) -> Result[SyftSuccess, Err]:
     def _all(self) -> Result[List[BaseStash.object_type], str]:
         raise NotImplementedError

+    def add_permission(self, permission: ActionObjectPermission) -> None:
+        raise NotImplementedError
+
+    def add_permissions(self, permissions: List[ActionObjectPermission]) -> None:
+        raise NotImplementedError
+
+    def remove_permission(self, permission: ActionObjectPermission) -> None:
+        raise NotImplementedError
+
+    def has_permission(self, permission: ActionObjectPermission) -> bool:
+        raise NotImplementedError
+

 @instrument
 @serializable()
@@ -543,6 +555,18 @@ def get_all(
     ) -> Result[List[BaseStash.object_type], str]:
         return self.partition.all(credentials, order_by, has_permission)

+    def add_permissions(self, permissions: List[ActionObjectPermission]) -> None:
+        self.partition.add_permissions(permissions)
+
+    def add_permission(self, permission: ActionObjectPermission) -> None:
+        self.partition.add_permission(permission)
+
+    def remove_permission(self, permission: ActionObjectPermission) -> None:
+        self.partition.remove_permission(permission)
+
+    def has_permission(self, permission: ActionObjectPermission) -> bool:
+        return self.partition.has_permission(permission=permission)
+
     def __len__(self) -> int:
         return len(self.partition)

@@ -679,9 +703,6 @@ def get_by_uid(
         qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)])
         return self.query_one(credentials=credentials, qks=qks)

-    def add_permissions(self, permissions: List[ActionObjectPermission]) -> None:
-        self.partition.add_permissions(permissions)
-
     def set(
         self,
         credentials: SyftVerifyKey,
diff --git a/packages/syft/src/syft/store/kv_document_store.py b/packages/syft/src/syft/store/kv_document_store.py
index 2c280c443d5..d7131f52be2 100644
--- a/packages/syft/src/syft/store/kv_document_store.py
+++ b/packages/syft/src/syft/store/kv_document_store.py
@@ -253,9 +253,8 @@ def remove_permission(self, permission: ActionObjectPermission):
         self.permissions[permission.uid] = permissions

     def add_permissions(self, permissions: List[ActionObjectPermission]) -> None:
-        results = []
         for permission in permissions:
-            results.append(self.add_permission(permission))
+            self.add_permission(permission)

     def has_permission(self, permission: ActionObjectPermission) -> bool:
         if not isinstance(permission.permission, ActionPermission):
diff --git a/packages/syft/src/syft/types/blob_storage.py b/packages/syft/src/syft/types/blob_storage.py
index 376b5c26a27..8ef880dd6d4 100644
--- a/packages/syft/src/syft/types/blob_storage.py
+++ b/packages/syft/src/syft/types/blob_storage.py
@@ -2,6 +2,7 @@
 import mimetypes
 from pathlib import Path
 import sys
+from typing import List
 from typing import Optional
 from typing import Type
 from typing import Union
@@ -11,14 +12,29 @@

 # relative
 from ..node.credentials import SyftVerifyKey
+from ..serde import serialize
 from ..serde.serializable import serializable
 from ..service.response import SyftException
+from ..types.transforms import keep
+from ..types.transforms import transform
 from .datetime import DateTime
 from .syft_object import SYFT_OBJECT_VERSION_1
 from .syft_object import SyftObject
 from .uid import UID


+@serializable()
+class BlobFile(SyftObject):
+    __canonical_name__ = "BlobFile"
+    __version__ = SYFT_OBJECT_VERSION_1
+
+    file_name: str
+
+
+class BlobFileType(type):
+    pass
+
+
 @serializable()
 class SecureFilePathLocation(SyftObject):
     __canonical_name__ = "SecureFilePathLocation"
@@ -27,6 +43,17 @@ class SecureFilePathLocation(SyftObject):
     id: UID
     path: str

+    def __repr__(self) -> str:
+        return f"{self.path}"
+
+
+@serializable()
+class SeaweedSecureFilePathLocation(SecureFilePathLocation):
+    __canonical_name__ = "SeaweedSecureFilePathLocation"
+    __version__ = SYFT_OBJECT_VERSION_1
+
+    upload_id: str
+

 @serializable()
 class BlobStorageEntry(SyftObject):
@@ -34,12 +61,22 @@ class BlobStorageEntry(SyftObject):
     __version__ = SYFT_OBJECT_VERSION_1

     id: UID
-    location: SecureFilePathLocation
-    type_: Optional[Type[SyftObject]]
+    location: Union[SecureFilePathLocation, SeaweedSecureFilePathLocation]
+    type_: Optional[Type]
     mimetype: str = "bytes"
     file_size: int
     uploaded_by: SyftVerifyKey
-    create_at: DateTime = DateTime.now()
+    created_at: DateTime = DateTime.now()
+
+
+@serializable()
+class BlobStorageMetadata(SyftObject):
+    __canonical_name__ = "BlobStorageMetadata"
+    __version__ = SYFT_OBJECT_VERSION_1
+
+    type_: Optional[Type[SyftObject]]
+    mimetype: str = "bytes"
+    file_size: int


 @serializable()
@@ -48,13 +85,15 @@ class CreateBlobStorageEntry(SyftObject):
     __version__ = SYFT_OBJECT_VERSION_1

     id: UID
-    type_: Optional[Type[SyftObject]]
+    type_: Optional[Type]
     mimetype: str = "bytes"
     file_size: int
+    extensions: List[str] = []

     @classmethod
     def from_obj(cls, obj: SyftObject) -> Self:
-        return cls(file_size=sys.getsizeof(obj), type_=type(obj))
+        file_size = sys.getsizeof(serialize._serialize(obj=obj, to_bytes=True))
+        return cls(file_size=file_size, type_=type(obj))

     @classmethod
     def from_path(cls, fp: Union[str, Path], mimetype: Optional[str] = None) -> Self:
@@ -74,4 +113,18 @@ def from_path(cls, fp: Union[str, Path], mimetype: Optional[str] = None) -> Self:
                 "Please specify mimetype manually `from_path(..., mimetype = ...)`."
             )

-        return cls(mimetype=mimetype, file_size=path.stat().st_size)
+        return cls(
+            mimetype=mimetype,
+            file_size=path.stat().st_size,
+            extensions=path.suffixes,
+            type_=BlobFileType,
+        )
+
+    @property
+    def file_name(self) -> str:
+        return str(self.id) + "".join(self.extensions)
+
+
+@transform(BlobStorageEntry, BlobStorageMetadata)
+def storage_entry_to_metadata():
+    return [keep(["id", "type_", "mimetype", "file_size"])]
diff --git a/packages/syft/src/syft/types/grid_url.py b/packages/syft/src/syft/types/grid_url.py
index b6708441db2..9100c03704e 100644
--- a/packages/syft/src/syft/types/grid_url.py
+++ b/packages/syft/src/syft/types/grid_url.py
@@ -11,6 +11,7 @@

 # third party
 import requests
+from typing_extensions import Self

 # relative
 from ..serde.serializable import serializable
@@ -19,8 +20,8 @@

 @serializable(attrs=["protocol", "host_or_ip", "port", "path", "query"])
 class GridURL:
-    @staticmethod
-    def from_url(url: Union[str, GridURL]) -> GridURL:
+    @classmethod
+    def from_url(cls, url: Union[str, GridURL]) -> Self:
         if isinstance(url, GridURL):
             return url
         try:
@@ -36,7 +37,7 @@ def from_url(url: Union[str, GridURL]) -> GridURL:
             host_or_ip = host_or_ip_parts[0]
             if parts.scheme == "https":
                 port = 443
-            return GridURL(
+            return cls(
                 host_or_ip=host_or_ip,
                 path=parts.path,
                 port=port,
@@ -77,12 +78,12 @@ def __init__(
         self.protocol = protocol
         self.query = query

-    def with_path(self, path: str) -> GridURL:
+    def with_path(self, path: str) -> Self:
         dupe = copy.copy(self)
         dupe.path = path
         return dupe

-    def as_container_host(self, container_host: Optional[str] = None) -> GridURL:
+    def as_container_host(self, container_host: Optional[str] = None) -> Self:
         if self.host_or_ip not in [
             "localhost",
             "host.docker.internal",
@@ -106,7 +107,7 @@ def as_container_host(self, container_host: Optional[str] = None) -> GridURL:
             # convert it back for non container clients
             hostname = "localhost"

-        return GridURL(
+        return self.__class__(
             protocol=self.protocol,
             host_or_ip=hostname,
             port=self.port,
@@ -140,7 +141,7 @@ def base_url_no_port(self) -> str:
     def url_path(self) -> str:
         return f"{self.path}{self.query_string}"

-    def to_tls(self) -> GridURL:
+    def to_tls(self) -> Self:
         if self.protocol == "https":
             return self

@@ -151,7 +152,9 @@ def to_tls(self) -> GridURL:
         new_base_url = r.url
         if new_base_url.endswith("/"):
             new_base_url = new_base_url[0:-1]
-        return GridURL.from_url(url=f"{new_base_url}{self.path}{self.query_string}")
+        return self.__class__.from_url(
+            url=f"{new_base_url}{self.path}{self.query_string}"
+        )

     def __repr__(self) -> str:
         return f"<{type(self).__name__} {self.url}>"
@@ -162,9 +165,9 @@ def __str__(self) -> str:
     def __hash__(self) -> int:
         return hash(self.__str__())

-    def copy(self) -> GridURL:
-        return GridURL.from_url(self.url)
+    def __copy__(self) -> Self:
+        return self.__class__.from_url(self.url)

-    def set_port(self, port: int) -> GridURL:
+    def set_port(self, port: int) -> Self:
         self.port = port
         return self
diff --git a/packages/syft/src/syft/types/syft_metaclass.py b/packages/syft/src/syft/types/syft_metaclass.py
index 874c5e9df2b..762213b7ba4 100644
--- a/packages/syft/src/syft/types/syft_metaclass.py
+++ b/packages/syft/src/syft/types/syft_metaclass.py
@@ -7,10 +7,8 @@
 from typing import Any
 from typing import Dict
 from typing import Generator
-from typing import T
 from typing import Tuple
 from typing import Type
-from typing import Union

 # third party
 from pydantic.fields import UndefinedType
@@ -30,9 +28,6 @@ class Empty:
     pass


-EmptyType = Union[T, Empty]
-
-
 class PartialModelMetaclass(ModelMetaclass):
     def __new__(
         meta: Type["PartialModelMetaclass"], *args: Any, **kwargs: Any
diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py
index 45f814e7f91..74fe7ee5f2e 100644
--- a/packages/syft/src/syft/types/syft_object.py
+++ b/packages/syft/src/syft/types/syft_object.py
@@ -90,6 +90,10 @@ class Config:
     syft_node_location: Optional[UID]
     syft_client_verify_key: Optional[SyftVerifyKey]

+    def _set_obj_location_(self, node_uid, credentials):
+        self.syft_node_location = node_uid
+        self.syft_client_verify_key = credentials
+

 class Context(SyftBaseObject):
     pass
@@ -250,6 +254,8 @@ def _repr_debug_(self) -> str:
             value = getattr(self, attr, "")
             value_type = full_name_with_qualname(type(attr))
             value_type = value_type.replace("builtins.", "")
+            if hasattr(value, "syft_action_data_str_"):
+                value = value.syft_action_data_str_
             value = f'"{value}"' if isinstance(value, str) else value
             _repr_str += f" {attr}: {value_type} = {value}\n"
         return _repr_str
diff --git a/packages/syft/src/syft/types/twin_object.py b/packages/syft/src/syft/types/twin_object.py
index b95eff90555..5529daaff53 100644
--- a/packages/syft/src/syft/types/twin_object.py
+++ b/packages/syft/src/syft/types/twin_object.py
@@ -23,7 +23,7 @@ def to_action_object(obj: Any) -> ActionObject:
         return obj

     if type(obj) in action_types:
-        return action_types[type(obj)](syft_action_data=obj)
+        return action_types[type(obj)](syft_action_data_cache=obj)
     raise Exception(f"{type(obj)} not in action_types")


@@ -71,3 +71,16 @@ def mock(self) -> ActionObject:
         mock.syft_twin_type = TwinMode.MOCK
         mock.id = twin_id
         return mock
+
+    def _save_to_blob_storage(self):
+        # Set node location and verify key
+        self.private_obj._set_obj_location_(
+            self.syft_node_location,
+            self.syft_client_verify_key,
+        )
+        # self.mock_obj._set_obj_location_(
+        #     self.syft_node_location,
+        #     self.syft_client_verify_key,
+        # )
+        return self.private_obj._save_to_blob_storage()
+        # self.mock_obj._save_to_blob_storage()
diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py
index 7b969adc49e..8f76f6de5be 100644
--- a/packages/syft/src/syft/util/util.py
diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py
index 7b969adc49e..8f76f6de5be 100644
--- a/packages/syft/src/syft/util/util.py
+++ b/packages/syft/src/syft/util/util.py
@@ -192,22 +192,21 @@ def get_root_data_path() -> Path:
     # on Windows the directory is: C:/Users/$USER/.syft/data

     data_dir = Path.home() / ".syft" / "data"
+    data_dir.mkdir(parents=True, exist_ok=True)
-    os.makedirs(data_dir, exist_ok=True)

     return data_dir


 def download_file(url: str, full_path: Union[str, Path]) -> Optional[Path]:
-    if not os.path.exists(full_path):
+    full_path = Path(full_path)
+    if not full_path.exists():
         r = requests.get(url, allow_redirects=True, verify=verify_tls())  # nosec
-        if r.status_code < 199 or 299 < r.status_code:
+        if not r.ok:
             print(f"Got {r.status_code} trying to download {url}")
             return None
-        path = os.path.dirname(full_path)
-        os.makedirs(path, exist_ok=True)
-        with open(full_path, "wb") as f:
-            f.write(r.content)
-    return Path(full_path)
+        full_path.parent.mkdir(parents=True, exist_ok=True)
+        full_path.write_bytes(r.content)
+    return full_path


 def verify_tls() -> bool:
diff --git a/packages/syft/tests/syft/code_verification_test.py b/packages/syft/tests/syft/code_verification_test.py
index c7182f118cf..c3a6a509fab 100644
--- a/packages/syft/tests/syft/code_verification_test.py
+++ b/packages/syft/tests/syft/code_verification_test.py
@@ -11,13 +11,13 @@
 @pytest.fixture
 def data1() -> ActionObject:
     """Returns an Action Object with a NumPy dataset with values between -1 and 1"""
-    return NumpyArrayObject(syft_action_data=2 * np.random.rand(10, 10) - 1)
+    return NumpyArrayObject.from_obj(2 * np.random.rand(10, 10) - 1)


 @pytest.fixture
 def data2() -> ActionObject:
     """Returns an Action Object with a NumPy dataset with values between -1 and 1"""
-    return NumpyArrayObject(syft_action_data=2 * np.random.rand(10, 10) - 1)
+    return NumpyArrayObject.from_obj(2 * np.random.rand(10, 10) - 1)


 @pytest.fixture
@@ -29,7 +29,7 @@ def empty1(data1) -> ActionObject:
 @pytest.fixture
 def empty2(data1) -> ActionObject:
     """Returns an Empty Action Object corresponding to data2"""
-    return NumpyArrayObject(syft_action_data=ActionDataEmpty(), id=data2.id)
+    return NumpyArrayObject.from_obj(ActionDataEmpty(), id=data2.id)


 def test_add_private(data1: ActionObject, data2: ActionObject) -> None:
@@ -73,9 +73,7 @@ def test_kwargs(data1: ActionObject) -> None:
 def test_trace_single_op(data1: ActionObject) -> None:
     """Test that we can recreate the correct history hash using TraceMode"""
     result1 = data1.std()
-    trace_result = NumpyArrayObject(
-        syft_action_data=ActionDataEmpty(), id=data1.id
-    ).std()
+    trace_result = NumpyArrayObject.from_obj(ActionDataEmpty(), id=data1.id).std()

     assert result1.syft_history_hash == trace_result.syft_history_hash
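The fixtures above switch from the removed syft_action_data= constructor argument to the from_obj factory (the raw attribute now lives behind an internal syft_action_data_cache, as the worker test further down shows). A minimal usage sketch against the public API:

    import numpy as np
    import syft as sy

    arr = 2 * np.random.rand(10, 10) - 1
    obj = sy.ActionObject.from_obj(arr)  # wrap the array instead of passing syft_action_data=
    assert isinstance(obj.syft_action_data, np.ndarray)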
diff --git a/packages/syft/tests/syft/eager_test.py b/packages/syft/tests/syft/eager_test.py
index 6ccf9fb826d..6870e444811 100644
--- a/packages/syft/tests/syft/eager_test.py
+++ b/packages/syft/tests/syft/eager_test.py
@@ -30,7 +30,7 @@ def test_eager_permissions(worker, guest_client):
     assert all(res_root == [3, 3, 3, 3, 3, 3])


-def test_plan(worker, guest_client):
+def test_plan(worker):
     root_domain_client = worker.root_client
     guest_client = worker.guest_client

@@ -65,9 +65,9 @@ def my_plan(x=np.array([[2, 2, 2], [2, 2, 2]])):  # noqa: B008
     res_ptr.request(guest_client)

     # root approves result
-    root_domain_client.api.services.request[0].approve_with_client(root_domain_client)
+    root_domain_client.api.services.request[-1].approve_with_client(root_domain_client)

-    assert res_ptr.get() == 729
+    assert res_ptr.get_from(guest_client) == 729


 def test_plan_with_function_call(worker, guest_client):
@@ -112,7 +112,7 @@ def my_plan(x=np.array([1, 2, 3, 4, 5, 6])):  # noqa: B008
     res_ptr = plan_ptr(x=pointer)

     assert all(
-        root_domain_client.api.services.action.get(res_ptr.id)
+        root_domain_client.api.services.action.get(res_ptr.id).syft_action_data
         == np.array([2, 3, 4, 5, 6, 7])
     )
diff --git a/packages/syft/tests/syft/serde/numpy_functions_test.py b/packages/syft/tests/syft/serde/numpy_functions_test.py
index c312ac33f89..122b0739fae 100644
--- a/packages/syft/tests/syft/serde/numpy_functions_test.py
+++ b/packages/syft/tests/syft/serde/numpy_functions_test.py
@@ -93,5 +93,5 @@ def test_numpy_functions(func, func_arguments, request):
     else:
         original_result = eval(f"np.{func}({func_arguments})")

-    assert result == original_result
+    assert np.all(result == original_result)
     assert isinstance(result, ActionObject)
diff --git a/packages/syft/tests/syft/service/action/action_object_test.py b/packages/syft/tests/syft/service/action/action_object_test.py
index b0d0b9faf3c..131b93d70da 100644
--- a/packages/syft/tests/syft/service/action/action_object_test.py
+++ b/packages/syft/tests/syft/service/action/action_object_test.py
@@ -334,10 +334,6 @@ def test_actionobject_syft_point_to():
         (True, "__and__", [False], {}, False),
         ((1, 1, 3), "count", [1], {}, 2),
         ([1, 2, 1], "count", [1], {}, 2),
-        ([1, 2, 3], "append", [4], {}, [1, 2, 3, 4]),
-        ({"a": 1, "b": 2}, "update", [{"c": 3}], {}, {"a": 1, "b": 2, "c": 3}),
-        ({1, 2, 3}, "add", [5], {}, {1, 2, 3, 5}),
-        ({1, 2, 3}, "clear", [], {}, {}),
         (complex(1, 2), "conjugate", [], {}, complex(1, -2)),
     ],
 )
@@ -556,7 +552,7 @@ def test_actionobject_syft_get_attr_context():
     assert obj._syft_get_attr_context("capitalize") is orig_obj
     assert obj._syft_get_attr_context("__add__") is orig_obj
-    assert obj._syft_get_attr_context("syft_action_data") is obj
+    assert obj._syft_get_attr_context("syft_action_data") is obj.syft_action_data


 @pytest.mark.parametrize(
@@ -1009,7 +1005,7 @@ def test_actionobject_syft_getattr_pandas(worker):

     obj = ActionObject.from_obj(orig_obj)

-    assert obj.columns == orig_obj.columns
+    assert (obj.columns == orig_obj.columns).all()
     obj.columns = ["a", "b", "c"]
-    assert obj.columns == ["a", "b", "c"]
+    assert (obj.columns == ["a", "b", "c"]).all()
diff --git a/packages/syft/tests/syft/worker_test.py b/packages/syft/tests/syft/worker_test.py
index c71f060b731..03772951d71 100644
--- a/packages/syft/tests/syft/worker_test.py
+++ b/packages/syft/tests/syft/worker_test.py
@@ -92,7 +92,7 @@ def test_action_store() -> None:
     assert set_result.is_ok()
     test_object_result = action_store.get(uid=uid, credentials=test_signing_key)
     assert test_object_result.is_ok()
-    assert test_object == test_object_result.ok()
+    assert (test_object == test_object_result.ok()).all()

     test_verift_key_2 = SyftVerifyKey.from_string(test_verify_key_string_2)
     test_object_result_fail = action_store.get(uid=uid, credentials=test_verift_key_2)
@@ -203,7 +203,7 @@ def test_action_object_hooks() -> None:
     def pre_add(context: Any, *args: Any, **kwargs: Any) -> Any:
         # double it
         new_value = args[0]
-        new_value.syft_action_data = new_value.syft_action_data * 2
+        new_value.syft_action_data_cache = new_value.syft_action_data_cache * 2
         return Ok((context, (new_value,), kwargs))

     def post_add(context: Any, name: str, new_result: Any) -> Any:
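A recurring pattern in the test updates above: equality on NumPy arrays (and pandas indexes) is elementwise, so asserting on the bare comparison raises instead of evaluating to a single bool; wrapping it in np.all(...) or .all() reduces it first. For example:

    import numpy as np

    a = np.array([1, 2, 3])
    b = np.array([1, 2, 3])
    comparison = a == b  # array([ True,  True,  True]), not a bool
    assert np.all(comparison)  # reduces the elementwise result to one bool
    # a bare `assert a == b` would raise ValueError: "The truth value of an
    # array with more than one element is ambiguous."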
bash -c "docker rm -f $(docker ps -a -q) || true" - bash -c "docker volume rm test_domain_1_mongo-data --force || true" - bash -c "docker volume rm test_domain_1_credentials-data --force || true" + bash -c "docker volume rm test-domain-1_mongo-data --force || true" + bash -c "docker volume rm test-domain-1_credentials-data --force || true" + bash -c "docker volume rm test-domain-1_seaweedfs-data --force || true" bash -c 'HAGRID_ART=$HAGRID_ART hagrid launch test_domain_1 domain to docker:9081 $HAGRID_FLAGS --enable-signup --no-health-checks --verbose --no-warnings' @@ -255,21 +256,20 @@ commands = ; reset volumes and create nodes bash -c "echo Starting Nodes; date" bash -c "docker rm -f $(docker ps -a -q) || true" - bash -c "docker volume rm test_domain_1_mongo-data --force || true" - bash -c "docker volume rm test_domain_1_credentials-data --force || true" - bash -c "docker volume rm test_domain_2_mongo-data --force || true" - bash -c "docker volume rm test_domain_2_credentials-data --force || true" - bash -c "docker volume rm test_gateway_1_mongo-data --force || true" - bash -c "docker volume rm test_gateway_1_credentials-data --force || true" - bash -c "docker volume rm test_domain_1_seaweedfs-data --force || true" - bash -c "docker volume rm test_domain_2_seaweedfs-data --force || true" - bash -c "docker volume rm test_domain_1_app-redis-data --force || true" - bash -c "docker volume rm test_domain_2_app-redis-data --force || true" - bash -c "docker volume rm test_gateway_1_app-redis-data --force || true" - bash -c "docker volume rm test_domain_1_tailscale-data --force || true" - bash -c "docker volume rm test_domain_2_tailscale-data --force || true" - bash -c "docker volume rm test_gateway_1_tailscale-data --force || true" - bash -c "docker volume rm test_gateway_1_headscale-data --force || true" + bash -c "docker volume rm test-domain-1_mongo-data --force || true" + bash -c "docker volume rm test-domain-1_credentials-data --force || true" + bash -c "docker volume rm test-domain-1_seaweedfs-data --force || true" + bash -c "docker volume rm test-domain-2_mongo-data --force || true" + bash -c "docker volume rm test-domain-2_credentials-data --force || true" + bash -c "docker volume rm test-domain-2_seaweedfs-data --force || true" + bash -c "docker volume rm test-gateway-1_mongo-data --force || true" + bash -c "docker volume rm test-gateway-1_credentials-data --force || true" + bash -c "docker volume rm test-gateway-1_seaweedfs-data --force || true" + + bash -c "docker volume rm test-domain-1_tailscale-data --force || true" + bash -c "docker volume rm test-domain-2_tailscale-data --force || true" + bash -c "docker volume rm test-gateway-1_tailscale-data --force || true" + bash -c "docker volume rm test-gateway-1_headscale-data --force || true" bash -c 'HAGRID_ART=$HAGRID_ART hagrid launch test_gateway_1 network to docker:9081 $HAGRID_FLAGS --no-health-checks --verbose --no-warnings' bash -c 'HAGRID_ART=$HAGRID_ART hagrid launch test_domain_1 domain to docker:9082 $HAGRID_FLAGS --no-health-checks --enable-signup --verbose --no-warnings' @@ -455,6 +455,7 @@ commands = # Volume cleanup bash -c "docker volume rm test-domain-1_mongo-data --force || true" bash -c "docker volume rm test-domain-1_credentials-data --force || true" + bash -c "docker volume rm test-domain-1_seaweedfs-data --force || true" bash -c "echo Running with ORCHESTRA_DEPLOYMENT_TYPE=$ORCHESTRA_DEPLOYMENT_TYPE DEV_MODE=$DEV_MODE TEST_NOTEBOOK_PATHS=$TEST_NOTEBOOK_PATHS; date" bash -c "for subfolder in $(echo 
    bash -c "for subfolder in $(echo ${TEST_NOTEBOOK_PATHS} | tr ',' ' ');\