diff --git a/.github/workflows/pr-tests-frontend.yml b/.github/workflows/pr-tests-frontend.yml index 4225c635424..fb9520c59b0 100644 --- a/.github/workflows/pr-tests-frontend.yml +++ b/.github/workflows/pr-tests-frontend.yml @@ -66,7 +66,7 @@ jobs: - name: Docker on MacOS if: steps.changes.outputs.frontend == 'true' && matrix.os == 'macos-latest' - uses: crazy-max/ghaction-setup-docker@v3.0.0 + uses: crazy-max/ghaction-setup-docker@v3.1.0 - name: Install Tox if: steps.changes.outputs.frontend == 'true' @@ -156,7 +156,7 @@ jobs: - name: Docker on MacOS if: steps.changes.outputs.stack == 'true' && matrix.os == 'macos-latest' - uses: crazy-max/ghaction-setup-docker@v3.0.0 + uses: crazy-max/ghaction-setup-docker@v3.1.0 - name: Install Tox if: steps.changes.outputs.stack == 'true' diff --git a/.github/workflows/pr-tests-helm-lint.yml b/.github/workflows/pr-tests-helm-lint.yml new file mode 100644 index 00000000000..1ef21e5a5f9 --- /dev/null +++ b/.github/workflows/pr-tests-helm-lint.yml @@ -0,0 +1,50 @@ +name: PR Tests - Lint Helm Charts + +on: + # temporary disabled + # pull_request: + # branches: + # - dev + # paths: + # - packages/grid/helm/syft/** + + workflow_dispatch: + inputs: + none: + description: "Run Tests Manually" + required: false + +concurrency: + group: pr-tests-helm-lint + cancel-in-progress: true + +jobs: + pr-tests-helm-lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install dependencies + run: | + eval "$(/home/linuxbrew/.linuxbrew/bin/brew shellenv)" + brew update + brew tap FairwindsOps/tap + brew install kube-linter FairwindsOps/tap/polaris + + # Install python deps + pip install --upgrade pip + pip install tox + + kube-linter version + polaris version + + - name: Run Polaris + run: | + eval "$(/home/linuxbrew/.linuxbrew/bin/brew shellenv)" + polaris audit --helm-chart packages/grid/helm/syft --helm-values packages/grid/helm/syft/values.yaml --format=pretty --only-show-failed-tests + + - name: Run Linter + run: | + eval "$(/home/linuxbrew/.linuxbrew/bin/brew shellenv)" + tox -e syft.lint.helm diff --git a/.github/workflows/pr-tests-helm-upgrade.yml b/.github/workflows/pr-tests-helm-upgrade.yml new file mode 100644 index 00000000000..be8bbc21996 --- /dev/null +++ b/.github/workflows/pr-tests-helm-upgrade.yml @@ -0,0 +1,66 @@ +name: PR Tests - Helm Upgrade + +on: + # Re-enable when we have a stable helm chart + # pull_request: + # branches: + # - dev + # paths: + # - packages/grid/helm/syft/** + + workflow_dispatch: + inputs: + upgrade_type: + description: "Select upgrade path type" + required: false + default: "BetaToDev" + type: choice + options: + - BetaToDev + - ProdToBeta + - ProdToDev + +concurrency: + group: pr-tests-helm-upgrade + cancel-in-progress: true + +jobs: + pr-tests-helm-upgrade: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install dependencies + run: | + eval "$(/home/linuxbrew/.linuxbrew/bin/brew shellenv)" + brew update + + # Install python deps + pip install --upgrade pip + pip install tox + + # Install kubernetes + brew install helm k3d devspace kubectl + + - name: Setup cluster + run: | + eval "$(/home/linuxbrew/.linuxbrew/bin/brew shellenv)" + + tox -e dev.k8s.start + + - name: Upgrade helm chart + run: | + eval "$(/home/linuxbrew/.linuxbrew/bin/brew shellenv)" + + # default upgrade is beta to dev, but override with input if provided + UPGRADE_TYPE_INPUT=${{ github.event.inputs.upgrade_type }} + export UPGRADE_TYPE=${UPGRADE_TYPE_INPUT:-BetaToDev} + tox -e syft.test.helm.upgrade + + - name: Destroy cluster + if: always() + run: | + eval "$(/home/linuxbrew/.linuxbrew/bin/brew shellenv)" + + tox -e dev.k8s.destroyall diff --git a/.github/workflows/pr-tests-stack-arm64.yml b/.github/workflows/pr-tests-stack-arm64.yml index 1f9dff53a15..705a95ac16b 100644 --- a/.github/workflows/pr-tests-stack-arm64.yml +++ b/.github/workflows/pr-tests-stack-arm64.yml @@ -94,7 +94,7 @@ jobs: docker run --rm --privileged multiarch/qemu-user-static --reset -p yes - name: Run integration tests - uses: nick-fields/retry@v2 + uses: nick-fields/retry@v3 with: timeout_seconds: 36000 max_attempts: 3 diff --git a/.github/workflows/pr-tests-stack-public.yml b/.github/workflows/pr-tests-stack-public.yml index 7a173000d02..8e102ce0a94 100644 --- a/.github/workflows/pr-tests-stack-public.yml +++ b/.github/workflows/pr-tests-stack-public.yml @@ -117,7 +117,7 @@ jobs: - name: Docker on MacOS if: steps.changes.outputs.stack == 'true' && matrix.os == 'macos-latest' - uses: crazy-max/ghaction-setup-docker@v3.0.0 + uses: crazy-max/ghaction-setup-docker@v3.1.0 - name: Docker Compose on MacOS if: steps.changes.outputs.stack == 'true' && matrix.os == 'macos-latest' @@ -148,7 +148,7 @@ jobs: # tox -e stack.test.integration - name: Run integration tests - uses: nick-fields/retry@v2 + uses: nick-fields/retry@v3 if: steps.changes.outputs.stack == 'true' env: HAGRID_ART: false diff --git a/.github/workflows/pr-tests-stack.yml b/.github/workflows/pr-tests-stack.yml index 67f0d4854b1..421559b42d3 100644 --- a/.github/workflows/pr-tests-stack.yml +++ b/.github/workflows/pr-tests-stack.yml @@ -141,7 +141,7 @@ jobs: - name: Docker on MacOS if: steps.changes.outputs.stack == 'true' && matrix.os == 'macos-latest' - uses: crazy-max/ghaction-setup-docker@v3.0.0 + uses: crazy-max/ghaction-setup-docker@v3.1.0 - name: Docker Compose on MacOS if: steps.changes.outputs.stack == 'true' && matrix.os == 'macos-latest' @@ -168,6 +168,7 @@ jobs: HAGRID_ART: false PYTEST_MODULES: "${{ matrix.pytest-modules }}" run: | + export AZURE_BLOB_STORAGE_KEY="${{ secrets.AZURE_BLOB_STORAGE_KEY }}" tox -e stack.test.integration #Run log collector python script @@ -413,7 +414,7 @@ jobs: - name: Docker on MacOS if: steps.changes.outputs.stack == 'true' && matrix.os == 'macos-latest' - uses: crazy-max/ghaction-setup-docker@v3.0.0 + uses: crazy-max/ghaction-setup-docker@v3.1.0 - name: Docker Compose on MacOS if: steps.changes.outputs.stack == 'true' && matrix.os == 'macos-latest' @@ -645,10 +646,10 @@ jobs: shell: bash run: | mkdir -p ./k8s-logs - kubectl describe all -A --context k3d-testgateway1 --namespace testgateway1 > ./k8s-logs/testgateway1-desc-${{ steps.date.outputs.date }}.txt - kubectl describe all -A --context k3d-testdomain1 --namespace testdomain1 > ./k8s-logs/testdomain1-desc-${{ steps.date.outputs.date }}.txt - kubectl logs -l app.kubernetes.io/name!=random --prefix=true --context k3d-testgateway1 --namespace testgateway1 > ./k8s-logs/testgateway1-logs-${{ steps.date.outputs.date }}.txt - kubectl logs -l app.kubernetes.io/name!=random --prefix=true --context k3d-testdomain1 --namespace testdomain1 > ./k8s-logs/testdomain1-logs-${{ steps.date.outputs.date }}.txt + kubectl describe all -A --context k3d-testgateway1 --namespace syft > ./k8s-logs/testgateway1-desc-${{ steps.date.outputs.date }}.txt + kubectl describe all -A --context k3d-testdomain1 --namespace syft > ./k8s-logs/testdomain1-desc-${{ steps.date.outputs.date }}.txt + kubectl logs -l app.kubernetes.io/name!=random --prefix=true --context k3d-testgateway1 --namespace syft > ./k8s-logs/testgateway1-logs-${{ steps.date.outputs.date }}.txt + kubectl logs -l app.kubernetes.io/name!=random --prefix=true --context k3d-testdomain1 --namespace syft > ./k8s-logs/testdomain1-logs-${{ steps.date.outputs.date }}.txt ls -la ./k8s-logs - name: Upload logs to GitHub diff --git a/.github/workflows/pr-tests-syft.yml b/.github/workflows/pr-tests-syft.yml index 0852db9269b..f55a37ee3d5 100644 --- a/.github/workflows/pr-tests-syft.yml +++ b/.github/workflows/pr-tests-syft.yml @@ -90,7 +90,7 @@ jobs: - name: Docker on MacOS if: steps.changes.outputs.syft == 'true' && matrix.os == 'macos-latest' - uses: crazy-max/ghaction-setup-docker@v3.0.0 + uses: crazy-max/ghaction-setup-docker@v3.1.0 - name: Run unit tests if: steps.changes.outputs.syft == 'true' @@ -174,7 +174,7 @@ jobs: pip install --upgrade tox packaging wheel --default-timeout=60 - name: Run notebook tests - uses: nick-fields/retry@v2 + uses: nick-fields/retry@v3 if: steps.changes.outputs.syft == 'true' || steps.changes.outputs.notebooks == 'true' env: ORCHESTRA_DEPLOYMENT_TYPE: "${{ matrix.deployment-type }}" @@ -265,7 +265,7 @@ jobs: - name: Docker on MacOS if: (steps.changes.outputs.stack == 'true' || steps.changes.outputs.notebooks == 'true') && matrix.os == 'macos-latest' - uses: crazy-max/ghaction-setup-docker@v3.0.0 + uses: crazy-max/ghaction-setup-docker@v3.1.0 - name: Docker Compose on MacOS if: (steps.changes.outputs.stack == 'true' || steps.changes.outputs.notebooks == 'true') && matrix.os == 'macos-latest' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 78d7205afb6..77995cb5a74 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -172,8 +172,8 @@ repos: - id: mypy name: "mypy: syft" always_run: true - files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util|^packages/syft/src/syft/service" - # files: "^packages/syft/src/syft/" + files: "^packages/syft/src/syft/" + exclude: "packages/syft/src/syft/types/dicttuple.py|^packages/syft/src/syft/service/action/action_graph.py|^packages/syft/src/syft/external/oblv/" args: [ "--follow-imports=skip", "--ignore-missing-imports", diff --git a/DEBUGGING.md b/DEBUGGING.md new file mode 100644 index 00000000000..5dc8d6bae17 --- /dev/null +++ b/DEBUGGING.md @@ -0,0 +1,74 @@ +# Debugging PySyft + +We currently provide information on how to debug PySyft using Visual Studio Code and PyCharm. If you have any other IDE or debugger that you would like to add to this list, please feel free to contribute. + +## VSCode + +If you're running Add the following in `.vscode/launch.json` + +``` +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Remote Attach", + "type": "debugpy", + "request": "attach", + "connect": { + "host": "localhost", + "port": 5678 + }, + "justMyCode": false, + "internalConsoleOptions": "openOnSessionStart", + "pathMappings": [ + { + "localRoot": "${workspaceFolder}/packages/syft/src", + "remoteRoot": "/root/app/syft/src" + }, + { + "localRoot": "${workspaceFolder}/packages/grid/backend/grid", + "remoteRoot": "/root/app/grid" + } + ] + } + ] +} +``` + +Then run + +```bash +tox -e dev.k8s.hotreload +``` + +And you can attach the debugger running on port 5678. + +## PyCharm + +Add the following to `packages/grid/backend/grid/__init__.py` + +```py +import pydevd_pycharm +pydevd_pycharm.settrace('your-local-addr', port=5678, suspend=False) +``` + +Ensure that `your-local-addr` is reachable from the containers. + +Next, replace the debugpy install and `DEBUG_CMD` in `packages/grid/backend/grid/start.sh`: + +```bash +# only set by kubernetes to avoid conflict with docker tests +if [[ ${DEBUGGER_ENABLED} == "True" ]]; +then + pip install --user pydevd-pycharm==233.14475.56 # remove debugpy, add pydevd-pycharm + DEBUG_CMD="" # empty the debug command +fi +``` + +If it fails to connect, check the backend logs. You might need to install a different pydevd-pycharm version. The version to be installed is shown in the log error message. + +Whenever you start a container, it attempts to connect to PyCharm. + +```bash +tox -e dev.k8s.hotreload +``` diff --git a/docs/requirements.txt b/docs/requirements.txt index d9247032aa1..6f3176dae92 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,6 @@ certifi>=2023.7.22 # not directly required, pinned by Snyk to avoid a vulnerability ipython==8.10.0 +jinja2>=3.1.3 # not directly required, pinned by Snyk to avoid a vulnerability markupsafe==2.0.1 pydata-sphinx-theme==0.7.2 pygments>=2.15.0 # not directly required, pinned by Snyk to avoid a vulnerability diff --git a/notebooks/api/0.8/01-submit-code.ipynb b/notebooks/api/0.8/01-submit-code.ipynb index 080216c153d..71f40191867 100644 --- a/notebooks/api/0.8/01-submit-code.ipynb +++ b/notebooks/api/0.8/01-submit-code.ipynb @@ -318,6 +318,22 @@ " return (float(total / 1_000_000), float(noise))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before we can run this, we need to run:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```\n", + "pip install opendp\n", + "```" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -583,7 +599,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.2" + "version": "3.9.16" }, "toc": { "base_numbering": 1, diff --git a/notebooks/api/0.8/03-data-scientist-download-result.ipynb b/notebooks/api/0.8/03-data-scientist-download-result.ipynb index 7788b7484ca..81ce4f783fc 100644 --- a/notebooks/api/0.8/03-data-scientist-download-result.ipynb +++ b/notebooks/api/0.8/03-data-scientist-download-result.ipynb @@ -54,7 +54,7 @@ }, "outputs": [], "source": [ - "node = sy.orchestra.launch(name=\"test-domain-1\", port=\"auto\", dev_mode=True)" + "node = sy.orchestra.launch(name=\"test-domain-1\", dev_mode=True)" ] }, { @@ -166,19 +166,19 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "Because the output policy is `OutputPolicyExecuteOnce`, this function cannot be run with other inputs. We can verify the validatiy of the policy as follows" + "domain_client.code" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "ops.valid" + "Because the output policy is `OutputPolicyExecuteOnce`, this function cannot be run with other inputs. We can verify the validatiy of the policy as follows" ] }, { @@ -189,8 +189,8 @@ }, "outputs": [], "source": [ - "assert isinstance(ops.valid, sy.SyftError)\n", - "assert ops.count == 1" + "assert isinstance(ops.is_valid, sy.SyftError)\n", + "assert ops.count > 0" ] }, { @@ -231,7 +231,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.2" + "version": "3.9.16" }, "toc": { "base_numbering": 1, diff --git a/notebooks/api/0.8/05-custom-policy.ipynb b/notebooks/api/0.8/05-custom-policy.ipynb index a72942170a7..3095d051a07 100644 --- a/notebooks/api/0.8/05-custom-policy.ipynb +++ b/notebooks/api/0.8/05-custom-policy.ipynb @@ -41,7 +41,7 @@ }, "outputs": [], "source": [ - "node = sy.orchestra.launch(name=\"test-domain-1\", port=\"auto\", dev_mode=True)" + "node = sy.orchestra.launch(name=\"test-domain-1\", port=\"auto\", dev_mode=True, reset=True)" ] }, { @@ -412,7 +412,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.9.16" }, "toc": { "base_numbering": 1, diff --git a/notebooks/helm/direct_azure.ipynb b/notebooks/helm/direct_azure.ipynb deleted file mode 100644 index ddb00875b61..00000000000 --- a/notebooks/helm/direct_azure.ipynb +++ /dev/null @@ -1,879 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [], - "source": [ - "# stdlib\n", - "import os\n", - "\n", - "# syft absolute\n", - "import syft as sy" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Logged into <test: High side Domain> as <info@openmined.org>\n" - ] - }, - { - "data": { - "text/html": [ - "<div class=\"alert-warning\" style=\"padding:5px;\"><strong>SyftWarning</strong>: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`.</div><br />" - ], - "text/plain": [ - "SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`." - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# node = sy.orchestra.launch(\n", - "# name=\"test-domain-helm2\",\n", - "# dev_mode=True,\n", - "# reset=True,\n", - "# n_consumers=4,\n", - "# create_producer=True,\n", - "# )\n", - "# client = node.login(email=\"info@openmined.org\", password=\"changethis\")\n", - "\n", - "client = sy.login(\n", - " url=\"http://localhost:8080\", email=\"info@openmined.org\", password=\"changethis\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [], - "source": [ - "# # syft absolute\n", - "# from syft.store.blob_storage import BlobStorageClientConfig\n", - "# from syft.store.blob_storage import BlobStorageConfig\n", - "# from syft.store.blob_storage.seaweedfs import SeaweedFSClient\n", - "# from syft.store.blob_storage.seaweedfs import SeaweedFSClientConfig\n", - "\n", - "# blob_config = BlobStorageConfig(\n", - "# client_type=SeaweedFSClient,\n", - "# client_config=SeaweedFSClientConfig(\n", - "# host=\"http://0.0.0.0\",\n", - "# port=\"8333\",\n", - "# access_key=\"admin\",\n", - "# secret_key=\"admin\",\n", - "# bucket_name=\"test_bucket\",\n", - "# region=\"us-east-1\",\n", - "# # mount_port=4001\n", - "# ),\n", - "# )\n", - "# node.python_node.init_blob_storage(blob_config)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "<div class=\"alert-success\" style=\"padding:5px;\"><strong>SyftSuccess</strong>: Mounting Azure Successful!</div><br />" - ], - "text/plain": [ - "SyftSuccess: Mounting Azure Successful!" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "client.api.services.blob_storage.mount_azure(\n", - " account_name=\"helmprojectstorage\",\n", - " container_name=\"helm\",\n", - " account_key=os.environ[\"HELM_STORAGE_ACCOUNT_KEY\"],\n", - " bucket_name=\"helmazurebucket\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "metadata": {}, - "outputs": [], - "source": [ - "files = client.api.services.blob_storage.get_files_from_bucket(\"helmazurebucket\")" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "<style>\n", - " body[data-jp-theme-light='false'] {\n", - " --primary-color: #111111;\n", - " --secondary-color: #212121;\n", - " --tertiary-color: #CFCDD6;\n", - " --button-color: #111111;\n", - " }\n", - "\n", - " body {\n", - " --primary-color: #ffffff;\n", - " --secondary-color: #f5f5f5;\n", - " --tertiary-color: #000000de;\n", - " --button-color: #d1d5db;\n", - " }\n", - "\n", - " .header-1 {\n", - " font-style: normal;\n", - " font-weight: 600;\n", - " font-size: 2.0736em;\n", - " line-height: 100%;\n", - " leading-trim: both;\n", - " text-edge: cap;\n", - " color: #17161D;\n", - " }\n", - "\n", - " .header-2 {\n", - " font-style: normal;\n", - " font-weight: 600;\n", - " font-size: 1.728em;\n", - " line-height: 100%;\n", - " leading-trim: both;\n", - " text-edge: cap;\n", - " color: #17161D;\n", - " }\n", - "\n", - " .header-3 {\n", - " font-style: normal;\n", - " font-weight: 600;\n", - " font-size: 1.44em;\n", - " line-height: 100%;\n", - " leading-trim: both;\n", - " text-edge: cap;\n", - " color: var(--tertiary-color);\n", - " }\n", - "\n", - " .header-4 {\n", - " font-style: normal;\n", - " font-weight: 600;\n", - " font-size: 1.2em;\n", - " line-height: 100%;\n", - " leading-trim: both;\n", - " text-edge: cap;\n", - " color: #17161D;\n", - " }\n", - "\n", - " .paragraph {\n", - " font-style: normal;\n", - " font-weight: 400;\n", - " font-size: 14px;\n", - " line-height: 100%;\n", - " leading-trim: both;\n", - " text-edge: cap;\n", - " color: #2E2B3B;\n", - " }\n", - "\n", - " .paragraph-sm {\n", - " font-family: 'Roboto';\n", - " font-style: normal;\n", - " font-weight: 400;\n", - " font-size: 11.62px;\n", - " line-height: 100%;\n", - " leading-trim: both;\n", - " text-edge: cap;\n", - " color: #2E2B3B;\n", - " }\n", - " .code-text {\n", - " font-family: 'Consolas';\n", - " font-style: normal;\n", - " font-weight: 400;\n", - " font-size: 13px;\n", - " line-height: 130%;\n", - " leading-trim: both;\n", - " text-edge: cap;\n", - " color: #2E2B3B;\n", - " }\n", - "\n", - " .numbering-entry { display: none }\n", - "\n", - " /* Tooltip container */\n", - " .tooltip {\n", - " position: relative;\n", - " display: inline-block;\n", - " border-bottom: 1px dotted black; /* If you want dots under the hoverable text */\n", - " }\n", - "\n", - " /* Tooltip text */\n", - " .tooltip .tooltiptext {\n", - " visibility: hidden;\n", - " width: 120px;\n", - " background-color: black;\n", - " color: #fff;\n", - " text-align: center;\n", - " padding: 5px 0;\n", - " border-radius: 6px;\n", - "\n", - " /* Position the tooltip text - see examples below! */\n", - " position: absolute;\n", - " z-index: 1;\n", - " }\n", - "\n", - " .repr-cell {\n", - " padding-top: 20px;\n", - " }\n", - "\n", - " .text-bold {\n", - " font-weight: bold;\n", - " }\n", - "\n", - " .pr-8 {\n", - " padding-right: 8px;\n", - " }\n", - " .pt-8 {\n", - " padding-top: 8px;\n", - " }\n", - " .pl-8 {\n", - " padding-left: 8px;\n", - " }\n", - " .pb-8 {\n", - " padding-bottom: 8px;\n", - " }\n", - "\n", - " .py-25{\n", - " padding-top: 25px;\n", - " padding-bottom: 25px;\n", - " }\n", - "\n", - " .flex {\n", - " display: flex;\n", - " }\n", - "\n", - " .gap-10 {\n", - " gap: 10px;\n", - " }\n", - " .items-center{\n", - " align-items: center;\n", - " }\n", - "\n", - " .folder-icon {\n", - " color: var(--tertiary-color);\n", - " }\n", - "\n", - " .search-input{\n", - " display: flex;\n", - " flex-direction: row;\n", - " align-items: center;\n", - " padding: 8px 12px;\n", - " width: 343px;\n", - " height: 24px;\n", - " /* Lt On Surface/Low */\n", - " background-color: var(--secondary-color);\n", - " border-radius: 30px;\n", - "\n", - " /* Lt On Surface/Highest */\n", - " color: var(--tertiary-color);\n", - " border:none;\n", - " /* Inside auto layout */\n", - " flex: none;\n", - " order: 0;\n", - " flex-grow: 0;\n", - " }\n", - " .search-input:focus {\n", - " outline: none;\n", - " }\n", - " .search-input:focus::placeholder,\n", - " .search-input::placeholder { /* Chrome, Firefox, Opera, Safari 10.1+ */\n", - " color: var(--tertiary-color);\n", - " opacity: 1; /* Firefox */\n", - " }\n", - "\n", - " .search-button{\n", - " /* Search */\n", - " leading-trim: both;\n", - " text-edge: cap;\n", - " display: flex;\n", - " align-items: center;\n", - " text-align: center;\n", - "\n", - " /* Primary/On Light */\n", - " background-color: var(--button-color);\n", - " color: var(--tertiary-color);\n", - "\n", - " border-radius: 30px;\n", - " border-color: var(--secondary-color);\n", - " border-style: solid;\n", - " box-shadow: rgba(60, 64, 67, 0.3) 0px 1px 2px 0px, rgba(60, 64, 67, 0.15) 0px 1px 3px 1px;\n", - " cursor: pointer;\n", - " /* Inside auto layout */\n", - " flex: none;\n", - " order: 1;\n", - " flex-grow: 0;\n", - " }\n", - "\n", - " .grid-table1f0c653a8f0140bb98abde83607cb8db {\n", - " display:grid;\n", - " grid-template-columns: 1fr repeat(8, 1fr);\n", - " grid-template-rows: repeat(2, 1fr);\n", - " overflow-x: auto;\n", - " }\n", - "\n", - " .grid-std-cells {\n", - " grid-column: span 4;\n", - "\n", - " }\n", - " .grid-index-cells {\n", - " grid-column: span 1;\n", - " /* tmp fix to make left col stand out (fix with font-family) */\n", - " font-weight: 600;\n", - " background-color: var(--secondary-color) !important;\n", - " color: var(--tertiary-color);\n", - " }\n", - "\n", - " .grid-header {\n", - " /* Auto layout */\n", - " display: flex;\n", - " flex-direction: column;\n", - " align-items: center;\n", - " padding: 6px 4px;\n", - "\n", - " /* Lt On Surface/Surface */\n", - " /* Lt On Surface/High */\n", - " border: 1px solid #CFCDD6;\n", - " /* tmp fix to make header stand out (fix with font-family) */\n", - " font-weight: 600;\n", - " background-color: var(--secondary-color);\n", - " color: var(--tertiary-color);\n", - " }\n", - "\n", - " .grid-row {\n", - " display: flex;\n", - " flex-direction: column;\n", - " align-items: flex-start;\n", - " padding: 6px 4px;\n", - " overflow: hidden;\n", - " border: 1px solid #CFCDD6;\n", - " background-color: var(--primary-color);\n", - " color: var(--tertiary-color);\n", - " }\n", - "\n", - " .badge {\n", - " code-text;\n", - " border-radius: 30px;\n", - " }\n", - "\n", - " .badge-blue {\n", - " badge;\n", - " background-color: #C2DEF0;\n", - " color: #1F567A;\n", - " }\n", - "\n", - " .badge-purple {\n", - " badge;\n", - " background-color: #C9CFE8;\n", - " color: #373B7B;\n", - " }\n", - "\n", - " .badge-green {\n", - " badge;\n", - "\n", - " /* Success/Container */\n", - " background-color: #D5F1D5;\n", - " color: #256B24;\n", - " }\n", - "\n", - " .badge-red {\n", - " badge;\n", - " background-color: #F2D9DE;\n", - " color: #9B2737;\n", - " }\n", - "\n", - " .badge-gray {\n", - " badge;\n", - " background-color: #ECEBEF;\n", - " color: #2E2B3B;\n", - " }\n", - " .paginationContainer{\n", - " width: 100%;\n", - " height: 30px;\n", - " display: flex;\n", - " justify-content: center;\n", - " gap: 8px;\n", - " padding: 5px;\n", - " color: var(--tertiary-color);\n", - " }\n", - "\n", - " .page{\n", - " color: black;\n", - " font-weight: bold;\n", - " color: var(--tertiary-color);\n", - " }\n", - " .page:hover {\n", - " color: #38bdf8;\n", - " cursor: pointer;\n", - " }\n", - " .clipboard:hover{\n", - " cursor: pointer;\n", - " color: var(--tertiary-color);\n", - " }\n", - "\n", - " .search-field {\n", - " display: flex;\n", - " align-items: center;\n", - " border-radius: 30px;\n", - " background-color: var(--secondary-color);\n", - " }\n", - "\n", - " .syft-dropdown {\n", - " margin: 5px;\n", - " margin-left: 5px;\n", - " position: relative;\n", - " display: inline-block;\n", - " text-align: center;\n", - " background-color: var(--button-color);\n", - " min-width: 100px;\n", - " padding: 2px;\n", - " border-radius: 30px;\n", - " }\n", - "\n", - " .syft-dropdown:hover {\n", - " cursor: pointer;\n", - " }\n", - " .syft-dropdown-content {\n", - " margin-top:26px;\n", - " display: none;\n", - " position: absolute;\n", - " min-width: 100px;\n", - " box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);\n", - " padding: 12px 6px;\n", - " z-index: 1;\n", - " background-color: var(--primary-color);\n", - " color: var(--tertiary-color);\n", - " }\n", - " .dd-options {\n", - " padding-top: 4px;\n", - " }\n", - " .dd-options:first-of-type {\n", - " padding-top: 0px;\n", - " }\n", - "\n", - " .dd-options:hover {\n", - " cursor: pointer;\n", - " background: #d1d5db;\n", - " }\n", - " .arrow {\n", - " border: solid black;\n", - " border-width: 0 3px 3px 0;\n", - " display: inline-block;\n", - " padding: 3px;\n", - " }\n", - " .down {\n", - " transform: rotate(45deg);\n", - " -webkit-transform: rotate(45deg);\n", - " }\n", - "</style>\n", - "\n", - "\n", - " <div style='margin-top:15px;'>\n", - " <div class='flex gap-10' style='align-items: center;'>\n", - " <div class='folder-icon'><svg width=\"32\" height=\"32\" viewBox=\"0 0 32 32\" fill=\"none\" xmlns=\"http://www.w3.org/2000/svg\"> <path d=\"M28 6H4C3.73478 6 3.48043 6.10536 3.29289 6.29289C3.10536 6.48043 3 6.73478 3 7V24C3 24.5304 3.21071 25.0391 3.58579 25.4142C3.96086 25.7893 4.46957 26 5 26H27C27.5304 26 28.0391 25.7893 28.4142 25.4142C28.7893 25.0391 29 24.5304 29 24V7C29 6.73478 28.8946 6.48043 28.7071 6.29289C28.5196 6.10536 28.2652 6 28 6ZM5 14H10V18H5V14ZM12 14H27V18H12V14ZM27 8V12H5V8H27ZM5 20H10V24H5V20ZM27 24H12V20H27V24Z\" fill=\"#343330\"/></svg></div>\n", - " <div><p class='header-3'>BlobFile List</p></div>\n", - " </div>\n", - "\n", - " <div style=\"padding-top: 16px; display:flex;justify-content: space-between; align-items: center;\">\n", - " <div class='pt-25 gap-10' style=\"display:flex;\">\n", - " <div class=\"search-field\">\n", - " <div id='search-menu1f0c653a8f0140bb98abde83607cb8db' class=\"syft-dropdown\" onclick=\"{\n", - " let doc = document.getElementById('search-dropdown-content1f0c653a8f0140bb98abde83607cb8db')\n", - " if (doc.style.display === 'block'){\n", - " doc.style.display = 'none'\n", - " } else {\n", - " doc.style.display = 'block'\n", - " }\n", - " }\">\n", - " <div id='search-dropdown-content1f0c653a8f0140bb98abde83607cb8db' class='syft-dropdown-content'></div>\n", - " <script>\n", - " var element1f0c653a8f0140bb98abde83607cb8db = [{\"id\": {\"value\": \"c6eb664977464afa8e6e49b518b1ca3f\", \"type\": \"clipboard\"}, \"file_name\": \"train-22.jsonl\"}, {\"id\": {\"value\": \"7dd1e9dcdca443b785ebc19e20082810\", \"type\": \"clipboard\"}, \"file_name\": \"train-28.jsonl\"}, {\"id\": {\"value\": \"be6e757db1de40bf86aa5376d2a7d38e\", \"type\": \"clipboard\"}, \"file_name\": \"train-10.jsonl\"}, {\"id\": {\"value\": \"d08884257c55492ebc748d6116d110e0\", \"type\": \"clipboard\"}, \"file_name\": \"train-09.jsonl\"}, {\"id\": {\"value\": \"99ebe41e7ce14123924a4788a88917b0\", \"type\": \"clipboard\"}, \"file_name\": \"train-00.jsonl\"}, {\"id\": {\"value\": \"883b12180b4e4dc2a7ffdfb6c4ecf28e\", \"type\": \"clipboard\"}, \"file_name\": \"train-16.jsonl\"}, {\"id\": {\"value\": \"3b3a6dd2d8684d41817f1c8141e70a5f\", \"type\": \"clipboard\"}, \"file_name\": \"train-03.jsonl\"}, {\"id\": {\"value\": \"63cb44e6f14842f880eeef8f279cd3bf\", \"type\": \"clipboard\"}, \"file_name\": \"train-08.jsonl\"}, {\"id\": {\"value\": \"8b22e5778b474d3988cc40cc90535fc0\", \"type\": \"clipboard\"}, \"file_name\": \"train-15.jsonl\"}, {\"id\": {\"value\": \"8f233131bf4846939726332780a07c60\", \"type\": \"clipboard\"}, \"file_name\": \"train-14.jsonl\"}, {\"id\": {\"value\": \"72a156bb972b49178473c65a23366dfe\", \"type\": \"clipboard\"}, \"file_name\": \"train-05.jsonl\"}, {\"id\": {\"value\": \"a30b93d9e83b49d18d750fc9e1a3b7d6\", \"type\": \"clipboard\"}, \"file_name\": \"pile-uncopyrighted@3be90335b66f24456a5d6659d9c8d208c0357119-test.jsonl\"}, {\"id\": {\"value\": \"3c90a243a4c04c229932bc3f9761b29a\", \"type\": \"clipboard\"}, \"file_name\": \"train-18.jsonl\"}, {\"id\": {\"value\": \"8e5f3d15161c45439e07c9335121d9cb\", \"type\": \"clipboard\"}, \"file_name\": \"train-24.jsonl\"}, {\"id\": {\"value\": \"21c2fda9b7154d1e913b5c1acdc6b673\", \"type\": \"clipboard\"}, \"file_name\": \"train-17.jsonl\"}, {\"id\": {\"value\": \"5007841eaa7344bda0172be78bfda8d5\", \"type\": \"clipboard\"}, \"file_name\": \"train-02.jsonl\"}, {\"id\": {\"value\": \"47aa7b16178c4b78afe3ae67c636c74b\", \"type\": \"clipboard\"}, \"file_name\": \"train-01.jsonl\"}, {\"id\": {\"value\": \"03bb5e02bdf24d9abf947e6253dafdf6\", \"type\": \"clipboard\"}, \"file_name\": \"train-21.jsonl\"}, {\"id\": {\"value\": \"a5bedd7d4be94d6ab80cd2a8783e6b92\", \"type\": \"clipboard\"}, \"file_name\": \"train-06.jsonl\"}, {\"id\": {\"value\": \"b01e67ed720848e5bed1cf203074e93a\", \"type\": \"clipboard\"}, \"file_name\": \"pile-uncopyrighted@3be90335b66f24456a5d6659d9c8d208c0357119-val.jsonl\"}, {\"id\": {\"value\": \"f1ba5789316a429bbc3df50b7ac47328\", \"type\": \"clipboard\"}, \"file_name\": \"train-20.jsonl\"}, {\"id\": {\"value\": \"f48d604a007a403caacbee285b31147e\", \"type\": \"clipboard\"}, \"file_name\": \"train-11.jsonl\"}, {\"id\": {\"value\": \"fb7fcbb35e2343ae9960ac106915f79f\", \"type\": \"clipboard\"}, \"file_name\": \"train-04.jsonl\"}, {\"id\": {\"value\": \"863dc5e76a024915b4564f45e26f20b0\", \"type\": \"clipboard\"}, \"file_name\": \"train-19.jsonl\"}, {\"id\": {\"value\": \"773ee81539174876975f8e2d6f5cb5d2\", \"type\": \"clipboard\"}, \"file_name\": \"train-13.jsonl\"}, {\"id\": {\"value\": \"8db2289f99a0482aa22ac308ed162c30\", \"type\": \"clipboard\"}, \"file_name\": \"train-27.jsonl\"}, {\"id\": {\"value\": \"cb2ef738082c49418ed70eb05a193770\", \"type\": \"clipboard\"}, \"file_name\": \"test.json\"}, {\"id\": {\"value\": \"77e21eeadda94f8a9e1bd735131d9971\", \"type\": \"clipboard\"}, \"file_name\": \"train-26.jsonl\"}, {\"id\": {\"value\": \"a7776c9595114c8ca0a2d60f6a964d9c\", \"type\": \"clipboard\"}, \"file_name\": \"train-07.jsonl\"}, {\"id\": {\"value\": \"055d4a5a61974dd4a075c42bbaaefc30\", \"type\": \"clipboard\"}, \"file_name\": \"train-12.jsonl\"}, {\"id\": {\"value\": \"fe8cec1260544c5f8053aa2fa115ae0c\", \"type\": \"clipboard\"}, \"file_name\": \"filtered_scenario_data_new.jsonl\"}, {\"id\": {\"value\": \"555328628e7546ce86a992712d10a313\", \"type\": \"clipboard\"}, \"file_name\": \"train-25.jsonl\"}, {\"id\": {\"value\": \"128d3962d1d04e45aef457906e16a7ab\", \"type\": \"clipboard\"}, \"file_name\": \"train-29.jsonl\"}, {\"id\": {\"value\": \"83c2188ace1e4cf5bdbf6b906bd56bae\", \"type\": \"clipboard\"}, \"file_name\": \"train-23.jsonl\"}]\n", - " var page_size1f0c653a8f0140bb98abde83607cb8db = 5\n", - " var pageIndex1f0c653a8f0140bb98abde83607cb8db = 1\n", - " var paginatedElements1f0c653a8f0140bb98abde83607cb8db = []\n", - " var activeFilter1f0c653a8f0140bb98abde83607cb8db;\n", - "\n", - " function buildDropDownMenu(elements){\n", - " let init_filter;\n", - " let menu = document.getElementById('search-dropdown-content1f0c653a8f0140bb98abde83607cb8db')\n", - " if (elements.length > 0) {\n", - " let sample = elements[0]\n", - " for (const attr in sample) {\n", - " if (typeof init_filter === 'undefined'){\n", - " init_filter = attr;\n", - " }\n", - " let content = document.createElement('div');\n", - " content.onclick = function(event) {\n", - " event.stopPropagation()\n", - " document.getElementById('menu-active-filter1f0c653a8f0140bb98abde83607cb8db').innerText = attr;\n", - " activeFilter1f0c653a8f0140bb98abde83607cb8db = attr;\n", - " document.getElementById(\n", - " 'search-dropdown-content1f0c653a8f0140bb98abde83607cb8db'\n", - " ).style.display= 'none';\n", - " }\n", - " content.classList.add(\"dd-options\");\n", - " content.innerText = attr;\n", - " menu.appendChild(content);\n", - " }\n", - " } else {\n", - " let init_filter = '---'\n", - " }\n", - " let dropdown_field = document.getElementById('search-menu1f0c653a8f0140bb98abde83607cb8db')\n", - " let span = document.createElement('span')\n", - " span.setAttribute('id', 'menu-active-filter1f0c653a8f0140bb98abde83607cb8db')\n", - " span.innerText = init_filter\n", - " activeFilter1f0c653a8f0140bb98abde83607cb8db = init_filter;\n", - " dropdown_field.appendChild(span)\n", - " }\n", - "\n", - " buildDropDownMenu(element1f0c653a8f0140bb98abde83607cb8db)\n", - " </script>\n", - " </div>\n", - " <input id='searchKey1f0c653a8f0140bb98abde83607cb8db' class='search-input' placeholder='Enter search here ...' />\n", - " </div>\n", - " <button class='search-button' type=\"button\" onclick=\"searchGrid1f0c653a8f0140bb98abde83607cb8db(element1f0c653a8f0140bb98abde83607cb8db)\">\n", - " <svg width=\"11\" height=\"10\" viewBox=\"0 0 11 10\" fill=\"none\" xmlns=\"http://www.w3.org/2000/svg\"><path d=\"M10.5652 9.23467L8.21819 6.88811C8.89846 6.07141 9.23767 5.02389 9.16527 3.96345C9.09287 2.90302 8.61443 1.91132 7.82948 1.19466C7.04453 0.477995 6.01349 0.0915414 4.95087 0.115691C3.88824 0.139841 2.87583 0.572735 2.12425 1.32432C1.37266 2.0759 0.939768 3.08831 0.915618 4.15094C0.891468 5.21357 1.27792 6.2446 1.99459 7.02955C2.71125 7.8145 3.70295 8.29294 4.76338 8.36535C5.82381 8.43775 6.87134 8.09853 7.68804 7.41827L10.0346 9.7653C10.0694 9.80014 10.1108 9.82778 10.1563 9.84663C10.2018 9.86549 10.2506 9.87519 10.2999 9.87519C10.3492 9.87519 10.398 9.86549 10.4435 9.84663C10.489 9.82778 10.5304 9.80014 10.5652 9.7653C10.6001 9.73046 10.6277 9.68909 10.6466 9.64357C10.6654 9.59805 10.6751 9.54926 10.6751 9.49998C10.6751 9.45071 10.6654 9.40192 10.6466 9.3564C10.6277 9.31088 10.6001 9.26951 10.5652 9.23467ZM1.67491 4.24998C1.67491 3.58247 1.87285 2.92995 2.2437 2.37493C2.61455 1.81992 3.14165 1.38734 3.75835 1.13189C4.37506 0.876446 5.05366 0.809609 5.70834 0.939835C6.36303 1.07006 6.96439 1.3915 7.4364 1.8635C7.9084 2.3355 8.22984 2.93687 8.36006 3.59155C8.49029 4.24624 8.42345 4.92484 8.168 5.54154C7.91256 6.15824 7.47998 6.68535 6.92496 7.05619C6.36995 7.42704 5.71742 7.62498 5.04991 7.62498C4.15511 7.62399 3.29724 7.26809 2.66452 6.63537C2.0318 6.00265 1.6759 5.14479 1.67491 4.24998Z\" fill=\"currentColor\"/></svg>\n", - " <span class='pl-8'>Search</span>\n", - " </button>\n", - " </div>\n", - "\n", - " <div><h4 id='total1f0c653a8f0140bb98abde83607cb8db'>0</h4></div>\n", - " </div>\n", - " <div id='table1f0c653a8f0140bb98abde83607cb8db' class='grid-table1f0c653a8f0140bb98abde83607cb8db' style='margin-top: 25px;'>\n", - " <script>\n", - " function paginate1f0c653a8f0140bb98abde83607cb8db(arr, size) {\n", - " const res = [];\n", - " for (let i = 0; i < arr.length; i += size) {\n", - " const chunk = arr.slice(i, i + size);\n", - " res.push(chunk);\n", - " }\n", - "\n", - " return res;\n", - " }\n", - "\n", - " function searchGrid1f0c653a8f0140bb98abde83607cb8db(elements){\n", - " let searchKey = document.getElementById('searchKey1f0c653a8f0140bb98abde83607cb8db').value;\n", - " let result;\n", - " if (searchKey === ''){\n", - " result = elements;\n", - " } else {\n", - " result = elements.filter((element) => {\n", - " let property = element[activeFilter1f0c653a8f0140bb98abde83607cb8db]\n", - " if (typeof property === 'object' && property !== null){\n", - " return property.value.toLowerCase().includes(searchKey.toLowerCase());\n", - " } else if (typeof property === 'string' ) {\n", - " return element[activeFilter1f0c653a8f0140bb98abde83607cb8db].toLowerCase().includes(searchKey.toLowerCase());\n", - " } else if (property !== null ) {\n", - " return element[activeFilter1f0c653a8f0140bb98abde83607cb8db].toString() === searchKey;\n", - " } else {\n", - " return element[activeFilter1f0c653a8f0140bb98abde83607cb8db] === searchKey;\n", - " }\n", - " } );\n", - " }\n", - " resetById1f0c653a8f0140bb98abde83607cb8db('table1f0c653a8f0140bb98abde83607cb8db');\n", - " resetById1f0c653a8f0140bb98abde83607cb8db('pag1f0c653a8f0140bb98abde83607cb8db');\n", - " result = paginate1f0c653a8f0140bb98abde83607cb8db(result, page_size1f0c653a8f0140bb98abde83607cb8db)\n", - " paginatedElements1f0c653a8f0140bb98abde83607cb8db = result\n", - " buildGrid1f0c653a8f0140bb98abde83607cb8db(result,pageIndex1f0c653a8f0140bb98abde83607cb8db);\n", - " buildPaginationContainer1f0c653a8f0140bb98abde83607cb8db(result);\n", - " }\n", - "\n", - " function resetById1f0c653a8f0140bb98abde83607cb8db(id){\n", - " let element = document.getElementById(id);\n", - " while (element.firstChild) {\n", - " element.removeChild(element.firstChild);\n", - " }\n", - " }\n", - "\n", - " function buildGrid1f0c653a8f0140bb98abde83607cb8db(items, pageIndex){\n", - " let headers = Object.keys(element1f0c653a8f0140bb98abde83607cb8db[0]);\n", - "\n", - " let grid = document.getElementById(\"table1f0c653a8f0140bb98abde83607cb8db\");\n", - " let div = document.createElement(\"div\");\n", - " div.classList.add('grid-header', 'grid-index-cells');\n", - " grid.appendChild(div);\n", - " headers.forEach((title) =>{\n", - " let div = document.createElement(\"div\");\n", - " div.classList.add('grid-header', 'grid-std-cells');\n", - " div.innerText = title;\n", - "\n", - " grid.appendChild(div);\n", - " });\n", - "\n", - " let page = items[pageIndex -1]\n", - " if (page !== 'undefine'){\n", - " let table_index1f0c653a8f0140bb98abde83607cb8db = ((pageIndex - 1) * page_size1f0c653a8f0140bb98abde83607cb8db)\n", - " page.forEach((item) => {\n", - " let grid = document.getElementById(\"table1f0c653a8f0140bb98abde83607cb8db\");\n", - " // Add new index value in index cells\n", - " let divIndex = document.createElement(\"div\");\n", - " divIndex.classList.add('grid-row', 'grid-index-cells');\n", - " divIndex.innerText = table_index1f0c653a8f0140bb98abde83607cb8db;\n", - " grid.appendChild(divIndex);\n", - "\n", - " // Iterate over the actual obj\n", - " for (const attr in item) {\n", - " let div = document.createElement(\"div\");\n", - " if (typeof item[attr] === 'object'\n", - " && item[attr] !== null\n", - " && item[attr].hasOwnProperty('type')) {\n", - " if (item[attr].type.includes('badge')){\n", - " let badge_div = document.createElement(\"div\");\n", - " badge_div.classList.add('badge',item[attr].type)\n", - " badge_div.innerText = String(item[attr].value).toUpperCase();\n", - " div.appendChild(badge_div);\n", - " div.classList.add('grid-row','grid-std-cells');\n", - " } else if (item[attr].type === \"clipboard\") {\n", - " div.classList.add('grid-row','grid-std-cells');\n", - "\n", - " // Create clipboard div\n", - " let clipboard_div = document.createElement('div');\n", - " clipboard_div.style.display= 'flex';\n", - " clipboard_div.classList.add(\"gap-10\")\n", - " clipboard_div.style.justifyContent = \"space-between\";\n", - "\n", - " let id_text = document.createElement('div');\n", - " if (item[attr].value == \"None\"){\n", - " id_text.innerText = \"None\";\n", - " }\n", - " else{\n", - " id_text.innerText = item[attr].value.slice(0,5) + \"...\";\n", - " }\n", - "\n", - " clipboard_div.appendChild(id_text);\n", - " let clipboard_img = document.createElement('div');\n", - " clipboard_img.classList.add(\"clipboard\")\n", - " div.onclick = function() {\n", - " navigator.clipboard.writeText(item[attr].value);\n", - " };\n", - " clipboard_img.innerHTML = \"<svg width='8' height='8' viewBox='0 0 8 8' fill='none' xmlns='http://www.w3.org/2000/svg'><path d='M7.4375 0.25H2.4375C2.35462 0.25 2.27513 0.282924 2.21653 0.341529C2.15792 0.400134 2.125 0.47962 2.125 0.5625V2.125H0.5625C0.47962 2.125 0.400134 2.15792 0.341529 2.21653C0.282924 2.27513 0.25 2.35462 0.25 2.4375V7.4375C0.25 7.52038 0.282924 7.59987 0.341529 7.65847C0.400134 7.71708 0.47962 7.75 0.5625 7.75H5.5625C5.64538 7.75 5.72487 7.71708 5.78347 7.65847C5.84208 7.59987 5.875 7.52038 5.875 7.4375V5.875H7.4375C7.52038 5.875 7.59987 5.84208 7.65847 5.78347C7.71708 5.72487 7.75 5.64538 7.75 5.5625V0.5625C7.75 0.47962 7.71708 0.400134 7.65847 0.341529C7.59987 0.282924 7.52038 0.25 7.4375 0.25ZM5.25 7.125H0.875V2.75H5.25V7.125ZM7.125 5.25H5.875V2.4375C5.875 2.35462 5.84208 2.27513 5.78347 2.21653C5.72487 2.15792 5.64538 2.125 5.5625 2.125H2.75V0.875H7.125V5.25Z' fill='#464158'/></svg>\";\n", - "\n", - " clipboard_div.appendChild(clipboard_img);\n", - " div.appendChild(clipboard_div);\n", - " }\n", - " } else{\n", - " div.classList.add('grid-row','grid-std-cells');\n", - " if (item[attr] == null) {\n", - " text = ' '\n", - " } else {\n", - " text = String(item[attr])\n", - " }\n", - " if (text.length > 150){\n", - " text = text.slice(0,150) + \"...\";\n", - " }\n", - " text = text.replaceAll(\"\\n\", \"</br>\");\n", - " div.innerHTML = text;\n", - " }\n", - " grid.appendChild(div);\n", - " }\n", - " table_index1f0c653a8f0140bb98abde83607cb8db = table_index1f0c653a8f0140bb98abde83607cb8db + 1;\n", - " })\n", - " }\n", - " }\n", - " paginatedElements1f0c653a8f0140bb98abde83607cb8db = paginate1f0c653a8f0140bb98abde83607cb8db(element1f0c653a8f0140bb98abde83607cb8db, page_size1f0c653a8f0140bb98abde83607cb8db)\n", - " buildGrid1f0c653a8f0140bb98abde83607cb8db(paginatedElements1f0c653a8f0140bb98abde83607cb8db, 1)\n", - " document.getElementById('total1f0c653a8f0140bb98abde83607cb8db').innerText = \"Total: \" + element1f0c653a8f0140bb98abde83607cb8db.length\n", - " </script>\n", - " </div>\n", - " <div id='pag1f0c653a8f0140bb98abde83607cb8db' class='paginationContainer'>\n", - " <script>\n", - " function buildPaginationContainer1f0c653a8f0140bb98abde83607cb8db(paginatedElements){\n", - " let pageContainer = document.getElementById(\"pag1f0c653a8f0140bb98abde83607cb8db\");\n", - " for (let i = 0; i < paginatedElements.length; i++) {\n", - " let div = document.createElement(\"div\");\n", - " div.classList.add('page');\n", - " if(i===0) div.style.color = \"gray\";\n", - " else div.style.color = 'var(--tertiary-color, \"gray\")';\n", - " div.onclick = function(event) {\n", - " let indexes = document.getElementsByClassName('page');\n", - " for (let index of indexes) { index.style.color = 'var(--tertiary-color, \"gray\")' }\n", - " event.target.style.color = \"gray\";\n", - " setPage1f0c653a8f0140bb98abde83607cb8db(i + 1);\n", - " };\n", - " div.innerText = i + 1;\n", - " pageContainer.appendChild(div);\n", - " }\n", - " }\n", - "\n", - " function setPage1f0c653a8f0140bb98abde83607cb8db(newPage){\n", - " pageIndex = newPage\n", - " resetById1f0c653a8f0140bb98abde83607cb8db('table1f0c653a8f0140bb98abde83607cb8db')\n", - " buildGrid1f0c653a8f0140bb98abde83607cb8db(paginatedElements1f0c653a8f0140bb98abde83607cb8db, pageIndex)\n", - " }\n", - "\n", - " buildPaginationContainer1f0c653a8f0140bb98abde83607cb8db(paginatedElements1f0c653a8f0140bb98abde83607cb8db)\n", - " </script>\n", - " </div>\n", - " </div>\n", - " </div>\n", - " </div>\n" - ], - "text/plain": [ - "[syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile,\n", - " syft.types.blob_storage.BlobFile]" - ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "files" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "metadata": {}, - "outputs": [], - "source": [ - "file = [f for f in files if f.file_name == \"test.json\"][0]" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "metadata": {}, - "outputs": [ - { - "data": { - "text/markdown": [ - "```python\n", - "class BlobFile:\n", - " id: str = cb2ef738082c49418ed70eb05a193770\n", - " file_name: str = \"test.json\"\n", - "\n", - "```" - ], - "text/plain": [ - "syft.types.blob_storage.BlobFile" - ] - }, - "execution_count": 45, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "file" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "b'{\\n\"abc\": \"def\"\\n}'" - ] - }, - "execution_count": 46, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "file.read()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.2" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": true - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/tutorials/data-scientist/05-syft-functions.ipynb b/notebooks/tutorials/data-scientist/05-syft-functions.ipynb index a04ff15d24d..7f426a596ba 100644 --- a/notebooks/tutorials/data-scientist/05-syft-functions.ipynb +++ b/notebooks/tutorials/data-scientist/05-syft-functions.ipynb @@ -641,7 +641,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.9.16" }, "toc": { "base_numbering": 1, diff --git a/notebooks/tutorials/enclaves/Enclave-single-notebook-DO-DS.ipynb b/notebooks/tutorials/enclaves/Enclave-single-notebook-DO-DS.ipynb index 66472c77fb5..3764f475b49 100644 --- a/notebooks/tutorials/enclaves/Enclave-single-notebook-DO-DS.ipynb +++ b/notebooks/tutorials/enclaves/Enclave-single-notebook-DO-DS.ipynb @@ -682,6 +682,14 @@ "source": [ "assert result_ptr.syft_action_data == 813" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d632521", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/notebooks/tutorials/model-auditing/colab/01-user-log.ipynb b/notebooks/tutorials/model-auditing/colab/01-user-log.ipynb index 4df99dec84a..621f98a3ac2 100644 --- a/notebooks/tutorials/model-auditing/colab/01-user-log.ipynb +++ b/notebooks/tutorials/model-auditing/colab/01-user-log.ipynb @@ -687,7 +687,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.9.16" }, "toc": { "base_numbering": 1, diff --git a/notebooks/tutorials/model-training/00-data-owner-upload-data.ipynb b/notebooks/tutorials/model-training/00-data-owner-upload-data.ipynb index 58826a8ac4e..71d80d51230 100644 --- a/notebooks/tutorials/model-training/00-data-owner-upload-data.ipynb +++ b/notebooks/tutorials/model-training/00-data-owner-upload-data.ipynb @@ -386,7 +386,20 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.9.16" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true } }, "nbformat": 4, diff --git a/notebooks/tutorials/model-training/01-data-scientist-submit-code.ipynb b/notebooks/tutorials/model-training/01-data-scientist-submit-code.ipynb index e45a0124ec3..c0266dc6d3d 100644 --- a/notebooks/tutorials/model-training/01-data-scientist-submit-code.ipynb +++ b/notebooks/tutorials/model-training/01-data-scientist-submit-code.ipynb @@ -558,7 +558,20 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.9.16" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true } }, "nbformat": 4, diff --git a/notebooks/tutorials/model-training/02-data-owner-review-approve-code.ipynb b/notebooks/tutorials/model-training/02-data-owner-review-approve-code.ipynb index 7834a406012..bd6ed479b72 100644 --- a/notebooks/tutorials/model-training/02-data-owner-review-approve-code.ipynb +++ b/notebooks/tutorials/model-training/02-data-owner-review-approve-code.ipynb @@ -103,7 +103,7 @@ "outputs": [], "source": [ "# gettting a reference to the user code object\n", - "user_code = change.link\n", + "user_code = change.code\n", "\n", "# viewing the actual code submitted for request\n", "user_code.show_code" @@ -304,7 +304,20 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.9.16" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true } }, "nbformat": 4, diff --git a/notebooks/tutorials/model-training/03-data-scientist-download-results.ipynb b/notebooks/tutorials/model-training/03-data-scientist-download-results.ipynb index 1a758c200be..560069b172e 100644 --- a/notebooks/tutorials/model-training/03-data-scientist-download-results.ipynb +++ b/notebooks/tutorials/model-training/03-data-scientist-download-results.ipynb @@ -246,7 +246,20 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.9.16" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true } }, "nbformat": 4, diff --git a/notebooks/tutorials/pandas-cookbook/01-reading-from-a-csv.ipynb b/notebooks/tutorials/pandas-cookbook/01-reading-from-a-csv.ipynb index 577c10204b6..4088bda8a55 100644 --- a/notebooks/tutorials/pandas-cookbook/01-reading-from-a-csv.ipynb +++ b/notebooks/tutorials/pandas-cookbook/01-reading-from-a-csv.ipynb @@ -699,6 +699,16 @@ "request = project_notification.link.events[0].request" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "6508050f", + "metadata": {}, + "outputs": [], + "source": [ + "func = request.code" + ] + }, { "cell_type": "code", "execution_count": null, @@ -708,8 +718,8 @@ }, "outputs": [], "source": [ - "func = request.changes[0].link\n", - "op = func.output_policy_type" + "# func = request.code\n", + "#" ] }, { @@ -853,6 +863,14 @@ "source": [ "node.land()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e80dab85", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/notebooks/tutorials/pandas-cookbook/02-selecting-data-finding-common-complain.ipynb b/notebooks/tutorials/pandas-cookbook/02-selecting-data-finding-common-complain.ipynb index c8fa334cdce..8d887116aa9 100644 --- a/notebooks/tutorials/pandas-cookbook/02-selecting-data-finding-common-complain.ipynb +++ b/notebooks/tutorials/pandas-cookbook/02-selecting-data-finding-common-complain.ipynb @@ -889,8 +889,7 @@ "outputs": [], "source": [ "request = project_notification.link.events[0].request\n", - "func = request.changes[0].link\n", - "op = func.output_policy_type" + "func = request.code" ] }, { diff --git a/notebooks/tutorials/pandas-cookbook/03-which-borough-has-the-most-noise-complaints.ipynb b/notebooks/tutorials/pandas-cookbook/03-which-borough-has-the-most-noise-complaints.ipynb index 8e38b213bb2..467f5f02873 100644 --- a/notebooks/tutorials/pandas-cookbook/03-which-borough-has-the-most-noise-complaints.ipynb +++ b/notebooks/tutorials/pandas-cookbook/03-which-borough-has-the-most-noise-complaints.ipynb @@ -1003,8 +1003,7 @@ "outputs": [], "source": [ "request = project_notification.link.events[0].request\n", - "func = request.changes[0].link\n", - "op = func.output_policy_type" + "func = request.code" ] }, { diff --git a/notebooks/tutorials/pandas-cookbook/04-weekday-bike-most-groupby-aggregate.ipynb b/notebooks/tutorials/pandas-cookbook/04-weekday-bike-most-groupby-aggregate.ipynb index d150435e3cf..22fb760c644 100644 --- a/notebooks/tutorials/pandas-cookbook/04-weekday-bike-most-groupby-aggregate.ipynb +++ b/notebooks/tutorials/pandas-cookbook/04-weekday-bike-most-groupby-aggregate.ipynb @@ -775,8 +775,7 @@ "outputs": [], "source": [ "request = project_notification.link.events[0].request\n", - "func = request.changes[0].link\n", - "op = func.output_policy_type" + "func = request.code" ] }, { diff --git a/notebooks/tutorials/pandas-cookbook/05-combining-dataframes-scraping-weather-data.ipynb b/notebooks/tutorials/pandas-cookbook/05-combining-dataframes-scraping-weather-data.ipynb index 31b6a07756b..e3600391853 100644 --- a/notebooks/tutorials/pandas-cookbook/05-combining-dataframes-scraping-weather-data.ipynb +++ b/notebooks/tutorials/pandas-cookbook/05-combining-dataframes-scraping-weather-data.ipynb @@ -962,8 +962,7 @@ "outputs": [], "source": [ "request = project_notification.link.events[0].request\n", - "func = request.changes[0].link\n", - "op = func.output_policy_type" + "func = request.code" ] }, { diff --git a/notebooks/tutorials/pandas-cookbook/06-string-operations-which-month-was-the-snowiest.ipynb b/notebooks/tutorials/pandas-cookbook/06-string-operations-which-month-was-the-snowiest.ipynb index db66e1497ef..67d1f2ce4b3 100644 --- a/notebooks/tutorials/pandas-cookbook/06-string-operations-which-month-was-the-snowiest.ipynb +++ b/notebooks/tutorials/pandas-cookbook/06-string-operations-which-month-was-the-snowiest.ipynb @@ -865,8 +865,7 @@ "outputs": [], "source": [ "request = project_notification.link.events[0].request\n", - "func = request.changes[0].link\n", - "op = func.output_policy_type" + "func = request.code" ] }, { diff --git a/notebooks/tutorials/pandas-cookbook/07-cleaning-up-messy-data.ipynb b/notebooks/tutorials/pandas-cookbook/07-cleaning-up-messy-data.ipynb index 479453fe34d..59b6b5a3dc7 100644 --- a/notebooks/tutorials/pandas-cookbook/07-cleaning-up-messy-data.ipynb +++ b/notebooks/tutorials/pandas-cookbook/07-cleaning-up-messy-data.ipynb @@ -920,8 +920,7 @@ "outputs": [], "source": [ "request = project_notification.link.events[0].request\n", - "func = request.changes[0].link\n", - "op = func.output_policy_type" + "func = request.code" ] }, { diff --git a/notebooks/tutorials/pandas-cookbook/08-how-to-deal-with-timestamps.ipynb b/notebooks/tutorials/pandas-cookbook/08-how-to-deal-with-timestamps.ipynb index bcf0160ea84..bc0268c923f 100644 --- a/notebooks/tutorials/pandas-cookbook/08-how-to-deal-with-timestamps.ipynb +++ b/notebooks/tutorials/pandas-cookbook/08-how-to-deal-with-timestamps.ipynb @@ -871,8 +871,7 @@ "outputs": [], "source": [ "request = project_notification.link.events[0].request\n", - "func = request.changes[0].link\n", - "op = func.output_policy_type" + "func = request.code" ] }, { diff --git a/packages/grid/backend/backend.dockerfile b/packages/grid/backend/backend.dockerfile index e0835c0c5af..8a19a55d284 100644 --- a/packages/grid/backend/backend.dockerfile +++ b/packages/grid/backend/backend.dockerfile @@ -20,6 +20,7 @@ ARG UID # Setup Python DEV RUN --mount=type=cache,target=/var/cache/apk,sharing=locked \ apk update && \ + apk upgrade && \ apk add build-base gcc tzdata python-$PYTHON_VERSION-dev py$PYTHON_VERSION-pip && \ ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone # uncomment for creating rootless user @@ -73,6 +74,7 @@ ARG USER_GRP # Setup Python RUN --mount=type=cache,target=/var/cache/apk,sharing=locked \ apk update && \ + apk upgrade && \ apk add tzdata git bash python-$PYTHON_VERSION py$PYTHON_VERSION-pip && \ ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone && \ # Uncomment for rootless user @@ -92,12 +94,10 @@ ENV PATH=$PATH:$HOME/.local/bin \ SERVICE_NAME="backend" \ RELEASE="production" \ DEV_MODE="False" \ + DEBUGGER_ENABLED="False" \ CONTAINER_HOST="docker" \ - PORT=80\ - HTTP_PORT=80 \ - HTTPS_PORT=443 \ - DOMAIN_CONNECTION_PORT=3030 \ - IGNORE_TLS_ERRORS="False" \ + OBLV_ENABLED="False" \ + OBLV_LOCALHOST_PORT=3030 \ DEFAULT_ROOT_EMAIL="info@openmined.org" \ DEFAULT_ROOT_PASSWORD="changethis" \ STACK_API_KEY="changeme" \ diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py index 9f13b7b54fd..8081a603967 100644 --- a/packages/grid/backend/grid/core/config.py +++ b/packages/grid/backend/grid/core/config.py @@ -2,17 +2,19 @@ import os import secrets from typing import Any -from typing import Dict from typing import List from typing import Optional from typing import Union # third party from pydantic import AnyHttpUrl -from pydantic import BaseSettings from pydantic import EmailStr from pydantic import HttpUrl -from pydantic import validator +from pydantic import field_validator +from pydantic import model_validator +from pydantic_settings import BaseSettings +from pydantic_settings import SettingsConfigDict +from typing_extensions import Self _truthy = {"yes", "y", "true", "t", "on", "1"} _falsy = {"no", "n", "false", "f", "off", "0"} @@ -50,7 +52,8 @@ class Settings(BaseSettings): # "http://localhost:8080", "http://local.dockertoolbox.tiangolo.com"]' BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = [] - @validator("BACKEND_CORS_ORIGINS", pre=True) + @field_validator("BACKEND_CORS_ORIGINS", mode="before") + @classmethod def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]: if isinstance(v, str) and not v.startswith("["): return [i.strip() for i in v.split(",")] @@ -62,25 +65,22 @@ def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str SENTRY_DSN: Optional[HttpUrl] = None - @validator("SENTRY_DSN", pre=True) + @field_validator("SENTRY_DSN", mode="before") + @classmethod def sentry_dsn_can_be_blank(cls, v: str) -> Optional[str]: if v is None or len(v) == 0: return None return v - SMTP_TLS: bool = True - SMTP_PORT: Optional[int] = None - SMTP_HOST: Optional[str] = None - SMTP_USER: Optional[str] = None - SMTP_PASSWORD: Optional[str] = None EMAILS_FROM_EMAIL: Optional[EmailStr] = None EMAILS_FROM_NAME: Optional[str] = None - @validator("EMAILS_FROM_NAME") - def get_project_name(cls, v: Optional[str], values: Dict[str, Any]) -> str: - if not v: - return values["PROJECT_NAME"] - return v + @model_validator(mode="after") + def get_project_name(self) -> Self: + if not self.EMAILS_FROM_NAME: + self.EMAILS_FROM_NAME = self.PROJECT_NAME + + return self EMAIL_RESET_TOKEN_EXPIRE_HOURS: int = 48 EMAIL_TEMPLATES_DIR: str = os.path.expandvars( @@ -88,15 +88,15 @@ def get_project_name(cls, v: Optional[str], values: Dict[str, Any]) -> str: ) EMAILS_ENABLED: bool = False - @validator("EMAILS_ENABLED", pre=True) - def get_emails_enabled(cls, v: bool, values: Dict[str, Any]) -> bool: - return bool( - values.get("SMTP_HOST") - and values.get("SMTP_PORT") - and values.get("EMAILS_FROM_EMAIL") + @model_validator(mode="after") + def get_emails_enabled(self) -> Self: + self.EMAILS_ENABLED = bool( + self.SMTP_HOST and self.SMTP_PORT and self.EMAILS_FROM_EMAIL ) - DEFAULT_ROOT_EMAIL: EmailStr = EmailStr("info@openmined.org") + return self + + DEFAULT_ROOT_EMAIL: EmailStr = "info@openmined.org" DEFAULT_ROOT_PASSWORD: str = "changethis" USERS_OPEN_REGISTRATION: bool = False @@ -144,14 +144,18 @@ def get_emails_enabled(cls, v: bool, values: Dict[str, Any]) -> bool: SINGLE_CONTAINER_MODE: bool = str_to_bool(os.getenv("SINGLE_CONTAINER_MODE", False)) CONSUMER_SERVICE_NAME: Optional[str] = os.getenv("CONSUMER_SERVICE_NAME") INMEMORY_WORKERS: bool = str_to_bool(os.getenv("INMEMORY_WORKERS", True)) + SMTP_USERNAME: str = os.getenv("SMTP_USERNAME", "") + EMAIL_SENDER: str = os.getenv("EMAIL_SENDER", "") + SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "") + SMTP_TLS: bool = True + SMTP_PORT: Optional[str] = os.getenv("SMTP_PORT", "") + SMTP_HOST: Optional[str] = os.getenv("SMTP_HOST", "") TEST_MODE: bool = ( True if os.getenv("TEST_MODE", "false").lower() == "true" else False ) ASSOCIATION_TIMEOUT: int = 10 - - class Config: - case_sensitive = True + model_config = SettingsConfigDict(case_sensitive=True) settings = Settings() diff --git a/packages/grid/backend/grid/core/node.py b/packages/grid/backend/grid/core/node.py index f2f61187597..89010e661dd 100644 --- a/packages/grid/backend/grid/core/node.py +++ b/packages/grid/backend/grid/core/node.py @@ -94,4 +94,9 @@ def seaweedfs_config() -> SeaweedFSConfig: queue_config=queue_config, migrate=True, in_memory_workers=settings.INMEMORY_WORKERS, + smtp_username=settings.SMTP_USERNAME, + smtp_password=settings.SMTP_PASSWORD, + email_sender=settings.EMAIL_SENDER, + smtp_port=settings.SMTP_PORT, + smtp_host=settings.SMTP_HOST, ) diff --git a/packages/grid/backend/grid/logger/config.py b/packages/grid/backend/grid/logger/config.py index 7c5ea9ddb41..5f2376a9615 100644 --- a/packages/grid/backend/grid/logger/config.py +++ b/packages/grid/backend/grid/logger/config.py @@ -11,7 +11,7 @@ from typing import Union # third party -from pydantic import BaseSettings +from pydantic_settings import BaseSettings # LOGURU_LEVEL type for version>3.8 @@ -40,14 +40,9 @@ class LogConfig(BaseSettings): LOGURU_LEVEL: str = LogLevel.INFO.value LOGURU_SINK: Optional[str] = "/var/log/pygrid/grid.log" - LOGURU_COMPRESSION: Optional[str] - LOGURU_ROTATION: Union[ - Optional[str], - Optional[int], - Optional[time], - Optional[timedelta], - ] - LOGURU_RETENTION: Union[Optional[str], Optional[int], Optional[timedelta]] + LOGURU_COMPRESSION: Optional[str] = None + LOGURU_ROTATION: Union[str, int, time, timedelta, None] = None + LOGURU_RETENTION: Union[str, int, timedelta, None] = None LOGURU_COLORIZE: Optional[bool] = True LOGURU_SERIALIZE: Optional[bool] = False LOGURU_BACKTRACE: Optional[bool] = True diff --git a/packages/grid/backend/grid/start.sh b/packages/grid/backend/grid/start.sh index 6b65070f981..2880800eee4 100755 --- a/packages/grid/backend/grid/start.sh +++ b/packages/grid/backend/grid/start.sh @@ -8,10 +8,12 @@ APP_MODULE=grid.main:app LOG_LEVEL=${LOG_LEVEL:-info} HOST=${HOST:-0.0.0.0} PORT=${PORT:-80} -RELOAD="" NODE_TYPE=${NODE_TYPE:-domain} APPDIR=${APPDIR:-$HOME/app} +RELOAD="" +DEBUG_CMD="" + # For debugging permissions ls -lisa $HOME/data ls -lisa $APPDIR/syft/ @@ -24,6 +26,13 @@ then pip install --user -e "$APPDIR/syft[telemetry,data_science]" fi +# only set by kubernetes to avoid conflict with docker tests +if [[ ${DEBUGGER_ENABLED} == "True" ]]; +then + pip install --user debugpy + DEBUG_CMD="python -m debugpy --listen 0.0.0.0:5678 -m" +fi + set +e export NODE_PRIVATE_KEY=$(python $APPDIR/grid/bootstrap.py --private_key) export NODE_UID=$(python $APPDIR/grid/bootstrap.py --uid) @@ -33,4 +42,4 @@ set -e echo "NODE_UID=$NODE_UID" echo "NODE_TYPE=$NODE_TYPE" -exec uvicorn $RELOAD --host $HOST --port $PORT --log-level $LOG_LEVEL "$APP_MODULE" +exec $DEBUG_CMD uvicorn $RELOAD --host $HOST --port $PORT --log-level $LOG_LEVEL "$APP_MODULE" diff --git a/packages/grid/backend/install_oblivious.sh b/packages/grid/backend/install_oblivious.sh index d383e4fbd94..486b812ac9e 100755 --- a/packages/grid/backend/install_oblivious.sh +++ b/packages/grid/backend/install_oblivious.sh @@ -2,7 +2,7 @@ echo "Running install_oblivious.sh with RELEASE=${RELEASE}" -if [[ ("${ENABLE_OBLV}" == "true") && ("${SERVICE_NAME}" == "backend" || "${SERVICE_NAME}" == "celeryworker" ) ]]; then +if [[ ("${OBLV_ENABLED}" == "true") && ("${SERVICE_NAME}" == "backend" || "${SERVICE_NAME}" == "celeryworker" ) ]]; then echo "Allowed to install Oblv CLI" # Oblivious Proxy Client Installation mkdir -p oblv-ccli-0.4.0-x86_64-unknown-linux-musl @@ -11,5 +11,5 @@ if [[ ("${ENABLE_OBLV}" == "true") && ("${SERVICE_NAME}" == "backend" || "${ ln -sf $(pwd)/oblv-ccli-0.4.0-x86_64-unknown-linux-musl/oblv /usr/local/bin/oblv #-f is for force echo "Installed Oblivious CLI: $(/usr/local/bin/oblv --version)" else - echo "Oblivious CLI not installed ENABLE_OBLV:${ENABLE_OBLV} , SERVICE_NAME:${SERVICE_NAME} " + echo "Oblivious CLI not installed OBLV_ENABLED:${OBLV_ENABLED} , SERVICE_NAME:${SERVICE_NAME} " fi diff --git a/packages/grid/default.env b/packages/grid/default.env index d599e47cf4e..7ba01f7a770 100644 --- a/packages/grid/default.env +++ b/packages/grid/default.env @@ -43,9 +43,9 @@ DEFAULT_ROOT_PASSWORD=changethis SMTP_TLS=True SMTP_PORT=587 SMTP_HOST= -SMTP_USER= +SMTP_USERNAME= SMTP_PASSWORD= -EMAILS_FROM_EMAIL=info@openmined.org +EMAIL_SENDER= SERVER_HOST="https://${DOMAIN}" NETWORK_CHECK_INTERVAL=60 DOMAIN_CHECK_INTERVAL=60 @@ -61,7 +61,7 @@ INMEMORY_WORKERS=True USE_NEW_SERVICE=False # Frontend -VITE_PUBLIC_API_BASE_URL="/api/v2" +BACKEND_API_BASE_URL="/api/v2" # SeaweedFS S3_ENDPOINT="seaweedfs" @@ -109,9 +109,9 @@ NODE_SIDE_TYPE=high USE_BLOB_STORAGE=False #Oblivious -ENABLE_OBLV=false +OBLV_ENABLED=false OBLV_KEY_PATH="~/.oblv" -DOMAIN_CONNECTION_PORT=3030 +OBLV_LOCALHOST_PORT=3030 # Registation ENABLE_SIGNUP=False diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 44b26219e39..ac81a14cccc 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -60,21 +60,21 @@ images: deployments: syft: helm: + releaseName: syft-dev chart: name: ./helm/syft values: - syft: + global: registry: ${CONTAINER_REGISTRY} version: dev-${DEVSPACE_TIMESTAMP} + useDefaultSecrets: true registry: - maxStorage: "5Gi" + storageSize: "5Gi" node: - settings: - nodeName: ${NODE_NAME} - nodeType: "domain" - defaultWorkerPoolCount: 1 - configuration: - devmode: True + name: ${NODE_NAME} + rootEmail: info@openmined.org + defaultWorkerPoolCount: 1 + resourcesPreset: micro dev: mongo: @@ -97,16 +97,22 @@ dev: app.kubernetes.io/name: syft app.kubernetes.io/component: backend env: + - name: RELEASE + value: development - name: DEV_MODE value: "True" - logs: {} + - name: DEBUGGER_ENABLED + value: "True" + ports: + - port: "5678" # debugpy sync: - path: ./backend/grid:/root/app/grid - path: ../syft:/root/app/syft + ssh: {} profiles: - name: gateway patches: - op: replace - path: deployments.syft.helm.values.node.settings.nodeType + path: deployments.syft.helm.values.node.type value: "gateway" diff --git a/packages/grid/docker-compose.yml b/packages/grid/docker-compose.yml index 9d4be113cb6..07615ebb787 100644 --- a/packages/grid/docker-compose.yml +++ b/packages/grid/docker-compose.yml @@ -51,7 +51,7 @@ services: - PORT=80 - HTTP_PORT=${HTTP_PORT} - HTTPS_PORT=${HTTPS_PORT} - - VITE_PUBLIC_API_BASE_URL=${VITE_PUBLIC_API_BASE_URL} + - BACKEND_API_BASE_URL=${BACKEND_API_BASE_URL} extra_hosts: - "host.docker.internal:host-gateway" labels: @@ -149,8 +149,8 @@ services: - JAEGER_PORT=${JAEGER_PORT} - ASSOCIATION_TIMEOUT=${ASSOCIATION_TIMEOUT} - DEV_MODE=${DEV_MODE} - - DOMAIN_CONNECTION_PORT=${DOMAIN_CONNECTION_PORT} - - ENABLE_OBLV=${ENABLE_OBLV} + - OBLV_LOCALHOST_PORT=${OBLV_LOCALHOST_PORT} + - OBLV_ENABLED=${OBLV_ENABLED} - DEFAULT_ROOT_EMAIL=${DEFAULT_ROOT_EMAIL} - DEFAULT_ROOT_PASSWORD=${DEFAULT_ROOT_PASSWORD} - BACKEND_STORAGE_PATH=${BACKEND_STORAGE_PATH} @@ -196,8 +196,8 @@ services: # - JAEGER_HOST=${JAEGER_HOST} # - JAEGER_PORT=${JAEGER_PORT} # - DEV_MODE=${DEV_MODE} - # - DOMAIN_CONNECTION_PORT=${DOMAIN_CONNECTION_PORT} - # - ENABLE_OBLV=${ENABLE_OBLV} + # - OBLV_LOCALHOST_PORT=${OBLV_LOCALHOST_PORT} + # - OBLV_ENABLED=${OBLV_ENABLED} # network_mode: service:proxy # volumes: # - credentials-data:/root/data/creds/ @@ -230,8 +230,8 @@ services: # - JAEGER_HOST=${JAEGER_HOST} # - JAEGER_PORT=${JAEGER_PORT} # - DEV_MODE=${DEV_MODE} - # - DOMAIN_CONNECTION_PORT=${DOMAIN_CONNECTION_PORT} - # - ENABLE_OBLV=${ENABLE_OBLV} + # - OBLV_LOCALHOST_PORT=${OBLV_LOCALHOST_PORT} + # - OBLV_ENABLED=${OBLV_ENABLED} # command: "/app/grid/worker-start.sh" # network_mode: service:proxy # volumes: diff --git a/packages/grid/frontend/.env.example b/packages/grid/frontend/.env.example index fc65e495bc3..21b1621cc43 100644 --- a/packages/grid/frontend/.env.example +++ b/packages/grid/frontend/.env.example @@ -1 +1 @@ -VITE_PUBLIC_API_BASE_URL=https://localhost/api/v2 +BACKEND_API_BASE_URL=https://localhost/api/v2 diff --git a/packages/grid/frontend/frontend.dockerfile b/packages/grid/frontend/frontend.dockerfile index 788008a6a87..f05aae7e410 100644 --- a/packages/grid/frontend/frontend.dockerfile +++ b/packages/grid/frontend/frontend.dockerfile @@ -1,10 +1,11 @@ -FROM node:18-alpine as base +FROM cgr.dev/chainguard/wolfi-base as base -ARG VITE_PUBLIC_API_BASE_URL -ENV VITE_PUBLIC_API_BASE_URL ${VITE_PUBLIC_API_BASE_URL} -ENV NODE_TYPE domain +ARG BACKEND_API_BASE_URL="/api/v2/" +ENV BACKEND_API_BASE_URL ${BACKEND_API_BASE_URL} -RUN apk update && apk upgrade --available +RUN apk update && \ + apk upgrade && \ + apk add --no-cache nodejs-20 pnpm corepack WORKDIR /app diff --git a/packages/grid/frontend/package.json b/packages/grid/frontend/package.json index 5157edd0353..2c7e7c31ca2 100644 --- a/packages/grid/frontend/package.json +++ b/packages/grid/frontend/package.json @@ -5,7 +5,7 @@ "scripts": { "dev": "pnpm i && vite dev --host --port 80", "build": "vite build", - "preview": "vite preview --host --port 80", + "preview": "vite preview --host 0.0.0.0 --port 80", "lint": "prettier --plugin-search-dir . --check . && eslint .", "test:e2e": "playwright test", "test:unit": "vitest run", diff --git a/packages/grid/frontend/pnpm-lock.yaml b/packages/grid/frontend/pnpm-lock.yaml index 3d7b6901efe..087164903a7 100644 --- a/packages/grid/frontend/pnpm-lock.yaml +++ b/packages/grid/frontend/pnpm-lock.yaml @@ -552,7 +552,7 @@ packages: sirv: 2.0.3 svelte: 3.59.2 tiny-glob: 0.2.9 - undici: 5.27.0 + undici: 6.6.2 vite: 4.5.2(@types/node@20.8.2) transitivePeerDependencies: - supports-color @@ -2519,9 +2519,9 @@ packages: resolution: {integrity: sha512-uY/99gMLIOlJPwATcMVYfqDSxUR9//AUcgZMzwfSTJPDKzA1S8mX4VLqa+fiAtveraQUBCz4FFcwVZBGbwBXIw==} dev: true - /undici@5.27.0: - resolution: {integrity: sha512-l3ydWhlhOJzMVOYkymLykcRRXqbUaQriERtR70B9LzNkZ4bX52Fc8wbTDneMiwo8T+AemZXvXaTx+9o5ROxrXg==} - engines: {node: '>=14.0'} + /undici@6.6.2: + resolution: {integrity: sha512-vSqvUE5skSxQJ5sztTZ/CdeJb1Wq0Hf44hlYMciqHghvz+K88U0l7D6u1VsndoFgskDcnU+nG3gYmMzJVzd9Qg==} + engines: {node: '>=18.0'} dependencies: '@fastify/busboy': 2.0.0 dev: true diff --git a/packages/grid/frontend/src/lib/constants.ts b/packages/grid/frontend/src/lib/constants.ts index d7d6bc5053f..4aac3d7936c 100644 --- a/packages/grid/frontend/src/lib/constants.ts +++ b/packages/grid/frontend/src/lib/constants.ts @@ -1,5 +1,4 @@ -export const API_BASE_URL = - import.meta.env.VITE_PUBLIC_API_BASE_URL || "/api/v2" +export const API_BASE_URL = import.meta.env.BACKEND_API_BASE_URL || "/api/v2" export const syftRoles = { 1: "Guest", diff --git a/packages/grid/frontend/src/routes/health/+page.svelte b/packages/grid/frontend/src/routes/health/+page.svelte new file mode 100644 index 00000000000..0c47afa24a3 --- /dev/null +++ b/packages/grid/frontend/src/routes/health/+page.svelte @@ -0,0 +1 @@ +<p>ok</p> diff --git a/packages/grid/helm/kubelinter-config.yaml b/packages/grid/helm/kubelinter-config.yaml new file mode 100644 index 00000000000..f3f18327888 --- /dev/null +++ b/packages/grid/helm/kubelinter-config.yaml @@ -0,0 +1,15 @@ +checks: + addAllBuiltIn: true + exclude: + - "access-to-create-pods" + - "default-service-account" + - "dnsconfig-options" + - "minimum-three-replicas" + - "no-node-affinity" + - "non-isolated-pod" + - "privileged-ports" + - "read-secret-from-env-var" + - "required-annotation-email" + - "required-label-owner" + - "run-as-non-root" + - "use-namespace" diff --git a/packages/grid/helm/syft/Chart.yaml b/packages/grid/helm/syft/Chart.yaml index f5d50941c01..9dd2c701980 100644 --- a/packages/grid/helm/syft/Chart.yaml +++ b/packages/grid/helm/syft/Chart.yaml @@ -4,4 +4,5 @@ description: Perform numpy-like analysis on data that remains in someone elses s type: application version: "0.8.5-beta.1" appVersion: "0.8.5-beta.1" -icon: https://raw.githubusercontent.com/OpenMined/PySyft/dev/docs/img/title_syft_light.png \ No newline at end of file +home: https://github.com/OpenMined/PySyft/ +icon: https://raw.githubusercontent.com/OpenMined/PySyft/dev/docs/img/title_syft_light.png diff --git a/packages/grid/helm/syft/templates/_labels.tpl b/packages/grid/helm/syft/templates/_labels.tpl new file mode 100644 index 00000000000..23f0b8f07f5 --- /dev/null +++ b/packages/grid/helm/syft/templates/_labels.tpl @@ -0,0 +1,31 @@ +{{/* +Common Chart Name +Usage: + {{- include "common.chartname" . }} +*/}} +{{- define "common.chartname" -}} +{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" -}} +{{- end -}} + +{{/* +Common labels for all resources +Usage: + {{- include "common.chartname" . | indent 4}} +*/}} +{{- define "common.labels" -}} +app.kubernetes.io/name: {{ .Chart.Name }} +app.kubernetes.io/version: {{ .Chart.Version }} +app.kubernetes.io/instance: {{ .Release.Name }} +app.kubernetes.io/managed-by: {{ .Release.Service }} +helm.sh/chart: {{ include "common.chartname" . }} +{{- end -}} + +{{/* +Common labels for all resources +Usage: + {{- include "common.selectorLabels" . | indent 4}} +*/}} +{{- define "common.selectorLabels" -}} +app.kubernetes.io/name: {{ .Chart.Name }} +app.kubernetes.io/instance: {{ .Release.Name }} +{{- end -}} diff --git a/packages/grid/helm/syft/templates/_resources.tpl b/packages/grid/helm/syft/templates/_resources.tpl new file mode 100644 index 00000000000..edff154a8dd --- /dev/null +++ b/packages/grid/helm/syft/templates/_resources.tpl @@ -0,0 +1,74 @@ +{{/* +Pod resource limit presets + +Usage: + {{- include "common.resources.preset" (dict "type" "nano") }} + +Params: + type - String (Required) - One of resource presets: nano, micro, small, medium, large, xlarge, 2xlarge, 4xlarge +*/}} +{{- define "common.resources.preset" -}} +{{- $presets := dict + "nano" (dict + "requests" (dict "cpu" "100m" "memory" "128Mi" "ephemeral-storage" "50Mi") + "limits" (dict "cpu" "200m" "memory" "256Mi" "ephemeral-storage" "1Gi") + ) + "micro" (dict + "requests" (dict "cpu" "250m" "memory" "256Mi" "ephemeral-storage" "50Mi") + "limits" (dict "cpu" "500m" "memory" "512Mi" "ephemeral-storage" "1Gi") + ) + "small" (dict + "requests" (dict "cpu" "500m" "memory" "512Mi" "ephemeral-storage" "50Mi") + "limits" (dict "cpu" "1.0" "memory" "1Gi" "ephemeral-storage" "1Gi") + ) + "medium" (dict + "requests" (dict "cpu" "500m" "memory" "1Gi" "ephemeral-storage" "50Mi") + "limits" (dict "cpu" "1.0" "memory" "2Gi" "ephemeral-storage" "1Gi") + ) + "large" (dict + "requests" (dict "cpu" "1.0" "memory" "2Gi" "ephemeral-storage" "50Mi") + "limits" (dict "cpu" "2.0" "memory" "4Gi" "ephemeral-storage" "1Gi") + ) + "xlarge" (dict + "requests" (dict "cpu" "2.0" "memory" "4Gi" "ephemeral-storage" "50Mi") + "limits" (dict "cpu" "4.0" "memory" "8Gi" "ephemeral-storage" "1Gi") + ) + "2xlarge" (dict + "requests" (dict "cpu" "4.0" "memory" "8Gi" "ephemeral-storage" "50Mi") + "limits" (dict "cpu" "8.0" "memory" "16Gi" "ephemeral-storage" "1Gi") + ) + "4xlarge" (dict + "requests" (dict "cpu" "8.0" "memory" "16Gi" "ephemeral-storage" "50Mi") + "limits" (dict "cpu" "16.0" "memory" "32Gi" "ephemeral-storage" "1Gi") + ) + }} +{{- if hasKey $presets .type -}} +{{- index $presets .type | toYaml -}} +{{- else -}} +{{- printf "ERROR: Preset key '%s' invalid. Allowed values are %s" .type (join "," (keys $presets)) | fail -}} +{{- end -}} +{{- end -}} + + +{{/* +Set resource limits based on preset or custom values. If both are provided, custom values take precedence. +Defaults to empty limits and requests. + +Usage: + resources: {{ include "common.resources.set" (dict "preset" "nano") | indent 4 }} + resources: {{ include "common.resources.set" (dict "resources" (dict "cpu" "100m" "memory" "128Mi")) | indent 12 }} + resources: {{ include "common.resources.set" (dict "preset" "nano" "resources" (dict "cpu" "100m" "memory" "128Mi")) | indent 2 }} + +Params: + resources - Dict (Optional) - Custom resources values + preset - String (Optional) - One of resource presets: nano, micro, small, medium, large, xlarge, 2xlarge, 4xlarge +*/}} +{{- define "common.resources.set" -}} +{{- if .resources -}} + {{- .resources | toYaml -}} +{{- else if .preset -}} + {{- include "common.resources.preset" (dict "type" .preset) -}} +{{- else -}} + {{- (dict "requests" nil "limits" nil) | toYaml -}} +{{- end -}} +{{- end -}} diff --git a/packages/grid/helm/syft/templates/_secrets.tpl b/packages/grid/helm/syft/templates/_secrets.tpl new file mode 100644 index 00000000000..4d0ad6bd153 --- /dev/null +++ b/packages/grid/helm/syft/templates/_secrets.tpl @@ -0,0 +1,55 @@ +{{/* +Lookup value from an existing secret. WILL NOT base64 decode the value. + +Usage: + {{- include "common.secrets.get" (dict "secret" "some-secret-name" "key" "keyName" "context" $) }} + +Params: + secret - String (Required) - Name of the 'Secret' resource where the key is stored. + key - String - (Required) - Name of the key in the secret. + context - Context (Required) - Parent context. +*/}} +{{- define "common.secrets.get" -}} + {{- $value := "" -}} + {{- $secretData := (lookup "v1" "Secret" .context.Release.Namespace .secret).data -}} + + {{- if and $secretData (hasKey $secretData .key) -}} + {{- $value = index $secretData .key -}} + {{- end -}} + + {{- if $value -}} + {{- printf "%s" $value -}} + {{- end -}} + +{{- end -}} + +{{/* +Re-use or set a new randomly generated secret value from an existing secret. +If global.useDefaultSecrets is set to true, the default value will be used if the secret does not exist. + +Usage: + {{- include "common.secrets.set " (dict "secret" "some-secret-name" "default" "default-value" "context" $ ) }} + +Params: + secret - String (Required) - Name of the 'Secret' resource where the key is stored. + key - String - (Required) - Name of the key in the secret. + default - String - (Optional) - Default value to use if the secret does not exist. + length - Int - (Optional) - The length of the generated secret. Default is 32. + context - Context (Required) - Parent context. +*/}} +{{- define "common.secrets.set" -}} + {{- $secretVal := "" -}} + {{- $existingSecret := include "common.secrets.get" (dict "secret" .secret "key" .key "context" .context ) | default "" -}} + + {{- if $existingSecret -}} + {{- $secretVal = $existingSecret -}} + {{- else if .context.Values.global.useDefaultSecrets -}} + {{- $secretVal = .default | b64enc -}} + {{- else -}} + {{- $length := .length | default 32 -}} + {{- $secretVal = randAlphaNum $length | b64enc -}} + {{- end -}} + + {{- printf "%s" $secretVal -}} + +{{- end -}} diff --git a/packages/grid/helm/syft/templates/backend-headless-service.yaml b/packages/grid/helm/syft/templates/backend-headless-service.yaml deleted file mode 100644 index 47078fb5af2..00000000000 --- a/packages/grid/helm/syft/templates/backend-headless-service.yaml +++ /dev/null @@ -1,17 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/managed-by: Helm - name: backend-headless -spec: - clusterIP: None - ports: - - name: web - port: 80 - selector: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: backend - app.kubernetes.io/managed-by: Helm diff --git a/packages/grid/helm/syft/templates/backend-service.yaml b/packages/grid/helm/syft/templates/backend-service.yaml deleted file mode 100644 index 22df9b615da..00000000000 --- a/packages/grid/helm/syft/templates/backend-service.yaml +++ /dev/null @@ -1,22 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - name: backend - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/managed-by: Helm -spec: - externalIPs: null - ports: - - name: web - port: 80 - protocol: TCP - targetPort: 80 - - name: queue - port: {{ .Values.queue.port }} - protocol: TCP - targetPort: {{ .Values.queue.port }} - selector: - app.kubernetes.io/component: backend - type: ClusterIP diff --git a/packages/grid/helm/syft/templates/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend-statefulset.yaml deleted file mode 100644 index ebebae80cc2..00000000000 --- a/packages/grid/helm/syft/templates/backend-statefulset.yaml +++ /dev/null @@ -1,164 +0,0 @@ -apiVersion: apps/v1 -kind: StatefulSet -metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/component: backend - app.kubernetes.io/managed-by: Helm - name: backend -spec: - podManagementPolicy: OrderedReady - replicas: 1 - selector: - matchLabels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: backend - app.kubernetes.io/managed-by: Helm - serviceName: backend-headless - template: - metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: backend - app.kubernetes.io/managed-by: Helm - spec: - affinity: null - serviceAccountName: backend-service-account - containers: - - args: null - command: null - imagePullPolicy: Always - env: - - name: K8S_POD_NAME - valueFrom: - fieldRef: - fieldPath: metadata.name - - name: K8S_NAMESPACE - valueFrom: - fieldRef: - fieldPath: metadata.namespace - - name: MONGO_PORT - value: "{{ .Values.mongo.port }}" - - name: MONGO_HOST - value: {{ .Values.mongo.host }} - - name: MONGO_USERNAME - value: {{ .Values.mongo.username }} - - name: MONGO_PASSWORD - valueFrom: - secretKeyRef: - name: {{ .Values.secrets.mongo }} - key: rootPassword - - name: SERVICE_NAME - value: backend - - name: RELEASE - value: production - - name: VERSION - value: "{{ .Values.syft.version }}" - - name: VERSION_HASH - value: {{ .Values.node.settings.versionHash }} - - name: NODE_TYPE - value: {{ .Values.node.settings.nodeType }} - - name: NODE_NAME - value: {{ .Values.node.settings.nodeName }} - - name: NODE_SIDE_TYPE - value: {{ .Values.node.settings.nodeSideType }} - - name: STACK_API_KEY - valueFrom: - secretKeyRef: - name: {{ .Values.secrets.syft }} - key: stackApiKey - - name: PORT - value: "80" - - name: IGNORE_TLS_ERRORS - value: "false" - - name: HTTP_PORT - value: "80" - - name: HTTPS_PORT - value: "443" - - name: CONTAINER_HOST - value: "k8s" - - name: TRACE - value: "false" - - name: JAEGER_HOST - value: localhost - - name: JAEGER_PORT - value: "14268" - - name: DEV_MODE - value: "false" - - name: DOMAIN_CONNECTION_PORT - value: "3030" - - name: ENABLE_OBLV - value: "false" - - name: DEFAULT_ROOT_EMAIL - value: {{ .Values.node.settings.defaultRootEmail }} - - name: DEFAULT_ROOT_PASSWORD - valueFrom: - secretKeyRef: - name: {{ .Values.secrets.syft }} - key: defaultRootPassword - - name: S3_ROOT_USER - value: "{{ .Values.seaweedfs.s3RootUser }}" - - name: S3_ROOT_PWD - valueFrom: - secretKeyRef: - name: {{ .Values.secrets.seaweedfs }} - key: s3RootPassword - - name: S3_PORT - value: "{{ .Values.seaweedfs.s3Port }}" - - name: SEAWEED_MOUNT_PORT - value: "{{ .Values.seaweedfs.mountPort }}" - - name: QUEUE_PORT - value: "{{ .Values.queue.port }}" - - name: CREATE_PRODUCER - value: "true" - - name: N_CONSUMERS - value: "0" - - name: INMEMORY_WORKERS - value: "{{ .Values.node.settings.inMemoryWorkers }}" - - name: LOG_LEVEL - value: {{ .Values.node.settings.logLevel }} - - name: DEFAULT_WORKER_POOL_IMAGE - value: {{ .Values.syft.registry }}/openmined/grid-backend:{{ .Values.syft.version }} - - name: DEFAULT_WORKER_POOL_COUNT - value: "{{ .Values.node.settings.defaultWorkerPoolCount }}" - envFrom: null - image: {{ .Values.syft.registry }}/openmined/grid-backend:{{ .Values.syft.version }} - lifecycle: null - livenessProbe: null - name: container-0 - readinessProbe: null - securityContext: null - startupProbe: null - volumeDevices: null - volumeMounts: - - mountPath: /root/data/creds/ - name: credentials-data - readOnly: false - subPath: credentials-data - dnsConfig: null - ephemeralContainers: null - hostAliases: null - imagePullSecrets: null - initContainers: null - nodeName: null - nodeSelector: null - overhead: null - readinessGates: null - securityContext: null - terminationGracePeriodSeconds: 5 - tolerations: null - topologySpreadConstraints: null - volumeClaimTemplates: - - metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: backend - app.kubernetes.io/managed-by: Helm - name: credentials-data - spec: - accessModes: - - ReadWriteOnce - resources: - requests: - storage: 100Mi diff --git a/packages/grid/helm/syft/templates/backend/backend-headless-service.yaml b/packages/grid/helm/syft/templates/backend/backend-headless-service.yaml new file mode 100644 index 00000000000..fd1bb1768df --- /dev/null +++ b/packages/grid/helm/syft/templates/backend/backend-headless-service.yaml @@ -0,0 +1,15 @@ +apiVersion: v1 +kind: Service +metadata: + name: backend-headless + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: backend +spec: + selector: + {{- include "common.selectorLabels" . | nindent 4 }} + app.kubernetes.io/component: backend + clusterIP: None + ports: + - name: api + port: 80 diff --git a/packages/grid/helm/syft/templates/backend/backend-secret.yaml b/packages/grid/helm/syft/templates/backend/backend-secret.yaml new file mode 100644 index 00000000000..12b14be20bd --- /dev/null +++ b/packages/grid/helm/syft/templates/backend/backend-secret.yaml @@ -0,0 +1,16 @@ +{{- $secretName := "backend-secret" }} +apiVersion: v1 +kind: Secret +metadata: + name: {{ $secretName }} + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: backend +type: Opaque +data: + defaultRootPassword: {{ include "common.secrets.set" (dict + "secret" $secretName + "key" "defaultRootPassword" + "default" .Values.node.defaultSecret.defaultRootPassword + "context" $) + }} diff --git a/packages/grid/helm/syft/templates/backend-service-account.yaml b/packages/grid/helm/syft/templates/backend/backend-service-account.yaml similarity index 67% rename from packages/grid/helm/syft/templates/backend-service-account.yaml rename to packages/grid/helm/syft/templates/backend/backend-service-account.yaml index 56e552b2b7f..a466d0c3fe4 100644 --- a/packages/grid/helm/syft/templates/backend-service-account.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-service-account.yaml @@ -3,12 +3,9 @@ kind: ServiceAccount metadata: name: backend-service-account labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/managed-by: Helm - + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: backend --- - apiVersion: v1 kind: Secret metadata: @@ -16,21 +13,17 @@ metadata: annotations: kubernetes.io/service-account.name: "backend-service-account" labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/managed-by: Helm + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: backend type: kubernetes.io/service-account-token - --- - apiVersion: rbac.authorization.k8s.io/v1 kind: Role metadata: name: backend-service-role labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/managed-by: Helm + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: backend rules: - apiGroups: [""] resources: ["pods", "configmaps", "secrets"] @@ -45,15 +38,13 @@ rules: resources: ["statefulsets"] verbs: ["create", "get", "list", "watch", "update", "patch", "delete"] --- - apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding metadata: name: backend-service-role-binding labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/managed-by: Helm + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: backend subjects: - kind: ServiceAccount name: backend-service-account diff --git a/packages/grid/helm/syft/templates/backend/backend-service.yaml b/packages/grid/helm/syft/templates/backend/backend-service.yaml new file mode 100644 index 00000000000..2839a8c45e2 --- /dev/null +++ b/packages/grid/helm/syft/templates/backend/backend-service.yaml @@ -0,0 +1,27 @@ +apiVersion: v1 +kind: Service +metadata: + name: backend + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: backend +spec: + type: ClusterIP + selector: + {{- include "common.selectorLabels" . | nindent 4 }} + app.kubernetes.io/component: backend + ports: + - name: api + protocol: TCP + port: 80 + targetPort: 80 + - name: queue + protocol: TCP + port: {{ .Values.node.queuePort }} + targetPort: {{ .Values.node.queuePort }} + {{- if .Values.node.debuggerEnabled }} + - name: debug + port: 5678 + targetPort: 5678 + protocol: TCP + {{- end }} diff --git a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml new file mode 100644 index 00000000000..3673312d922 --- /dev/null +++ b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml @@ -0,0 +1,162 @@ +apiVersion: apps/v1 +kind: StatefulSet +metadata: + name: backend + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: backend +spec: + replicas: 1 + updateStrategy: + type: RollingUpdate + selector: + matchLabels: + {{- include "common.selectorLabels" . | nindent 6 }} + app.kubernetes.io/component: backend + serviceName: backend-headless + podManagementPolicy: OrderedReady + template: + metadata: + labels: + {{- include "common.labels" . | nindent 8 }} + app.kubernetes.io/component: backend + spec: + containers: + - name: backend-container + image: {{ .Values.global.registry }}/openmined/grid-backend:{{ .Values.global.version }} + imagePullPolicy: Always + resources: {{ include "common.resources.set" (dict "resources" .Values.node.resources "preset" .Values.node.resourcesPreset) | nindent 12 }} + env: + # kubernetes runtime + - name: K8S_POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: K8S_NAMESPACE + valueFrom: + fieldRef: + fieldPath: metadata.namespace + - name: CONTAINER_HOST + value: "k8s" + # syft + - name: NODE_TYPE + value: {{ .Values.node.type | default "domain" | quote }} + - name: NODE_NAME + value: {{ .Values.node.name | default .Release.Name | quote }} + - name: NODE_SIDE_TYPE + value: {{ .Values.node.side | quote }} + - name: DEFAULT_ROOT_EMAIL + value: {{ .Values.node.rootEmail | required "node.rootEmail is required" | quote }} + - name: DEFAULT_ROOT_PASSWORD + valueFrom: + secretKeyRef: + name: {{ .Values.node.secretKeyName | required "node.secretKeyName is required" }} + key: defaultRootPassword + - name: LOG_LEVEL + value: {{ .Values.node.logLevel | quote }} + - name: QUEUE_PORT + value: {{ .Values.node.queuePort | quote }} + - name: CREATE_PRODUCER + value: "true" + - name: INMEMORY_WORKERS + value: {{ .Values.node.inMemoryWorkers | quote }} + - name: DEFAULT_WORKER_POOL_IMAGE + value: "{{ .Values.global.registry }}/openmined/grid-backend:{{ .Values.global.version }}" + - name: DEFAULT_WORKER_POOL_COUNT + value: {{ .Values.node.defaultWorkerPoolCount | quote }} + {{- if .Values.node.debuggerEnabled }} + - name: DEBUGGER_ENABLED + value: "true" + {{- end }} + # MongoDB + - name: MONGO_PORT + value: {{ .Values.mongo.port | quote }} + - name: MONGO_HOST + value: "mongo" + - name: MONGO_USERNAME + value: {{ .Values.mongo.username | quote }} + - name: MONGO_PASSWORD + valueFrom: + secretKeyRef: + name: {{ .Values.mongo.secretKeyName | required "mongo.secretKeyName is required" }} + key: rootPassword + # SMTP + - name: SMTP_HOST + value: {{ .Values.node.smtp.host | quote }} + - name: SMTP_PORT + value: {{ .Values.node.smtp.port | quote }} + - name: SMTP_USERNAME + value: {{ .Values.node.smtp.username | quote }} + - name: SMTP_PASSWORD + value: {{ .Values.node.smtp.password | quote }} + - name: EMAIL_SENDER + value: {{ .Values.node.smtp.from | quote}} + # SeaweedFS + {{- if ne .Values.node.type "gateway"}} + - name: S3_ROOT_USER + value: {{ .Values.seaweedfs.s3.rootUser | quote }} + - name: S3_ROOT_PWD + valueFrom: + secretKeyRef: + name: {{ .Values.seaweedfs.secretKeyName | required "seaweedfs.secretKeyName is required" }} + key: s3RootPassword + - name: S3_PORT + value: {{ .Values.seaweedfs.s3.port | quote }} + - name: SEAWEED_MOUNT_PORT + value: {{ .Values.seaweedfs.mountApi.port | quote }} + {{- end }} + # Tracing + - name: TRACE + value: "false" + - name: SERVICE_NAME + value: "backend" + - name: JAEGER_HOST + value: "localhost" + - name: JAEGER_PORT + value: "14268" + # Oblivious + {{- if .Values.node.oblv.enabled }} + - name: OBLV_LOCALHOST_PORT + value: {{ .Values.node.oblv.port | quote }} + - name: OBLV_ENABLED + value: {{ .Values.node.oblv.enabled | quote }} + {{- end }} + {{- if .Values.node.env }} + {{- toYaml .Values.node.env | nindent 12 }} + {{- end }} + ports: + - name: api-port + containerPort: 80 + volumeMounts: + - mountPath: /root/data/creds/ + name: credentials-data + readOnly: false + subPath: credentials-data + startupProbe: + httpGet: + path: /api/v2/metadata?probe=startupProbe + port: api-port + failureThreshold: 30 + periodSeconds: 10 + livenessProbe: + httpGet: + path: /api/v2/?probe=livenessProbe + port: api-port + periodSeconds: 15 + timeoutSeconds: 5 + failureThreshold: 3 + readinessProbe: null + serviceAccountName: backend-service-account + terminationGracePeriodSeconds: 5 + volumeClaimTemplates: + - metadata: + name: credentials-data + labels: + {{- include "common.labels" . | nindent 8 }} + app.kubernetes.io/component: backend + spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 10Mi diff --git a/packages/grid/helm/syft/templates/frontend-deployment.yaml b/packages/grid/helm/syft/templates/frontend-deployment.yaml deleted file mode 100644 index f43fd0018dc..00000000000 --- a/packages/grid/helm/syft/templates/frontend-deployment.yaml +++ /dev/null @@ -1,62 +0,0 @@ -apiVersion: apps/v1 -kind: Deployment -metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/component: frontend - app.kubernetes.io/managed-by: Helm - name: frontend -spec: - replicas: 1 - selector: - matchLabels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: frontend - app.kubernetes.io/managed-by: Helm - strategy: - type: Recreate - template: - metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: frontend - app.kubernetes.io/managed-by: Helm - spec: - affinity: null - containers: - - args: null - command: null - env: - - name: VERSION - value: "{{ .Values.syft.version }}" - - name: VERSION_HASH - value: {{ .Values.node.settings.versionHash }} - - name: NODE_TYPE - value: {{ .Values.node.settings.nodeType }} - - name: NEXT_PUBLIC_API_URL - value: ${NEXT_PUBLIC_API_URL} - envFrom: null - image: {{ .Values.syft.registry }}/openmined/grid-frontend:{{ .Values.syft.version }} - lifecycle: null - livenessProbe: null - name: container-0 - readinessProbe: null - securityContext: null - startupProbe: null - volumeDevices: null - volumeMounts: null - dnsConfig: null - ephemeralContainers: null - hostAliases: null - imagePullSecrets: null - initContainers: null - nodeName: null - nodeSelector: null - overhead: null - readinessGates: null - securityContext: null - terminationGracePeriodSeconds: 5 - tolerations: null - topologySpreadConstraints: null - volumes: null diff --git a/packages/grid/helm/syft/templates/frontend-service.yaml b/packages/grid/helm/syft/templates/frontend-service.yaml deleted file mode 100644 index ad60d1c4a22..00000000000 --- a/packages/grid/helm/syft/templates/frontend-service.yaml +++ /dev/null @@ -1,19 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/managed-by: Helm - name: frontend -spec: - externalIPs: null - ports: - - name: port-0 - port: 80 - protocol: TCP - targetPort: 80 - selector: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: frontend - type: ClusterIP diff --git a/packages/grid/helm/syft/templates/frontend/frontend-deployment.yaml b/packages/grid/helm/syft/templates/frontend/frontend-deployment.yaml new file mode 100644 index 00000000000..4147251b111 --- /dev/null +++ b/packages/grid/helm/syft/templates/frontend/frontend-deployment.yaml @@ -0,0 +1,40 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: frontend + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: frontend +spec: + replicas: 1 + selector: + matchLabels: + {{- include "common.selectorLabels" . | nindent 6 }} + app.kubernetes.io/component: frontend + strategy: + type: Recreate + template: + metadata: + labels: + {{- include "common.labels" . | nindent 8 }} + app.kubernetes.io/component: frontend + spec: + containers: + - name: frontend-container + image: {{ .Values.global.registry }}/openmined/grid-frontend:{{ .Values.global.version }} + imagePullPolicy: Always + resources: {{ include "common.resources.set" (dict "resources" .Values.frontend.resources "preset" .Values.frontend.resourcesPreset) | nindent 12 }} + {{- if .Values.frontend.env }} + env: {{ toYaml .Values.frontend.env | nindent 12 }} + {{- end }} + ports: + - name: ui-port + containerPort: 80 + livenessProbe: + httpGet: + path: /health?probe=startupProbe + port: ui-port + periodSeconds: 15 + timeoutSeconds: 5 + failureThreshold: 3 + terminationGracePeriodSeconds: 5 diff --git a/packages/grid/helm/syft/templates/frontend/frontend-service.yaml b/packages/grid/helm/syft/templates/frontend/frontend-service.yaml new file mode 100644 index 00000000000..3a37e03e927 --- /dev/null +++ b/packages/grid/helm/syft/templates/frontend/frontend-service.yaml @@ -0,0 +1,17 @@ +apiVersion: v1 +kind: Service +metadata: + name: frontend + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: frontend +spec: + type: ClusterIP + selector: + {{- include "common.selectorLabels" . | nindent 4 }} + app.kubernetes.io/component: frontend + ports: + - name: ui + protocol: TCP + port: 80 + targetPort: 80 diff --git a/packages/grid/helm/syft/templates/global/ingress.yaml b/packages/grid/helm/syft/templates/global/ingress.yaml new file mode 100644 index 00000000000..677a66313a6 --- /dev/null +++ b/packages/grid/helm/syft/templates/global/ingress.yaml @@ -0,0 +1,44 @@ +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + name: ingress + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: ingress + {{- if or .Values.ingress.annotations .Values.ingress.class }} + annotations: + {{- if .Values.ingress.class }} + kubernetes.io/ingress.class: {{ .Values.ingress.class | quote }} + {{- end }} + {{- if .Values.ingress.annotations }} + {{- toYaml .Values.ingress.annotations | nindent 4 }} + {{- end }} + {{- end }} +spec: + {{- if .Values.ingress.className }} + ingressClassName: {{ .Values.ingress.className | quote }} + {{- end }} + defaultBackend: + service: + name: proxy + port: + number: 80 + rules: + - host: {{ .Values.ingress.hostname | quote }} + http: + paths: + - backend: + service: + name: proxy + port: + number: 80 + path: / + pathType: Prefix + {{- if .Values.ingress.tls.enabled }} + tls: + - hosts: + - {{ .Values.ingress.hostname | required "ingress.hostname is required when TLS is enabled" | quote }} + {{- if .Values.ingress.tls.secretName }} + secretName: {{ .Values.ingress.tls.secretName }} + {{- end}} + {{- end }} diff --git a/packages/grid/helm/syft/templates/grid-stack-ingress-ingress.yaml b/packages/grid/helm/syft/templates/grid-stack-ingress-ingress.yaml deleted file mode 100644 index 6aed72bd414..00000000000 --- a/packages/grid/helm/syft/templates/grid-stack-ingress-ingress.yaml +++ /dev/null @@ -1,34 +0,0 @@ -{{- if not .Values.node.settings.tls }} -apiVersion: networking.k8s.io/v1 -kind: Ingress -metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/component: ingress - app.kubernetes.io/managed-by: Helm - name: grid-stack-ingress - {{- if .Values.ingress.class }} - annotations: - kubernetes.io/ingress.class: {{ .Values.ingress.class }} - {{- end }} -spec: - {{- if .Values.ingress.className }} - ingressClassName: {{ .Values.ingress.className }} - {{- end }} - defaultBackend: - service: - name: proxy - port: - number: 80 - rules: - - http: - paths: - - backend: - service: - name: proxy - port: - number: 80 - path: / - pathType: Prefix -{{ end }} diff --git a/packages/grid/helm/syft/templates/grid-stack-ingress-tls-ingress.yaml b/packages/grid/helm/syft/templates/grid-stack-ingress-tls-ingress.yaml deleted file mode 100644 index 58db0a03e29..00000000000 --- a/packages/grid/helm/syft/templates/grid-stack-ingress-tls-ingress.yaml +++ /dev/null @@ -1,38 +0,0 @@ -{{- if .Values.node.settings.tls }} -apiVersion: networking.k8s.io/v1 -kind: Ingress -metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/component: ingress - app.kubernetes.io/managed-by: Helm - name: grid-stack-ingress-tls - {{- if .Values.ingress.class }} - annotations: - kubernetes.io/ingress.class: {{ .Values.ingress.class }} - {{- end }} -spec: - {{- if .Values.ingress.className }} - ingressClassName: {{ .Values.ingress.className }} - {{- end }} - defaultBackend: - service: - name: proxy - port: - number: 80 - rules: - - host: {{ .Values.node.settings.hostname }} - http: - paths: - - backend: - service: - name: proxy - port: - number: 80 - path: / - pathType: Prefix - tls: - - hosts: - - {{ .Values.node.settings.hostname }} -{{ end }} diff --git a/packages/grid/helm/syft/templates/mongo-headless-service.yaml b/packages/grid/helm/syft/templates/mongo-headless-service.yaml deleted file mode 100644 index bab93cc0d6a..00000000000 --- a/packages/grid/helm/syft/templates/mongo-headless-service.yaml +++ /dev/null @@ -1,17 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/managed-by: Helm - name: mongo-headless -spec: - clusterIP: None - ports: - - name: web - port: 80 - selector: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: mongo - app.kubernetes.io/managed-by: Helm diff --git a/packages/grid/helm/syft/templates/mongo-secret.yaml b/packages/grid/helm/syft/templates/mongo-secret.yaml deleted file mode 100644 index a5cd98bf636..00000000000 --- a/packages/grid/helm/syft/templates/mongo-secret.yaml +++ /dev/null @@ -1,23 +0,0 @@ -apiVersion: v1 -kind: Secret -metadata: - name: "mongo-default-secret" - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/managed-by: Helm -type: Opaque -data: - {{- if not .Values.configuration.devmode }} - - {{- $secretObj := (lookup "v1" "Secret" .Release.Namespace "mongo-default-secret") | default dict }} - {{- $secretData := (get $secretObj "data") | default dict }} - {{- $rootPasswordEnv := (get $secretData "rootPassword") | default (randAlphaNum 32 | b64enc) }} - - rootPassword: {{ $rootPasswordEnv | quote }} - - {{- else }} - - rootPassword: {{ "example" | b64enc }} # Base64-encoded "example" - - {{- end }} \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/mongo-service.yaml b/packages/grid/helm/syft/templates/mongo-service.yaml deleted file mode 100644 index 8880d73378c..00000000000 --- a/packages/grid/helm/syft/templates/mongo-service.yaml +++ /dev/null @@ -1,19 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/managed-by: Helm - name: mongo -spec: - externalIPs: null - ports: - - name: port-0 - port: 27017 - protocol: TCP - targetPort: 27017 - selector: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: mongo - type: ClusterIP diff --git a/packages/grid/helm/syft/templates/mongo-statefulset.yaml b/packages/grid/helm/syft/templates/mongo-statefulset.yaml deleted file mode 100644 index 68804a33c5c..00000000000 --- a/packages/grid/helm/syft/templates/mongo-statefulset.yaml +++ /dev/null @@ -1,78 +0,0 @@ -apiVersion: apps/v1 -kind: StatefulSet -metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/component: mongo - app.kubernetes.io/managed-by: Helm - name: mongo -spec: - podManagementPolicy: OrderedReady - replicas: 1 - selector: - matchLabels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: mongo - app.kubernetes.io/managed-by: Helm - serviceName: mongo-headless - template: - metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: mongo - app.kubernetes.io/managed-by: Helm - spec: - affinity: null - containers: - - args: null - command: null - env: - - name: MONGO_INITDB_ROOT_USERNAME - value: {{ .Values.mongo.username }} - - name: MONGO_INITDB_ROOT_PASSWORD - valueFrom: - secretKeyRef: - name: {{ .Values.secrets.mongo }} - key: rootPassword - envFrom: null - image: mongo:7.0.4 - lifecycle: null - livenessProbe: null - name: container-0 - readinessProbe: null - securityContext: null - startupProbe: null - volumeDevices: null - volumeMounts: - - mountPath: /data/db - name: mongo-data - readOnly: false - subPath: '' - dnsConfig: null - ephemeralContainers: null - hostAliases: null - imagePullSecrets: null - initContainers: null - nodeName: null - nodeSelector: null - overhead: null - readinessGates: null - securityContext: null - terminationGracePeriodSeconds: 5 - tolerations: null - topologySpreadConstraints: null - volumes: null - volumeClaimTemplates: - - metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: mongo - app.kubernetes.io/managed-by: Helm - name: mongo-data - spec: - accessModes: - - ReadWriteOnce - resources: - requests: - storage: 5Gi diff --git a/packages/grid/helm/syft/templates/mongo/mongo-headless-service.yaml b/packages/grid/helm/syft/templates/mongo/mongo-headless-service.yaml new file mode 100644 index 00000000000..7cb97ee3592 --- /dev/null +++ b/packages/grid/helm/syft/templates/mongo/mongo-headless-service.yaml @@ -0,0 +1,15 @@ +apiVersion: v1 +kind: Service +metadata: + name: mongo-headless + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: mongo +spec: + clusterIP: None + ports: + - name: mongo + port: 27017 + selector: + {{- include "common.selectorLabels" . | nindent 4 }} + app.kubernetes.io/component: mongo diff --git a/packages/grid/helm/syft/templates/mongo/mongo-secret.yaml b/packages/grid/helm/syft/templates/mongo/mongo-secret.yaml new file mode 100644 index 00000000000..a58fb2b72c6 --- /dev/null +++ b/packages/grid/helm/syft/templates/mongo/mongo-secret.yaml @@ -0,0 +1,16 @@ +{{- $secretName := "mongo-secret" }} +apiVersion: v1 +kind: Secret +metadata: + name: {{ $secretName }} + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: mongo +type: Opaque +data: + rootPassword: {{ include "common.secrets.set" (dict + "secret" $secretName + "key" "rootPassword" + "default" .Values.mongo.defaultSecret.rootPassword + "context" $) + }} diff --git a/packages/grid/helm/syft/templates/mongo/mongo-service.yaml b/packages/grid/helm/syft/templates/mongo/mongo-service.yaml new file mode 100644 index 00000000000..a789f4e8f86 --- /dev/null +++ b/packages/grid/helm/syft/templates/mongo/mongo-service.yaml @@ -0,0 +1,17 @@ +apiVersion: v1 +kind: Service +metadata: + name: mongo + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: mongo +spec: + type: ClusterIP + selector: + {{- include "common.selectorLabels" . | nindent 4 }} + app.kubernetes.io/component: mongo + ports: + - name: mongo + port: 27017 + protocol: TCP + targetPort: 27017 diff --git a/packages/grid/helm/syft/templates/mongo/mongo-statefulset.yaml b/packages/grid/helm/syft/templates/mongo/mongo-statefulset.yaml new file mode 100644 index 00000000000..dfddffbcb48 --- /dev/null +++ b/packages/grid/helm/syft/templates/mongo/mongo-statefulset.yaml @@ -0,0 +1,60 @@ +apiVersion: apps/v1 +kind: StatefulSet +metadata: + name: mongo + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: mongo +spec: + replicas: 1 + updateStrategy: + type: RollingUpdate + selector: + matchLabels: + {{- include "common.selectorLabels" . | nindent 6 }} + app.kubernetes.io/component: mongo + serviceName: mongo-headless + podManagementPolicy: OrderedReady + template: + metadata: + labels: + {{- include "common.labels" . | nindent 8 }} + app.kubernetes.io/component: mongo + spec: + containers: + - name: mongo-container + image: mongo:7 + imagePullPolicy: Always + resources: {{ include "common.resources.set" (dict "resources" .Values.mongo.resources "preset" .Values.mongo.resourcesPreset) | nindent 12 }} + env: + - name: MONGO_INITDB_ROOT_USERNAME + value: {{ .Values.mongo.username | required "mongo.username is required" | quote }} + - name: MONGO_INITDB_ROOT_PASSWORD + valueFrom: + secretKeyRef: + name: {{ .Values.mongo.secretKeyName | required "mongo.secretKeyName is required" }} + key: rootPassword + {{- if .Values.mongo.env }} + {{- toYaml .Values.mongo.env | nindent 12 }} + {{- end }} + volumeMounts: + - mountPath: /data/db + name: mongo-data + readOnly: false + subPath: '' + ports: + - name: mongo-port + containerPort: 27017 + terminationGracePeriodSeconds: 5 + volumeClaimTemplates: + - metadata: + name: mongo-data + labels: + {{- include "common.labels" . | nindent 8 }} + app.kubernetes.io/component: mongo + spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.mongo.storageSize | quote }} diff --git a/packages/grid/helm/syft/templates/proxy-deployment.yaml b/packages/grid/helm/syft/templates/proxy-deployment.yaml deleted file mode 100644 index 3ef7c1717cc..00000000000 --- a/packages/grid/helm/syft/templates/proxy-deployment.yaml +++ /dev/null @@ -1,62 +0,0 @@ -apiVersion: apps/v1 -kind: Deployment -metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/component: proxy - app.kubernetes.io/managed-by: Helm - name: proxy -spec: - replicas: 1 - selector: - matchLabels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: proxy - app.kubernetes.io/managed-by: Helm - strategy: - type: Recreate - template: - metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: proxy - app.kubernetes.io/managed-by: Helm - spec: - affinity: null - containers: - - args: null - command: null - env: - - name: SERVICE_NAME - value: proxy - envFrom: null - image: traefik:v2.10 - lifecycle: null - livenessProbe: null - name: container-0 - readinessProbe: null - securityContext: null - startupProbe: null - volumeDevices: null - volumeMounts: - - mountPath: /etc/traefik - name: traefik-conf - readOnly: false - dnsConfig: null - ephemeralContainers: null - hostAliases: null - imagePullSecrets: null - initContainers: null - nodeName: null - nodeSelector: null - overhead: null - readinessGates: null - securityContext: null - terminationGracePeriodSeconds: 5 - tolerations: null - topologySpreadConstraints: null - volumes: - - configMap: - name: traefik-main-config - name: traefik-conf diff --git a/packages/grid/helm/syft/templates/proxy-service.yaml b/packages/grid/helm/syft/templates/proxy-service.yaml deleted file mode 100644 index 1c289f3e0be..00000000000 --- a/packages/grid/helm/syft/templates/proxy-service.yaml +++ /dev/null @@ -1,19 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/managed-by: Helm - name: proxy -spec: - externalIPs: null - ports: - - name: proxy - port: 80 - protocol: TCP - targetPort: 80 - selector: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: proxy - type: ClusterIP diff --git a/packages/grid/helm/syft/templates/traefik-main-config-configmap.yaml b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml similarity index 92% rename from packages/grid/helm/syft/templates/traefik-main-config-configmap.yaml rename to packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml index 829e7220b55..2a90d4ff232 100644 --- a/packages/grid/helm/syft/templates/traefik-main-config-configmap.yaml +++ b/packages/grid/helm/syft/templates/proxy/proxy-configmap.yaml @@ -1,4 +1,10 @@ apiVersion: v1 +kind: ConfigMap +metadata: + name: proxy-config + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: proxy data: dynamic.yml: | http: @@ -63,6 +69,4 @@ data: providers: file: filename: /etc/traefik/dynamic.yml -kind: ConfigMap -metadata: - name: traefik-main-config + diff --git a/packages/grid/helm/syft/templates/proxy/proxy-deployment.yaml b/packages/grid/helm/syft/templates/proxy/proxy-deployment.yaml new file mode 100644 index 00000000000..404cbecdb54 --- /dev/null +++ b/packages/grid/helm/syft/templates/proxy/proxy-deployment.yaml @@ -0,0 +1,52 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: proxy + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: proxy +spec: + replicas: 1 + selector: + matchLabels: + {{- include "common.selectorLabels" . | nindent 6 }} + app.kubernetes.io/component: proxy + strategy: + type: Recreate + template: + metadata: + labels: + {{- include "common.labels" . | nindent 8 }} + app.kubernetes.io/component: proxy + spec: + containers: + - name: proxy-container + image: traefik:v2.11 + imagePullPolicy: Always + resources: {{ include "common.resources.set" (dict "resources" .Values.proxy.resources "preset" .Values.proxy.resourcesPreset) | nindent 12 }} + {{- if .Values.proxy.env }} + env: {{ toYaml .Values.proxy.env | nindent 12 }} + {{- end }} + ports: + - name: proxy-port + containerPort: 80 + - name: ping-port + containerPort: 8082 + volumeMounts: + - mountPath: /etc/traefik + name: traefik-conf + readOnly: false + startupProbe: null + livenessProbe: + httpGet: + path: /ping?probe=livenessProbe + port: ping-port + periodSeconds: 15 + timeoutSeconds: 5 + failureThreshold: 3 + readinessProbe: null + terminationGracePeriodSeconds: 5 + volumes: + - configMap: + name: proxy-config + name: traefik-conf diff --git a/packages/grid/helm/syft/templates/proxy/proxy-service.yaml b/packages/grid/helm/syft/templates/proxy/proxy-service.yaml new file mode 100644 index 00000000000..322e9898e64 --- /dev/null +++ b/packages/grid/helm/syft/templates/proxy/proxy-service.yaml @@ -0,0 +1,17 @@ +apiVersion: v1 +kind: Service +metadata: + name: proxy + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: proxy +spec: + type: ClusterIP + selector: + {{- include "common.selectorLabels" . | nindent 4 }} + app.kubernetes.io/component: proxy + ports: + - name: proxy + protocol: TCP + port: 80 + targetPort: 80 diff --git a/packages/grid/helm/syft/templates/registry-service.yaml b/packages/grid/helm/syft/templates/registry-service.yaml deleted file mode 100644 index f96060e3a4d..00000000000 --- a/packages/grid/helm/syft/templates/registry-service.yaml +++ /dev/null @@ -1,17 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - name: registry - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/managed-by: Helm -spec: - type: ClusterIP - ports: - - protocol: TCP - port: 80 - targetPort: 5000 - selector: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: registry diff --git a/packages/grid/helm/syft/templates/registry-statefulset.yaml b/packages/grid/helm/syft/templates/registry-statefulset.yaml deleted file mode 100644 index c4fb60d474d..00000000000 --- a/packages/grid/helm/syft/templates/registry-statefulset.yaml +++ /dev/null @@ -1,47 +0,0 @@ -apiVersion: apps/v1 -kind: StatefulSet -metadata: - name: registry - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/component: registry - app.kubernetes.io/managed-by: Helm -spec: - replicas: 1 - selector: - matchLabels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: registry - app.kubernetes.io/managed-by: Helm - template: - metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: registry - app.kubernetes.io/managed-by: Helm - spec: - containers: - - image: registry:2 - name: registry - env: - - name: REGISTRY_STORAGE_DELETE_ENABLED - value: "true" - ports: - - containerPort: 5000 - volumeMounts: - - mountPath: /var/lib/registry - name: registry-data - volumeClaimTemplates: - - metadata: - name: registry-data - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: registry - app.kubernetes.io/managed-by: Helm - spec: - accessModes: - - ReadWriteOnce - resources: - requests: - storage: {{ .Values.registry.maxStorage }} diff --git a/packages/grid/helm/syft/templates/registry/registry-service.yaml b/packages/grid/helm/syft/templates/registry/registry-service.yaml new file mode 100644 index 00000000000..c132545bf2c --- /dev/null +++ b/packages/grid/helm/syft/templates/registry/registry-service.yaml @@ -0,0 +1,17 @@ +apiVersion: v1 +kind: Service +metadata: + name: registry + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: registry +spec: + type: ClusterIP + selector: + {{- include "common.selectorLabels" . | nindent 4 }} + app.kubernetes.io/component: registry + ports: + - name: registry + protocol: TCP + port: 80 + targetPort: 5000 diff --git a/packages/grid/helm/syft/templates/registry/registry-statefulset.yaml b/packages/grid/helm/syft/templates/registry/registry-statefulset.yaml new file mode 100644 index 00000000000..3e48131a694 --- /dev/null +++ b/packages/grid/helm/syft/templates/registry/registry-statefulset.yaml @@ -0,0 +1,66 @@ +apiVersion: apps/v1 +kind: StatefulSet +metadata: + name: registry + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: registry +spec: + replicas: 1 + serviceName: registry + updateStrategy: + type: RollingUpdate + selector: + matchLabels: + {{- include "common.selectorLabels" . | nindent 6 }} + app.kubernetes.io/component: registry + template: + metadata: + labels: + {{- include "common.labels" . | nindent 8 }} + app.kubernetes.io/component: registry + spec: + containers: + - name: registry-container + image: registry:2 + imagePullPolicy: Always + resources: {{ include "common.resources.set" (dict "resources" .Values.registry.resources "preset" .Values.registry.resourcesPreset) | nindent 12 }} + env: + - name: REGISTRY_STORAGE_DELETE_ENABLED + value: "true" + {{- if .Values.registry.env }} + {{- toYaml .Values.registry.env | nindent 12 }} + {{- end }} + ports: + - name: registry-port + containerPort: 5000 + volumeMounts: + - mountPath: /var/lib/registry + name: registry-data + startupProbe: + httpGet: + path: /v2/?probe=startupProbe + port: registry-port + failureThreshold: 30 + periodSeconds: 10 + livenessProbe: + httpGet: + path: /v2/?probe=livenessProbe + port: registry-port + initialDelaySeconds: 5 + periodSeconds: 15 + timeoutSeconds: 5 + failureThreshold: 3 + terminationGracePeriodSeconds: 5 + volumeClaimTemplates: + - metadata: + name: registry-data + labels: + {{- include "common.labels" . | nindent 8 }} + app.kubernetes.io/component: registry + spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.registry.storageSize | quote }} diff --git a/packages/grid/helm/syft/templates/seaweedfs-headless-service.yaml b/packages/grid/helm/syft/templates/seaweedfs-headless-service.yaml deleted file mode 100644 index 03320064af4..00000000000 --- a/packages/grid/helm/syft/templates/seaweedfs-headless-service.yaml +++ /dev/null @@ -1,19 +0,0 @@ -{{- if ne .Values.node.settings.nodeType "gateway"}} -apiVersion: v1 -kind: Service -metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/managed-by: Helm - name: seaweedfs-headless -spec: - clusterIP: None - ports: - - name: web - port: 80 - selector: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: seaweedfs - app.kubernetes.io/managed-by: Helm -{{ end }} diff --git a/packages/grid/helm/syft/templates/seaweedfs-secret.yaml b/packages/grid/helm/syft/templates/seaweedfs-secret.yaml deleted file mode 100644 index 504de8e8561..00000000000 --- a/packages/grid/helm/syft/templates/seaweedfs-secret.yaml +++ /dev/null @@ -1,19 +0,0 @@ -apiVersion: v1 -kind: Secret -metadata: - name: "seaweedfs-default-secret" -type: Opaque -data: - {{- if not .Values.configuration.devmode }} - - {{- $secretObj := (lookup "v1" "Secret" .Release.Namespace "seaweedfs-default-secret") | default dict }} - {{- $secretData := (get $secretObj "data") | default dict }} - {{- $s3RootPasswordEnv := (get $secretData "s3RootPassword") | default (randAlphaNum 32 | b64enc) }} - - s3RootPassword: {{ $s3RootPasswordEnv | quote }} - - {{- else }} - - s3RootPassword: {{ "admin" | b64enc }} # Base64-encoded "admin" - - {{- end }} \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/seaweedfs-service.yaml b/packages/grid/helm/syft/templates/seaweedfs-service.yaml deleted file mode 100644 index 9343112b019..00000000000 --- a/packages/grid/helm/syft/templates/seaweedfs-service.yaml +++ /dev/null @@ -1,29 +0,0 @@ -{{- if ne .Values.node.settings.nodeType "gateway"}} -apiVersion: v1 -kind: Service -metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/managed-by: Helm - name: seaweedfs -spec: - externalIPs: null - ports: - - name: port-0 - port: 8888 - protocol: TCP - targetPort: 8888 - - name: port-1 - port: 8333 - protocol: TCP - targetPort: 8333 - - name: port-2 - port: 4001 - protocol: TCP - targetPort: 4001 - selector: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: seaweedfs - type: ClusterIP -{{ end }} diff --git a/packages/grid/helm/syft/templates/seaweedfs-statefulset.yaml b/packages/grid/helm/syft/templates/seaweedfs-statefulset.yaml deleted file mode 100644 index d13e4091139..00000000000 --- a/packages/grid/helm/syft/templates/seaweedfs-statefulset.yaml +++ /dev/null @@ -1,97 +0,0 @@ -{{- if ne .Values.node.settings.nodeType "gateway"}} -apiVersion: apps/v1 -kind: StatefulSet -metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/component: seaweedfs - app.kubernetes.io/managed-by: Helm - name: seaweedfs -spec: - podManagementPolicy: OrderedReady - replicas: 1 - selector: - matchLabels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: seaweedfs - app.kubernetes.io/managed-by: Helm - serviceName: seaweedfs-headless - template: - metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: seaweedfs - app.kubernetes.io/managed-by: Helm - spec: - affinity: null - containers: - - args: null - command: null - env: - - name: S3_VOLUME_SIZE_MB - value: "{{ .Values.seaweedfs.s3VolumeSizeMB }}" - - name: S3_ROOT_USER - value: "{{ .Values.seaweedfs.s3RootUser }}" - - name: S3_ROOT_PWD - valueFrom: - secretKeyRef: - name: {{ .Values.secrets.seaweedfs }} - key: s3RootPassword - - name: S3_PORT - value: "{{ .Values.seaweedfs.s3Port }}" - - name: SEAWEED_MOUNT_PORT - value: "{{ .Values.seaweedfs.mountPort }}" - envFrom: null - image: {{ .Values.syft.registry }}/openmined/grid-seaweedfs:{{ .Values.syft.version }} - lifecycle: null - livenessProbe: null - name: container-0 - readinessProbe: null - securityContext: null - startupProbe: null - volumeDevices: null - volumeMounts: - - mountPath: /etc/seaweedfs/filer.toml - name: seaweedfs-config - readOnly: false - subPath: filer.toml - - mountPath: /etc/seaweedfs/start.sh - name: seaweedfs-config - readOnly: false - subPath: start.sh - - mountPath: /data/blob - name: seaweedfs-data - readOnly: false - subPath: '' - dnsConfig: null - ephemeralContainers: null - hostAliases: null - imagePullSecrets: null - initContainers: null - nodeName: null - nodeSelector: null - overhead: null - readinessGates: null - securityContext: null - terminationGracePeriodSeconds: 5 - tolerations: null - topologySpreadConstraints: null - volumes: - - configMap: - name: seaweedfs-config - name: seaweedfs-config - volumeClaimTemplates: - - metadata: - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/component: seaweedfs - app.kubernetes.io/managed-by: Helm - name: seaweedfs-data - spec: - accessModes: - - ReadWriteOnce - resources: - requests: - storage: {{ .Values.seaweedfs.maxStorage }} -{{ end }} diff --git a/packages/grid/helm/syft/templates/seaweedfs-config-configmap.yaml b/packages/grid/helm/syft/templates/seaweedfs/seaweedfs-configmap.yaml similarity index 76% rename from packages/grid/helm/syft/templates/seaweedfs-config-configmap.yaml rename to packages/grid/helm/syft/templates/seaweedfs/seaweedfs-configmap.yaml index 54168fa460b..c4a558d9fd0 100644 --- a/packages/grid/helm/syft/templates/seaweedfs-config-configmap.yaml +++ b/packages/grid/helm/syft/templates/seaweedfs/seaweedfs-configmap.yaml @@ -1,5 +1,11 @@ -{{- if ne .Values.node.settings.nodeType "gateway"}} +{{- if ne .Values.node.type "gateway"}} apiVersion: v1 +kind: ConfigMap +metadata: + name: seaweedfs-config + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: seaweedfs data: filer.toml: | [leveldb2] @@ -12,7 +18,4 @@ data: echo "s3.configure -access_key ${S3_ROOT_USER} -secret_key ${S3_ROOT_PWD} -user iam -actions Read,Write,List,Tagging,Admin -apply" \ | weed shell > /dev/null 2>&1 \ & weed server -s3 -s3.port=${S3_PORT} -master.volumeSizeLimitMB=${S3_VOLUME_SIZE_MB} -kind: ConfigMap -metadata: - name: seaweedfs-config {{ end }} diff --git a/packages/grid/helm/syft/templates/seaweedfs/seaweedfs-secret.yaml b/packages/grid/helm/syft/templates/seaweedfs/seaweedfs-secret.yaml new file mode 100644 index 00000000000..c4a0e9b5b09 --- /dev/null +++ b/packages/grid/helm/syft/templates/seaweedfs/seaweedfs-secret.yaml @@ -0,0 +1,18 @@ +{{- if ne .Values.node.type "gateway"}} +{{- $secretName := "seaweedfs-secret" }} +apiVersion: v1 +kind: Secret +metadata: + name: {{ $secretName }} + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: seaweedfs +type: Opaque +data: + s3RootPassword: {{ include "common.secrets.set" (dict + "secret" $secretName + "key" "s3RootPassword" + "default" .Values.seaweedfs.defaultSecret.s3RootPassword + "context" $) + }} +{{ end }} diff --git a/packages/grid/helm/syft/templates/seaweedfs/seaweedfs-service.yaml b/packages/grid/helm/syft/templates/seaweedfs/seaweedfs-service.yaml new file mode 100644 index 00000000000..33b7ea2dd96 --- /dev/null +++ b/packages/grid/helm/syft/templates/seaweedfs/seaweedfs-service.yaml @@ -0,0 +1,27 @@ +{{- if ne .Values.node.type "gateway"}} +apiVersion: v1 +kind: Service +metadata: + name: seaweedfs + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: seaweedfs +spec: + type: ClusterIP + selector: + {{- include "common.selectorLabels" . | nindent 4 }} + app.kubernetes.io/component: seaweedfs + ports: + - name: filer + protocol: TCP + port: 8888 + targetPort: 8888 + - name: s3 + protocol: TCP + port: 8333 + targetPort: 8333 + - name: mount-api + protocol: TCP + port: 4001 + targetPort: 4001 +{{ end }} diff --git a/packages/grid/helm/syft/templates/seaweedfs/seaweedfs-statefulset.yaml b/packages/grid/helm/syft/templates/seaweedfs/seaweedfs-statefulset.yaml new file mode 100644 index 00000000000..825a8b58d68 --- /dev/null +++ b/packages/grid/helm/syft/templates/seaweedfs/seaweedfs-statefulset.yaml @@ -0,0 +1,77 @@ +{{- if ne .Values.node.type "gateway"}} +apiVersion: apps/v1 +kind: StatefulSet +metadata: + name: seaweedfs + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: seaweedfs +spec: + replicas: 1 + updateStrategy: + type: RollingUpdate + selector: + matchLabels: + {{- include "common.selectorLabels" . | nindent 6 }} + app.kubernetes.io/component: seaweedfs + serviceName: seaweedfs + podManagementPolicy: OrderedReady + template: + metadata: + labels: + {{- include "common.labels" . | nindent 8 }} + app.kubernetes.io/component: seaweedfs + spec: + containers: + - name: seaweedfs-container + image: {{ .Values.global.registry }}/openmined/grid-seaweedfs:{{ .Values.global.version }} + imagePullPolicy: Always + resources: {{ include "common.resources.set" (dict "resources" .Values.seaweedfs.resources "preset" .Values.seaweedfs.resourcesPreset) | nindent 12 }} + env: + - name: S3_VOLUME_SIZE_MB + value: {{ .Values.seaweedfs.s3.volumeSizeMB | quote }} + - name: S3_ROOT_USER + value: {{ .Values.seaweedfs.s3.rootUser | quote }} + - name: S3_ROOT_PWD + valueFrom: + secretKeyRef: + name: {{ .Values.seaweedfs.secretKeyName | required "seaweedfs.secretKeyName is required" }} + key: s3RootPassword + - name: S3_PORT + value: {{ .Values.seaweedfs.s3.port | quote }} + - name: SEAWEED_MOUNT_PORT + value: {{ .Values.seaweedfs.mountApi.port | quote }} + {{- if .Values.seaweedfs.env }} + {{- toYaml .Values.seaweedfs.env | nindent 12 }} + {{- end }} + volumeMounts: + - mountPath: /etc/seaweedfs/filer.toml + name: seaweedfs-config + readOnly: false + subPath: filer.toml + - mountPath: /etc/seaweedfs/start.sh + name: seaweedfs-config + readOnly: false + subPath: start.sh + - mountPath: /data/blob + name: seaweedfs-data + readOnly: false + subPath: '' + terminationGracePeriodSeconds: 5 + volumes: + - configMap: + name: seaweedfs-config + name: seaweedfs-config + volumeClaimTemplates: + - metadata: + name: seaweedfs-data + labels: + {{- include "common.labels" . | nindent 8 }} + app.kubernetes.io/component: seaweedfs + spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.seaweedfs.storageSize | quote }} +{{ end }} diff --git a/packages/grid/helm/syft/templates/syft-secret.yaml b/packages/grid/helm/syft/templates/syft-secret.yaml deleted file mode 100644 index 57242b3b1ed..00000000000 --- a/packages/grid/helm/syft/templates/syft-secret.yaml +++ /dev/null @@ -1,26 +0,0 @@ -apiVersion: v1 -kind: Secret -metadata: - name: "syft-default-secret" - labels: - app.kubernetes.io/name: {{ .Chart.Name }} - app.kubernetes.io/version: {{ .Chart.AppVersion }} - app.kubernetes.io/managed-by: Helm -type: Opaque -data: - {{- if not .Values.configuration.devmode }} - - {{- $secretObj := (lookup "v1" "Secret" .Release.Namespace "syft-default-secret") | default dict }} - {{- $secretData := (get $secretObj "data") | default dict }} - {{- $stackApiKeyEnv := (get $secretData "stackApiKey") | default (randAlphaNum 32 | b64enc) }} - {{- $defaultRootPasswordEnv := (get $secretData "defaultRootPassword") | default (randAlphaNum 32 | b64enc) }} - - stackApiKey: {{ $stackApiKeyEnv | quote }} - defaultRootPassword: {{ $defaultRootPasswordEnv | quote }} - - {{- else }} - - stackApiKey: {{ "changeme" | b64enc }} - defaultRootPassword: {{ "changethis" | b64enc}} - - {{- end }} diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index 84f03c797df..59af4023e30 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -1,59 +1,147 @@ -secrets: - syft: syft-default-secret - mongo: mongo-default-secret - seaweedfs: seaweedfs-default-secret +global: + # Affects only backend, frontend, and seaweedfs containers + registry: docker.io + version: 0.8.5-beta.1 -configuration: - devmode: false + # Force default secret values for development. DO NOT USE IN PRODUCTION + useDefaultSecrets: false mongo: + # MongoDB config port: 27017 - host: "mongo" - username: "root" + username: root + + # Extra environment vars + env: null + + # Pod Resource Limits + resourcesPreset: nano + resources: null + + # PVC storage size + storageSize: 5Gi + + # Mongo secret name. Override this if you want to use a self-managed secret. + secretKeyName: mongo-secret + + # Dev mode default passwords + defaultSecret: + rootPassword: example + +frontend: + # Extra environment vars + env: null + + # Pod Resource Limits + resourcesPreset: nano + resources: null seaweedfs: - # SeaweedFS config - mountPort: 4001 - # SeaweedFS S3 Config - s3VolumeSizeMB: 1024 - s3Port: 8333 - s3RootUser: "admin" - # storage limits - maxStorage: "5Gi" - -queue: - port: 5556 + # S3 settings + s3: + port: 8333 + rootUser: admin + volumeSizeMB: 1024 + + # Mount API settings + mountApi: + port: 4001 + + # Extra environment vars + env: null + + # PVC storage size + storageSize: 5Gi + + # Seaweed secret name. Override this if you want to use a self-managed secret. + # Secret must contain the following keys: + # - s3RootPassword + secretKeyName: seaweedfs-secret + + # Pod Resource Limits + resourcesPreset: nano + resources: null + + # Dev mode default passwords + defaultSecret: + s3RootPassword: admin + +proxy: + # Extra environment vars + env: null + + # Pod Resource Limits + resourcesPreset: nano + resources: null registry: - maxStorage: "10Gi" + # Extra environment vars + env: null -syft: - registry: "docker.io" - version: 0.8.5-beta.1 + # PVC storage size + storageSize: 5Gi + + # Pod Resource Limits + resourcesPreset: nano + resources: null node: - settings: - tls: false - hostname: "" # do not make this localhost - nodeName: "mynode" - nodeType: "domain" - versionHash: "abc" - nodeSideType: "high" - defaultRootEmail: "info@openmined.org" - logLevel: "info" - inMemoryWorkers: false - defaultWorkerPoolCount: 1 - -# ---------------------------------------- -# For Azure -# className: "azure-application-gateway" -# ---------------------------------------- -# For AWS -# className: "alb" -# ---------------------------------------- -# For GCE, https://cloud.google.com/kubernetes-engine/docs/how-to/load-balance-ingress#create-ingress -# class: "gce" -# ---------------------------------------- + # Syft settings + name: null + rootEmail: info@openmined.org + type: domain + side: high + inMemoryWorkers: false + defaultWorkerPoolCount: 1 + queuePort: 5556 + logLevel: info + debuggerEnabled: false + + # SMTP Settings + smtp: + host: smtp.sendgrid.net + port: 587 + from: noreply@openmined.org + username: apikey + password: password + + # Oblivious settings + oblv: + enabled: false + port: 3030 + + # Extra environment vars + env: null + + # Pod Resource Limits + resourcesPreset: small + resources: null + + # Seaweed secret name. Override this if you want to use a self-managed secret. + # Secret must contain the following keys: + # - defaultRootPassword + secretKeyName: backend-secret + + # Dev mode default passwords + defaultSecret: + defaultRootPassword: changethis + ingress: + hostname: null # do not make this localhost + + tls: + enabled: false + secretName: null + + # ---------------------------------------- + # For Azure + # className: azure-application-gateway + # ---------------------------------------- + # For AWS + # className: alb + # ---------------------------------------- + # For GCE, https://cloud.google.com/kubernetes-engine/docs/how-to/load-balance-ingress#create-ingress + # class: gce + # ---------------------------------------- class: null className: null diff --git a/packages/grid/podman/podman-kube/podman-syft-kube-config.yaml b/packages/grid/podman/podman-kube/podman-syft-kube-config.yaml index e5849906f7c..eef9e420ab8 100644 --- a/packages/grid/podman/podman-kube/podman-syft-kube-config.yaml +++ b/packages/grid/podman/podman-kube/podman-syft-kube-config.yaml @@ -58,7 +58,7 @@ data: USE_NEW_SERVICE: False # Frontend - VITE_PUBLIC_API_BASE_URL: "/api/v2" + BACKEND_API_BASE_URL: "/api/v2" # SeaweedFS S3_ENDPOINT: "seaweedfs" @@ -104,9 +104,9 @@ data: USE_BLOB_STORAGE: False #Oblivious - ENABLE_OBLV: false + OBLV_ENABLED: false OBLV_KEY_PATH: "~/.oblv" - DOMAIN_CONNECTION_PORT: 3030 + OBLV_LOCALHOST_PORT: 3030 # Registation ENABLE_SIGNUP: False diff --git a/packages/grid/scripts/helm_upgrade.sh b/packages/grid/scripts/helm_upgrade.sh new file mode 100644 index 00000000000..82418eaa99a --- /dev/null +++ b/packages/grid/scripts/helm_upgrade.sh @@ -0,0 +1,53 @@ +#! /bin/bash + +set -e + +HELM_REPO="openmined/syft" +DOMAIN_NAME="test-domain" +KUBE_NAMESPACE="syft" +KUBE_CONTEXT=${KUBE_CONTEXT:-"k3d-syft-dev"} + +UPGRADE_TYPE=$1 + +PROD="openmined/syft" +BETA="openmined/syft --devel" +DEV="./helm/syft" + +if [ "$UPGRADE_TYPE" == "ProdToBeta" ]; then + INSTALL_SOURCE=$PROD # latest published prod + UPGRADE_SOURCE=$BETA # latest published beta + INSTALL_ARGS="" + UPGRADE_ARGS="" +elif [ "$UPGRADE_TYPE" == "BetaToDev" ]; then + INSTALL_SOURCE=$BETA # latest published beta + UPGRADE_SOURCE=$DEV # local chart + INSTALL_ARGS="" + UPGRADE_ARGS="" +elif [ "$UPGRADE_TYPE" == "ProdToDev" ]; then + INSTALL_SOURCE=$PROD # latest published prod + UPGRADE_SOURCE=$DEV # local chart + INSTALL_ARGS="" + UPGRADE_ARGS="" +else + echo Invalid upgrade type $UPGRADE_TYPE + exit 1 +fi + +kubectl config use-context $KUBE_CONTEXT +kubectl delete namespace syft || true +helm repo add openmined https://openmined.github.io/PySyft/helm +helm repo update openmined + +echo Installing syft... +helm install $DOMAIN_NAME $INSTALL_SOURCE $INSTALL_ARGS --namespace $KUBE_NAMESPACE --create-namespace +helm ls -A + +WAIT_TIME=5 bash ./scripts/wait_for.sh service backend --namespace $KUBE_NAMESPACE +WAIT_TIME=5 bash ./scripts/wait_for.sh pod default-pool-0 --namespace $KUBE_NAMESPACE + +echo Upgrading syft... +helm upgrade $DOMAIN_NAME $UPGRADE_SOURCE $UPGRADE_ARGS --namespace $KUBE_NAMESPACE +helm ls -A + +echo "Post-upgrade sleep" && sleep 5 +WAIT_TIME=5 bash ./scripts/wait_for.sh service backend --namespace $KUBE_NAMESPACE diff --git a/packages/hagrid/hagrid/cli.py b/packages/hagrid/hagrid/cli.py index df1428930b7..69f367ba553 100644 --- a/packages/hagrid/hagrid/cli.py +++ b/packages/hagrid/hagrid/cli.py @@ -322,6 +322,41 @@ def clean(location: str) -> None: type=str, help="Container image tag to use", ) +@click.option( + "--smtp-username", + default=None, + required=False, + type=str, + help="Username used to auth in email server and enable notification via emails", +) +@click.option( + "--smtp-password", + default=None, + required=False, + type=str, + help="Password used to auth in email server and enable notification via emails", +) +@click.option( + "--smtp-port", + default=None, + required=False, + type=str, + help="Port used by email server to send notification via emails", +) +@click.option( + "--smtp-host", + default=None, + required=False, + type=str, + help="Address used by email server to send notification via emails", +) +@click.option( + "--smtp-sender", + default=None, + required=False, + type=str, + help="Sender email used to deliver PyGrid email notifications.", +) @click.option( "--build-src", default=DEFAULT_BRANCH, @@ -1309,6 +1344,12 @@ def create_launch_cmd( else: parsed_kwargs["node_side_type"] = NodeSideType.HIGH_SIDE.value + parsed_kwargs["smtp_username"] = kwargs["smtp_username"] + parsed_kwargs["smtp_password"] = kwargs["smtp_password"] + parsed_kwargs["smtp_port"] = kwargs["smtp_port"] + parsed_kwargs["smtp_host"] = kwargs["smtp_host"] + parsed_kwargs["smtp_sender"] = kwargs["smtp_sender"] + parsed_kwargs["enable_warnings"] = not kwargs["no_warnings"] # choosing deployment type @@ -2156,6 +2197,11 @@ def create_launch_docker_cmd( single_container_mode = kwargs["deployment_type"] == "single_container" in_mem_workers = kwargs.get("in_mem_workers") + smtp_username = kwargs.get("smtp_username") + smtp_sender = kwargs.get("smtp_sender") + smtp_password = kwargs.get("smtp_password") + smtp_port = kwargs.get("smtp_port") + smtp_host = kwargs.get("smtp_host") enable_oblv = bool(kwargs["oblv"]) print(" - NAME: " + str(snake_name)) @@ -2215,11 +2261,16 @@ def create_launch_docker_cmd( "STACK_API_KEY": str( generate_sec_random_password(length=48, special_chars=False) ), - "ENABLE_OBLV": str(enable_oblv).lower(), + "OBLV_ENABLED": str(enable_oblv).lower(), "CREDENTIALS_VOLUME": host_path, "NODE_SIDE_TYPE": kwargs["node_side_type"], "SINGLE_CONTAINER_MODE": single_container_mode, "INMEMORY_WORKERS": in_mem_workers, + "SMTP_USERNAME": smtp_username, + "SMTP_PASSWORD": smtp_password, + "EMAIL_SENDER": smtp_sender, + "SMTP_PORT": smtp_port, + "SMTP_HOST": smtp_host, } if "trace" in kwargs and kwargs["trace"] is True: diff --git a/packages/hagrid/hagrid/orchestra.py b/packages/hagrid/hagrid/orchestra.py index 8a0e74c06b6..8ee771c0036 100644 --- a/packages/hagrid/hagrid/orchestra.py +++ b/packages/hagrid/hagrid/orchestra.py @@ -9,6 +9,7 @@ import inspect import os import subprocess # nosec +import sys from threading import Thread from typing import Any from typing import Callable @@ -598,7 +599,10 @@ def shutdown( elif "No resource found to remove for project" in land_output: print(f" ✅ {snake_name} Container does not exist") else: - print(f"❌ Unable to remove container: {snake_name} :{land_output}") + print( + f"❌ Unable to remove container: {snake_name} :{land_output}", + file=sys.stderr, + ) @staticmethod def reset(name: str, deployment_type_enum: DeploymentType) -> None: diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index e88ce585f21..56f052a9231 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -26,35 +26,35 @@ package_dir = # Add here dependencies of your project (semicolon/line-separated), e.g. syft = - bcrypt==4.0.1 - boto3==1.28.65 + bcrypt==4.1.2 + boto3==1.34.56 forbiddenfruit==0.1.4 gevent==23.9.1 loguru==0.7.2 - networkx==2.8 + networkx==3.2.1 packaging>=23.0 - pyarrow==14.0.1 - pycapnp==1.3.0 - pydantic[email]==1.10.13 - pymongo==4.6.1 + pyarrow==15.0.0 + # pycapnp is beta version, update to stable version when available + pycapnp==2.0.0b2 + pydantic[email]==2.6.0 + pydantic-settings==2.2.1 + pymongo==4.6.2 pynacl==1.5.0 pyzmq>=23.2.1,<=25.1.1 - redis==4.6.0 requests==2.31.0 RestrictedPython==7.0 - result==0.10.0 - tqdm==4.66.1 - typeguard==2.13.3 - typing_extensions==4.8.0 - sherlock[redis,filelock]==0.4.1 - uvicorn[standard]==0.24.0.post1 - fastapi==0.103.2 - psutil==5.9.6 + result==0.16.1 + tqdm==4.66.2 + typeguard==4.1.5 + typing_extensions==4.10.0 + sherlock[filelock]==0.4.1 + uvicorn[standard]==0.27.1 + fastapi==0.110.0 + psutil==5.9.8 hagrid>=0.3 - itables==1.6.2 - safetensors==0.4.1 + itables==1.7.1 argon2-cffi==23.1.0 - matplotlib==3.8.0 + matplotlib==3.8.3 # jaxlib is a DL library but we are needing it for serialization jaxlib==0.4.20 jax==0.4.20 @@ -62,9 +62,9 @@ syft = numpy>=1.23.5,<=1.24.4 pandas==1.5.3 docker==6.1.3 - kr8s==0.13.1 + kr8s==0.13.5 PyYAML==6.0.1 - azure-storage-blob==12.19 + azure-storage-blob==12.19.1 install_requires = %(syft)s @@ -92,11 +92,11 @@ dev = %(test_plugins)s %(telemetry)s bandit==1.7.7 - ruff==0.1.6 + ruff==0.3.0 importlib-metadata==6.8.0 - isort==5.12.0 + isort==5.13.2 mypy==1.7.1 - pre-commit==3.5.0 + pre-commit==3.6.2 safety>=2.4.0b2 telemetry = diff --git a/packages/syft/src/syft/__init__.py b/packages/syft/src/syft/__init__.py index d438f9ef9bd..2a0fcfa5b6d 100644 --- a/packages/syft/src/syft/__init__.py +++ b/packages/syft/src/syft/__init__.py @@ -93,6 +93,9 @@ logger.start() try: + # third party + from IPython import get_ipython + get_ipython() # noqa: F821 # TODO: add back later or auto detect # display( diff --git a/packages/syft/src/syft/abstract_node.py b/packages/syft/src/syft/abstract_node.py index 2341d6e4926..046c7e493ff 100644 --- a/packages/syft/src/syft/abstract_node.py +++ b/packages/syft/src/syft/abstract_node.py @@ -2,12 +2,17 @@ from enum import Enum from typing import Callable from typing import Optional +from typing import TYPE_CHECKING from typing import Union # relative from .serde.serializable import serializable from .types.uid import UID +if TYPE_CHECKING: + # relative + from .service.service import AbstractService + @serializable() class NodeType(str, Enum): @@ -37,5 +42,5 @@ class AbstractNode: node_side_type: Optional[NodeSideType] in_memory_workers: bool - def get_service(self, path_or_func: Union[str, Callable]) -> Callable: + def get_service(self, path_or_func: Union[str, Callable]) -> "AbstractService": raise NotImplementedError diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index cb4e2d8aa2f..aab69ab04f1 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -16,6 +16,7 @@ from typing import Tuple from typing import Union from typing import _GenericAlias +from typing import cast from typing import get_args from typing import get_origin @@ -52,6 +53,7 @@ from ..service.warnings import WarningContext from ..types.identity import Identity from ..types.syft_object import SYFT_OBJECT_VERSION_1 +from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.syft_object import SyftBaseObject from ..types.syft_object import SyftMigrationRegistry from ..types.syft_object import SyftObject @@ -63,6 +65,7 @@ if TYPE_CHECKING: # relative + from ..node import Node from ..service.job.job_stash import Job @@ -115,11 +118,11 @@ class APIEndpoint(SyftObject): module_path: str name: str description: str - doc_string: Optional[str] + doc_string: Optional[str] = None signature: Signature has_self: bool = False - pre_kwargs: Optional[Dict[str, Any]] - warning: Optional[APIEndpointWarning] + pre_kwargs: Optional[Dict[str, Any]] = None + warning: Optional[APIEndpointWarning] = None @serializable() @@ -132,16 +135,16 @@ class LibEndpoint(SyftBaseObject): module_path: str name: str description: str - doc_string: Optional[str] + doc_string: Optional[str] = None signature: Signature has_self: bool = False - pre_kwargs: Optional[Dict[str, Any]] + pre_kwargs: Optional[Dict[str, Any]] = None @serializable(attrs=["signature", "credentials", "serialized_message"]) class SignedSyftAPICall(SyftObject): __canonical_name__ = "SignedSyftAPICall" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 credentials: SyftVerifyKey signature: bytes @@ -205,7 +208,7 @@ class SyftAPIData(SyftBaseObject): __version__ = SYFT_OBJECT_VERSION_1 # fields - data: Any + data: Any = None def sign(self, credentials: SyftSigningKey) -> SignedSyftAPICall: signed_message = credentials.signing_key.sign(_serialize(self, to_bytes=True)) @@ -231,17 +234,17 @@ class RemoteFunction(SyftObject): signature: Signature path: str make_call: Callable - pre_kwargs: Optional[Dict[str, Any]] + pre_kwargs: Optional[Dict[str, Any]] = None communication_protocol: PROTOCOL_TYPE - warning: Optional[APIEndpointWarning] + warning: Optional[APIEndpointWarning] = None @property def __ipython_inspector_signature_override__(self) -> Optional[Signature]: return self.signature def prepare_args_and_kwargs( - self, args: List[Any], kwargs: Dict[str, Any] - ) -> Union[SyftError, Tuple[List[Any], Dict[str, Any]]]: + self, args: Union[list, tuple], kwargs: dict[str, Any] + ) -> Union[SyftError, tuple[tuple, dict[str, Any]]]: # Validate and migrate args and kwargs res = validate_callable_args_and_kwargs(args, kwargs, self.signature) if isinstance(res, SyftError): @@ -278,7 +281,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: api_call = SyftAPICall( node_uid=self.node_uid, path=self.path, - args=_valid_args, + args=list(_valid_args), kwargs=_valid_kwargs, blocking=blocking, ) @@ -303,8 +306,8 @@ class RemoteUserCodeFunction(RemoteFunction): api: SyftAPI def prepare_args_and_kwargs( - self, args: List[Any], kwargs: Dict[str, Any] - ) -> Union[SyftError, Tuple[List[Any], Dict[str, Any]]]: + self, args: Union[list, tuple], kwargs: Dict[str, Any] + ) -> Union[SyftError, tuple[tuple, dict[str, Any]]]: # relative from ..service.action.action_object import convert_to_pointers @@ -505,9 +508,12 @@ def _repr_html_(self) -> Any: results = self.get_all() return results._repr_html_() + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return NotImplementedError + def debox_signed_syftapicall_response( - signed_result: SignedSyftAPICall, + signed_result: Union[SignedSyftAPICall, Any], ) -> Union[Any, SyftError]: if not isinstance(signed_result, SignedSyftAPICall): return SyftError(message="The result is not signed") @@ -824,16 +830,16 @@ def build_endpoint_tree( ) @property - def services(self) -> Optional[APIModule]: + def services(self) -> APIModule: if self.api_module is None: self.generate_endpoints() - return self.api_module + return cast(APIModule, self.api_module) @property - def lib(self) -> Optional[APIModule]: + def lib(self) -> APIModule: if self.libs is None: self.generate_endpoints() - return self.libs + return cast(APIModule, self.libs) def has_service(self, service_name: str) -> bool: return hasattr(self.services, service_name) @@ -939,25 +945,35 @@ class NodeIdentity(Identity): node_name: str @staticmethod - def from_api(api: SyftAPI) -> Optional[NodeIdentity]: + def from_api(api: SyftAPI) -> NodeIdentity: # stores the name root verify key of the domain node - if api.connection is not None: - node_metadata = api.connection.get_node_metadata(api.signing_key) - return NodeIdentity( - node_name=node_metadata.name, - node_id=api.node_uid, - verify_key=SyftVerifyKey.from_string(node_metadata.verify_key), - ) - return None + if api.connection is None: + raise ValueError("{api}'s connection is None. Can't get the node identity") + node_metadata = api.connection.get_node_metadata(api.signing_key) + return NodeIdentity( + node_name=node_metadata.name, + node_id=api.node_uid, + verify_key=SyftVerifyKey.from_string(node_metadata.verify_key), + ) @classmethod def from_change_context(cls, context: ChangeContext) -> NodeIdentity: + if context.node is None: + raise ValueError(f"{context}'s node is None") return cls( node_name=context.node.name, node_id=context.node.id, verify_key=context.node.signing_key.verify_key, ) + @classmethod + def from_node(cls, node: Node) -> NodeIdentity: + return cls( + node_name=node.name, + node_id=node.id, + verify_key=node.signing_key.verify_key, + ) + def __eq__(self, other: Any) -> bool: if not isinstance(other, NodeIdentity): return False @@ -1002,7 +1018,7 @@ def validate_callable_args_and_kwargs( if issubclass(v, EmailStr): v = str try: - check_type(key, value, v) # raises Exception + check_type(value, v) # raises Exception success = True break # only need one to match except Exception: # nosec @@ -1010,7 +1026,7 @@ def validate_callable_args_and_kwargs( if not success: raise TypeError() else: - check_type(key, value, t) # raises Exception + check_type(value, t) # raises Exception except TypeError: _type_str = getattr(t, "__name__", str(t)) msg = f"`{key}` must be of type `{_type_str}` not `{type(value).__name__}`" @@ -1038,10 +1054,10 @@ def validate_callable_args_and_kwargs( for v in t.__args__: if issubclass(v, EmailStr): v = str - check_type(param_key, arg, v) # raises Exception + check_type(arg, v) # raises Exception break # only need one to match else: - check_type(param_key, arg, t) # raises Exception + check_type(arg, t) # raises Exception except TypeError: t_arg = type(arg) if ( @@ -1063,5 +1079,5 @@ def validate_callable_args_and_kwargs( return _valid_args, _valid_kwargs -RemoteFunction.update_forward_refs() -RemoteUserCodeFunction.update_forward_refs() +RemoteFunction.model_rebuild(force=True) +RemoteUserCodeFunction.model_rebuild(force=True) diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 11b4b2ab9a0..02dd8641f4a 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -13,19 +13,18 @@ from typing import List from typing import Optional from typing import TYPE_CHECKING -from typing import Tuple from typing import Type from typing import Union from typing import cast # third party from argon2 import PasswordHasher -import pydantic +from pydantic import field_validator import requests from requests import Response from requests import Session from requests.adapters import HTTPAdapter -from requests.packages.urllib3.util.retry import Retry +from requests.packages.urllib3.util.retry import Retry # type: ignore[import-untyped] from typing_extensions import Self # relative @@ -94,9 +93,9 @@ def forward_message_to_proxy( proxy_target_uid: UID, path: str, credentials: Optional[SyftSigningKey] = None, - args: Optional[Tuple] = None, + args: Optional[list] = None, kwargs: Optional[Dict] = None, -): +) -> Union[Any, SyftError]: kwargs = {} if kwargs is None else kwargs args = [] if args is None else args call = SyftAPICall( @@ -136,14 +135,19 @@ class HTTPConnection(NodeConnection): __canonical_name__ = "HTTPConnection" __version__ = SYFT_OBJECT_VERSION_1 - proxy_target_uid: Optional[UID] url: GridURL + proxy_target_uid: Optional[UID] = None routes: Type[Routes] = Routes - session_cache: Optional[Session] + session_cache: Optional[Session] = None - @pydantic.validator("url", pre=True, always=True) - def make_url(cls, v: Union[GridURL, str]) -> GridURL: - return GridURL.from_url(v).as_container_host() + @field_validator("url", mode="before") + @classmethod + def make_url(cls, v: Any) -> Any: + return ( + GridURL.from_url(v).as_container_host() + if isinstance(v, (str, GridURL)) + else v + ) def with_proxy(self, proxy_target_uid: UID) -> Self: return HTTPConnection(url=self.url, proxy_target_uid=proxy_target_uid) @@ -155,7 +159,7 @@ def get_cache_key(self) -> str: def api_url(self) -> GridURL: return self.url.with_path(self.routes.ROUTE_API_CALL.value) - def to_blob_route(self, path: str, **kwargs) -> GridURL: + def to_blob_route(self, path: str, **kwargs: Any) -> GridURL: _path = self.routes.ROUTE_BLOB_STORE.value + path return self.url.with_path(_path) @@ -330,7 +334,7 @@ class PythonConnection(NodeConnection): __version__ = SYFT_OBJECT_VERSION_1 node: AbstractNode - proxy_target_uid: Optional[UID] + proxy_target_uid: Optional[UID] = None def with_proxy(self, proxy_target_uid: UID) -> Self: return PythonConnection(node=self.node, proxy_target_uid=proxy_target_uid) @@ -347,7 +351,7 @@ def get_node_metadata(self, credentials: SyftSigningKey) -> NodeMetadataJSON: else: return self.node.metadata.to(NodeMetadataJSON) - def to_blob_route(self, path: str, host=None) -> GridURL: + def to_blob_route(self, path: str, host: Optional[str] = None) -> GridURL: # TODO: FIX! if host is not None: return GridURL(host_or_ip=host, port=8333).with_path(path) @@ -474,8 +478,8 @@ def __init__( self.metadata = metadata self.credentials: Optional[SyftSigningKey] = credentials self._api = api - self.communication_protocol = None - self.current_protocol = None + self.communication_protocol: Optional[Union[int, str]] = None + self.current_protocol: Optional[Union[int, str]] = None self.post_init() @@ -485,7 +489,7 @@ def get_env(self) -> str: def post_init(self) -> None: if self.metadata is None: self._fetch_node_metadata(self.credentials) - + self.metadata = cast(NodeMetadataJSON, self.metadata) self.communication_protocol = self._get_communication_protocol( self.metadata.supported_protocols ) @@ -528,9 +532,11 @@ def create_project( project = project_create.start() return project - def sync_code_from_request(self, request): + # TODO: type of request should be REQUEST, but it will give circular import error + def sync_code_from_request(self, request: Any) -> Union[SyftSuccess, SyftError]: # relative from ..service.code.user_code import UserCode + from ..service.code.user_code import UserCodeStatusCollection from ..store.linked_obj import LinkedObject code: Union[UserCode, SyftError] = request.code @@ -541,9 +547,12 @@ def sync_code_from_request(self, request): code.node_uid = self.id code.user_verify_key = self.verify_key - def get_nested_codes(code: UserCode): - result = [] - for __, (linked_code_obj, _) in code.nested_codes.items(): + def get_nested_codes(code: UserCode) -> list[UserCode]: + result: list[UserCode] = [] + if code.nested_codes is None: + return result + + for _, (linked_code_obj, _) in code.nested_codes.items(): nested_code = linked_code_obj.resolve nested_code = deepcopy(nested_code) nested_code.node_uid = code.node_uid @@ -551,19 +560,29 @@ def get_nested_codes(code: UserCode): result.append(nested_code) result += get_nested_codes(nested_code) - updated_code_links = { - nested_code.service_func_name: (LinkedObject.from_obj(nested_code), {}) - for nested_code in result - } - code.nested_codes = updated_code_links return result + def get_code_statusses(codes: List[UserCode]) -> List[UserCodeStatusCollection]: + statusses = [] + for code in codes: + status = deepcopy(code.status) + statusses.append(status) + code.status_link = LinkedObject.from_obj(status, node_uid=code.node_uid) + return statusses + nested_codes = get_nested_codes(code) + statusses = get_code_statusses(nested_codes + [code]) for c in nested_codes + [code]: res = self.code.submit(c) if isinstance(res, SyftError): return res + + for status in statusses: + res = self.api.services.code_status.create(status) + if isinstance(res, SyftError): + return res + self._fetch_api(self.credentials) return SyftSuccess(message="User Code Submitted") @@ -591,7 +610,7 @@ def verify_key(self) -> SyftVerifyKey: @classmethod def from_url(cls, url: Union[str, GridURL]) -> Self: - return cls(connection=HTTPConnection(GridURL.from_url(url))) + return cls(connection=HTTPConnection(url=GridURL.from_url(url))) @classmethod def from_node(cls, node: AbstractNode) -> Self: @@ -625,8 +644,7 @@ def api(self) -> SyftAPI: # invalidate API if self._api is None or (self._api.signing_key != self.credentials): self._fetch_api(self.credentials) - - return self._api + return cast(SyftAPI, self._api) # we are sure self._api is not None after fetch def guest(self) -> Self: return self.__class__( @@ -641,7 +659,8 @@ def exchange_route(self, client: Self) -> Union[SyftSuccess, SyftError]: self_node_route = connection_to_route(self.connection) remote_node_route = connection_to_route(client.connection) - + if client.metadata is None: + return SyftError(f"client {client}'s metadata is None!") result = self.api.services.network.exchange_credentials_with( self_node_route=self_node_route, remote_node_route=remote_node_route, @@ -676,14 +695,16 @@ def settings(self) -> Optional[APIModule]: @property def notifications(self) -> Optional[APIModule]: - print( - "WARNING: Notifications is currently is in a beta state, so use carefully!" - ) - print("If possible try using client.requests/client.projects") if self.api.has_service("notifications"): return self.api.services.notifications return None + @property + def notifier(self) -> Optional[APIModule]: + if self.api.has_service("notifier"): + return self.api.services.notifier + return None + @property def peers(self) -> Optional[Union[List[NodePeer], SyftError]]: if self.api.has_service("network"): @@ -699,10 +720,11 @@ def me(self) -> Optional[Union[UserView, SyftError]]: def login_as_guest(self) -> Self: _guest_client = self.guest() - print( - f"Logged into <{self.name}: {self.metadata.node_side_type.capitalize()}-side " - f"{self.metadata.node_type.capitalize()}> as GUEST" - ) + if self.metadata is not None: + print( + f"Logged into <{self.name}: {self.metadata.node_side_type.capitalize()}-side " + f"{self.metadata.node_type.capitalize()}> as GUEST" + ) return _guest_client @@ -746,11 +768,11 @@ def login( client.__logged_in_user = email - if user_private_key is not None: + if user_private_key is not None and client.users is not None: client.__user_role = user_private_key.role client.__logged_in_username = client.users.get_current_user().name - if signing_key is not None: + if signing_key is not None and client.metadata is not None: print( f"Logged into <{client.name}: {client.metadata.node_side_type.capitalize()} side " f"{client.metadata.node_type.capitalize()}> as <{email}>" @@ -788,16 +810,18 @@ def login( # relative from ..node.node import CODE_RELOADER - CODE_RELOADER[thread_ident()] = client._reload_user_code + thread_id = thread_ident() + if thread_id is not None: + CODE_RELOADER[thread_id] = client._reload_user_code return client - def _reload_user_code(self): + def _reload_user_code(self) -> None: # relative from ..service.code.user_code import load_approved_policy_code user_code_items = self.code.get_all_for_user() - load_approved_policy_code(user_code_items) + load_approved_policy_code(user_code_items=user_code_items, context=None) def register( self, @@ -807,7 +831,7 @@ def register( password_verify: Optional[str] = None, institution: Optional[str] = None, website: Optional[str] = None, - ): + ) -> Optional[Union[SyftError, SyftSigningKey]]: if not email: email = input("Email: ") if not password: @@ -832,7 +856,10 @@ def register( except Exception as e: return SyftError(message=str(e)) - if self.metadata.node_side_type == NodeSideType.HIGH_SIDE.value: + if ( + self.metadata + and self.metadata.node_side_type == NodeSideType.HIGH_SIDE.value + ): message = ( "You're registering a user to a high side " f"{self.metadata.node_type}, which could " @@ -841,7 +868,7 @@ def register( if self.metadata.show_warnings and not prompt_warning_message( message=message ): - return + return None response = self.connection.register(new_user=new_user) if isinstance(response, tuple): @@ -878,16 +905,20 @@ def _fetch_node_metadata(self, credentials: SyftSigningKey) -> None: metadata.check_version(__version__) self.metadata = metadata - def _fetch_api(self, credentials: SyftSigningKey): + def _fetch_api(self, credentials: SyftSigningKey) -> None: _api: SyftAPI = self.connection.get_api( credentials=credentials, communication_protocol=self.communication_protocol, ) - def refresh_callback(): + def refresh_callback() -> None: return self._fetch_api(self.credentials) _api.refresh_api_callback = refresh_callback + + if self.credentials is None: + raise ValueError(f"{self}'s credentials (signing key) is None!") + APIRegistry.set_api_for( node_uid=self.id, user_verify_key=self.credentials.verify_key, @@ -927,7 +958,7 @@ def register( password: str, institution: Optional[str] = None, website: Optional[str] = None, -): +) -> Optional[Union[SyftError, SyftSigningKey]]: guest_client = connect(url=url, port=port) return guest_client.register( name=name, @@ -944,13 +975,13 @@ def login_as_guest( node: Optional[AbstractNode] = None, port: Optional[int] = None, verbose: bool = True, -): +) -> SyftClient: _client = connect(url=url, node=node, port=port) if isinstance(_client, SyftError): return _client - if verbose: + if verbose and _client.metadata is not None: print( f"Logged into <{_client.name}: {_client.metadata.node_side_type.capitalize()}-" f"side {_client.metadata.node_type.capitalize()}> as GUEST" @@ -1023,7 +1054,7 @@ def add_client( password: str, connection: NodeConnection, syft_client: SyftClient, - ): + ) -> None: hash_key = cls._get_key(email, password, connection.get_cache_key()) cls.__credentials_store__[hash_key] = syft_client cls.__client_cache__[syft_client.id] = syft_client @@ -1034,7 +1065,7 @@ def add_client_by_uid_and_verify_key( verify_key: SyftVerifyKey, node_uid: UID, syft_client: SyftClient, - ): + ) -> None: hash_key = str(node_uid) + str(verify_key) cls.__client_cache__[hash_key] = syft_client @@ -1051,8 +1082,8 @@ def get_client( ) -> Optional[SyftClient]: # we have some bugs here so lets disable until they are fixed. return None - hash_key = cls._get_key(email, password, connection.get_cache_key()) - return cls.__credentials_store__.get(hash_key, None) + # hash_key = cls._get_key(email, password, connection.get_cache_key()) + # return cls.__credentials_store__.get(hash_key, None) @classmethod def get_client_for_node_uid(cls, node_uid: UID) -> Optional[SyftClient]: diff --git a/packages/syft/src/syft/client/connection.py b/packages/syft/src/syft/client/connection.py index 5b9928c8355..a94cb1c0707 100644 --- a/packages/syft/src/syft/client/connection.py +++ b/packages/syft/src/syft/client/connection.py @@ -10,7 +10,7 @@ class NodeConnection(SyftObject): __canonical_name__ = "NodeConnection" __version__ = SYFT_OBJECT_VERSION_1 - def get_cache_key() -> str: + def get_cache_key(self) -> str: raise NotImplementedError def __repr__(self) -> str: diff --git a/packages/syft/src/syft/client/domain_client.py b/packages/syft/src/syft/client/domain_client.py index d9d8cc3ae4e..57b60e0f489 100644 --- a/packages/syft/src/syft/client/domain_client.py +++ b/packages/syft/src/syft/client/domain_client.py @@ -4,11 +4,14 @@ # stdlib from pathlib import Path import re +from typing import List from typing import Optional from typing import TYPE_CHECKING from typing import Union +from typing import cast # third party +from hagrid.orchestra import NodeHandle from loguru import logger from tqdm import tqdm @@ -24,7 +27,9 @@ from ..service.dataset.dataset import CreateDataset from ..service.response import SyftError from ..service.response import SyftSuccess +from ..service.sync.diff_state import ResolvedSyncState from ..service.user.roles import Roles +from ..service.user.user import UserView from ..service.user.user_roles import ServiceRole from ..types.blob_storage import BlobFile from ..types.uid import UID @@ -35,13 +40,14 @@ from .client import SyftClient from .client import login from .client import login_as_guest +from .connection import NodeConnection if TYPE_CHECKING: # relative from ..service.project.project import Project -def _get_files_from_glob(glob_path: str) -> list: +def _get_files_from_glob(glob_path: str) -> list[Path]: files = Path().glob(glob_path) return [f for f in files if f.is_file() and not f.name.startswith(".")] @@ -59,7 +65,7 @@ def _contains_subdir(dir: Path) -> bool: def add_default_uploader( - user, obj: Union[CreateDataset, CreateAsset] + user: UserView, obj: Union[CreateDataset, CreateAsset] ) -> Union[CreateDataset, CreateAsset]: uploader = None for contributor in obj.contributors: @@ -88,6 +94,9 @@ def upload_dataset(self, dataset: CreateDataset) -> Union[SyftSuccess, SyftError # relative from ..types.twin_object import TwinObject + if self.users is None: + return SyftError(f"can't get user service for {self}") + user = self.users.get_current_user() dataset = add_default_uploader(user, dataset) for i in range(len(dataset.asset_list)): @@ -95,9 +104,12 @@ def upload_dataset(self, dataset: CreateDataset) -> Union[SyftSuccess, SyftError dataset.asset_list[i] = add_default_uploader(user, asset) dataset._check_asset_must_contain_mock() - dataset_size = 0 + dataset_size: float = 0.0 # TODO: Refactor so that object can also be passed to generate warnings + + self.api.connection = cast(NodeConnection, self.api.connection) + metadata = self.api.connection.get_node_metadata(self.api.signing_key) if ( @@ -132,26 +144,77 @@ def upload_dataset(self, dataset: CreateDataset) -> Union[SyftSuccess, SyftError dataset_size += get_mb_size(asset.data) dataset.mb_size = dataset_size valid = dataset.check() - if valid.ok(): - return self.api.services.dataset.add(dataset=dataset) - else: - if len(valid.err()) > 0: - return tuple(valid.err()) - return valid.err() + if isinstance(valid, SyftError): + return valid + return self.api.services.dataset.add(dataset=dataset) + + # def get_permissions_for_other_node( + # self, + # items: list[Union[ActionObject, SyftObject]], + # ) -> dict: + # if len(items) > 0: + # if not len({i.syft_node_location for i in items}) == 1 or ( + # not len({i.syft_client_verify_key for i in items}) == 1 + # ): + # raise ValueError("permissions from different nodes") + # item = items[0] + # api = APIRegistry.api_for( + # item.syft_node_location, item.syft_client_verify_key + # ) + # if api is None: + # raise ValueError( + # f"Can't access the api. Please log in to {item.syft_node_location}" + # ) + # return api.services.sync.get_permissions(items) + # else: + # return {} + + def apply_state( + self, resolved_state: ResolvedSyncState + ) -> Union[SyftSuccess, SyftError]: + if len(resolved_state.delete_objs): + raise NotImplementedError("TODO implement delete") + items = resolved_state.create_objs + resolved_state.update_objs + + action_objects = [x for x in items if isinstance(x, ActionObject)] + # permissions = self.get_permissions_for_other_node(items) + permissions: dict[UID, set[str]] = {} + for p in resolved_state.new_permissions: + if p.uid in permissions: + permissions[p.uid].add(p.permission_string) + else: + permissions[p.uid] = {p.permission_string} + + for action_object in action_objects: + action_object = action_object.refresh_object() + action_object.send(self) + + res = self.api.services.sync.sync_items(items, permissions) + if isinstance(res, SyftError): + return res + + # Add updated node state to store to have a previous_state for next sync + new_state = self.api.services.sync.get_state(add_to_store=True) + if isinstance(new_state, SyftError): + return new_state + + self._fetch_api(self.credentials) + return res def upload_files( self, file_list: Union[BlobFile, list[BlobFile], str, list[str], Path, list[Path]], - allow_recursive=False, - show_files=False, + allow_recursive: bool = False, + show_files: bool = False, ) -> Union[SyftSuccess, SyftError]: if not file_list: return SyftError(message="No files to upload") if not isinstance(file_list, list): - file_list = [file_list] + file_list = [file_list] # type: ignore[assignment] + file_list = cast(list, file_list) - expanded_file_list = [] + expanded_file_list: List[Union[BlobFile, Path]] = [] for file in file_list: if isinstance(file, BlobFile): @@ -212,7 +275,7 @@ def connect_to_gateway( handle: Optional[NodeHandle] = None, # noqa: F821 email: Optional[str] = None, password: Optional[str] = None, - ) -> None: + ) -> Optional[Union[SyftSuccess, SyftError]]: if via_client is not None: client = via_client elif handle is not None: @@ -228,9 +291,12 @@ def connect_to_gateway( res = self.exchange_route(client) if isinstance(res, SyftSuccess): - return SyftSuccess( - message=f"Connected {self.metadata.node_type} to {client.name} gateway" - ) + if self.metadata: + return SyftSuccess( + message=f"Connected {self.metadata.node_type} to {client.name} gateway" + ) + else: + return SyftSuccess(message=f"Connected to {client.name} gateway") return res @property @@ -303,10 +369,28 @@ def worker_images(self) -> Optional[APIModule]: return self.api.services.worker_image return None + @property + def sync(self) -> Optional[APIModule]: + if self.api.has_service("sync"): + return self.api.services.sync + return None + + @property + def code_status(self) -> Optional[APIModule]: + if self.api.has_service("code_status"): + return self.api.services.code_status + return None + + @property + def output(self) -> Optional[APIModule]: + if self.api.has_service("output"): + return self.api.services.output + return None + def get_project( self, - name: str = None, - uid: UID = None, + name: Optional[str] = None, + uid: Optional[UID] = None, ) -> Optional[Project]: """Get project by name or UID""" @@ -373,18 +457,17 @@ def _repr_html_(self) -> str: url = getattr(self.connection, "url", None) node_details = f"<strong>URL:</strong> {url}<br />" if url else "" - node_details += ( - f"<strong>Node Type:</strong> {self.metadata.node_type.capitalize()}<br />" - ) - node_side_type = ( - "Low Side" - if self.metadata.node_side_type == NodeSideType.LOW_SIDE.value - else "High Side" - ) - node_details += f"<strong>Node Side Type:</strong> {node_side_type}<br />" - node_details += ( - f"<strong>Syft Version:</strong> {self.metadata.syft_version}<br />" - ) + if self.metadata is not None: + node_details += f"<strong>Node Type:</strong> {self.metadata.node_type.capitalize()}<br />" + node_side_type = ( + "Low Side" + if self.metadata.node_side_type == NodeSideType.LOW_SIDE.value + else "High Side" + ) + node_details += f"<strong>Node Side Type:</strong> {node_side_type}<br />" + node_details += ( + f"<strong>Syft Version:</strong> {self.metadata.syft_version}<br />" + ) return f""" <style> diff --git a/packages/syft/src/syft/client/enclave_client.py b/packages/syft/src/syft/client/enclave_client.py index e0a09167805..59c11aaf50b 100644 --- a/packages/syft/src/syft/client/enclave_client.py +++ b/packages/syft/src/syft/client/enclave_client.py @@ -2,14 +2,20 @@ from __future__ import annotations # stdlib +from typing import Any from typing import Optional from typing import TYPE_CHECKING +from typing import Union + +# third party +from hagrid.orchestra import NodeHandle # relative from ..abstract_node import NodeSideType from ..client.api import APIRegistry from ..img.base64 import base64read from ..serde.serializable import serializable +from ..service.metadata.node_metadata import NodeMetadataJSON from ..service.network.routes import NodeRouteType from ..service.response import SyftError from ..service.response import SyftSuccess @@ -68,7 +74,7 @@ def connect_to_gateway( handle: Optional[NodeHandle] = None, # noqa: F821 email: Optional[str] = None, password: Optional[str] = None, - ) -> None: + ) -> Optional[Union[SyftSuccess, SyftError]]: if via_client is not None: client = via_client elif handle is not None: @@ -82,17 +88,20 @@ def connect_to_gateway( if isinstance(client, SyftError): return client + self.metadata: NodeMetadataJSON = self.metadata res = self.exchange_route(client) + if isinstance(res, SyftSuccess): return SyftSuccess( message=f"Connected {self.metadata.node_type} to {client.name} gateway" ) + return res def get_enclave_metadata(self) -> EnclaveMetadata: return EnclaveMetadata(route=self.connection.route) - def request_code_execution(self, code: SubmitUserCode): + def request_code_execution(self, code: SubmitUserCode) -> Union[Any, SyftError]: # relative from ..service.code.user_code_service import SubmitUserCode @@ -100,6 +109,8 @@ def request_code_execution(self, code: SubmitUserCode): raise Exception( f"The input code should be of type: {SubmitUserCode} got:{type(code)}" ) + if code.input_policy_init_kwargs is None: + raise ValueError(f"code {code}'s input_policy_init_kwargs is None") enclave_metadata = self.get_enclave_metadata() @@ -154,18 +165,17 @@ def _repr_html_(self) -> str: url = getattr(self.connection, "url", None) node_details = f"<strong>URL:</strong> {url}<br />" if url else "" - node_details += ( - f"<strong>Node Type:</strong> {self.metadata.node_type.capitalize()}<br />" - ) - node_side_type = ( - "Low Side" - if self.metadata.node_side_type == NodeSideType.LOW_SIDE.value - else "High Side" - ) - node_details += f"<strong>Node Side Type:</strong> {node_side_type}<br />" - node_details += ( - f"<strong>Syft Version:</strong> {self.metadata.syft_version}<br />" - ) + if self.metadata: + node_details += f"<strong>Node Type:</strong> {self.metadata.node_type.capitalize()}<br />" + node_side_type = ( + "Low Side" + if self.metadata.node_side_type == NodeSideType.LOW_SIDE.value + else "High Side" + ) + node_details += f"<strong>Node Side Type:</strong> {node_side_type}<br />" + node_details += ( + f"<strong>Syft Version:</strong> {self.metadata.syft_version}<br />" + ) return f""" <style> diff --git a/packages/syft/src/syft/client/gateway_client.py b/packages/syft/src/syft/client/gateway_client.py index 2a569cabd22..cae8bc076cf 100644 --- a/packages/syft/src/syft/client/gateway_client.py +++ b/packages/syft/src/syft/client/gateway_client.py @@ -2,11 +2,9 @@ from typing import Any from typing import List from typing import Optional +from typing import Type from typing import Union -# third party -from typing_extensions import Self - # relative from ..abstract_node import NodeSideType from ..abstract_node import NodeType @@ -26,7 +24,7 @@ class GatewayClient(SyftClient): # TODO: add widget repr for gateway client - def proxy_to(self, peer: Any) -> Self: + def proxy_to(self, peer: Any) -> SyftClient: # relative from .domain_client import DomainClient from .enclave_client import EnclaveClient @@ -34,7 +32,7 @@ def proxy_to(self, peer: Any) -> Self: connection = self.connection.with_proxy(peer.id) metadata = connection.get_node_metadata(credentials=SyftSigningKey.generate()) if metadata.node_type == NodeType.DOMAIN.value: - client_type = DomainClient + client_type: Type[SyftClient] = DomainClient elif metadata.node_type == NodeType.ENCLAVE.value: client_type = EnclaveClient else: @@ -53,8 +51,8 @@ def proxy_client_for( name: str, email: Optional[str] = None, password: Optional[str] = None, - **kwargs, - ): + **kwargs: Any, + ) -> SyftClient: peer = None if self.api.has_service("network"): peer = self.api.services.network.get_peer_by_name(name=name) @@ -97,18 +95,17 @@ def _repr_html_(self) -> str: url = getattr(self.connection, "url", None) node_details = f"<strong>URL:</strong> {url}<br />" if url else "" - node_details += ( - f"<strong>Node Type:</strong> {self.metadata.node_type.capitalize()}<br />" - ) - node_side_type = ( - "Low Side" - if self.metadata.node_side_type == NodeSideType.LOW_SIDE.value - else "High Side" - ) - node_details += f"<strong>Node Side Type:</strong> {node_side_type}<br />" - node_details += ( - f"<strong>Syft Version:</strong> {self.metadata.syft_version}<br />" - ) + if self.metadata: + node_details += f"<strong>Node Type:</strong> {self.metadata.node_type.capitalize()}<br />" + node_side_type = ( + "Low Side" + if self.metadata.node_side_type == NodeSideType.LOW_SIDE.value + else "High Side" + ) + node_details += f"<strong>Node Side Type:</strong> {node_side_type}<br />" + node_details += ( + f"<strong>Syft Version:</strong> {self.metadata.syft_version}<br />" + ) return f""" <style> @@ -157,7 +154,7 @@ class ProxyClient(SyftObject): __version__ = SYFT_OBJECT_VERSION_1 routing_client: GatewayClient - node_type: Optional[NodeType] + node_type: Optional[NodeType] = None def retrieve_nodes(self) -> List[NodePeer]: if self.node_type in [NodeType.DOMAIN, NodeType.ENCLAVE]: @@ -178,7 +175,7 @@ def _repr_html_(self) -> str: def __len__(self) -> int: return len(self.retrieve_nodes()) - def __getitem__(self, key: int): + def __getitem__(self, key: Union[int, str]) -> SyftClient: if not isinstance(key, int): raise SyftException(f"Key: {key} must be an integer") diff --git a/packages/syft/src/syft/client/registry.py b/packages/syft/src/syft/client/registry.py index c23d21b7b8e..b67a3dd1c5d 100644 --- a/packages/syft/src/syft/client/registry.py +++ b/packages/syft/src/syft/client/registry.py @@ -23,7 +23,6 @@ from ..util.logger import error from ..util.logger import warning from .client import SyftClient as Client -from .enclave_client import EnclaveClient NETWORK_REGISTRY_URL = ( "https://raw.githubusercontent.com/OpenMined/NetworkRegistry/main/gateways.json" @@ -200,8 +199,10 @@ def check_network(network: Dict) -> Optional[Dict[Any, Any]]: return online_networks @property - def online_domains(self) -> List[Tuple[NodePeer, NodeMetadataJSON]]: - def check_domain(peer: NodePeer) -> Optional[Tuple[NodePeer, NodeMetadataJSON]]: + def online_domains(self) -> List[Tuple[NodePeer, Optional[NodeMetadataJSON]]]: + def check_domain( + peer: NodePeer, + ) -> Optional[tuple[NodePeer, Optional[NodeMetadataJSON]]]: try: guest_client = peer.guest_client metadata = guest_client.metadata @@ -232,14 +233,15 @@ def check_domain(peer: NodePeer) -> Optional[Tuple[NodePeer, NodeMetadataJSON]]: online_domains.append(each) return online_domains - def __make_dict__(self) -> List[Dict[str, str]]: + def __make_dict__(self) -> list[dict[str, Any]]: on = self.online_domains domains = [] - domain_dict = {} + domain_dict: dict[str, Any] = {} for domain, metadata in on: domain_dict["name"] = domain.name - domain_dict["organization"] = metadata.organization - domain_dict["version"] = metadata.syft_version + if metadata is not None: + domain_dict["organization"] = metadata.organization + domain_dict["version"] = metadata.syft_version route = None if len(domain.node_routes) > 0: route = domain.pick_highest_priority_route() @@ -371,7 +373,7 @@ def create_client(enclave: Dict[str, Any]) -> Client: error(f"Failed to login with: {enclave}. {e}") raise SyftException(f"Failed to login with: {enclave}. {e}") - def __getitem__(self, key: Union[str, int]) -> EnclaveClient: + def __getitem__(self, key: Union[str, int]) -> Client: if isinstance(key, int): return self.create_client(enclave=self.online_enclaves[key]) else: diff --git a/packages/syft/src/syft/client/search.py b/packages/syft/src/syft/client/search.py index 66d46c5b3d4..9a979cb6475 100644 --- a/packages/syft/src/syft/client/search.py +++ b/packages/syft/src/syft/client/search.py @@ -1,5 +1,6 @@ # stdlib from typing import List +from typing import Optional from typing import Tuple from typing import Union @@ -59,7 +60,7 @@ def __init__(self, domains: DomainRegistry): @staticmethod def __search_one_node( peer_tuple: Tuple[NodePeer, NodeMetadataJSON], name: str - ) -> List[Dataset]: + ) -> Tuple[Optional[SyftClient], List[Dataset]]: try: peer, _ = peer_tuple client = peer.guest_client @@ -74,7 +75,7 @@ def __search(self, name: str) -> List[Tuple[SyftClient, List[Dataset]]]: ] # filter out SyftError - filtered = ((client, result) for client, result in results if result) + filtered = [(client, result) for client, result in results if client and result] return filtered def search(self, name: str) -> SearchResults: diff --git a/packages/syft/src/syft/client/syncing.py b/packages/syft/src/syft/client/syncing.py new file mode 100644 index 00000000000..cb3d8fc7e3d --- /dev/null +++ b/packages/syft/src/syft/client/syncing.py @@ -0,0 +1,195 @@ +# stdlib +from time import sleep +from typing import List +from typing import Optional +from typing import Union + +# relative +from ..service.action.action_object import ActionObject +from ..service.action.action_permissions import ActionObjectPermission +from ..service.action.action_permissions import ActionPermission +from ..service.code.user_code import UserCode +from ..service.job.job_stash import Job +from ..service.log.log import SyftLog +from ..service.sync.diff_state import NodeDiff +from ..service.sync.diff_state import ObjectDiffBatch +from ..service.sync.diff_state import ResolvedSyncState +from ..service.sync.sync_state import SyncState + + +def compare_states(low_state: SyncState, high_state: SyncState) -> NodeDiff: + return NodeDiff.from_sync_state(low_state=low_state, high_state=high_state) + + +def get_user_input_for_resolve() -> Optional[str]: + print( + "Do you want to keep the low state or the high state for these objects? choose 'low' or 'high'" + ) + + while True: + decision = input() + decision = decision.lower() + + if decision in ["low", "high"]: + return decision + else: + print("Please choose between `low` or `high`") + + +def resolve( + state: NodeDiff, decision: Optional[str] = None, share_private_objects: bool = False +) -> tuple[ResolvedSyncState, ResolvedSyncState]: + # TODO: only add permissions for objects where we manually give permission + # Maybe default read permission for some objects (high -> low) + resolved_state_low: ResolvedSyncState = ResolvedSyncState(alias="low") + resolved_state_high: ResolvedSyncState = ResolvedSyncState(alias="high") + + for batch_diff in state.hierarchies: + batch_decision = decision + if all(diff.status == "SAME" for diff in batch_diff.diffs): + # Hierarchy has no diffs + continue + + print(batch_diff.__repr__()) + + # ask question: which side do you want + # ask question: The batch has private items that you may want to share with the related user + # user with verify key: abc. The items are + # Log with id (123) + # Result with id (567) + # do you want to give read permission to items + # TODO: get decision + # get items + if batch_decision is None: + batch_decision = get_user_input_for_resolve() + + get_user_input_for_batch_permissions( + batch_diff, share_private_objects=share_private_objects + ) + + print(f"Decision: Syncing {len(batch_diff)} objects from {batch_decision} side") + + for object_diff in batch_diff.diffs: + resolved_state_low.add_cruds_from_diff(object_diff, batch_decision) + resolved_state_high.add_cruds_from_diff(object_diff, batch_decision) + + resolved_state_low.new_permissions += object_diff.new_low_permissions + + print() + print("=" * 100) + print() + + return resolved_state_low, resolved_state_high + + +def get_user_input_for_batch_permissions( + batch_diff: ObjectDiffBatch, share_private_objects: bool = False +) -> None: + private_high_objects: List[Union[SyftLog, ActionObject]] = [] + + for diff in batch_diff.diffs: + if isinstance(diff.high_obj, (SyftLog, ActionObject)): + private_high_objects.append(diff) + + user_codes_high: List[UserCode] = [ + diff.high_obj + for diff in batch_diff.diffs + if isinstance(diff.high_obj, UserCode) + ] + if not len(user_codes_high) < 2: + raise ValueError("too many user codes") + + if user_codes_high: + user_code_high = user_codes_high[0] + + # TODO: only do this under condition that its accepted to sync + high_job_diffs = [ + diff for diff in batch_diff.diffs if isinstance(diff.high_obj, Job) + ] + + for diff in high_job_diffs: + read_permission_job = ActionObjectPermission( + uid=diff.object_id, + permission=ActionPermission.READ, + credentials=user_code_high.user_verify_key, + ) + diff.new_low_permissions.append(read_permission_job) + + if share_private_objects: + for diff in private_high_objects: + read_permission_private_obj = ActionObjectPermission( + uid=diff.object_id, + permission=ActionPermission.READ, + credentials=user_code_high.user_verify_key, + ) + diff.new_low_permissions.append(read_permission_private_obj) + + else: + print( + f"""This batch of updates contains new private objects on the high side that you may want \ + to share with user {user_code_high.user_verify_key}.""" + ) + while True: + if len(private_high_objects) > 0: + if user_code_high is None: + raise ValueError("No usercode found for private objects") + objects_str = "\n".join( + [ + f"{diff.object_type} #{diff.object_id}" + for diff in private_high_objects + ] + ) + print( + f""" + You currently have the following private objects: + + {objects_str} + + Do you want to share some of these private objects? If so type the first 3 characters of the id e.g. 'abc'. + If you dont want to share any more private objects, type "no" + """, + flush=True, + ) + else: + break + + sleep(0.1) + res = input() + if res == "no": + break + elif len(res) >= 3: + matches = [ + diff + for diff in private_high_objects + if str(diff.object_id).startswith(res) + ] + if len(matches) == 0: + print("Invalid input") + continue + elif len(matches) == 1: + diff = matches[0] + print() + print("=" * 100) + print() + print( + f""" + Setting permissions for {diff.object_type} #{diff.object_id} to share with ABC, + this will become effective when you call client.apply_state(<resolved_state>)) + """ + ) + private_high_objects.remove(diff) + read_permission_private_obj = ActionObjectPermission( + uid=diff.object_id, + permission=ActionPermission.READ, + credentials=user_code_high.user_verify_key, + ) + diff.new_low_permissions.append(read_permission_private_obj) + + # questions + # Q:do we also want to give read permission if we defined that by accept_by_depositing_result? + # A:only if we pass: sync_read_permission to resolve + else: + print("Found multiple matches for provided id, exiting") + break + else: + print("invalid input") diff --git a/packages/syft/src/syft/custom_worker/builder.py b/packages/syft/src/syft/custom_worker/builder.py index 5e479cee71c..8109ac94b43 100644 --- a/packages/syft/src/syft/custom_worker/builder.py +++ b/packages/syft/src/syft/custom_worker/builder.py @@ -3,6 +3,7 @@ import os.path from pathlib import Path from typing import Any +from typing import Optional # relative from .builder_docker import DockerBuilder @@ -39,7 +40,7 @@ def builder(self) -> BuilderBase: def build_image( self, config: WorkerConfig, - tag: str = None, + tag: Optional[str] = None, **kwargs: Any, ) -> ImageBuildResult: """ @@ -82,7 +83,7 @@ def _build_dockerfile( self, config: DockerWorkerConfig, tag: str, - **kwargs, + **kwargs: Any, ) -> ImageBuildResult: return self.builder.build_image( dockerfile=config.dockerfile, diff --git a/packages/syft/src/syft/custom_worker/builder_docker.py b/packages/syft/src/syft/custom_worker/builder_docker.py index 3f7f16cf185..6b68d1e99c2 100644 --- a/packages/syft/src/syft/custom_worker/builder_docker.py +++ b/packages/syft/src/syft/custom_worker/builder_docker.py @@ -2,6 +2,7 @@ import contextlib import io from pathlib import Path +from typing import Any from typing import Iterable from typing import Optional @@ -22,11 +23,11 @@ class DockerBuilder(BuilderBase): def build_image( self, tag: str, - dockerfile: str = None, - dockerfile_path: Path = None, + dockerfile: Optional[str] = None, + dockerfile_path: Optional[Path] = None, buildargs: Optional[dict] = None, - **kwargs, - ): + **kwargs: Any, + ) -> ImageBuildResult: if dockerfile: # convert dockerfile string to file-like object kwargs["fileobj"] = io.BytesIO(dockerfile.encode("utf-8")) @@ -53,9 +54,10 @@ def build_image( def push_image( self, tag: str, - registry_url: str, username: str, password: str, + registry_url: str, + **kwargs: Any, ) -> ImagePushResult: with contextlib.closing(docker.from_env()) as client: if registry_url and username and password: diff --git a/packages/syft/src/syft/custom_worker/builder_k8s.py b/packages/syft/src/syft/custom_worker/builder_k8s.py index 1be16d3c0ac..24e494c7756 100644 --- a/packages/syft/src/syft/custom_worker/builder_k8s.py +++ b/packages/syft/src/syft/custom_worker/builder_k8s.py @@ -1,6 +1,7 @@ # stdlib from hashlib import sha256 from pathlib import Path +from typing import Any from typing import Dict from typing import List from typing import Optional @@ -33,16 +34,16 @@ class BuildFailed(Exception): class KubernetesBuilder(BuilderBase): COMPONENT = "builder" - def __init__(self): + def __init__(self) -> None: self.client = get_kr8s_client() def build_image( self, tag: str, - dockerfile: str = None, - dockerfile_path: Path = None, + dockerfile: Optional[str] = None, + dockerfile_path: Optional[Path] = None, buildargs: Optional[dict] = None, - **kwargs, + **kwargs: Any, ) -> ImageBuildResult: image_digest = None logs = None @@ -102,7 +103,7 @@ def push_image( username: str, password: str, registry_url: str, - **kwargs, + **kwargs: Any, ) -> ImagePushResult: exit_code = 1 logs = None @@ -354,7 +355,9 @@ def _create_push_job( ) return KubeUtils.create_or_get(job) - def _create_push_secret(self, id: str, url: str, username: str, password: str): + def _create_push_secret( + self, id: str, url: str, username: str, password: str + ) -> Secret: return KubeUtils.create_dockerconfig_secret( secret_name=f"push-secret-{id}", component=KubernetesBuilder.COMPONENT, diff --git a/packages/syft/src/syft/custom_worker/builder_types.py b/packages/syft/src/syft/custom_worker/builder_types.py index 8007bf476e9..9464bafced5 100644 --- a/packages/syft/src/syft/custom_worker/builder_types.py +++ b/packages/syft/src/syft/custom_worker/builder_types.py @@ -2,6 +2,7 @@ from abc import ABC from abc import abstractmethod from pathlib import Path +from typing import Any from typing import Optional # third party @@ -33,20 +34,22 @@ class ImagePushResult(BaseModel): class BuilderBase(ABC): @abstractmethod def build_image( + self, tag: str, - dockerfile: str = None, - dockerfile_path: Path = None, + dockerfile: Optional[str] = None, + dockerfile_path: Optional[Path] = None, buildargs: Optional[dict] = None, - **kwargs, + **kwargs: Any, ) -> ImageBuildResult: pass @abstractmethod def push_image( + self, tag: str, username: str, password: str, registry_url: str, - **kwargs, + **kwargs: Any, ) -> ImagePushResult: pass diff --git a/packages/syft/src/syft/custom_worker/config.py b/packages/syft/src/syft/custom_worker/config.py index c54d4f77c40..b35505f6994 100644 --- a/packages/syft/src/syft/custom_worker/config.py +++ b/packages/syft/src/syft/custom_worker/config.py @@ -12,7 +12,7 @@ # third party import docker from packaging import version -from pydantic import validator +from pydantic import field_validator from typing_extensions import Self import yaml @@ -54,10 +54,11 @@ class CustomBuildConfig(SyftBaseModel): # f"Python version must be between {PYTHON_MIN_VER} and {PYTHON_MAX_VER}" # ) - @validator("python_packages") + @field_validator("python_packages") + @classmethod def validate_python_packages(cls, pkgs: List[str]) -> List[str]: for pkg in pkgs: - ver_parts = () + ver_parts: Union[tuple, list] = () name_ver = pkg.split("==") if len(name_ver) != 2: raise ValueError(_malformed_python_package_error_msg(pkg)) @@ -72,13 +73,13 @@ def validate_python_packages(cls, pkgs: List[str]) -> List[str]: return pkgs - def merged_python_pkgs(self, sep=" ") -> str: + def merged_python_pkgs(self, sep: str = " ") -> str: return sep.join(self.python_packages) - def merged_system_pkgs(self, sep=" ") -> str: + def merged_system_pkgs(self, sep: str = " ") -> str: return sep.join(self.system_packages) - def merged_custom_cmds(self, sep=";") -> str: + def merged_custom_cmds(self, sep: str = ";") -> str: return sep.join(self.custom_cmds) @@ -114,7 +115,7 @@ def get_signature(self) -> str: class PrebuiltWorkerConfig(WorkerConfig): # tag that is already built and pushed in some registry tag: str - description: Optional[str] + description: Optional[str] = None def __str__(self) -> str: if self.description: @@ -129,10 +130,11 @@ def set_description(self, description_text: str) -> None: @serializable() class DockerWorkerConfig(WorkerConfig): dockerfile: str - file_name: Optional[str] - description: Optional[str] + file_name: Optional[str] = None + description: Optional[str] = None - @validator("dockerfile") + @field_validator("dockerfile") + @classmethod def validate_dockerfile(cls, dockerfile: str) -> str: if not dockerfile: raise ValueError("Dockerfile cannot be empty") @@ -166,7 +168,9 @@ def __str__(self) -> str: def set_description(self, description_text: str) -> None: self.description = description_text - def test_image_build(self, tag: str, **kwargs) -> Union[SyftSuccess, SyftError]: + def test_image_build( + self, tag: str, **kwargs: Any + ) -> Union[SyftSuccess, SyftError]: try: with contextlib.closing(docker.from_env()) as client: if not client.ping(): diff --git a/packages/syft/src/syft/custom_worker/k8s.py b/packages/syft/src/syft/custom_worker/k8s.py index fb777f6ec6c..067d23d1a3f 100644 --- a/packages/syft/src/syft/custom_worker/k8s.py +++ b/packages/syft/src/syft/custom_worker/k8s.py @@ -17,6 +17,7 @@ from kr8s.objects import Pod from kr8s.objects import Secret from pydantic import BaseModel +from typing_extensions import Self # Time after which Job will be deleted JOB_COMPLETION_TTL = 60 @@ -47,7 +48,7 @@ class PodCondition(BaseModel): ready: bool @classmethod - def from_conditions(cls, conditions: list): + def from_conditions(cls, conditions: list) -> Self: pod_cond = KubeUtils.list_dict_unpack(conditions, key="type", value="status") pod_cond_flags = {k: v == "True" for k, v in pod_cond.items()} return cls( @@ -62,12 +63,12 @@ class ContainerStatus(BaseModel): ready: bool running: bool waiting: bool - reason: Optional[str] # when waiting=True - message: Optional[str] # when waiting=True - startedAt: Optional[str] # when running=True + reason: Optional[str] = None # when waiting=True + message: Optional[str] = None # when waiting=True + startedAt: Optional[str] = None # when running=True @classmethod - def from_status(cls, cstatus: dict): + def from_status(cls, cstatus: dict) -> Self: cstate = cstatus.get("state", {}) return cls( @@ -86,7 +87,7 @@ class PodStatus(BaseModel): container: ContainerStatus @classmethod - def from_status_dict(cls: "PodStatus", status: dict): + def from_status_dict(cls, status: dict) -> Self: return cls( phase=PodPhase(status.get("phase", "Unknown")), condition=PodCondition.from_conditions(status.get("conditions", [])), @@ -120,8 +121,10 @@ def resolve_pod(client: kr8s.Api, pod: Union[str, Pod]) -> Optional[Pod]: for _pod in client.get("pods", pod): return _pod + return None + @staticmethod - def get_logs(pods: List[Pod]): + def get_logs(pods: List[Pod]) -> str: """Combine and return logs for all the pods as string""" logs = [] for pod in pods: @@ -142,11 +145,13 @@ def get_pod_status(pod: Pod) -> Optional[PodStatus]: def get_pod_env(pod: Pod) -> Optional[List[Dict]]: """Return the environment variables of the first container in the pod.""" if not pod: - return + return None for container in pod.spec.containers: return container.env.to_list() + return None + @staticmethod def get_container_exit_code(pods: List[Pod]) -> List[int]: """Return the exit codes of all the containers in the given pods.""" @@ -203,11 +208,11 @@ def create_secret( type: str, component: str, data: str, - encoded=True, + encoded: bool = True, ) -> Secret: if not encoded: for k, v in data.items(): - data[k] = KubeUtils.b64encode_secret(v) + data[k] = KubeUtils.b64encode_secret(v) # type: ignore secret = Secret( { diff --git a/packages/syft/src/syft/custom_worker/runner_k8s.py b/packages/syft/src/syft/custom_worker/runner_k8s.py index 3b35830c0f4..25d3dbfd2a3 100644 --- a/packages/syft/src/syft/custom_worker/runner_k8s.py +++ b/packages/syft/src/syft/custom_worker/runner_k8s.py @@ -1,4 +1,5 @@ # stdlib +from typing import Any from typing import Dict from typing import List from typing import Optional @@ -21,7 +22,7 @@ class KubernetesRunner: - def __init__(self): + def __init__(self) -> None: self.client = get_kr8s_client() def create_pool( @@ -34,7 +35,7 @@ def create_pool( reg_username: Optional[str] = None, reg_password: Optional[str] = None, reg_url: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> StatefulSet: try: # create pull secret if registry credentials are passed @@ -134,8 +135,8 @@ def _create_image_pull_secret( reg_username: str, reg_password: str, reg_url: str, - **kwargs, - ): + **kwargs: Any, + ) -> Secret: return KubeUtils.create_dockerconfig_secret( secret_name=f"pull-secret-{pool_name}", component=pool_name, @@ -148,11 +149,11 @@ def _create_stateful_set( self, pool_name: str, tag: str, - replicas=1, + replicas: int = 1, env_vars: Optional[List[Dict]] = None, mount_secrets: Optional[Dict] = None, pull_secret: Optional[Secret] = None, - **kwargs, + **kwargs: Any, ) -> StatefulSet: """Create a stateful set for a pool""" diff --git a/packages/syft/src/syft/exceptions/exception.py b/packages/syft/src/syft/exceptions/exception.py index 6bab71a747a..16f1717686b 100644 --- a/packages/syft/src/syft/exceptions/exception.py +++ b/packages/syft/src/syft/exceptions/exception.py @@ -2,6 +2,9 @@ from typing import List from typing import Optional +# third party +from typing_extensions import Self + # relative from ..service.context import NodeServiceContext from ..service.response import SyftError @@ -16,7 +19,7 @@ def __init__(self, message: str, roles: Optional[List[ServiceRole]] = None): self.message = message self.roles = roles if roles else [ServiceRole.ADMIN] - def raise_with_context(self, context: NodeServiceContext): + def raise_with_context(self, context: NodeServiceContext) -> Self: self.context = context return self diff --git a/packages/syft/src/syft/external/__init__.py b/packages/syft/src/syft/external/__init__.py index 467294ecd5c..2de40a58f87 100644 --- a/packages/syft/src/syft/external/__init__.py +++ b/packages/syft/src/syft/external/__init__.py @@ -16,7 +16,7 @@ # if the external library is not installed, we prompt the user # to install it with the pip package name. -OBLV = str_to_bool(os.getenv("ENABLE_OBLV", "false")) +OBLV = str_to_bool(os.getenv("OBLV_ENABLED", "false")) EXTERNAL_LIBS = { "oblv": { diff --git a/packages/syft/src/syft/external/oblv/auth.py b/packages/syft/src/syft/external/oblv/auth.py index a3a382aff88..2360e7b477f 100644 --- a/packages/syft/src/syft/external/oblv/auth.py +++ b/packages/syft/src/syft/external/oblv/auth.py @@ -1,12 +1,13 @@ # stdlib from getpass import getpass +from typing import Any from typing import Optional # third party from oblv_ctl import authenticate -def login(apikey: Optional[str] = None): +def login(apikey: Optional[str] = None) -> Any: if apikey is None: apikey = getpass("Please provide your oblv API_KEY to login:") diff --git a/packages/syft/src/syft/external/oblv/constants.py b/packages/syft/src/syft/external/oblv/constants.py index b2d93ed3011..444c915f986 100644 --- a/packages/syft/src/syft/external/oblv/constants.py +++ b/packages/syft/src/syft/external/oblv/constants.py @@ -7,4 +7,4 @@ VISIBILITY = "private" REF_TYPE = "branch" LOCAL_MODE = True -DOMAIN_CONNECTION_PORT = 3030 +OBLV_LOCALHOST_PORT = 3030 diff --git a/packages/syft/src/syft/external/oblv/deployment_client.py b/packages/syft/src/syft/external/oblv/deployment_client.py index 6b4c1f6d304..deecee225a1 100644 --- a/packages/syft/src/syft/external/oblv/deployment_client.py +++ b/packages/syft/src/syft/external/oblv/deployment_client.py @@ -14,10 +14,11 @@ from typing import List from typing import Optional from typing import TYPE_CHECKING +from typing import Union # third party from oblv_ctl import OblvClient -from pydantic import validator +from pydantic import field_validator import requests # relative @@ -26,7 +27,9 @@ from ...client.client import login from ...client.client import login_as_guest from ...client.enclave_client import EnclaveMetadata +from ...node.credentials import SyftSigningKey from ...serde.serializable import serializable +from ...service.response import SyftError from ...types.uid import UID from ...util.util import bcolors from .constants import LOCAL_MODE @@ -43,10 +46,11 @@ class OblvMetadata(EnclaveMetadata): """Contains Metadata to connect to Oblivious Enclave""" - deployment_id: Optional[str] - oblv_client: Optional[OblvClient] + deployment_id: Optional[str] = None + oblv_client: Optional[OblvClient] = None - @validator("deployment_id") + @field_validator("deployment_id") + @classmethod def check_valid_deployment_id(cls, deployment_id: str) -> str: if not deployment_id and not LOCAL_MODE: raise ValueError( @@ -56,7 +60,8 @@ def check_valid_deployment_id(cls, deployment_id: str) -> str: ) return deployment_id - @validator("oblv_client") + @field_validator("oblv_client") + @classmethod def check_valid_oblv_client(cls, oblv_client: OblvClient) -> OblvClient: if not oblv_client and not LOCAL_MODE: raise ValueError( @@ -75,7 +80,7 @@ class DeploymentClient: __conn_string: str __logs: Any __process: Any - __enclave_client: SyftClient + __enclave_client: Optional[SyftClient] def __init__( self, @@ -90,14 +95,14 @@ def __init__( "domain_clients should be populated with valid domain nodes" ) self.deployment_id = deployment_id - self.key_name = key_name + self.key_name: Optional[str] = key_name self.oblv_client = oblv_client self.domain_clients = domain_clients self.__conn_string = "" self.__process = None self.__logs = None self._api = api - self.__enclave_client = None + self.__enclave_client: Optional[SyftClient] = None def make_request_to_enclave( self, @@ -245,7 +250,7 @@ def register( password: str, institution: Optional[str] = None, website: Optional[str] = None, - ): + ) -> Optional[Union[SyftError, SyftSigningKey]]: self.check_connection_string() guest_client = login_as_guest(url=self.__conn_string) return guest_client.register( diff --git a/packages/syft/src/syft/external/oblv/oblv_service.py b/packages/syft/src/syft/external/oblv/oblv_service.py index beb4d4cad92..cb4b2bf2971 100644 --- a/packages/syft/src/syft/external/oblv/oblv_service.py +++ b/packages/syft/src/syft/external/oblv/oblv_service.py @@ -8,7 +8,6 @@ from typing import Dict from typing import List from typing import Optional -from typing import Tuple from typing import cast # third party @@ -36,8 +35,8 @@ from ...store.document_store import DocumentStore from ...types.uid import UID from ...util.util import find_available_port -from .constants import DOMAIN_CONNECTION_PORT from .constants import LOCAL_MODE +from .constants import OBLV_LOCALHOST_PORT from .deployment_client import OblvMetadata from .exceptions import OblvEnclaveError from .exceptions import OblvProxyConnectPCRError @@ -191,7 +190,7 @@ def make_request_to_enclave( def create_keys_from_db( oblv_keys_stash: OblvKeysStash, verify_key: SyftVerifyKey, oblv_key_name: str -): +) -> None: oblv_key_path = os.path.expanduser(os.getenv("OBLV_KEY_PATH", "~/.oblv")) os.makedirs(oblv_key_path, exist_ok=True) @@ -211,7 +210,7 @@ def create_keys_from_db( f_public.close() -def generate_oblv_key(oblv_key_name: str) -> Tuple[bytes]: +def generate_oblv_key(oblv_key_name: str) -> tuple[bytes, bytes]: oblv_key_path = os.path.expanduser(os.getenv("OBLV_KEY_PATH", "~/.oblv")) os.makedirs(oblv_key_path, exist_ok=True) @@ -256,11 +255,12 @@ def __init__(self, store: DocumentStore) -> None: def create_key( self, context: AuthedServiceContext, + oblv_key_name: str, override_existing_key: bool = False, ) -> Result[Ok, Err]: """Domain Public/Private Key pair creation""" # TODO 🟣 Check for permission after it is fully integrated - public_key, private_key = generate_oblv_key() + public_key, private_key = generate_oblv_key(oblv_key_name) if override_existing_key: self.oblv_keys_stash.clear() @@ -323,7 +323,7 @@ def get_api_for( ) connection_string = f"http://127.0.0.1:{port}" else: - port = os.getenv("DOMAIN_CONNECTION_PORT", DOMAIN_CONNECTION_PORT) + port = os.getenv("OBLV_LOCALHOST_PORT", OBLV_LOCALHOST_PORT) connection_string = f"http://127.0.0.1:{port}" # To identify if we are in docker container diff --git a/packages/syft/src/syft/node/credentials.py b/packages/syft/src/syft/node/credentials.py index c60c57e1db2..d774f0f4c91 100644 --- a/packages/syft/src/syft/node/credentials.py +++ b/packages/syft/src/syft/node/credentials.py @@ -9,7 +9,7 @@ from nacl.encoding import HexEncoder from nacl.signing import SigningKey from nacl.signing import VerifyKey -import pydantic +from pydantic import field_validator # relative from ..serde.serializable import serializable @@ -54,8 +54,9 @@ def __hash__(self) -> int: class SyftSigningKey(SyftBaseModel): signing_key: SigningKey - @pydantic.validator("signing_key", pre=True, always=True) - def make_signing_key(cls, v: Union[str, SigningKey]) -> SigningKey: + @field_validator("signing_key", mode="before") + @classmethod + def make_signing_key(cls, v: Any) -> Any: return SigningKey(bytes.fromhex(v)) if isinstance(v, str) else v @property diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index fe5872f1aba..aa7c5da6bdf 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -18,7 +18,6 @@ from typing import Dict from typing import List from typing import Optional -from typing import Tuple from typing import Type from typing import Union import uuid @@ -39,6 +38,7 @@ from ..client.api import SyftAPICall from ..client.api import SyftAPIData from ..client.api import debox_signed_syftapicall_response +from ..client.client import SyftClient from ..exceptions.exception import PySyftException from ..external import OBLV from ..protocol.data_protocol import PROTOCOL_TYPE @@ -46,10 +46,12 @@ from ..service.action.action_object import Action from ..service.action.action_object import ActionObject from ..service.action.action_service import ActionService +from ..service.action.action_store import ActionStore from ..service.action.action_store import DictActionStore from ..service.action.action_store import MongoActionStore from ..service.action.action_store import SQLiteActionStore from ..service.blob_storage.service import BlobStorageService +from ..service.code.status_service import UserCodeStatusService from ..service.code.user_code_service import UserCodeService from ..service.code.user_code_stash import UserCodeStash from ..service.code_history.code_history_service import CodeHistoryService @@ -63,12 +65,15 @@ from ..service.enclave.enclave_service import EnclaveService from ..service.job.job_service import JobService from ..service.job.job_stash import Job +from ..service.job.job_stash import JobStash from ..service.log.log_service import LogService from ..service.metadata.metadata_service import MetadataService from ..service.metadata.node_metadata import NodeMetadataV3 from ..service.network.network_service import NetworkService from ..service.notification.notification_service import NotificationService +from ..service.notifier.notifier_service import NotifierService from ..service.object_search.migration_state_service import MigrateStateService +from ..service.output.output_service import OutputService from ..service.policy.policy_service import PolicyService from ..service.project.project_service import ProjectService from ..service.queue.base_queue import AbstractMessageHandler @@ -91,6 +96,7 @@ from ..service.settings.settings import NodeSettingsV2 from ..service.settings.settings_service import SettingsService from ..service.settings.settings_stash import SettingsStash +from ..service.sync.sync_service import SyncService from ..service.user.user import User from ..service.user.user import UserCreate from ..service.user.user_roles import ServiceRole @@ -101,9 +107,11 @@ from ..service.worker.utils import DEFAULT_WORKER_POOL_NAME from ..service.worker.utils import create_default_image from ..service.worker.worker_image_service import SyftWorkerImageService +from ..service.worker.worker_pool import WorkerPool from ..service.worker.worker_pool_service import SyftWorkerPoolService from ..service.worker.worker_pool_stash import SyftWorkerPoolStash from ..service.worker.worker_service import WorkerService +from ..service.worker.worker_stash import WorkerStash from ..store.blob_storage import BlobStorageConfig from ..store.blob_storage.on_disk import OnDiskBlobStorageClientConfig from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig @@ -156,7 +164,7 @@ def get_node_name() -> Optional[str]: return get_env(NODE_NAME, None) -def get_node_side_type() -> str: +def get_node_side_type() -> Optional[str]: return get_env(NODE_SIDE_TYPE, "high") @@ -188,15 +196,15 @@ def get_container_host() -> Optional[str]: return get_env("CONTAINER_HOST") -def get_default_worker_image() -> str: +def get_default_worker_image() -> Optional[str]: return get_env("DEFAULT_WORKER_POOL_IMAGE") -def get_default_worker_pool_name() -> str: +def get_default_worker_pool_name() -> Optional[str]: return get_env("DEFAULT_WORKER_POOL_NAME", DEFAULT_WORKER_POOL_NAME) -def get_default_worker_pool_count(node) -> int: +def get_default_worker_pool_count(node: Node) -> int: return int( get_env( "DEFAULT_WORKER_POOL_COUNT", node.queue_config.client_config.n_consumers @@ -204,7 +212,7 @@ def get_default_worker_pool_count(node) -> int: ) -def in_kubernetes() -> Optional[str]: +def in_kubernetes() -> bool: return get_container_host() == "k8s" @@ -242,7 +250,7 @@ def get_syft_worker_uid() -> Optional[str]: class AuthNodeContextRegistry: - __node_context_registry__: Dict[Tuple, Node] = OrderedDict() + __node_context_registry__: Dict[str, NodeServiceContext] = OrderedDict() @classmethod def set_node_context( @@ -250,7 +258,7 @@ def set_node_context( node_uid: Union[UID, str], context: NodeServiceContext, user_verify_key: Union[SyftVerifyKey, str], - ): + ) -> None: if isinstance(node_uid, str): node_uid = UID.from_string(node_uid) @@ -290,9 +298,9 @@ def __init__( signing_key: Optional[Union[SyftSigningKey, SigningKey]] = None, action_store_config: Optional[StoreConfig] = None, document_store_config: Optional[StoreConfig] = None, - root_email: str = default_root_email, - root_username: str = default_root_username, - root_password: str = default_root_password, + root_email: Optional[str] = default_root_email, + root_username: Optional[str] = default_root_username, + root_password: Optional[str] = default_root_password, processes: int = 0, is_subprocess: bool = False, node_type: Union[str, NodeType] = NodeType.DOMAIN, @@ -309,6 +317,11 @@ def __init__( dev_mode: bool = False, migrate: bool = False, in_memory_workers: bool = True, + smtp_username: Optional[str] = None, + smtp_password: Optional[str] = None, + email_sender: Optional[str] = None, + smtp_port: Optional[str] = None, + smtp_host: Optional[str] = None, ): # 🟡 TODO 22: change our ENV variable format and default init args to make this # less horrible or add some convenience functions @@ -350,6 +363,7 @@ def __init__( DataSubjectService, NetworkService, PolicyService, + NotifierService, NotificationService, DataSubjectMemberService, ProjectService, @@ -361,6 +375,9 @@ def __init__( SyftWorkerImageService, SyftWorkerPoolService, SyftImageRegistryService, + SyncService, + OutputService, + UserCodeStatusService, ] if services is None else services @@ -394,7 +411,17 @@ def __init__( node=self, ) - self.client_cache = {} + NotifierService.init_notifier( + node=self, + email_password=smtp_password, + email_username=smtp_username, + email_sender=email_sender, + smtp_port=smtp_port, + smtp_host=smtp_host, + ) + + self.client_cache: dict = {} + if isinstance(node_type, str): node_type = NodeType(node_type) self.node_type = node_type @@ -425,7 +452,7 @@ def __init__( NodeRegistry.set_node_for(self.id, self) @property - def runs_in_docker(self): + def runs_in_docker(self) -> bool: path = "/proc/self/cgroup" return ( os.path.exists("/.dockerenv") @@ -447,7 +474,7 @@ def init_blob_storage(self, config: Optional[BlobStorageConfig] = None) -> None: # relative from ..store.blob_storage.seaweedfs import SeaweedFSConfig - if isinstance(config, SeaweedFSConfig): + if isinstance(config, SeaweedFSConfig) and self.signing_key: blob_storage_service = self.get_service(BlobStorageService) remote_profiles = blob_storage_service.remote_profile_stash.get_all( credentials=self.signing_key.verify_key, has_permission=True @@ -457,14 +484,14 @@ def init_blob_storage(self, config: Optional[BlobStorageConfig] = None) -> None: remote_profile.profile_name ] = remote_profile - def stop(self): + def stop(self) -> None: for consumer_list in self.queue_manager.consumers.values(): for c in consumer_list: c.close() for p in self.queue_manager.producers.values(): p.close() - def close(self): + def close(self) -> None: self.stop() def create_queue_config( @@ -493,10 +520,10 @@ def create_queue_config( return queue_config_ - def init_queue_manager(self, queue_config: QueueConfig): + def init_queue_manager(self, queue_config: QueueConfig) -> None: MessageHandlers = [APICallMessageHandler] if self.is_subprocess: - return + return None self.queue_manager = QueueManager(config=queue_config) for message_handler in MessageHandlers: @@ -551,8 +578,8 @@ def add_consumer_for_service( service_name: str, syft_worker_id: UID, address: str, - message_handler: AbstractMessageHandler = APICallMessageHandler, - ): + message_handler: Type[AbstractMessageHandler] = APICallMessageHandler, + ) -> None: consumer: QueueConsumer = self.queue_manager.create_consumer( message_handler, address=address, @@ -664,7 +691,7 @@ def is_root(self, credentials: SyftVerifyKey) -> bool: return credentials == self.verify_key @property - def root_client(self): + def root_client(self) -> SyftClient: # relative from ..client.client import PythonConnection @@ -673,7 +700,8 @@ def root_client(self): if isinstance(client_type, SyftError): return client_type root_client = client_type(connection=connection, credentials=self.signing_key) - root_client.api.refresh_api_callback() + if root_client.api.refresh_api_callback is not None: + root_client.api.refresh_api_callback() return root_client def _find_klasses_pending_for_migration( @@ -707,7 +735,7 @@ def _find_klasses_pending_for_migration( return klasses_to_be_migrated - def find_and_migrate_data(self): + def find_and_migrate_data(self) -> None: # Track all object type that need migration for document store context = AuthedServiceContext( node=self, @@ -772,24 +800,26 @@ def find_and_migrate_data(self): print("Data Migrated to latest version !!!") @property - def guest_client(self): + def guest_client(self) -> SyftClient: return self.get_guest_client() @property - def current_protocol(self) -> List: + def current_protocol(self) -> Union[str, int]: data_protocol = get_data_protocol() return data_protocol.latest_version - def get_guest_client(self, verbose: bool = True): + def get_guest_client(self, verbose: bool = True) -> SyftClient: # relative from ..client.client import PythonConnection connection = PythonConnection(node=self) - if verbose: - print( + if verbose and self.node_side_type: + message: str = ( f"Logged into <{self.name}: {self.node_side_type.value.capitalize()} " - f"side {self.node_type.value.capitalize()} > as GUEST" ) + if self.node_type: + message += f"side {self.node_type.value.capitalize()} > as GUEST" + print(message) client_type = connection.get_client_type() if isinstance(client_type, SyftError): @@ -798,7 +828,8 @@ def get_guest_client(self, verbose: bool = True): guest_client = client_type( connection=connection, credentials=SyftSigningKey.generate() ) - guest_client.api.refresh_api_callback() + if guest_client.api.refresh_api_callback is not None: + guest_client.api.refresh_api_callback() return guest_client def __repr__(self) -> str: @@ -834,13 +865,15 @@ def post_init(self) -> None: def reload_user_code() -> None: user_code_service.load_user_code(context=context) - CODE_RELOADER[thread_ident()] = reload_user_code + ti = thread_ident() + if ti is not None: + CODE_RELOADER[ti] = reload_user_code def init_stores( self, document_store_config: Optional[StoreConfig] = None, action_store_config: Optional[StoreConfig] = None, - ): + ) -> None: if document_store_config is None: if self.local_db or (self.processes > 0 and not self.is_subprocess): client_config = SQLiteStoreClientConfig(path=self.sqlite_path) @@ -884,7 +917,7 @@ def init_stores( action_store_config.client_config.filename = f"{self.id}.sqlite" if isinstance(action_store_config, SQLiteStoreConfig): - self.action_store = SQLiteActionStore( + self.action_store: ActionStore = SQLiteActionStore( store_config=action_store_config, root_verify_key=self.verify_key, ) @@ -905,14 +938,14 @@ def init_stores( self.queue_stash = QueueStash(store=self.document_store) @property - def job_stash(self): + def job_stash(self) -> JobStash: return self.get_service("jobservice").stash @property - def worker_stash(self): + def worker_stash(self) -> WorkerStash: return self.get_service("workerservice").stash - def _construct_services(self): + def _construct_services(self) -> None: self.service_path_map = {} for service_klass in self.services: @@ -932,6 +965,7 @@ def _construct_services(self): DataSubjectService, NetworkService, PolicyService, + NotifierService, NotificationService, DataSubjectMemberService, ProjectService, @@ -943,6 +977,9 @@ def _construct_services(self): SyftWorkerImageService, SyftWorkerPoolService, SyftImageRegistryService, + SyncService, + OutputService, + UserCodeStatusService, ] if OBLV: @@ -952,7 +989,7 @@ def _construct_services(self): store_services += [OblvService] if service_klass in store_services: - kwargs["store"] = self.document_store + kwargs["store"] = self.document_store # type: ignore[assignment] self.service_path_map[service_klass.__name__.lower()] = service_klass( **kwargs ) @@ -962,7 +999,7 @@ def get_service_method(self, path_or_func: Union[str, Callable]) -> Callable: path_or_func = path_or_func.__qualname__ return self._get_service_method_from_path(path_or_func) - def get_service(self, path_or_func: Union[str, Callable]) -> Callable: + def get_service(self, path_or_func: Union[str, Callable]) -> AbstractService: if callable(path_or_func): path_or_func = path_or_func.__qualname__ return self._get_service_from_path(path_or_func) @@ -984,6 +1021,8 @@ def _get_service_method_from_path(self, path: str) -> Callable: @property def settings(self) -> NodeSettingsV2: settings_stash = SettingsStash(store=self.document_store) + if self.signing_key is None: + raise ValueError(f"{self} has no signing key") settings = settings_stash.get_all(self.signing_key.verify_key) if settings.is_ok() and len(settings.ok()) > 0: settings_data = settings.ok()[0] @@ -1000,6 +1039,8 @@ def metadata(self) -> NodeMetadataV3: organization = settings_data.organization description = settings_data.description show_warnings = settings_data.show_warnings + node_type = self.node_type.value if self.node_type else "" + node_side_type = self.node_side_type.value if self.node_side_type else "" return NodeMetadataV3( name=name, @@ -1010,8 +1051,8 @@ def metadata(self) -> NodeMetadataV3: syft_version=__version__, description=description, organization=organization, - node_type=self.node_type.value, - node_side_type=self.node_side_type.value, + node_type=node_type, + node_side_type=node_side_type, show_warnings=show_warnings, ) @@ -1021,6 +1062,8 @@ def icon(self) -> str: @property def verify_key(self) -> SyftVerifyKey: + if self.signing_key is None: + raise ValueError(f"{self} has no signing key") return self.signing_key.verify_key def __hash__(self) -> int: @@ -1135,7 +1178,7 @@ def handle_api_call( self, api_call: Union[SyftAPICall, SignedSyftAPICall], job_id: Optional[UID] = None, - check_call_location=True, + check_call_location: bool = True, ) -> Result[SignedSyftAPICall, Err]: # Get the result result = self.handle_api_call_with_unsigned_result( @@ -1150,7 +1193,7 @@ def handle_api_call_with_unsigned_result( self, api_call: Union[SyftAPICall, SignedSyftAPICall], job_id: Optional[UID] = None, - check_call_location=True, + check_call_location: bool = True, ) -> Union[Result, QueueItem, SyftObject, SyftError]: if self.required_signed_calls and isinstance(api_call, SyftAPICall): return SyftError( @@ -1188,8 +1231,8 @@ def handle_api_call_with_unsigned_result( if api_call.path not in user_config_registry: if ServiceConfigRegistry.path_exists(api_call.path): return SyftError( - message=f"As a `{role}`," - f"you have has no access to: {api_call.path}" + message=f"As a `{role}`, " + f"you have no access to: {api_call.path}" ) else: return SyftError( @@ -1212,9 +1255,9 @@ def handle_api_call_with_unsigned_result( def add_action_to_queue( self, - action, + action: Action, credentials: SyftVerifyKey, - parent_job_id=None, + parent_job_id: Optional[UID] = None, has_execute_permissions: bool = False, worker_pool_name: Optional[str] = None, ) -> Union[Job, SyftError]: @@ -1268,10 +1311,10 @@ def add_action_to_queue( def add_queueitem_to_queue( self, - queue_item: ActionQueueItem, + queue_item: QueueItem, credentials: SyftVerifyKey, - action=None, - parent_job_id=None, + action: Optional[Action] = None, + parent_job_id: Optional[UID] = None, ) -> Union[Job, SyftError]: log_id = UID() role = self.get_role_for_credentials(credentials=credentials) @@ -1333,7 +1376,9 @@ def _is_usercode_call_on_owned_kwargs( user_code_service = self.get_service("usercodeservice") return user_code_service.is_execution_on_owned_args(api_call.kwargs, context) - def add_api_call_to_queue(self, api_call, parent_job_id=None): + def add_api_call_to_queue( + self, api_call: SyftAPICall, parent_job_id: Optional[UID] = None + ) -> Union[Job, SyftError]: unsigned_call = api_call if isinstance(api_call, SignedSyftAPICall): unsigned_call = api_call.message @@ -1420,7 +1465,7 @@ def pool_stash(self) -> SyftWorkerPoolStash: def user_code_stash(self) -> UserCodeStash: return self.get_service(UserCodeService).stash - def get_default_worker_pool(self): + def get_default_worker_pool(self) -> Union[Optional[WorkerPool], SyftError]: result = self.pool_stash.get_by_name( credentials=self.verify_key, pool_name=get_default_worker_pool_name(), @@ -1457,6 +1502,9 @@ def create_initial_settings(self, admin_email: str) -> Optional[NodeSettingsV2]: self.name = random_name() try: settings_stash = SettingsStash(store=self.document_store) + if self.signing_key is None: + print("create_initial_settings failed as there is no signing key") + return None settings_exists = settings_stash.get_all(self.signing_key.verify_key).ok() if settings_exists: self.name = settings_exists[0].name @@ -1474,7 +1522,7 @@ def create_initial_settings(self, admin_email: str) -> Optional[NodeSettingsV2]: deployed_on=datetime.now().date().strftime("%m/%d/%Y"), signup_enabled=flags.CAN_REGISTER, admin_email=admin_email, - node_side_type=self.node_side_type.value, + node_side_type=self.node_side_type.value, # type: ignore show_warnings=self.enable_warnings, ) result = settings_stash.set( @@ -1484,7 +1532,8 @@ def create_initial_settings(self, admin_email: str) -> Optional[NodeSettingsV2]: return result.ok() return None except Exception as e: - print("create_worker_metadata failed", e) + print(f"create_initial_settings failed with error {e}") + return None def create_admin_new( @@ -1525,6 +1574,8 @@ def create_admin_new( except Exception as e: print("Unable to create new admin", e) + return None + def create_oblv_key_pair( worker: Node, @@ -1537,7 +1588,7 @@ def create_oblv_key_pair( oblv_keys_stash = OblvKeysStash(store=worker.document_store) - if not len(oblv_keys_stash): + if not len(oblv_keys_stash) and worker.signing_key: public_key, private_key = generate_oblv_key(oblv_key_name=worker.name) oblv_keys = OblvKeys(public_key=public_key, private_key=private_key) res = oblv_keys_stash.set(worker.signing_key.verify_key, oblv_keys) @@ -1548,6 +1599,9 @@ def create_oblv_key_pair( print(f"Using Existing Public/Private Key pair: {len(oblv_keys_stash)}") except Exception as e: print("Unable to create Oblv Keys.", e) + return None + + return None class NodeRegistry: @@ -1573,7 +1627,7 @@ def get_all_nodes(cls) -> List[Node]: return list(cls.__node_registry__.values()) -def get_default_worker_tag_by_env(dev_mode=False): +def get_default_worker_tag_by_env(dev_mode: bool = False) -> Optional[str]: if in_kubernetes(): return get_default_worker_image() elif dev_mode: @@ -1621,7 +1675,7 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]: if isinstance(result, SyftError): print("Failed to build default worker image: ", result.message) - return + return None # Create worker pool if it doesn't exists print( @@ -1654,7 +1708,7 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]: if isinstance(result, SyftError): print(f"Default worker pool error. {result.message}") - return + return None for n in range(worker_to_add_): container_status = result[n] @@ -1663,6 +1717,7 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]: f"Failed to create container: Worker: {container_status.worker}," f"Error: {container_status.error}" ) - return + return None print("Created default worker pool.") + return None diff --git a/packages/syft/src/syft/node/run.py b/packages/syft/src/syft/node/run.py index 3b8376fb4e1..10aa942a498 100644 --- a/packages/syft/src/syft/node/run.py +++ b/packages/syft/src/syft/node/run.py @@ -2,6 +2,9 @@ import argparse from typing import Optional +# third party +from hagrid.orchestra import NodeHandle + # relative from ..client.deploy import Orchestra @@ -14,7 +17,7 @@ def str_to_bool(bool_str: Optional[str]) -> bool: return result -def run(): +def run() -> Optional[NodeHandle]: parser = argparse.ArgumentParser() parser.add_argument("command", help="command: launch", type=str, default="none") parser.add_argument( diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index fc7bbb2bc8d..28032da15fd 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -81,16 +81,16 @@ def run_uvicorn( queue_port: Optional[int], create_producer: bool, n_consumers: int, -): +) -> None: async def _run_uvicorn( name: str, - node_type: Enum, + node_type: NodeType, host: str, port: int, reset: bool, dev_mode: bool, node_side_type: Enum, - ): + ) -> None: if node_type not in worker_classes: raise NotImplementedError(f"node_type: {node_type} is not supported") worker_class = worker_classes[node_type] @@ -205,7 +205,7 @@ def serve_node( ), ) - def stop(): + def stop() -> None: print(f"Stopping {name}") server_process.terminate() server_process.join(3) @@ -214,7 +214,7 @@ def stop(): server_process.kill() print("killed") - def start(): + def start() -> None: print(f"Starting {name} server on {host}:{port}") server_process.start() diff --git a/packages/syft/src/syft/node/worker_settings.py b/packages/syft/src/syft/node/worker_settings.py index 106e2e94821..57542d89c1c 100644 --- a/packages/syft/src/syft/node/worker_settings.py +++ b/packages/syft/src/syft/node/worker_settings.py @@ -16,30 +16,11 @@ from ..service.queue.base_queue import QueueConfig from ..store.blob_storage import BlobStorageConfig from ..store.document_store import StoreConfig -from ..types.syft_migration import migrate -from ..types.syft_object import SYFT_OBJECT_VERSION_1 from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.syft_object import SyftObject -from ..types.transforms import drop -from ..types.transforms import make_set_default from ..types.uid import UID -@serializable() -class WorkerSettingsV1(SyftObject): - __canonical_name__ = "WorkerSettings" - __version__ = SYFT_OBJECT_VERSION_1 - - id: UID - name: str - node_type: NodeType - node_side_type: NodeSideType - signing_key: SyftSigningKey - document_store_config: StoreConfig - action_store_config: StoreConfig - blob_store_config: Optional[BlobStorageConfig] - - @serializable() class WorkerSettings(SyftObject): __canonical_name__ = "WorkerSettings" @@ -52,37 +33,23 @@ class WorkerSettings(SyftObject): signing_key: SyftSigningKey document_store_config: StoreConfig action_store_config: StoreConfig - blob_store_config: Optional[BlobStorageConfig] - queue_config: Optional[QueueConfig] - - @staticmethod - def from_node(node: AbstractNode) -> Self: - return WorkerSettings( + blob_store_config: Optional[BlobStorageConfig] = None + queue_config: Optional[QueueConfig] = None + + @classmethod + def from_node(cls, node: AbstractNode) -> Self: + if node.node_side_type: + node_side_type: str = node.node_side_type.value + else: + node_side_type = NodeSideType.HIGH_SIDE + return cls( id=node.id, name=node.name, node_type=node.node_type, signing_key=node.signing_key, document_store_config=node.document_store_config, action_store_config=node.action_store_config, - node_side_type=node.node_side_type.value, + node_side_type=node_side_type, blob_store_config=node.blob_store_config, queue_config=node.queue_config, ) - - -# queue_config - - -@migrate(WorkerSettings, WorkerSettingsV1) -def downgrade_workersettings_v2_to_v1(): - return [ - drop(["queue_config"]), - ] - - -@migrate(WorkerSettingsV1, WorkerSettings) -def upgrade_workersettings_v1_to_v2(): - # relative - from ..service.queue.zmq_queue import ZMQQueueConfig - - return [make_set_default("queue_config", ZMQQueueConfig())] diff --git a/packages/syft/src/syft/protocol/data_protocol.py b/packages/syft/src/syft/protocol/data_protocol.py index 4512b25e9be..079647dd7e7 100644 --- a/packages/syft/src/syft/protocol/data_protocol.py +++ b/packages/syft/src/syft/protocol/data_protocol.py @@ -9,6 +9,7 @@ import re from typing import Any from typing import Dict +from typing import Iterable from typing import List from typing import Optional from typing import Tuple @@ -33,7 +34,7 @@ PROTOCOL_TYPE = Union[str, int] -def natural_key(key: PROTOCOL_TYPE) -> List[int]: +def natural_key(key: PROTOCOL_TYPE) -> List[Union[int, str, Any]]: """Define key for natural ordering of strings.""" if isinstance(key, int): key = str(key) @@ -128,7 +129,7 @@ def _hash_to_sha256(obj_dict: Dict) -> str: def build_state(self, stop_key: Optional[str] = None) -> dict: sorted_dict = sort_dict_naturally(self.protocol_history) - state_dict = defaultdict(dict) + state_dict: dict = defaultdict(dict) for protocol_number in sorted_dict: object_versions = sorted_dict[protocol_number]["object_versions"] for canonical_name, versions in object_versions.items(): @@ -165,8 +166,8 @@ def build_state(self, stop_key: Optional[str] = None) -> dict: return state_dict def diff_state(self, state: Dict) -> tuple[Dict, Dict]: - compare_dict = defaultdict(dict) # what versions are in the latest code - object_diff = defaultdict(dict) # diff in latest code with saved json + compare_dict: dict = defaultdict(dict) # what versions are in the latest code + object_diff: dict = defaultdict(dict) # diff in latest code with saved json for k in TYPE_BANK: ( nonrecursive, @@ -497,7 +498,7 @@ def debox_arg_and_migrate(arg: Any, protocol_state: dict) -> Any: arg = arg.value if isinstance(arg, MutableMapping): - iterable_keys = arg.keys() + iterable_keys: Iterable = arg.keys() elif isinstance(arg, MutableSequence): iterable_keys = range(len(arg)) elif isinstance(arg, tuple): @@ -548,7 +549,7 @@ def migrate_args_and_kwargs( to_protocol = data_protocol.latest_version if to_latest_protocol else None if to_protocol is None: - raise SyftException(message="Protocol version missing.") + raise SyftException("Protocol version missing.") # If latest protocol being used is equal to the protocol to be migrate # then skip migration of the object diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 1baa64af0a8..9f26b400215 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -10,10 +10,616 @@ }, "dev": { "object_versions": { + "ActionObject": { + "1": { + "version": 1, + "hash": "632446f1415102490c93fafb56dd9eb29d79623bcc5e9f2e6e37c4f63c2c51c3", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "577aa1f010b90194958a18ec38ee21db3718bd96d9e036501c6ddeefabedf432", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "0fe8c63c7ebf317c9b3791563eede28ce301dc0a2a1a98b13e657f34ed1e9edb", + "action": "add" + } + }, + "AnyActionObject": { + "1": { + "version": 1, + "hash": "bcb31f847907edc9c95d2d120dc5427854604f40940e3f41cd0474a1820ac65e", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "002d8be821140befebbc0503e6bc1ef8779094e24e46305e5da5af6eecb56b13", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "0ac9122d40743966890247c7444c1033ba52bdbb0d2396daf8767adbe42faaad", + "action": "add" + } + }, + "BlobFileOBject": { + "1": { + "version": 1, + "hash": "8da2c80ced4f0414c671313c4b63d05846df1e397c763d99d803be86c29755bb", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "cf3789022517ea88c968672566e7e3ae1dbf35c9f8ac5f09fd1ff7ca79534444", + "action": "add" + } + }, + "JobInfo": { + "1": { + "version": 1, + "hash": "cf26eeac3d9254dfa439917493b816341f8a379a77d182bbecba3b7ed2c1d00a", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "5c1f7d5e6a991123a1907c1823be14a75458ba06af1fe5a1b77aaac7fa546c78", + "action": "add" + } + }, + "ExecutionOutput": { + "1": { + "version": 1, + "hash": "833addc66807a638939aac00a4be306c93bd8d80a8f4ce6fcdb16d98e87ceb8b", + "action": "add" + } + }, + "OutputPolicyExecuteCount": { + "1": { + "version": 1, + "hash": "6bb24b3b35e19564c43b838ca3f46ccdeadb6596511917f2d220681a378e439d", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "ca0ba249f4f32379f5b83279a27df4a21eb23c531a86538c821a10ddf2c799ff", + "action": "add" + } + }, + "OutputPolicyExecuteOnce": { + "1": { + "version": 1, + "hash": "32a40fc9966b277528eebc61c01041f3a5447417731954abdaffbb14dabc76bb", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "e6b0f23047037734c1cc448771bc2770f5bf6c8b8f80cf46939eb7ba66dd377e", + "action": "add" + } + }, + "UserCodeStatusCollection": { + "1": { + "version": 1, + "hash": "4afcdcebd4b0ba95a8ac65eda9fcaa88129b7c520e8e6b093c6ab5208641a617", + "action": "add" + } + }, "UserCode": { + "1": { + "version": 1, + "hash": "e14c22686cdc7d1fb2b0d01c0aebdea37e62a61b051677c1d30234214f05cd42", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "660e1abc15034f525e91ffdd820c2a2179bfddf83b7b9e3ce7823b2efc515c69", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "90fcae0f556f375ba1e91d2e345f57241660695c6e2b84c8e311df89d09e6c66", + "action": "remove" + }, "4": { "version": 4, - "hash": "09e2c6a119246d8beb14f34a365c6e016947f854db86658bbef99c8970bf7e27", + "hash": "4acb1fa6856da943966b6a93eb7874000f785b29f12ecbed9025606f8fe51aa4", + "action": "add" + } + }, + "UserCodeExecutionOutput": { + "1": { + "version": 1, + "hash": "94c18d2dec05b39993c1a7a70bca2c991c95bd168005a93e578a810e57ef3164", + "action": "add" + } + }, + "NumpyArrayObject": { + "1": { + "version": 1, + "hash": "dcc7b44fa5ad22ae0bc576948f856c172dac1e9de2bc8e2a302e428f3309a278", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "2c631121d9211006edab5620b214dea83e2398bee92244d822227ee316647e22", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "b6c27c63285f55425942296a91bb16010fd359909fb82fcd52efa9e744e5f2a4", + "action": "add" + } + }, + "NumpyScalarObject": { + "1": { + "version": 1, + "hash": "5c1b6b6e8ba88bc79e76646d621489b889fe8f9b9fd59f117d594be18a409633", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "0d5d81b9d45c140f6e07b43ed68d31e0ef060d6b4d0431c9b4795997bb35c69d", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "028e645eea21425a049a56393218c2e89343edf09e9ff70d7fed6561c6508a43", + "action": "add" + } + }, + "NumpyBoolObject": { + "1": { + "version": 1, + "hash": "a5c822a6a3ca9eefd6a2b68f7fd0bc614fba7995f6bcc30bdc9dc882296b9b16", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "24839ba1c88ed833a134124750d5f299abcdf318670315028ed87b254f4578b3", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "e36b44d1829aff0e127bb1ba7b8e8f6853d6cf94cc86ef11c521019f1eec7e96", + "action": "add" + } + }, + "PandasDataframeObject": { + "1": { + "version": 1, + "hash": "35058924b3de2e0a604a92f91f4dd2e3cc0dac80c219d34f360e7cedd52f5f4c", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "66729d4ba7a92210d45c5a5c24fbdb4c8e58138a515a7bdb71ac8f6e8b868544", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "90fb7e7e5c7b03f37573012029c6979ccaaa44e720a48a7f829d83c6a41393e5", + "action": "add" + } + }, + "PandasSeriesObject": { + "1": { + "version": 1, + "hash": "2a0d8a55f1c27bd8fccd276cbe01bf272c40cab10417d7027273983fed423caa", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "cb05a714f75b1140a943f56a3622fcc0477b3a1f504cd545a98510959ffe1528", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "50d5d68c0b4d57f8ecf594ee9761a6b4a9cd726354a4c8e3ff28e4e0a2fe58a4", + "action": "add" + } + }, + "UserCodeStatusChange": { + "1": { + "version": 1, + "hash": "4f5b405cc2b3976ed8f7018df82e873435d9187dff15fa5a23bc85a738969f3f", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "d83e0905ae882c824ba8fbbf455cd3881906bf8b2ebbfff07bcf471ef869cedc", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "999ab977d4fe5a7b74ee2d90370599ce9caa1b38fd6e6c29bd543d379c4dae31", + "action": "add" + } + }, + "SyncStateItem": { + "1": { + "version": 1, + "hash": "7e1f22d0e24bb615b077d76feae7bed96a49a998358bd842aba18e8d69a22481", + "action": "add" + } + }, + "SyncState": { + "1": { + "version": 1, + "hash": "6da39adb0ecffb4ca7873c0d95ed31c8bf037610cde144662285b921de5d8f04", + "action": "add" + } + }, + "StoreConfig": { + "1": { + "version": 1, + "hash": "17de8875cf590311ddb042140347ffc79d4a85028e504dad178ca4e1237ec861", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "517ca390f0a92e60b79ee7a70772a6b2c29f82ed9042266957f0ce0d61b636f1", + "action": "add" + } + }, + "MongoStoreConfig": { + "1": { + "version": 1, + "hash": "e52aa382e300b0b69aaa2d80aadb4e3a9a3c02b3c741b71d56f959c4d3891ce5", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "9c47910aa82d955b11c62cbab5e23e83f90cfb6b82aa0b6d4aae7dffc9f2d846", + "action": "add" + } + }, + "Action": { + "1": { + "version": 1, + "hash": "5cf71ee35097f17fbb1dd05096f875211d71cf07161205d7f6a9c11fd49d5272", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "a13b50c4d23bd6deb7896e394f2a20e6cef4c33c5e6f4ee30f19eaffab708f21", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "0588c49fe6f38fbe2a6aefa1a2fe50ed79273f218ead40b3a8c4d2fd63a22d08", + "action": "add" + } + }, + "DataSubjectCreate": { + "1": { + "version": 1, + "hash": "5a94f9fcba75c50d78d71222f0235c5fd4d8003ae0db4d74bdbc4d56a99de3aa", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "4eb3d7fb24d674ad23e3aec584e0332054768d61d62bba329488183816732f6e", + "action": "add" + } + }, + "Dataset": { + "1": { + "version": 1, + "hash": "99ca2fa3e46fd9810222d269fac6accb546f632e94d5d57529016ba5e55af5a8", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "ef072e802af563bb5bb95e928ac50fa30ff6b07da2dccf16cf134d71f8744132", + "action": "add" + } + }, + "CreateDataset": { + "1": { + "version": 1, + "hash": "3b020d9b8928cbd7e91f41c749ab4c932e19520696a183f2c7cd1312ebb640d1", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "4c3cbd2b10e43e750fea1bad5368c7de9e66e49840cd4dc84f80bbbf1e81f359", + "action": "add" + } + }, + "DictStoreConfig": { + "1": { + "version": 1, + "hash": "256e9c623ce0becd555ddd2a55a0c15514e162786b1549388cef98a92a9b18c9", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "dc42f71c620250c74f798304cb0cdfd8c3df42ddc0e38b9663f084a451e4e0f6", + "action": "add" + } + }, + "SQLiteStoreConfig": { + "1": { + "version": 1, + "hash": "b656b26c14cf4e97aba702dd62a0927aec7f860c12eed512c2c688e1b7109aa5", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "41c8ead76c6babfe8c1073ef705b1c5d4d96fba5735d9d8cb669073637f83f5f", + "action": "add" + } + }, + "Plan": { + "1": { + "version": 1, + "hash": "a0bba2b7792c9e08c453e9e256f0ac6e6185610726566bcd50b057ae83b42d9a", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "6103055aebe436855987c18aeb63d6ec90e0ec6654f960eaa8212c0a6d2964aa", + "action": "add" + } + }, + "NodeMetadata": { + "1": { + "version": 1, + "hash": "6bee018894dfdf697ea624740d0bf051750e0b0d8470ced59646f6d8812068ac", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "f856169fea72486cd436875ce4411ef935da11eb7c5af48121adfa00d4c0cdb6", + "action": "remove" + } + }, + "NodeSettings": { + "1": { + "version": 1, + "hash": "b662047bb278f4f5db77c102f94b733c3a929839271b3d6b82ea174a60e2aaf0", + "action": "remove" + } + }, + "BlobFile": { + "1": { + "version": 1, + "hash": "47ed55183d619c6c624e35412360a41de42833e2c24223c1de1ad12a84fdafc2", + "action": "remove" + } + }, + "SeaweedSecureFilePathLocation": { + "1": { + "version": 1, + "hash": "5724a38b1a92b8a55da3d9cc34a720365a6d0c32683acda630fc44067173e201", + "action": "remove" + } + }, + "BlobStorageEntry": { + "1": { + "version": 1, + "hash": "9f1b027cce390ee6f71c7a81e7420bb71a477b29c6c62ba74e781a97bc5434e6", + "action": "remove" + } + }, + "BlobStorageMetadata": { + "1": { + "version": 1, + "hash": "6888943be3f97186190dd26d7eefbdf29b15c6f2fa459e13608065ebcdb799e2", + "action": "remove" + } + }, + "BlobRetrieval": { + "1": { + "version": 1, + "hash": "a8d7e1d6483e7a9b5a130e837fa398862aa6cbb316cc5f4470450d835755fdd9", + "action": "remove" + } + }, + "SyftObjectRetrieval": { + "2": { + "version": 2, + "hash": "d9d7a7e1b8843145c9687fd013c9223700285886073547734267e91ac53e0996", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "952958e9afae007bef3cb89aa15be95dddc4c310e3a8ce4191576f90ac6fcbc8", + "action": "remove" + }, + "4": { + "version": 4, + "hash": "939934f46b72eb2c903606bce8e7ac2e59b1707b73c65fa2b9de8eed6e35f9da", + "action": "add" + } + }, + "WorkerSettings": { + "1": { + "version": 1, + "hash": "0dcd95422ec8a7c74e45ee68a125084c08f898dc94a13d25fe5a5fd0e4fc5027", + "action": "remove" + } + }, + "SubmitUserCode": { + "2": { + "version": 2, + "hash": "9b29e060973a3de8d3564a2b7d2bb5c53745aa445bf257576994b613505d7194", + "action": "remove" + } + }, + "SeaweedFSBlobDeposit": { + "1": { + "version": 1, + "hash": "382a9ac178deed2a9591e1ebbb39f265cbe67027fb93a420d473a4c26b7fda11", + "action": "remove" + } + }, + "QueueItem": { + "1": { + "version": 1, + "hash": "5aa94681d9d0715d5b605f9625a54e114927271378cf2ea7245f85c488035e0b", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "9503b878de4b5b7a1793580301353523b7d6219ebd27d38abe598061979b7570", + "action": "remove" + } + }, + "ZMQClientConfig": { + "1": { + "version": 1, + "hash": "e6054969b495791569caaf33239039beae3d116e1fe74e9575467c48b9007c45", + "action": "remove" + } + }, + "ActionQueueItem": { + "1": { + "version": 1, + "hash": "11a43caf9164eb2a5a21f4bcb0ca361d0a5d134bf3c60173f2c502d0d80219de", + "action": "remove" + } + }, + "JobItem": { + "1": { + "version": 1, + "hash": "7b8723861837b0b7e948b2cf9244159d232185f3407dd6bef108346f941ddf6e", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "e99cf5a78c6dd3a0adc37af3472c7c21570a9e747985dff540a2b06d24de6446", + "action": "remove" + } + }, + "SyftLog": { + "1": { + "version": 1, + "hash": "bd3f62b8fe4b2718a6380c8f05a93c5c40169fc4ab174db291929298e588429e", + "action": "remove" + } + }, + "SignedSyftAPICall": { + "1": { + "version": 1, + "hash": "e66a116de2fa44ebdd0d4c2d7d5a047dedb555fd201a0f431cd8017d9d33a61d", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "ecc6891b770f1f543d02c1eb0007443b0eb3553fd0b9347522b8aa4b22c4cdba", + "action": "add" + } + }, + "UserUpdate": { + "2": { + "version": 2, + "hash": "32cba8fbd786c575f92e26c31384d282e68e3ebfe5c4b0a0e793820b1228d246", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "ca32926b95a88406796d2d7ea23eeeb15b7a632ec46f0cf300d3890a19ae78e3", + "action": "add" + } + }, + "UserCreate": { + "2": { + "version": 2, + "hash": "2540188c5aaea866914dccff459df6e0f4727108a503414bb1567ff6297d4646", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "8d87bd936564628f5e7c08ab1dedc9b26e9cd8a53899ce1604c91fbd281ae0ab", + "action": "add" + } + }, + "UserSearch": { + "1": { + "version": 1, + "hash": "69d1e10b81c8a4143cf70e4f911d8562732af2458ebbc455ca64542f11373dd1", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "79f95cd9b4dabca88773a54e7993a0f7c80f5fad1f1aa144d82bd13375173ea3", + "action": "add" + } + }, + "NodeSettingsUpdate": { + "1": { + "version": 1, + "hash": "b6ddc66ff270a3c2c4760e31e1a55d72ed04ccae2d0115ebe2fba6f2bf9bd119", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "859a91c8229a59e03ed4c20d38de569f7670bdea4b0a8cf2d4bd702da37aeabe", + "action": "add" + } + }, + "User": { + "2": { + "version": 2, + "hash": "ded970c92f202716ed33a2117cf541789f35fad66bd4b1db39da5026b1d7d0e7", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "4550a80d1e4682de38adb71f79f89b42bb42fa85b1383ece51bb737a30bd5522", + "action": "add" + } + }, + "UserView": { + "2": { + "version": 2, + "hash": "e410de583bb15bc5af57acef7be55ea5fc56b5b0fc169daa3869f4203c4d7473", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "9849a2182fed2f54ecaf03bd9febf0efec6639b8e27e5b1501683aa846b5a2d3", + "action": "add" + } + }, + "Notification": { + "1": { + "version": 1, + "hash": "d13981f721fe2b3e2717640ee07dc716c596e4ecd442461665c3fdab0b85bf0e", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "9032bac0e8ede1a3d118a0e31e0f1f05699d1efc88327fceb0917d40185a7930", + "action": "add" + } + }, + "CreateNotification": { + "1": { + "version": 1, + "hash": "b1f459de374fe674f873a4a5f3fb8a8aabe0d83faad84a933f0a77dd1141159a", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "5098e1ab1cf7ffd8da4ba5bff36ebdb235d3983453185035d6796a7517f8272c", + "action": "add" + } + }, + "NotificationPreferences": { + "1": { + "version": 1, + "hash": "57e033e2ebac5414a057b80599a31f277027a4980e49d31770f96017c57e638f", + "action": "add" + } + }, + "NotifierSettings": { + "1": { + "version": 1, + "hash": "8753b4ee72d673958783879bc3726c51077bf6a1deca37bacac3f3475605e812", "action": "add" } } diff --git a/packages/syft/src/syft/serde/recursive.py b/packages/syft/src/syft/serde/recursive.py index eeb1236b749..9efd64e02c0 100644 --- a/packages/syft/src/syft/serde/recursive.py +++ b/packages/syft/src/syft/serde/recursive.py @@ -111,9 +111,14 @@ def recursive_serde_register( # if pydantic object and attrs are provided, the get attrs from __fields__ # cls.__fields__ auto inherits attrs pydantic_fields = [ - f.name - for f in cls.__fields__.values() - if f.outer_type_ not in (Callable, types.FunctionType, types.LambdaType) + field + for field, field_info in cls.model_fields.items() + if not ( + field_info.annotation is not None + and hasattr(field_info.annotation, "__origin__") + and field_info.annotation.__origin__ + in (Callable, types.FunctionType, types.LambdaType) + ) ] attribute_list.update(pydantic_fields) @@ -125,7 +130,9 @@ def recursive_serde_register( attribute_list.update(["value"]) exclude_attrs = [] if exclude_attrs is None else exclude_attrs - attribute_list = attribute_list - set(exclude_attrs) + attribute_list = ( + attribute_list - set(exclude_attrs) - {"syft_pre_hooks__", "syft_post_hooks__"} + ) if inheritable_attrs and attribute_list and not is_pydantic: # only set __syft_serializable__ for non-pydantic classes because diff --git a/packages/syft/src/syft/serde/third_party.py b/packages/syft/src/syft/serde/third_party.py index 2d70250cbe6..fddbb5ae755 100644 --- a/packages/syft/src/syft/serde/third_party.py +++ b/packages/syft/src/syft/serde/third_party.py @@ -20,6 +20,7 @@ import pyarrow as pa import pyarrow.parquet as pq import pydantic +from pydantic._internal._model_construction import ModelMetaclass from pymongo.collection import Collection from result import Err from result import Ok @@ -29,6 +30,8 @@ # relative from ..types.dicttuple import DictTuple from ..types.dicttuple import _Meta as _DictTupleMetaClass +from ..types.syft_metaclass import EmptyType +from ..types.syft_metaclass import PartialModelMetaclass from .deserialize import _deserialize as deserialize from .recursive_primitives import _serialize_kv_pairs from .recursive_primitives import deserialize_kv @@ -54,8 +57,6 @@ # result Ok and Err recursive_serde_register(Ok, serialize_attrs=["_value"]) recursive_serde_register(Err, serialize_attrs=["_value"]) - -recursive_serde_register_type(pydantic.main.ModelMetaclass) recursive_serde_register(Result) # exceptions @@ -148,6 +149,17 @@ def _serialize_dicttuple(x: DictTuple) -> bytes: ) +recursive_serde_register( + EmptyType, + serialize=serialize_type, + deserialize=deserialize_type, +) + + +recursive_serde_register_type(ModelMetaclass) +recursive_serde_register_type(PartialModelMetaclass) + + def serialize_bytes_io(io: BytesIO) -> bytes: io.seek(0) return serialize(io.read(), to_bytes=True) @@ -180,9 +192,9 @@ def serialize_bytes_io(io: BytesIO) -> bytes: recursive_serde_register(np.core._ufunc_config._unspecified()) recursive_serde_register( - pydantic.networks.EmailStr, + pydantic.EmailStr, serialize=lambda x: x.encode(), - deserialize=lambda x: pydantic.networks.EmailStr(x.decode()), + deserialize=lambda x: pydantic.EmailStr(x.decode()), ) recursive_serde_register( diff --git a/packages/syft/src/syft/service/action/action_data_empty.py b/packages/syft/src/syft/service/action/action_data_empty.py index e32f4e339bb..c8f0e143e3d 100644 --- a/packages/syft/src/syft/service/action/action_data_empty.py +++ b/packages/syft/src/syft/service/action/action_data_empty.py @@ -2,6 +2,7 @@ from __future__ import annotations # stdlib +import sys from typing import Optional from typing import Type @@ -11,7 +12,11 @@ from ...types.syft_object import SyftObject from ...types.uid import UID -NoneType = type(None) +if sys.version_info >= (3, 10): + # stdlib + from types import NoneType +else: + NoneType = type(None) @serializable() diff --git a/packages/syft/src/syft/service/action/action_graph.py b/packages/syft/src/syft/service/action/action_graph.py index 8456f6852a1..ab14fa81b09 100644 --- a/packages/syft/src/syft/service/action/action_graph.py +++ b/packages/syft/src/syft/service/action/action_graph.py @@ -15,9 +15,8 @@ # third party import matplotlib.pyplot as plt import networkx as nx -import pydantic from pydantic import Field -from pydantic import validator +from pydantic import field_validator from result import Err from result import Ok from result import Result @@ -38,6 +37,7 @@ from ...types.datetime import DateTime from ...types.syft_object import PartialSyftObject from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.uid import UID from .action_object import Action @@ -62,21 +62,17 @@ class NodeActionData(SyftObject): __canonical_name__ = "NodeActionData" __version__ = SYFT_OBJECT_VERSION_1 - id: Optional[UID] + id: Optional[UID] = None # type: ignore[assignment] type: NodeType status: ExecutionStatus = ExecutionStatus.PROCESSING retry: int = 0 - created_at: Optional[DateTime] - updated_at: Optional[DateTime] + created_at: DateTime = Field(default_factory=DateTime.now) + updated_at: DateTime = Field(default_factory=DateTime.now) user_verify_key: SyftVerifyKey is_mutated: bool = False # denotes that this node has been mutated is_mutagen: bool = False # denotes that this node is causing a mutation - next_mutagen_node: Optional[UID] # next neighboring mutagen node - last_nm_mutagen_node: Optional[UID] # last non mutated mutagen node - - @pydantic.validator("created_at", pre=True, always=True) - def make_created_at(cls, v: Optional[DateTime]) -> DateTime: - return DateTime.now() if v is None else v + next_mutagen_node: Optional[UID] = None # next neighboring mutagen node + last_nm_mutagen_node: Optional[UID] = None # last non mutated mutagen node @classmethod def from_action(cls, action: Action, credentials: SyftVerifyKey) -> Self: @@ -117,24 +113,20 @@ def __repr__(self) -> str: @serializable() class NodeActionDataUpdate(PartialSyftObject): __canonical_name__ = "NodeActionDataUpdate" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 id: UID type: NodeType status: ExecutionStatus retry: int created_at: DateTime - updated_at: Optional[DateTime] + updated_at: DateTime = Field(default_factory=DateTime.now) credentials: SyftVerifyKey is_mutated: bool is_mutagen: bool next_mutagen_node: UID # next neighboring mutagen node last_nm_mutagen_node: UID # last non mutated mutagen node - @pydantic.validator("updated_at", pre=True, always=True) - def set_updated_at(cls, v: Optional[DateTime]) -> DateTime: - return DateTime.now() if v is None else v - @serializable() class BaseGraphStore: @@ -197,7 +189,8 @@ class InMemoryStoreClientConfig(StoreClientConfig): # We need this in addition to Field(default_factory=...) # so users can still do InMemoryStoreClientConfig(path=None) - @validator("path", pre=True) + @field_validator("path", mode="before") + @classmethod def __default_path(cls, path: Optional[Union[str, Path]]) -> Union[str, Path]: if path is None: return tempfile.gettempdir() @@ -211,15 +204,17 @@ def file_path(self) -> Path: @serializable(without=["_lock"]) class NetworkXBackingStore(BaseGraphStore): def __init__(self, store_config: StoreConfig, reset: bool = False) -> None: - self.path_str = store_config.client_config.file_path.as_posix() - + if store_config.client_config: + self.path_str = store_config.client_config.file_path.as_posix() + else: + self.path_str = "" if not reset and os.path.exists(self.path_str): self._db = self._load_from_path(self.path_str) else: self._db = nx.DiGraph() self.locking_config = store_config.locking_config - self._lock = None + self._lock: Optional[SyftLock] = None @property def lock(self) -> SyftLock: diff --git a/packages/syft/src/syft/service/action/action_graph_service.py b/packages/syft/src/syft/service/action/action_graph_service.py index 6e06a9c84a0..886669f7deb 100644 --- a/packages/syft/src/syft/service/action/action_graph_service.py +++ b/packages/syft/src/syft/service/action/action_graph_service.py @@ -4,7 +4,7 @@ from typing import Union # third party -from pydantic.error_wrappers import ValidationError +from pydantic import ValidationError # relative from ...node.credentials import SyftVerifyKey @@ -118,10 +118,7 @@ def _extract_input_and_output_from_action( for _, kwarg in action.kwargs.items(): input_uids.add(kwarg.id) - if action.result_id is not None: - output_uid = action.result_id.id - else: - output_uid = None + output_uid = action.result_id.id if action.result_id is not None else None return input_uids, output_uid diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index bbff783a49a..e07953b1fa5 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -15,12 +15,17 @@ from typing import Dict from typing import List from typing import Optional +from typing import TYPE_CHECKING from typing import Tuple from typing import Type from typing import Union +from typing import cast # third party -import pydantic +from pydantic import ConfigDict +from pydantic import Field +from pydantic import field_validator +from pydantic import model_validator from result import Err from result import Ok from result import Result @@ -37,13 +42,10 @@ from ...service.response import SyftError from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime -from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SYFT_OBJECT_VERSION_2 +from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import SyftBaseObject from ...types.syft_object import SyftObject -from ...types.transforms import drop -from ...types.transforms import make_set_default from ...types.uid import LineageID from ...types.uid import UID from ...util.logger import debug @@ -57,6 +59,10 @@ from .action_types import action_type_for_type from .action_types import action_types +if TYPE_CHECKING: + # relative + from ..sync.diff_state import AttrDiff + NoneType = type(None) @@ -81,40 +87,6 @@ def repr_cls(c: Any) -> str: return f"{c.__module__}.{c.__name__}" -@serializable() -class ActionV1(SyftObject): - """Serializable Action object. - - Parameters: - path: str - The path of the Type of the remote object. - op: str - The method to be executed from the remote object. - remote_self: Optional[LineageID] - The extended UID of the SyftObject - args: List[LineageID] - `op` args - kwargs: Dict[str, LineageID] - `op` kwargs - result_id: Optional[LineageID] - Extended UID of the resulted SyftObject - """ - - __canonical_name__ = "Action" - __version__ = SYFT_OBJECT_VERSION_1 - - __attr_searchable__: List[str] = [] - - path: str - op: str - remote_self: Optional[LineageID] - args: List[LineageID] - kwargs: Dict[str, LineageID] - result_id: Optional[LineageID] - action_type: Optional[ActionType] - create_object: Optional[SyftObject] = None - - @serializable() class Action(SyftObject): """Serializable Action object. @@ -135,28 +107,23 @@ class Action(SyftObject): """ __canonical_name__ = "Action" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_3 - __attr_searchable__: List[str] = [] + __attr_searchable__: ClassVar[List[str]] = [] - path: Optional[str] - op: Optional[str] - remote_self: Optional[LineageID] + path: Optional[str] = None + op: Optional[str] = None + remote_self: Optional[LineageID] = None args: List[LineageID] kwargs: Dict[str, LineageID] - result_id: Optional[LineageID] - action_type: Optional[ActionType] + result_id: LineageID = Field(default_factory=lambda: LineageID(UID())) + action_type: Optional[ActionType] = None create_object: Optional[SyftObject] = None user_code_id: Optional[UID] = None - @pydantic.validator("id", pre=True, always=True) - def make_id(cls, v: Optional[UID]) -> UID: - """Generate or reuse an UID""" - return v if isinstance(v, UID) else UID() - - @pydantic.validator("result_id", pre=True, always=True) - def make_result_id(cls, v: Optional[Union[UID, LineageID]]) -> UID: - """Generate or reuse a LineageID""" + @field_validator("result_id", mode="before") + @classmethod + def make_result_id(cls, v: Any) -> LineageID: return v if isinstance(v, LineageID) else LineageID(v) @property @@ -166,11 +133,11 @@ def full_path(self) -> str: @property def job_display_name(self) -> str: - if self.user_code_id is not None: - api = APIRegistry.api_for( - node_uid=self.syft_node_location, - user_verify_key=self.syft_client_verify_key, - ) + api = APIRegistry.api_for( + node_uid=self.syft_node_location, + user_verify_key=self.syft_client_verify_key, + ) + if self.user_code_id is not None and api is not None: user_code = api.services.code.get_by_id(self.user_code_id) return user_code.service_func_name else: @@ -248,20 +215,6 @@ def repr_uid(_id: LineageID) -> str: ) -@migrate(Action, ActionV1) -def downgrade_action_v2_to_v1() -> list[Callable]: - return [ - drop("user_code_id"), - make_set_default("op", ""), - make_set_default("path", ""), - ] - - -@migrate(ActionV1, Action) -def upgrade_action_v1_to_v2() -> list[Callable]: - return [make_set_default("user_code_id", None)] - - class ActionObjectPointer: pass @@ -299,6 +252,8 @@ class ActionObjectPointer: "_save_to_blob_storage_", # syft "syft_action_data", # syft "syft_resolved", # syft + "syft_action_data_node_id", + "node_uid", "migrate_to", # syft "to_dict", # syft "dict", # syft @@ -307,6 +262,41 @@ class ActionObjectPointer: "__include_fields__", # pydantic "_calculate_keys", # pydantic "_get_value", # pydantic + "__pydantic_validator__", # pydantic + "__class_vars__", # pydantic + "__private_attributes__", # pydantic + "__signature__", # pydantic + "__pydantic_complete__", # pydantic + "__pydantic_core_schema__", # pydantic + "__pydantic_custom_init__", # pydantic + "__pydantic_decorators__", # pydantic + "__pydantic_generic_metadata__", # pydantic + "__pydantic_parent_namespace__", # pydantic + "__pydantic_post_init__", # pydantic + "__pydantic_root_model__", # pydantic + "__pydantic_serializer__", # pydantic + "__pydantic_validator__", # pydantic + "__pydantic_extra__", # pydantic + "__pydantic_fields_set__", # pydantic + "__pydantic_private__", # pydantic + "model_config", # pydantic + "model_computed_fields", # pydantic + "model_extra", # pydantic + "model_fields", # pydantic + "model_fields_set", # pydantic + "model_construct", # pydantic + "model_copy", # pydantic + "model_dump", # pydantic + "model_dump_json", # pydantic + "model_json_schema", # pydantic + "model_parametrized_name", # pydantic + "model_post_init", # pydantic + "model_rebuild", # pydantic + "model_validate", # pydantic + "model_validate_json", # pydantic + "copy", # pydantic + "__sha256__", # syft + "__hash_exclude_attrs__", # syft ] dont_wrap_output_attrs = [ "__repr__", @@ -320,6 +310,10 @@ class ActionObjectPointer: "__bool__", "__len__", "syft_resolved", # syft + "node_uid", + "syft_action_data_node_id", + "__sha256__", + "__hash_exclude_attrs__", ] dont_make_side_effects = [ "_repr_html_", @@ -331,6 +325,10 @@ class ActionObjectPointer: "__len__", "shape", "syft_resolved", # syft + "node_uid", + "syft_action_data_node_id", + "__sha256__", + "__hash_exclude_attrs__", ] action_data_empty_must_run = [ "__repr__", @@ -357,13 +355,13 @@ class PreHookContext(SyftBaseObject): The action generated by the current hook """ - obj: Any + obj: Any = None op_name: str - node_uid: Optional[UID] - result_id: Optional[Union[UID, LineageID]] - result_twin_type: Optional[TwinMode] - action: Optional[Action] - action_type: Optional[ActionType] + node_uid: Optional[UID] = None + result_id: Optional[Union[UID, LineageID]] = None + result_twin_type: Optional[TwinMode] = None + action: Optional[Action] = None + action_type: Optional[ActionType] = None def make_action_side_effect( @@ -401,7 +399,7 @@ def make_action_side_effect( class TraceResult: result: list = [] - _client: SyftClient = None + _client: Optional[SyftClient] = None is_tracing: bool = False @classmethod @@ -432,7 +430,10 @@ def convert_to_pointers( kwarg_dict = {} if args is not None: for arg in args: - if not isinstance(arg, (ActionObject, Asset, UID)): + if ( + not isinstance(arg, (ActionObject, Asset, UID)) + and api.signing_key is not None # type: ignore[unreachable] + ): arg = ActionObject.from_obj( # type: ignore[unreachable] syft_action_data=arg, syft_client_verify_key=api.signing_key.verify_key, @@ -447,7 +448,10 @@ def convert_to_pointers( if kwargs is not None: for k, arg in kwargs.items(): - if not isinstance(arg, (ActionObject, Asset, UID)): + if ( + not isinstance(arg, (ActionObject, Asset, UID)) + and api.signing_key is not None # type: ignore[unreachable] + ): arg = ActionObject.from_obj( # type: ignore[unreachable] syft_action_data=arg, syft_client_verify_key=api.signing_key.verify_key, @@ -546,7 +550,7 @@ def debox_args_and_kwargs(args: Any, kwargs: Any) -> Tuple[Any, Any]: return tuple(filtered_args), filtered_kwargs -BASE_PASSTHROUGH_ATTRS = [ +BASE_PASSTHROUGH_ATTRS: list[str] = [ "is_mock", "is_real", "is_twin", @@ -569,73 +573,73 @@ def debox_args_and_kwargs(args: Any, kwargs: Any) -> Tuple[Any, Any]: "syft_action_data_cache", "reload_cache", "syft_resolved", + "refresh_object", + "syft_action_data_node_id", + "node_uid", + "__sha256__", + "__hash_exclude_attrs__", + "__hash__", ] -@serializable() -class ActionObjectV1(SyftObject): +@serializable(without=["syft_pre_hooks__", "syft_post_hooks__"]) +class ActionObject(SyftObject): """Action object for remote execution.""" __canonical_name__ = "ActionObject" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_3 - __attr_searchable__: List[str] = [] + __attr_searchable__: List[str] = [] # type: ignore[misc] syft_action_data_cache: Optional[Any] = None syft_blob_storage_entry_id: Optional[UID] = None syft_pointer_type: ClassVar[Type[ActionObjectPointer]] # Help with calculating history hash for code verification - syft_parent_hashes: Optional[Union[int, List[int]]] - syft_parent_op: Optional[str] - syft_parent_args: Optional[Any] - syft_parent_kwargs: Optional[Any] - syft_history_hash: Optional[int] + syft_parent_hashes: Optional[Union[int, List[int]]] = None + syft_parent_op: Optional[str] = None + syft_parent_args: Optional[Any] = None + syft_parent_kwargs: Optional[Any] = None + syft_history_hash: Optional[int] = None syft_internal_type: ClassVar[Type[Any]] - syft_node_uid: Optional[UID] - _syft_pre_hooks__: Dict[str, List] = {} - _syft_post_hooks__: Dict[str, List] = {} + syft_node_uid: Optional[UID] = None + syft_pre_hooks__: Dict[str, List] = {} + syft_post_hooks__: Dict[str, List] = {} syft_twin_type: TwinMode = TwinMode.NONE - syft_passthrough_attrs = BASE_PASSTHROUGH_ATTRS - syft_action_data_type: Optional[Type] - syft_action_data_repr_: Optional[str] - syft_action_data_str_: Optional[str] - syft_has_bool_attr: Optional[bool] - syft_resolve_data: Optional[bool] - syft_created_at: Optional[DateTime] + syft_passthrough_attrs: List[str] = BASE_PASSTHROUGH_ATTRS + syft_action_data_type: Optional[Type] = None + syft_action_data_repr_: Optional[str] = None + syft_action_data_str_: Optional[str] = None + syft_has_bool_attr: Optional[bool] = None + syft_resolve_data: Optional[bool] = None + syft_created_at: Optional[DateTime] = None + syft_resolved: bool = True + syft_action_data_node_id: Optional[UID] = None + # syft_dont_wrap_attrs = ["shape"] + def get_diff(self, ext_obj: Any) -> List[AttrDiff]: + # relative + from ...service.sync.diff_state import AttrDiff -@serializable() -class ActionObject(SyftObject): - """Action object for remote execution.""" + diff_attrs = [] - __canonical_name__ = "ActionObject" - __version__ = SYFT_OBJECT_VERSION_2 + # Sanity check + if ext_obj.id != self.id: + raise Exception("Not the same id for low side and high side requests") - __attr_searchable__: List[str] = [] - syft_action_data_cache: Optional[Any] = None - syft_blob_storage_entry_id: Optional[UID] = None - syft_pointer_type: ClassVar[Type[ActionObjectPointer]] + low_data = ext_obj.syft_action_data + high_data = self.syft_action_data + if low_data != high_data: + diff_attr = AttrDiff( + attr_name="syft_action_data", low_attr=low_data, high_attr=high_data + ) + diff_attrs.append(diff_attr) + return diff_attrs - # Help with calculating history hash for code verification - syft_parent_hashes: Optional[Union[int, List[int]]] - syft_parent_op: Optional[str] - syft_parent_args: Optional[Any] - syft_parent_kwargs: Optional[Any] - syft_history_hash: Optional[int] - syft_internal_type: ClassVar[Type[Any]] - syft_node_uid: Optional[UID] - _syft_pre_hooks__: Dict[str, List] = {} - _syft_post_hooks__: Dict[str, List] = {} - syft_twin_type: TwinMode = TwinMode.NONE - syft_passthrough_attrs = BASE_PASSTHROUGH_ATTRS - syft_action_data_type: Optional[Type] - syft_action_data_repr_: Optional[str] - syft_action_data_str_: Optional[str] - syft_has_bool_attr: Optional[bool] - syft_resolve_data: Optional[bool] - syft_created_at: Optional[DateTime] - syft_resolved: bool = True - # syft_dont_wrap_attrs = ["shape"] + def _set_obj_location_(self, node_uid: UID, credentials: SyftVerifyKey) -> None: + self.syft_node_location = node_uid + self.syft_client_verify_key = credentials + if self.syft_action_data_node_id is None: + self.syft_action_data_node_id = node_uid @property def syft_action_data(self) -> Any: @@ -661,6 +665,12 @@ def reload_cache(self) -> Optional[SyftError]: blob_retrieval_object = blob_storage_read_method( uid=self.syft_blob_storage_entry_id ) + if isinstance(blob_retrieval_object, SyftError): + print( + "Could not fetch actionobject data\n", + type(blob_retrieval_object), + ) + return blob_retrieval_object # relative from ...store.blob_storage import BlobRetrieval @@ -697,6 +707,9 @@ def _save_to_blob_storage_(self, data: Any) -> Optional[SyftError]: data.upload_to_blobstorage_from_api(api) else: storage_entry = CreateBlobStorageEntry.from_obj(data) + if self.syft_blob_storage_entry_id is not None: + # TODO: check if it already exists + storage_entry.id = self.syft_blob_storage_entry_id allocate_method = from_api_or_context( func_or_path="blob_storage.allocate", syft_node_location=self.syft_node_location, @@ -760,15 +773,10 @@ def syft_lineage_id(self) -> LineageID: """Compute the LineageID of the ActionObject, using the `id` and the `syft_history_hash` memebers""" return LineageID(self.id, self.syft_history_hash) - @pydantic.validator("id", pre=True, always=True) - def make_id(cls, v: Optional[UID]) -> UID: - """Generate or reuse an UID""" - return Action.make_id(v) - - class Config: - validate_assignment = True + model_config = ConfigDict(validate_assignment=True) - @pydantic.root_validator() + @model_validator(mode="before") + @classmethod def __check_action_data(cls, values: dict) -> dict: v = values.get("syft_action_data_cache") if values.get("syft_action_data_type", None) is None: @@ -798,18 +806,6 @@ def is_real(self) -> bool: def is_twin(self) -> bool: return self.syft_twin_type != TwinMode.NONE - # @pydantic.validator("syft_action_data", pre=True, always=True) - # def check_action_data( - # cls, v: ActionObject.syft_pointer_type - # ) -> ActionObject.syft_pointer_type: - # if cls == AnyActionObject or isinstance( - # v, (cls.syft_internal_type, ActionDataEmpty) - # ): - # return v - # raise SyftException( - # f"Must init {cls} with {cls.syft_internal_type} not {type(v)}" - # ) - def syft_point_to(self, node_uid: UID) -> ActionObject: """Set the syft_node_uid, used in the post hooks""" self.syft_node_uid = node_uid @@ -852,7 +848,8 @@ def syft_execute_action( node_uid=self.syft_node_uid, user_verify_key=self.syft_client_verify_key, ) - + if api is None: + raise ValueError(f"api is None. You must login to {self.syft_node_uid}") kwargs = {"action": action} api_call = SyftAPICall( node_uid=self.syft_node_uid, path="action.execute", args=[], kwargs=kwargs @@ -868,7 +865,8 @@ def request(self, client: SyftClient) -> Union[Any, SyftError]: permission_change = ActionStoreChange( linked_obj=action_object_link, apply_permission_type=ActionPermission.READ ) - + if client.credentials is None: + return SyftError(f"{client} has no signing key") submit_request = SubmitRequest( changes=[permission_change], requesting_user_verify_key=client.credentials.verify_key, @@ -880,6 +878,9 @@ def _syft_try_to_save_to_store(self, obj: SyftObject) -> None: return elif obj.syft_node_uid is not None: return + + if obj.syft_blob_storage_entry_id is not None: + return # TODO fix: the APIRegistry often gets the wrong client # if you have 2 clients in memory # therefore the following happens if you call a method @@ -898,11 +899,8 @@ def _syft_try_to_save_to_store(self, obj: SyftObject) -> None: if TraceResult._client is not None: api = TraceResult._client.api - if api is not None: + if api is not None and api.signing_key is not None: obj._set_obj_location_(api.node_uid, api.signing_key.verify_key) - res = obj._save_to_blob_storage() - if isinstance(res, SyftError): - print(f"failed saving {obj} to blob storage, error: {res}") action = Action( path="", @@ -922,6 +920,12 @@ def _syft_try_to_save_to_store(self, obj: SyftObject) -> None: node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, ) + if api is None: + print( + f"failed saving {obj} to blob storage, api is None. You must login to {self.syft_node_location}" + ) + + api = cast(SyftAPI, api) res = api.services.action.execute(action) if isinstance(res, SyftError): print(f"Failed to to store (arg) {obj} to store, {res}") @@ -941,12 +945,8 @@ def _syft_prepare_obj_uid(self, obj: Any) -> LineageID: return obj.syft_lineage_id # We got a raw object. We need to create the ActionObject from scratch and save it in the store. - obj_id = Action.make_id(None) - lin_obj_id = Action.make_result_id(obj_id) act_obj = ActionObject.from_obj( obj, - id=obj_id, - syft_lineage_id=lin_obj_id, syft_client_verify_key=self.syft_client_verify_key, syft_node_location=self.syft_node_location, ) @@ -1095,22 +1095,33 @@ def get_from(self, client: SyftClient) -> Any: else: return res.syft_action_data - def get(self, block: bool = False) -> Any: - """Get the object from a Syft Client""" + def refresh_object(self) -> ActionObject: # relative from ...client.api import APIRegistry - if block: - self.wait() - api = APIRegistry.api_for( node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, ) + if api is None: + return SyftError( + message=f"api is None. You must login to {self.syft_node_location}" + ) + res = api.services.action.get(self.id) + return res + + def get(self, block: bool = False) -> Any: + """Get the object from a Syft Client""" + # relative + + if block: + self.wait() + + res = self.refresh_object() if not isinstance(res, ActionObject): - return SyftError(message=f"{res}") + return SyftError(message=f"{res}") # type: ignore else: nested_res = res.syft_action_data if isinstance(nested_res, ActionObject): @@ -1124,7 +1135,11 @@ def as_empty(self) -> ActionObject: if isinstance(id, LineageID): id = id.id return ActionObject.empty( - self.syft_internal_type, id, self.syft_lineage_id, self.syft_resolved + self.syft_internal_type, + id, + self.syft_lineage_id, + self.syft_resolved, + syft_blob_storage_entry_id=self.syft_blob_storage_entry_id, ) @staticmethod @@ -1174,6 +1189,8 @@ def from_obj( syft_client_verify_key: Optional[SyftVerifyKey] = None, syft_node_location: Optional[UID] = None, syft_resolved: Optional[bool] = True, + data_node_id: Optional[UID] = None, + syft_blob_storage_entry_id: Optional[UID] = None, ) -> ActionObject: """Create an ActionObject from an existing object. @@ -1190,6 +1207,8 @@ def from_obj( action_type = action_type_for_object(syft_action_data) action_object = action_type(syft_action_data_cache=syft_action_data) + action_object.syft_blob_storage_entry_id = syft_blob_storage_entry_id + action_object.syft_action_data_node_id = data_node_id action_object.syft_resolved = syft_resolved if id is not None: @@ -1212,13 +1231,13 @@ def from_obj( @classmethod def add_trace_hook(cls) -> bool: return True - # if trace_action_side_effect not in self._syft_pre_hooks__[HOOK_ALWAYS]: - # self._syft_pre_hooks__[HOOK_ALWAYS].append(trace_action_side_effect) + # if trace_action_side_effect not in self.syft_pre_hooks__[HOOK_ALWAYS]: + # self.syft_pre_hooks__[HOOK_ALWAYS].append(trace_action_side_effect) @classmethod def remove_trace_hook(cls) -> bool: return True - # self._syft_pre_hooks__[HOOK_ALWAYS].pop(trace_action_side_effct, None) + # self.syft_pre_hooks__[HOOK_ALWAYS].pop(trace_action_side_effct, None) def as_empty_data(self) -> ActionDataEmpty: return ActionDataEmpty(syft_internal_type=self.syft_internal_type) @@ -1236,7 +1255,7 @@ def wait(self) -> ActionObject: else: obj_id = self.id - while not api.services.action.is_resolved(obj_id): + while api and not api.services.action.is_resolved(obj_id): time.sleep(1) return self @@ -1265,14 +1284,17 @@ def obj_not_ready( ) return res - @staticmethod + @classmethod def empty( # TODO: fix the mypy issue + cls, syft_internal_type: Optional[Type[Any]] = None, id: Optional[UID] = None, syft_lineage_id: Optional[LineageID] = None, syft_resolved: Optional[bool] = True, - ) -> ActionObject: + data_node_id: Optional[UID] = None, + syft_blob_storage_entry_id: Optional[UID] = None, + ) -> Self: """Create an ActionObject from a type, using a ActionDataEmpty object Parameters: @@ -1288,44 +1310,46 @@ def empty( type(None) if syft_internal_type is None else syft_internal_type ) empty = ActionDataEmpty(syft_internal_type=syft_internal_type) - res = ActionObject.from_obj( + res = cls.from_obj( id=id, syft_lineage_id=syft_lineage_id, syft_action_data=empty, syft_resolved=syft_resolved, + data_node_id=data_node_id, + syft_blob_storage_entry_id=syft_blob_storage_entry_id, ) res.__dict__["syft_internal_type"] = syft_internal_type return res def __post_init__(self) -> None: """Add pre/post hooks.""" - if HOOK_ALWAYS not in self._syft_pre_hooks__: - self._syft_pre_hooks__[HOOK_ALWAYS] = [] + if HOOK_ALWAYS not in self.syft_pre_hooks__: + self.syft_pre_hooks__[HOOK_ALWAYS] = [] - if HOOK_ON_POINTERS not in self._syft_post_hooks__: - self._syft_pre_hooks__[HOOK_ON_POINTERS] = [] + if HOOK_ON_POINTERS not in self.syft_post_hooks__: + self.syft_pre_hooks__[HOOK_ON_POINTERS] = [] # this should be a list as orders matters for side_effect in [make_action_side_effect]: - if side_effect not in self._syft_pre_hooks__[HOOK_ALWAYS]: - self._syft_pre_hooks__[HOOK_ALWAYS].append(side_effect) + if side_effect not in self.syft_pre_hooks__[HOOK_ALWAYS]: + self.syft_pre_hooks__[HOOK_ALWAYS].append(side_effect) for side_effect in [send_action_side_effect]: - if side_effect not in self._syft_pre_hooks__[HOOK_ON_POINTERS]: - self._syft_pre_hooks__[HOOK_ON_POINTERS].append(side_effect) + if side_effect not in self.syft_pre_hooks__[HOOK_ON_POINTERS]: + self.syft_pre_hooks__[HOOK_ON_POINTERS].append(side_effect) - if trace_action_side_effect not in self._syft_pre_hooks__[HOOK_ALWAYS]: - self._syft_pre_hooks__[HOOK_ALWAYS].append(trace_action_side_effect) + if trace_action_side_effect not in self.syft_pre_hooks__[HOOK_ALWAYS]: + self.syft_pre_hooks__[HOOK_ALWAYS].append(trace_action_side_effect) - if HOOK_ALWAYS not in self._syft_post_hooks__: - self._syft_post_hooks__[HOOK_ALWAYS] = [] + if HOOK_ALWAYS not in self.syft_post_hooks__: + self.syft_post_hooks__[HOOK_ALWAYS] = [] - if HOOK_ON_POINTERS not in self._syft_post_hooks__: - self._syft_post_hooks__[HOOK_ON_POINTERS] = [] + if HOOK_ON_POINTERS not in self.syft_post_hooks__: + self.syft_post_hooks__[HOOK_ON_POINTERS] = [] for side_effect in [propagate_node_uid]: - if side_effect not in self._syft_post_hooks__[HOOK_ALWAYS]: - self._syft_post_hooks__[HOOK_ALWAYS].append(side_effect) + if side_effect not in self.syft_post_hooks__[HOOK_ALWAYS]: + self.syft_post_hooks__[HOOK_ALWAYS].append(side_effect) if isinstance(self.syft_action_data_type, ActionObject): raise Exception("Nested ActionObjects", self.syft_action_data_repr_) @@ -1337,16 +1361,16 @@ def _syft_run_pre_hooks__( ) -> Tuple[PreHookContext, Tuple[Any, ...], Dict[str, Any]]: """Hooks executed before the actual call""" result_args, result_kwargs = args, kwargs - if name in self._syft_pre_hooks__: - for hook in self._syft_pre_hooks__[name]: + if name in self.syft_pre_hooks__: + for hook in self.syft_pre_hooks__[name]: result = hook(context, *result_args, **result_kwargs) if result.is_ok(): context, result_args, result_kwargs = result.ok() else: debug(f"Pre-hook failed with {result.err()}") if name not in self._syft_dont_wrap_attrs(): - if HOOK_ALWAYS in self._syft_pre_hooks__: - for hook in self._syft_pre_hooks__[HOOK_ALWAYS]: + if HOOK_ALWAYS in self.syft_pre_hooks__: + for hook in self.syft_pre_hooks__[HOOK_ALWAYS]: result = hook(context, *result_args, **result_kwargs) if result.is_ok(): context, result_args, result_kwargs = result.ok() @@ -1356,8 +1380,8 @@ def _syft_run_pre_hooks__( if self.is_pointer: if name not in self._syft_dont_wrap_attrs(): - if HOOK_ALWAYS in self._syft_pre_hooks__: - for hook in self._syft_pre_hooks__[HOOK_ON_POINTERS]: + if HOOK_ALWAYS in self.syft_pre_hooks__: + for hook in self.syft_pre_hooks__[HOOK_ON_POINTERS]: result = hook(context, *result_args, **result_kwargs) if result.is_ok(): context, result_args, result_kwargs = result.ok() @@ -1372,8 +1396,8 @@ def _syft_run_post_hooks__( ) -> Any: """Hooks executed after the actual call""" new_result = result - if name in self._syft_post_hooks__: - for hook in self._syft_post_hooks__[name]: + if name in self.syft_post_hooks__: + for hook in self.syft_post_hooks__[name]: result = hook(context, name, new_result) if result.is_ok(): new_result = result.ok() @@ -1381,8 +1405,8 @@ def _syft_run_post_hooks__( debug(f"Post hook failed with {result.err()}") if name not in self._syft_dont_wrap_attrs(): - if HOOK_ALWAYS in self._syft_post_hooks__: - for hook in self._syft_post_hooks__[HOOK_ALWAYS]: + if HOOK_ALWAYS in self.syft_post_hooks__: + for hook in self.syft_post_hooks__[HOOK_ALWAYS]: result = hook(context, name, new_result) if result.is_ok(): new_result = result.ok() @@ -1391,8 +1415,8 @@ def _syft_run_post_hooks__( if self.is_pointer: if name not in self._syft_dont_wrap_attrs(): - if HOOK_ALWAYS in self._syft_post_hooks__: - for hook in self._syft_post_hooks__[HOOK_ON_POINTERS]: + if HOOK_ALWAYS in self.syft_post_hooks__: + for hook in self.syft_post_hooks__[HOOK_ON_POINTERS]: result = hook(context, name, new_result) if result.is_ok(): new_result = result.ok() @@ -1666,6 +1690,9 @@ def __getattribute__(self, name: str) -> Any: if name.startswith("_syft") or name.startswith("syft"): return object.__getattribute__(self, name) + if name in passthrough_attrs: + return object.__getattribute__(self, name) + # third party if name in self._syft_passthrough_attrs(): return object.__getattribute__(self, name) @@ -1709,7 +1736,7 @@ def __setattr__(self, name: str, value: Any) -> Any: # if we do not implement these boiler plate __method__'s then special infix # operations like x + y won't trigger __getattribute__ # unless there is a super special reason we should write no code in these functions - def _repr_markdown_(self) -> str: + def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: if self.is_mock: res = "TwinPointer(Mock)" elif self.is_real: @@ -1762,6 +1789,9 @@ def __str__(self) -> str: def __len__(self) -> int: return self.__len__() + def __hash__(self, *args: Any, **kwargs: Any) -> int: + return super().__hash__(*args, **kwargs) + def __getitem__(self, key: Any) -> Any: return self._syft_output_action_object(self.__getitem__(key)) @@ -1906,39 +1936,15 @@ def __rrshift__(self, other: Any) -> Any: return self._syft_output_action_object(self.__rrshift__(other)) -@migrate(ActionObject, ActionObjectV1) -def downgrade_actionobject_v2_to_v1() -> list[Callable]: - return [ - drop("syft_resolved"), - ] - - -@migrate(ActionObjectV1, ActionObject) -def upgrade_actionobject_v1_to_v2() -> list[Callable]: - return [ - make_set_default("syft_resolved", True), - ] - - -@serializable() -class AnyActionObjectV1(ActionObjectV1): - __canonical_name__ = "AnyActionObject" - __version__ = SYFT_OBJECT_VERSION_1 - - syft_internal_type: ClassVar[Type[Any]] = NoneType # type: ignore - # syft_passthrough_attrs: List[str] = [] - syft_dont_wrap_attrs: List[str] = ["__str__", "__repr__", "syft_action_data_str_"] - - @serializable() class AnyActionObject(ActionObject): __canonical_name__ = "AnyActionObject" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_3 syft_internal_type: ClassVar[Type[Any]] = NoneType # type: ignore # syft_passthrough_attrs: List[str] = [] syft_dont_wrap_attrs: List[str] = ["__str__", "__repr__", "syft_action_data_str_"] - syft_action_data_str_ = "" + syft_action_data_str_: str = "" def __float__(self) -> float: return float(self.syft_action_data) @@ -1947,22 +1953,6 @@ def __int__(self) -> float: return int(self.syft_action_data) -@migrate(AnyActionObject, AnyActionObjectV1) -def downgrade_anyactionobject_v2_to_v1() -> list[Callable]: - return [ - drop("syft_action_data_str"), - drop("syft_resolved"), - ] - - -@migrate(AnyActionObjectV1, AnyActionObject) -def upgrade_anyactionobject_v1_to_v2() -> list[Callable]: - return [ - make_set_default("syft_action_data_str", ""), - make_set_default("syft_resolved", True), - ] - - action_types[Any] = AnyActionObject diff --git a/packages/syft/src/syft/service/action/action_permissions.py b/packages/syft/src/syft/service/action/action_permissions.py index 1ce4552dc12..76984451392 100644 --- a/packages/syft/src/syft/service/action/action_permissions.py +++ b/packages/syft/src/syft/service/action/action_permissions.py @@ -1,5 +1,7 @@ # stdlib from enum import Enum +from typing import Any +from typing import Dict from typing import Optional # relative @@ -50,6 +52,13 @@ def permission_string(self) -> str: return f"{self.credentials.verify}_{self.permission.name}" return f"{self.permission.name}" + def _coll_repr_(self) -> Dict[str, Any]: + return { + "uid": str(self.uid), + "credentials": str(self.credentials), + "permission": str(self.permission), + } + def __repr__(self) -> str: if self.credentials is not None: return f"[{self.permission.name}: {self.uid} as {self.credentials.verify}]" diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 056cb90e7af..806d2ad6a37 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -35,6 +35,7 @@ from ..service import TYPE_TO_SERVICE from ..service import UserLibConfigRegistry from ..service import service_method +from ..user.user_roles import ADMIN_ROLE_LEVEL from ..user.user_roles import GUEST_ROLE_LEVEL from ..user.user_roles import ServiceRole from .action_object import Action @@ -102,7 +103,7 @@ def _set( if isinstance(action_object, ActionObject): action_object.syft_created_at = DateTime.now() else: - action_object.private_obj.syft_created_at = DateTime.now() + action_object.private_obj.syft_created_at = DateTime.now() # type: ignore[unreachable] action_object.mock_obj.syft_created_at = DateTime.now() # If either context or argument is True, has_result_read_permission is True @@ -230,11 +231,11 @@ def _get( ) # Resolve graph links if ( - not isinstance(obj, TwinObject) + not isinstance(obj, TwinObject) # type: ignore[unreachable] and resolve_nested and isinstance(obj.syft_action_data, ActionDataLink) ): - if not self.is_resolved( + if not self.is_resolved( # type: ignore[unreachable] context, obj.syft_action_data.action_object_id.id ).ok(): return SyftError(message="This object is not resolved yet.") @@ -299,7 +300,7 @@ def _user_code_execute( ) if not override_execution_permission: - input_policy = code_item.input_policy + input_policy = code_item.get_input_policy(context) if input_policy is None: if not code_item.output_policy_approved: return Err("Execution denied: Your code is waiting for approval") @@ -313,11 +314,13 @@ def _user_code_execute( else: filtered_kwargs = retrieve_from_db(code_item.id, kwargs, context).ok() # update input policy to track any input state - # code_item.input_policy = input_policy - if not override_execution_permission and code_item.input_policy is not None: + if ( + not override_execution_permission + and code_item.get_input_policy(context) is not None + ): expected_input_kwargs = set() - for _inp_kwarg in code_item.input_policy.inputs.values(): + for _inp_kwarg in code_item.get_input_policy(context).inputs.values(): # type: ignore keys = _inp_kwarg.keys() for k in keys: if k not in kwargs: @@ -397,7 +400,7 @@ def _user_code_execute( def set_result_to_store( self, - result_action_object: ActionObject, + result_action_object: Union[ActionObject, TwinObject], context: AuthedServiceContext, output_policy: Optional[OutputPolicy] = None, ) -> Union[Result[ActionObject, str], SyftError]: @@ -427,7 +430,7 @@ def set_result_to_store( if isinstance(result_action_object, TwinObject): result_blob_id = result_action_object.private.syft_blob_storage_entry_id else: - result_blob_id = result_action_object.syft_blob_storage_entry_id + result_blob_id = result_action_object.syft_blob_storage_entry_id # type: ignore[unreachable] # pass permission information to the action store as extra kwargs context.extra_kwargs = {"has_result_read_permission": True} @@ -437,7 +440,7 @@ def set_result_to_store( if set_result.is_err(): return set_result - blob_storage_service: BlobStorageService = context.node.get_service( + blob_storage_service: AbstractService = context.node.get_service( BlobStorageService ) @@ -550,7 +553,7 @@ def set_attribute( ) else: # TODO: Implement for twinobject args - args = filter_twin_args(args, twin_mode=TwinMode.NONE) + args = filter_twin_args(args, twin_mode=TwinMode.NONE) # type: ignore[unreachable] val = args[0] setattr(resolved_self.syft_action_data, name, val) return Ok( @@ -577,7 +580,7 @@ def get_attribute( ) ) else: - val = getattr(resolved_self.syft_action_data, action.op) + val = getattr(resolved_self.syft_action_data, action.op) # type: ignore[unreachable] return Ok(wrap_result(action.result_id, val)) def call_method( @@ -620,7 +623,7 @@ def call_method( ) ) else: - return execute_object(self, context, resolved_self, action) + return execute_object(self, context, resolved_self, action) # type:ignore[unreachable] @service_method(path="action.execute", name="execute", roles=GUEST_ROLE_LEVEL) def execute( @@ -732,6 +735,15 @@ def exists( else: return SyftError(message=f"Object: {obj_id} does not exist") + @service_method(path="action.delete", name="delete", roles=ADMIN_ROLE_LEVEL) + def delete( + self, context: AuthedServiceContext, uid: UID + ) -> Union[SyftSuccess, SyftError]: + res = self.store.delete(context.credentials, uid) + if res.is_err(): + return SyftError(message=res.err()) + return SyftSuccess(message="Great Success!") + def resolve_action_args( action: Action, context: AuthedServiceContext, service: ActionService diff --git a/packages/syft/src/syft/service/action/numpy.py b/packages/syft/src/syft/service/action/numpy.py index dfd43907b92..78fcc905376 100644 --- a/packages/syft/src/syft/service/action/numpy.py +++ b/packages/syft/src/syft/service/action/numpy.py @@ -1,7 +1,7 @@ # stdlib from typing import Any -from typing import Callable from typing import ClassVar +from typing import List from typing import Type from typing import Union @@ -11,14 +11,9 @@ # relative from ...serde.serializable import serializable -from ...types.syft_migration import migrate -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SYFT_OBJECT_VERSION_2 -from ...types.transforms import drop -from ...types.transforms import make_set_default +from ...types.syft_object import SYFT_OBJECT_VERSION_3 from .action_object import ActionObject from .action_object import ActionObjectPointer -from .action_object import ActionObjectV1 from .action_object import BASE_PASSTHROUGH_ATTRS from .action_types import action_types @@ -46,28 +41,17 @@ def numpy_like_eq(left: Any, right: Any) -> bool: return bool(result) -@serializable() -class NumpyArrayObjectV1(ActionObjectV1, np.lib.mixins.NDArrayOperatorsMixin): - __canonical_name__ = "NumpyArrayObject" - __version__ = SYFT_OBJECT_VERSION_1 - - syft_internal_type: ClassVar[Type[Any]] = np.ndarray - syft_pointer_type = NumpyArrayObjectPointer - syft_passthrough_attrs = BASE_PASSTHROUGH_ATTRS - syft_dont_wrap_attrs = ["dtype", "shape"] - - # 🔵 TODO 7: Map TPActionObjects and their 3rd Party types like numpy type to these # classes for bi-directional lookup. @serializable() class NumpyArrayObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): __canonical_name__ = "NumpyArrayObject" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_3 syft_internal_type: ClassVar[Type[Any]] = np.ndarray - syft_pointer_type = NumpyArrayObjectPointer - syft_passthrough_attrs = BASE_PASSTHROUGH_ATTRS - syft_dont_wrap_attrs = ["dtype", "shape"] + syft_pointer_type: ClassVar[Type[ActionObjectPointer]] = NumpyArrayObjectPointer + syft_passthrough_attrs: List[str] = BASE_PASSTHROUGH_ATTRS + syft_dont_wrap_attrs: List[str] = ["dtype", "shape"] # def __eq__(self, other: Any) -> bool: # # 🟡 TODO 8: move __eq__ to a Data / Serdeable type interface on ActionObject @@ -82,9 +66,11 @@ def __array_ufunc__( self, ufunc: Any, method: str, *inputs: Any, **kwargs: Any ) -> Union[Self, tuple[Self, ...]]: inputs = tuple( - np.array(x.syft_action_data, dtype=x.dtype) - if isinstance(x, NumpyArrayObject) - else x + ( + np.array(x.syft_action_data, dtype=x.dtype) + if isinstance(x, NumpyArrayObject) + else x + ) for x in inputs ) @@ -100,89 +86,27 @@ def __array_ufunc__( ) -@migrate(NumpyArrayObject, NumpyArrayObjectV1) -def downgrade_numpyarrayobject_v2_to_v1() -> list[Callable]: - return [ - drop("syft_resolved"), - ] - - -@migrate(NumpyArrayObjectV1, NumpyArrayObject) -def upgrade_numpyarrayobject_v1_to_v2() -> list[Callable]: - return [ - make_set_default("syft_resolved", True), - ] - - -@serializable() -class NumpyScalarObjectV1(ActionObjectV1, np.lib.mixins.NDArrayOperatorsMixin): - __canonical_name__ = "NumpyScalarObject" - __version__ = SYFT_OBJECT_VERSION_1 - - syft_internal_type = np.number - syft_passthrough_attrs = BASE_PASSTHROUGH_ATTRS - syft_dont_wrap_attrs = ["dtype", "shape"] - - @serializable() class NumpyScalarObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): __canonical_name__ = "NumpyScalarObject" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_3 - syft_internal_type = np.number - syft_passthrough_attrs = BASE_PASSTHROUGH_ATTRS - syft_dont_wrap_attrs = ["dtype", "shape"] + syft_internal_type: ClassVar[Type] = np.number + syft_passthrough_attrs: List[str] = BASE_PASSTHROUGH_ATTRS + syft_dont_wrap_attrs: List[str] = ["dtype", "shape"] def __float__(self) -> float: return float(self.syft_action_data) -@migrate(NumpyScalarObject, NumpyScalarObjectV1) -def downgrade_numpyscalarobject_v2_to_v1() -> list[Callable]: - return [ - drop("syft_resolved"), - ] - - -@migrate(NumpyScalarObjectV1, NumpyScalarObject) -def upgrade_numpyscalarobject_v1_to_v2() -> list[Callable]: - return [ - make_set_default("syft_resolved", True), - ] - - -@serializable() -class NumpyBoolObjectV1(ActionObjectV1, np.lib.mixins.NDArrayOperatorsMixin): - __canonical_name__ = "NumpyBoolObject" - __version__ = SYFT_OBJECT_VERSION_1 - - syft_internal_type = np.bool_ - syft_passthrough_attrs = BASE_PASSTHROUGH_ATTRS - syft_dont_wrap_attrs = ["dtype", "shape"] - - @serializable() class NumpyBoolObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): __canonical_name__ = "NumpyBoolObject" - __version__ = SYFT_OBJECT_VERSION_2 - - syft_internal_type = np.bool_ - syft_passthrough_attrs = BASE_PASSTHROUGH_ATTRS - syft_dont_wrap_attrs = ["dtype", "shape"] - - -@migrate(NumpyBoolObject, NumpyBoolObjectV1) -def downgrade_numpyboolobject_v2_to_v1() -> list[Callable]: - return [ - drop("syft_resolved"), - ] - + __version__ = SYFT_OBJECT_VERSION_3 -@migrate(NumpyBoolObjectV1, NumpyBoolObject) -def upgrade_numpyboolobject_v1_to_v2() -> list[Callable]: - return [ - make_set_default("syft_resolved", True), - ] + syft_internal_type: ClassVar[Type] = np.bool_ + syft_passthrough_attrs: List[str] = BASE_PASSTHROUGH_ATTRS + syft_dont_wrap_attrs: List[str] = ["dtype", "shape"] np_array = np.array([1, 2, 3]) @@ -212,5 +136,5 @@ def upgrade_numpyboolobject_v1_to_v2() -> list[Callable]: np.float64, ] -for scalar_type in SUPPORTED_INT_TYPES + SUPPORTED_FLOAT_TYPES: +for scalar_type in SUPPORTED_INT_TYPES + SUPPORTED_FLOAT_TYPES: # type: ignore action_types[scalar_type] = NumpyScalarObject diff --git a/packages/syft/src/syft/service/action/pandas.py b/packages/syft/src/syft/service/action/pandas.py index 2dac63f3b46..1d9d73f34d5 100644 --- a/packages/syft/src/syft/service/action/pandas.py +++ b/packages/syft/src/syft/service/action/pandas.py @@ -1,7 +1,7 @@ # stdlib from typing import Any -from typing import Callable from typing import ClassVar +from typing import List from typing import Type # third party @@ -10,33 +10,19 @@ # relative from ...serde.serializable import serializable -from ...types.syft_migration import migrate -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SYFT_OBJECT_VERSION_2 -from ...types.transforms import drop -from ...types.transforms import make_set_default +from ...types.syft_object import SYFT_OBJECT_VERSION_3 from .action_object import ActionObject -from .action_object import ActionObjectV1 from .action_object import BASE_PASSTHROUGH_ATTRS from .action_types import action_types -@serializable() -class PandasDataFrameObjectV1(ActionObjectV1): - __canonical_name__ = "PandasDataframeObject" - __version__ = SYFT_OBJECT_VERSION_1 - - syft_internal_type: ClassVar[Type[Any]] = DataFrame - syft_passthrough_attrs = BASE_PASSTHROUGH_ATTRS - - @serializable() class PandasDataFrameObject(ActionObject): __canonical_name__ = "PandasDataframeObject" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_3 - syft_internal_type: ClassVar[Type[Any]] = DataFrame - syft_passthrough_attrs = BASE_PASSTHROUGH_ATTRS + syft_internal_type: ClassVar[Type] = DataFrame + syft_passthrough_attrs: List[str] = BASE_PASSTHROUGH_ATTRS # this is added for instance checks for dataframes # syft_dont_wrap_attrs = ["shape"] @@ -56,36 +42,13 @@ def syft_is_property(self, obj: Any, method: str) -> bool: return super().syft_is_property(obj, method) -@migrate(PandasDataFrameObject, PandasDataFrameObjectV1) -def downgrade_pandasdataframeobject_v2_to_v1() -> list[Callable]: - return [ - drop("syft_resolved"), - ] - - -@migrate(PandasDataFrameObjectV1, PandasDataFrameObject) -def upgrade_pandasdataframeobject_v1_to_v2() -> list[Callable]: - return [ - make_set_default("syft_resolved", True), - ] - - -@serializable() -class PandasSeriesObjectV1(ActionObjectV1): - __canonical_name__ = "PandasSeriesObject" - __version__ = SYFT_OBJECT_VERSION_1 - - syft_internal_type = Series - syft_passthrough_attrs = BASE_PASSTHROUGH_ATTRS - - @serializable() class PandasSeriesObject(ActionObject): __canonical_name__ = "PandasSeriesObject" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_3 syft_internal_type = Series - syft_passthrough_attrs = BASE_PASSTHROUGH_ATTRS + syft_passthrough_attrs: List[str] = BASE_PASSTHROUGH_ATTRS # name: Optional[str] = None # syft_dont_wrap_attrs = ["shape"] @@ -105,19 +68,5 @@ def syft_is_property(self, obj: Any, method: str) -> bool: return super().syft_is_property(obj, method) -@migrate(PandasSeriesObject, PandasSeriesObjectV1) -def downgrade_pandasseriesframeobject_v2_to_v1() -> list[Callable]: - return [ - drop("syft_resolved"), - ] - - -@migrate(PandasSeriesObjectV1, PandasSeriesObject) -def upgrade_pandasseriesframeobject_v1_to_v2() -> list[Callable]: - return [ - make_set_default("syft_resolved", True), - ] - - action_types[DataFrame] = PandasDataFrameObject action_types[Series] = PandasSeriesObject diff --git a/packages/syft/src/syft/service/action/plan.py b/packages/syft/src/syft/service/action/plan.py index a2a81b6f473..6572e22585c 100644 --- a/packages/syft/src/syft/service/action/plan.py +++ b/packages/syft/src/syft/service/action/plan.py @@ -12,7 +12,7 @@ from ... import Worker from ...client.client import SyftClient from ...serde.recursive import recursive_serde_register -from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from .action_object import Action from .action_object import TraceResult @@ -20,8 +20,14 @@ class Plan(SyftObject): __canonical_name__ = "Plan" - __version__ = SYFT_OBJECT_VERSION_1 - syft_passthrough_attrs = ["inputs", "outputs", "code", "actions", "client"] + __version__ = SYFT_OBJECT_VERSION_2 + syft_passthrough_attrs: List[str] = [ + "inputs", + "outputs", + "code", + "actions", + "client", + ] inputs: Dict[str, ActionObject] outputs: List[ActionObject] diff --git a/packages/syft/src/syft/service/blob_storage/service.py b/packages/syft/src/syft/service/blob_storage/service.py index 1a6bb94c932..6781828a287 100644 --- a/packages/syft/src/syft/service/blob_storage/service.py +++ b/packages/syft/src/syft/service/blob_storage/service.py @@ -201,9 +201,12 @@ def read( ) -> Union[BlobRetrieval, SyftError]: result = self.stash.get_by_uid(context.credentials, uid=uid) if result.is_ok(): - obj: BlobStorageEntry = result.ok() + obj: Optional[BlobStorageEntry] = result.ok() if obj is None: - return SyftError(message=f"No blob storage entry exists for uid: {uid}") + return SyftError( + message=f"No blob storage entry exists for uid: {uid}, or you have no permissions to read it" + ) + context.node = cast(AbstractNode, context.node) with context.node.blob_storage_client.connect() as conn: res: BlobRetrieval = conn.read( @@ -262,7 +265,9 @@ def write_to_disk( obj: Optional[BlobStorageEntry] = result.ok() if obj is None: - return SyftError(message=f"No blob storage entry exists for uid: {uid}") + return SyftError( + message=f"No blob storage entry exists for uid: {uid}, or you have no permissions to read it" + ) try: Path(obj.location.path).write_bytes(data) @@ -292,7 +297,9 @@ def mark_write_complete( obj: Optional[BlobStorageEntry] = result.ok() if obj is None: - return SyftError(message=f"No blob storage entry exists for uid: {uid}") + return SyftError( + message=f"No blob storage entry exists for uid: {uid}, or you have no permissions to read it" + ) obj.no_lines = no_lines result = self.stash.update( @@ -316,7 +323,9 @@ def delete( obj = result.ok() if obj is None: - return SyftError(message=f"No blob storage entry exists for uid: {uid}") + return SyftError( + message=f"No blob storage entry exists for uid: {uid}, or you have no permissions to read it" + ) context.node = cast(AbstractNode, context.node) diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py new file mode 100644 index 00000000000..352c1715abc --- /dev/null +++ b/packages/syft/src/syft/service/code/status_service.py @@ -0,0 +1,97 @@ +# stdlib +from typing import List +from typing import Union + +# third party +from result import Result + +# relative +from ...node.credentials import SyftVerifyKey +from ...serde.serializable import serializable +from ...store.document_store import BaseUIDStoreStash +from ...store.document_store import DocumentStore +from ...store.document_store import PartitionSettings +from ...store.document_store import QueryKeys +from ...store.document_store import UIDPartitionKey +from ...types.uid import UID +from ...util.telemetry import instrument +from ..context import AuthedServiceContext +from ..response import SyftError +from ..service import AbstractService +from ..service import TYPE_TO_SERVICE +from ..service import service_method +from ..user.user_roles import ADMIN_ROLE_LEVEL +from ..user.user_roles import GUEST_ROLE_LEVEL +from .user_code import UserCodeStatusCollection + + +@instrument +@serializable() +class StatusStash(BaseUIDStoreStash): + object_type = UserCodeStatusCollection + settings: PartitionSettings = PartitionSettings( + name=UserCodeStatusCollection.__canonical_name__, + object_type=UserCodeStatusCollection, + ) + + def __init__(self, store: DocumentStore) -> None: + super().__init__(store) + self.store = store + self.settings = self.settings + self._object_type = self.object_type + + def get_by_uid( + self, credentials: SyftVerifyKey, uid: UID + ) -> Result[UserCodeStatusCollection, str]: + qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)]) + return self.query_one(credentials=credentials, qks=qks) + + +@instrument +@serializable() +class UserCodeStatusService(AbstractService): + store: DocumentStore + stash: StatusStash + + def __init__(self, store: DocumentStore): + self.store = store + self.stash = StatusStash(store=store) + + @service_method(path="code_status.create", name="create", roles=ADMIN_ROLE_LEVEL) + def create( + self, + context: AuthedServiceContext, + status: UserCodeStatusCollection, + ) -> Union[UserCodeStatusCollection, SyftError]: + result = self.stash.set( + credentials=context.credentials, + obj=status, + ) + if result.is_ok(): + return result.ok() + return SyftError(message=result.err()) + + @service_method( + path="code_status.get_by_uid", name="get_by_uid", roles=GUEST_ROLE_LEVEL + ) + def get_status( + self, context: AuthedServiceContext, uid: UID + ) -> Union[UserCodeStatusCollection, SyftError]: + """Get the status of a user code item""" + result = self.stash.get_by_uid(context.credentials, uid=uid) + if result.is_ok(): + return result.ok() + return SyftError(message=result.err()) + + @service_method(path="code_status.get_all", name="get_all", roles=ADMIN_ROLE_LEVEL) + def get_all( + self, context: AuthedServiceContext + ) -> Union[List[UserCodeStatusCollection], SyftError]: + """Get all user code item statuses""" + result = self.stash.get_all(context.credentials) + if result.is_ok(): + return result.ok() + return SyftError(message=result.err()) + + +TYPE_TO_SERVICE[UserCodeStatusCollection] = UserCodeStatusService diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 6954e2c0dec..b363fbb7340 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -17,10 +17,12 @@ import traceback from typing import Any from typing import Callable +from typing import ClassVar from typing import Dict from typing import Generator from typing import List from typing import Optional +from typing import TYPE_CHECKING from typing import Tuple from typing import Type from typing import Union @@ -29,6 +31,7 @@ # third party from IPython.display import display +from pydantic import field_validator from result import Err from typing_extensions import Self @@ -45,18 +48,13 @@ from ...store.document_store import PartitionKey from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime -from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import SYFT_OBJECT_VERSION_4 -from ...types.syft_object import SyftHashableObject from ...types.syft_object import SyftObject from ...types.transforms import TransformContext from ...types.transforms import add_node_uid_for_key -from ...types.transforms import drop from ...types.transforms import generate_id -from ...types.transforms import make_set_default from ...types.transforms import transform from ...types.uid import UID from ...util import options @@ -68,6 +66,8 @@ from ..context import AuthedServiceContext from ..dataset.dataset import Asset from ..job.job_stash import Job +from ..output.output_service import ExecutionOutput +from ..output.output_service import OutputService from ..policy.policy import CustomInputPolicy from ..policy.policy import CustomOutputPolicy from ..policy.policy import EmpyInputPolicy @@ -77,6 +77,7 @@ from ..policy.policy import SingleExecutionExactOutput from ..policy.policy import SubmitUserPolicy from ..policy.policy import UserPolicy +from ..policy.policy import filter_only_uids from ..policy.policy import init_policy from ..policy.policy import load_policy_code from ..policy.policy_service import PolicyService @@ -90,6 +91,10 @@ from .unparse import unparse from .utils import submit_subjobs_code +if TYPE_CHECKING: + # relative + from ...service.sync.diff_state import AttrDiff + UserVerifyKeyPartitionKey = PartitionKey(key="user_verify_key", type_=SyftVerifyKey) CodeHashPartitionKey = PartitionKey(key="code_hash", type_=str) ServiceFuncNamePartitionKey = PartitionKey(key="service_func_name", type_=str) @@ -98,28 +103,6 @@ PyCodeObject = Any -def extract_uids(kwargs: Dict[str, Any]) -> Dict[str, UID]: - # relative - from ...types.twin_object import TwinObject - from ..action.action_object import ActionObject - - uid_kwargs = {} - for k, v in kwargs.items(): - uid = v - if isinstance(v, ActionObject): - uid = v.id - if isinstance(v, TwinObject): - uid = v.id - if isinstance(v, Asset): - uid = v.action_id - - if not isinstance(uid, UID): - raise Exception(f"Input {k} must have a UID not {type(v)}") - - uid_kwargs[k] = uid - return uid_kwargs - - @serializable() class UserCodeStatus(Enum): PENDING = "pending" @@ -130,15 +113,32 @@ def __hash__(self) -> int: return hash(self.value) -# User Code status context for multiple approvals -# To make nested dicts hashable for mongodb -# as status is in attr_searchable -@serializable(attrs=["status_dict"]) -class UserCodeStatusCollection(SyftHashableObject): +@serializable() +class UserCodeStatusCollection(SyftObject): + __canonical_name__ = "UserCodeStatusCollection" + __version__ = SYFT_OBJECT_VERSION_1 + + __repr_attrs__ = ["approved", "status_dict"] + status_dict: Dict[NodeIdentity, Tuple[UserCodeStatus, str]] = {} + user_code_link: LinkedObject + + def get_diffs(self, ext_obj: Any) -> List[AttrDiff]: + # relative + from ...service.sync.diff_state import AttrDiff - def __init__(self, status_dict: Dict): - self.status_dict = status_dict + diff_attrs = [] + status = list(self.status_dict.values())[0] + ext_status = list(ext_obj.status_dict.values())[0] + + if status != ext_status: + diff_attr = AttrDiff( + attr_name="status_dict", + low_attr=status, + high_attr=ext_status, + ) + diff_attrs.append(diff_attr) + return diff_attrs def __repr__(self) -> str: return str(self.status_dict) @@ -252,103 +252,8 @@ def mutate( message="Cannot Modify Status as the Domain's data is not included in the request" ) - -@serializable() -class UserCodeV1(SyftObject): - # version - __canonical_name__ = "UserCode" - __version__ = SYFT_OBJECT_VERSION_1 - - id: UID - node_uid: Optional[UID] - user_verify_key: SyftVerifyKey - raw_code: str - input_policy_type: Union[Type[InputPolicy], UserPolicy] - input_policy_init_kwargs: Optional[Dict[Any, Any]] = None - input_policy_state: bytes = b"" - output_policy_type: Union[Type[OutputPolicy], UserPolicy] - output_policy_init_kwargs: Optional[Dict[Any, Any]] = None - output_policy_state: bytes = b"" - parsed_code: str - service_func_name: str - unique_func_name: str - user_unique_func_name: str - code_hash: str - signature: inspect.Signature - status: UserCodeStatusCollection - input_kwargs: List[str] - enclave_metadata: Optional[EnclaveMetadata] = None - submit_time: Optional[DateTime] - - __attr_searchable__ = [ - "user_verify_key", - "status", - "service_func_name", - "code_hash", - ] - - -@serializable() -class UserCodeV2(SyftObject): - # version - __canonical_name__ = "UserCode" - __version__ = SYFT_OBJECT_VERSION_2 - - id: UID - node_uid: Optional[UID] - user_verify_key: SyftVerifyKey - raw_code: str - input_policy_type: Union[Type[InputPolicy], UserPolicy] - input_policy_init_kwargs: Optional[Dict[Any, Any]] = None - input_policy_state: bytes = b"" - output_policy_type: Union[Type[OutputPolicy], UserPolicy] - output_policy_init_kwargs: Optional[Dict[Any, Any]] = None - output_policy_state: bytes = b"" - parsed_code: str - service_func_name: str - unique_func_name: str - user_unique_func_name: str - code_hash: str - signature: inspect.Signature - status: UserCodeStatusCollection - input_kwargs: List[str] - enclave_metadata: Optional[EnclaveMetadata] = None - submit_time: Optional[DateTime] - uses_domain = False # tracks if the code calls domain.something, variable is set during parsing - nested_requests: Dict[str, str] = {} - nested_codes: Optional[Dict[str, Tuple[LinkedObject, Dict]]] = {} - - -@serializable() -class UserCodeV3(SyftObject): - # version - __canonical_name__ = "UserCode" - __version__ = SYFT_OBJECT_VERSION_3 - - id: UID - node_uid: Optional[UID] - user_verify_key: SyftVerifyKey - raw_code: str - input_policy_type: Union[Type[InputPolicy], UserPolicy] - input_policy_init_kwargs: Optional[Dict[Any, Any]] = None - input_policy_state: bytes = b"" - output_policy_type: Union[Type[OutputPolicy], UserPolicy] - output_policy_init_kwargs: Optional[Dict[Any, Any]] = None - output_policy_state: bytes = b"" - parsed_code: str - service_func_name: str - unique_func_name: str - user_unique_func_name: str - code_hash: str - signature: inspect.Signature - status: UserCodeStatusCollection - input_kwargs: List[str] - enclave_metadata: Optional[EnclaveMetadata] = None - submit_time: Optional[DateTime] - uses_domain = False # tracks if the code calls domain.something, variable is set during parsing - nested_requests: Dict[str, str] = {} - nested_codes: Optional[Dict[str, Tuple[LinkedObject, Dict]]] = {} - worker_pool_name: Optional[str] + def get_sync_dependencies(self, api: Any = None) -> List[UID]: + return [self.user_code_link.object_uid] @serializable() @@ -358,7 +263,7 @@ class UserCode(SyftObject): __version__ = SYFT_OBJECT_VERSION_4 id: UID - node_uid: Optional[UID] + node_uid: Optional[UID] = None user_verify_key: SyftVerifyKey raw_code: str input_policy_type: Union[Type[InputPolicy], UserPolicy] @@ -373,28 +278,37 @@ class UserCode(SyftObject): user_unique_func_name: str code_hash: str signature: inspect.Signature - status: UserCodeStatusCollection + status_link: LinkedObject input_kwargs: List[str] enclave_metadata: Optional[EnclaveMetadata] = None - submit_time: Optional[DateTime] - uses_domain = False # tracks if the code calls domain.something, variable is set during parsing + submit_time: Optional[DateTime] = None + uses_domain: bool = False # tracks if the code calls domain.something, variable is set during parsing nested_codes: Optional[Dict[str, Tuple[LinkedObject, Dict]]] = {} - worker_pool_name: Optional[str] + worker_pool_name: Optional[str] = None - __attr_searchable__ = [ + __attr_searchable__: ClassVar[List[str]] = [ "user_verify_key", - "status", "service_func_name", "code_hash", ] - __attr_unique__: list = [] - __repr_attrs__ = [ + __attr_unique__: ClassVar[List[str]] = [] + __repr_attrs__: ClassVar[List[str]] = [ "service_func_name", "input_owners", "code_status", "worker_pool_name", ] + __exclude_sync_diff_attrs__: ClassVar[List[str]] = [ + "node_uid", + "input_policy_type", + "input_policy_init_kwargs", + "input_policy_state", + "output_policy_type", + "output_policy_init_kwargs", + "output_policy_state", + ] + def __setattr__(self, key: str, value: Any) -> None: # Get the attribute from the class, it might be a descriptor or None attr = getattr(type(self), key, None) @@ -428,6 +342,20 @@ def _coll_repr_(self) -> Dict[str, Any]: "Submit time": str(self.submit_time), } + @property + def status(self) -> Union[UserCodeStatusCollection, SyftError]: + # Clientside only + res = self.status_link.resolve + return res + + def get_status( + self, context: AuthedServiceContext + ) -> Union[UserCodeStatusCollection, SyftError]: + status = self.status_link.resolve_with_context(context) + if status.is_err(): + return SyftError(message=status.err()) + return status.ok() + @property def is_enclave_code(self) -> bool: return self.enclave_metadata is not None @@ -476,7 +404,15 @@ def code_status(self) -> list: def input_policy(self) -> Optional[InputPolicy]: if not self.status.approved: return None + return self._get_input_policy() + + def get_input_policy(self, context: AuthedServiceContext) -> Optional[InputPolicy]: + status = self.get_status(context) + if not status.approved: + return None + return self._get_input_policy() + def _get_input_policy(self) -> Optional[InputPolicy]: if len(self.input_policy_state) == 0: input_policy = None if ( @@ -517,8 +453,11 @@ def input_policy(self) -> Optional[InputPolicy]: print(f"Failed to deserialize custom input policy state. {e}") return None - @input_policy.setter - def input_policy(self, value: Any) -> None: + def is_output_policy_approved(self, context: AuthedServiceContext) -> bool: + return self.get_status(context).approved + + @input_policy.setter # type: ignore + def input_policy(self, value: Any) -> None: # type: ignore if isinstance(value, InputPolicy): self.input_policy_state = _serialize(value, to_bytes=True) elif (isinstance(value, bytes) and len(value) == 0) or value is None: @@ -527,14 +466,21 @@ def input_policy(self, value: Any) -> None: raise Exception(f"You can't set {type(value)} as input_policy_state") @property - def output_policy_approved(self) -> bool: - return self.status.approved - - @property - def output_policy(self) -> Optional[OutputPolicy]: + def output_policy(self) -> Optional[OutputPolicy]: # type: ignore if not self.status.approved: return None + return self._get_output_policy() + def get_output_policy( + self, context: AuthedServiceContext + ) -> Optional[OutputPolicy]: + if not self.get_status(context).approved: + return None + return self._get_output_policy() + + def _get_output_policy(self) -> Optional[OutputPolicy]: + # if not self.status.approved: + # return None if len(self.output_policy_state) == 0: output_policy = None if isinstance(self.output_policy_type, type) and issubclass( @@ -553,6 +499,8 @@ def output_policy(self) -> Optional[OutputPolicy]: ) if output_policy is not None: + output_policy.syft_node_location = self.syft_node_location + output_policy.syft_client_verify_key = self.syft_client_verify_key output_blob = _serialize(output_policy, to_bytes=True) self.output_policy_state = output_blob return output_policy @@ -565,8 +513,8 @@ def output_policy(self) -> Optional[OutputPolicy]: print(f"Failed to deserialize custom output policy state. {e}") return None - @output_policy.setter - def output_policy(self, value: Any) -> None: + @output_policy.setter # type: ignore + def output_policy(self, value: Any) -> None: # type: ignore if isinstance(value, OutputPolicy): self.output_policy_state = _serialize(value, to_bytes=True) elif (isinstance(value, bytes) and len(value) == 0) or value is None: @@ -574,6 +522,55 @@ def output_policy(self, value: Any) -> None: else: raise Exception(f"You can't set {type(value)} as output_policy_state") + @property + def output_history(self) -> Union[List[ExecutionOutput], SyftError]: + api = APIRegistry.api_for(self.syft_node_location, self.syft_client_verify_key) + if api is None: + return SyftError( + message=f"Can't access the api. You must login to {self.syft_node_location}" + ) + return api.services.output.get_by_user_code_id(self.id) + + def get_output_history( + self, context: AuthedServiceContext + ) -> Union[List[ExecutionOutput], SyftError]: + if not self.get_status(context).approved: + return SyftError( + message="Execution denied, Please wait for the code to be approved" + ) + node = cast(AbstractNode, context.node) + output_service = cast(OutputService, node.get_service("outputservice")) + return output_service.get_by_user_code_id(context, self.id) + + def apply_output( + self, + context: AuthedServiceContext, + outputs: Any, + job_id: Optional[UID] = None, + ) -> Union[ExecutionOutput, SyftError]: + output_policy = self.get_output_policy(context) + if output_policy is None: + return SyftError( + message="You must wait for the output policy to be approved" + ) + + output_ids = filter_only_uids(outputs) + context.node = cast(AbstractNode, context.node) + output_service = context.node.get_service("outputservice") + output_service = cast(OutputService, output_service) + execution_result = output_service.create( + context, + user_code_id=self.id, + output_ids=output_ids, + executing_user_verify_key=self.user_verify_key, + job_id=job_id, + output_policy_id=output_policy.id, + ) + if isinstance(execution_result, SyftError): + return execution_result + + return execution_result + @property def byte_code(self) -> Optional[PyCodeObject]: return compile_byte_code(self.parsed_code) @@ -583,6 +580,10 @@ def get_results(self) -> Any: from ...client.api import APIRegistry api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key) + if api is None: + return SyftError( + message=f"Can't access the api. You must login to {self.node_uid}" + ) return api.services.code.get_results(self) @property @@ -612,6 +613,15 @@ def assets(self) -> List[Asset]: all_assets += assets return all_assets + def get_sync_dependencies(self, api: Any = None) -> Union[List[UID], SyftError]: + dependencies = [] + + if self.nested_codes is not None: + nested_code_ids = [link.object_uid for link in self.nested_codes.values()] + dependencies.extend(nested_code_ids) + + return dependencies + @property def unsafe_function(self) -> Optional[Callable]: warning = SyftWarning( @@ -695,7 +705,7 @@ def _inner_repr(self, level: int = 0) -> str: return md - def _repr_markdown_(self) -> str: + def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: return as_markdown_code(self._inner_repr()) @property @@ -713,60 +723,13 @@ def show_code_cell(self) -> None: ip.set_next_input(warning_message + self.raw_code) -@migrate(UserCodeV3, UserCodeV2) -def downgrade_usercode_v3_to_v2() -> list[Callable]: - return [ - drop("worker_pool_name"), - ] - - -@migrate(UserCodeV2, UserCodeV3) -def upgrade_usercode_v2_to_v3() -> list[Callable]: - return [ - make_set_default("worker_pool_name", None), - ] - - -@migrate(UserCode, UserCodeV3) -def downgrade_usercode_v4_to_v3() -> list[Callable]: - return [ - make_set_default("nested_requests", {}), - ] - - -@migrate(UserCodeV3, UserCode) -def upgrade_usercode_v3_to_v4() -> list[Callable]: - return [ - drop("nested_requests"), - ] - - -@serializable(without=["local_function"]) -class SubmitUserCodeV2(SyftObject): - # version - __canonical_name__ = "SubmitUserCode" - __version__ = SYFT_OBJECT_VERSION_2 - - id: Optional[UID] - code: str - func_name: str - signature: inspect.Signature - input_policy_type: Union[SubmitUserPolicy, UID, Type[InputPolicy]] - input_policy_init_kwargs: Optional[Dict[Any, Any]] = {} - output_policy_type: Union[SubmitUserPolicy, UID, Type[OutputPolicy]] - output_policy_init_kwargs: Optional[Dict[Any, Any]] = {} - local_function: Optional[Callable] - input_kwargs: List[str] - enclave_metadata: Optional[EnclaveMetadata] = None - - @serializable(without=["local_function"]) class SubmitUserCode(SyftObject): # version __canonical_name__ = "SubmitUserCode" __version__ = SYFT_OBJECT_VERSION_3 - id: Optional[UID] + id: Optional[UID] = None # type: ignore[assignment] code: str func_name: str signature: inspect.Signature @@ -774,13 +737,20 @@ class SubmitUserCode(SyftObject): input_policy_init_kwargs: Optional[Dict[Any, Any]] = {} output_policy_type: Union[SubmitUserPolicy, UID, Type[OutputPolicy]] output_policy_init_kwargs: Optional[Dict[Any, Any]] = {} - local_function: Optional[Callable] + local_function: Optional[Callable] = None input_kwargs: List[str] enclave_metadata: Optional[EnclaveMetadata] = None worker_pool_name: Optional[str] = None __repr_attrs__ = ["func_name", "code"] + @field_validator("output_policy_init_kwargs", mode="after") + @classmethod + def add_output_policy_ids(cls, values: Any) -> Any: + if isinstance(values, dict) and "id" not in values: + values["id"] = UID() + return values + @property def kwargs(self) -> Optional[dict[Any, Any]]: return self.input_policy_init_kwargs @@ -860,7 +830,10 @@ def _ephemeral_node_call( # node_uid=node_id.node_id, user_verify_key=node_id.verify_key # ) api = APIRegistry.get_by_recent_node_uid(node_uid=node_id.node_id) - + if api is None: + return SyftError( + f"Can't access the api. You must login to {node_id.node_id}" + ) # Creating TwinObject from the ids of the kwargs # Maybe there are some corner cases where this is not enough # And need only ActionObjects @@ -922,20 +895,6 @@ def input_owner_verify_keys(self) -> Optional[List[str]]: return None -@migrate(SubmitUserCode, SubmitUserCodeV2) -def downgrade_submitusercode_v3_to_v2() -> list[Callable]: - return [ - drop("worker_pool_name"), - ] - - -@migrate(SubmitUserCodeV2, SubmitUserCode) -def upgrade_submitusercode_v2_to_v3() -> list[Callable]: - return [ - make_set_default("worker_pool_name", None), - ] - - class ArgumentType(Enum): REAL = 1 MOCK = 2 @@ -1023,15 +982,16 @@ def decorator(f: Any) -> SubmitUserCode: def generate_unique_func_name(context: TransformContext) -> TransformContext: - code_hash = context.output["code_hash"] - service_func_name = context.output["func_name"] - context.output["service_func_name"] = service_func_name - func_name = f"user_func_{service_func_name}_{context.credentials}_{code_hash}" - user_unique_func_name = ( - f"user_func_{service_func_name}_{context.credentials}_{time.time()}" - ) - context.output["unique_func_name"] = func_name - context.output["user_unique_func_name"] = user_unique_func_name + if context.output is not None: + code_hash = context.output["code_hash"] + service_func_name = context.output["func_name"] + context.output["service_func_name"] = service_func_name + func_name = f"user_func_{service_func_name}_{context.credentials}_{code_hash}" + user_unique_func_name = ( + f"user_func_{service_func_name}_{context.credentials}_{time.time()}" + ) + context.output["unique_func_name"] = func_name + context.output["user_unique_func_name"] = user_unique_func_name return context @@ -1053,7 +1013,7 @@ def process_code( f.decorator_list = [] call_args = function_input_kwargs - if "domain" in function_input_kwargs: + if "domain" in function_input_kwargs and context.output is not None: context.output["uses_domain"] = True call_stmt_keywords = [ast.keyword(arg=i, value=[ast.Name(id=i)]) for i in call_args] call_stmt = ast.Assign( @@ -1080,7 +1040,10 @@ def process_code( def new_check_code(context: TransformContext) -> TransformContext: - # TODO remove this tech debt hack + # TODO: remove this tech debt hack + if context.output is None: + return context + input_kwargs = context.output["input_policy_init_kwargs"] node_view_workaround = False for k in input_kwargs.keys(): @@ -1108,25 +1071,28 @@ def new_check_code(context: TransformContext) -> TransformContext: def locate_launch_jobs(context: TransformContext) -> TransformContext: - nested_codes = {} - tree = ast.parse(context.output["raw_code"]) - - # look for domain arg - if "domain" in [arg.arg for arg in tree.body[0].args.args]: - v = LaunchJobVisitor() - v.visit(tree) - nested_calls = v.nested_calls - user_code_service = context.node.get_service("usercodeService") - for call in nested_calls: - user_codes = user_code_service.get_by_service_name(context, call) - if isinstance(user_codes, SyftError): - raise Exception(user_codes.message) - # TODO: Not great - user_code = user_codes[-1] - user_code_link = LinkedObject.from_obj(user_code, node_uid=context.node.id) - - nested_codes[call] = (user_code_link, user_code.nested_codes) - context.output["nested_codes"] = nested_codes + if context.node is None: + raise ValueError(f"context {context}'s node is None") + if context.output is not None: + nested_codes = {} + tree = ast.parse(context.output["raw_code"]) + # look for domain arg + if "domain" in [arg.arg for arg in tree.body[0].args.args]: + v = LaunchJobVisitor() + v.visit(tree) + nested_calls = v.nested_calls + user_code_service = context.node.get_service("usercodeService") + for call in nested_calls: + user_codes = user_code_service.get_by_service_name(context, call) + if isinstance(user_codes, SyftError): + raise Exception(user_codes.message) + # TODO: Not great + user_code = user_codes[-1] + user_code_link = LinkedObject.from_obj( + user_code, node_uid=context.node.id + ) + nested_codes[call] = (user_code_link, user_code.nested_codes) + context.output["nested_codes"] = nested_codes return context @@ -1139,9 +1105,12 @@ def compile_byte_code(parsed_code: str) -> Optional[PyCodeObject]: def compile_code(context: TransformContext) -> TransformContext: + if context.output is None: + return context + byte_code = compile_byte_code(context.output["parsed_code"]) if byte_code is None: - raise Exception( + raise ValueError( "Unable to compile byte code from parsed code. " + context.output["parsed_code"] ) @@ -1149,81 +1118,123 @@ def compile_code(context: TransformContext) -> TransformContext: def hash_code(context: TransformContext) -> TransformContext: + if context.output is None: + return context + code = context.output["code"] context.output["raw_code"] = code code_hash = hashlib.sha256(code.encode("utf8")).hexdigest() context.output["code_hash"] = code_hash + return context def add_credentials_for_key(key: str) -> Callable: def add_credentials(context: TransformContext) -> TransformContext: - context.output[key] = context.credentials + if context.output is not None: + context.output[key] = context.credentials return context return add_credentials def check_policy(policy: Any, context: TransformContext) -> TransformContext: - policy_service = context.node.get_service(PolicyService) - if isinstance(policy, SubmitUserPolicy): - policy = policy.to(UserPolicy, context=context) - elif isinstance(policy, UID): - policy = policy_service.get_policy_by_uid(context, policy) - if policy.is_ok(): - policy = policy.ok() - + if context.node is not None: + policy_service = context.node.get_service(PolicyService) + if isinstance(policy, SubmitUserPolicy): + policy = policy.to(UserPolicy, context=context) + elif isinstance(policy, UID): + policy = policy_service.get_policy_by_uid(context, policy) + if policy.is_ok(): + policy = policy.ok() return policy def check_input_policy(context: TransformContext) -> TransformContext: + if context.output is None: + return context + ip = context.output["input_policy_type"] ip = check_policy(policy=ip, context=context) context.output["input_policy_type"] = ip + return context def check_output_policy(context: TransformContext) -> TransformContext: - op = context.output["output_policy_type"] - op = check_policy(policy=op, context=context) - context.output["output_policy_type"] = op + if context.output is not None: + op = context.output["output_policy_type"] + op = check_policy(policy=op, context=context) + context.output["output_policy_type"] = op return context -def add_custom_status(context: TransformContext) -> TransformContext: +def create_code_status(context: TransformContext) -> TransformContext: + # relative + from .user_code_service import UserCodeService + + if context.node is None: + raise ValueError(f"{context}'s node is None") + + if context.output is None: + return context + input_keys = list(context.output["input_policy_init_kwargs"].keys()) + code_link = LinkedObject.from_uid( + context.output["id"], + UserCode, + service_type=UserCodeService, + node_uid=context.node.id, + ) if context.node.node_type == NodeType.DOMAIN: node_identity = NodeIdentity( node_name=context.node.name, node_id=context.node.id, verify_key=context.node.signing_key.verify_key, ) - context.output["status"] = UserCodeStatusCollection( - status_dict={node_identity: (UserCodeStatus.PENDING, "")} + status = UserCodeStatusCollection( + status_dict={node_identity: (UserCodeStatus.PENDING, "")}, + user_code_link=code_link, ) - # if node_identity in input_keys or len(input_keys) == 0: - # context.output["status"] = UserCodeStatusContext( - # base_dict={node_identity: UserCodeStatus.SUBMITTED} - # ) - # else: - # raise ValueError(f"Invalid input keys: {input_keys} for {node_identity}") + elif context.node.node_type == NodeType.ENCLAVE: status_dict = {key: (UserCodeStatus.PENDING, "") for key in input_keys} - context.output["status"] = UserCodeStatusCollection(status_dict=status_dict) + status = UserCodeStatusCollection( + status_dict=status_dict, + user_code_link=code_link, + ) else: raise NotImplementedError( f"Invalid node type:{context.node.node_type} for code submission" ) + + res = context.node.get_service("usercodestatusservice").create(context, status) + # relative + from .status_service import UserCodeStatusService + + # TODO error handling in transform functions + if not isinstance(res, SyftError): + context.output["status_link"] = LinkedObject.from_uid( + res.id, + UserCodeStatusCollection, + service_type=UserCodeStatusService, + node_uid=context.node.id, + ) return context def add_submit_time(context: TransformContext) -> TransformContext: - context.output["submit_time"] = DateTime.now() + if context.output: + context.output["submit_time"] = DateTime.now() return context def set_default_pool_if_empty(context: TransformContext) -> TransformContext: - if context.output.get("worker_pool_name", None) is None: + if ( + context.node + and context.output + and context.output.get("worker_pool_name", None) is None + ): default_pool = context.node.get_default_worker_pool() context.output["worker_pool_name"] = default_pool.name return context @@ -1240,7 +1251,7 @@ def submit_user_code_to_user_code() -> List[Callable]: new_check_code, locate_launch_jobs, add_credentials_for_key("user_verify_key"), - add_custom_status, + create_code_status, add_node_uid_for_key("node_uid"), add_submit_time, set_default_pool_if_empty, @@ -1257,7 +1268,20 @@ class UserCodeExecutionResult(SyftObject): user_code_id: UID stdout: str stderr: str - result: Any + result: Any = None + + +@serializable() +class UserCodeExecutionOutput(SyftObject): + # version + __canonical_name__ = "UserCodeExecutionOutput" + __version__ = SYFT_OBJECT_VERSION_1 + + id: UID + user_code_id: UID + stdout: str + stderr: str + result: Any = None class SecureContext: @@ -1441,19 +1465,24 @@ def to_str(arg: Any) -> str: try: result = eval(evil_string, _globals, _locals) # nosec except Exception as e: + error_msg = traceback_from_error(e, code_item) if context.job is not None: - error_msg = traceback_from_error(e, code_item) time = datetime.datetime.now().strftime("%d/%m/%y %H:%M:%S") original_print( f"{time} EXCEPTION LOG ({job_id}):\n{error_msg}", file=sys.stderr ) - if context.node is not None: - log_service = context.node.get_service("LogService") - log_service.append(context=context, uid=log_id, new_err=error_msg) - result = Err( + if context.node is not None: + log_service = context.node.get_service("LogService") + log_service.append(context=context, uid=log_id, new_err=error_msg) + + result_message = ( f"Exception encountered while running {code_item.service_func_name}" ", please contact the Node Admin for more info." ) + if context.dev_mode: + result_message += error_msg + + result = Err(result_message) # reset print print = original_print @@ -1462,7 +1491,7 @@ def to_str(arg: Any) -> str: sys.stdout = stdout_ sys.stderr = stderr_ - return UserCodeExecutionResult( + return UserCodeExecutionOutput( user_code_id=code_item.id, stdout=str(stdout.getvalue()), stderr=str(stderr.getvalue()), @@ -1511,11 +1540,18 @@ def traceback_from_error(e: Exception, code: UserCode) -> str: return error_msg -def load_approved_policy_code(user_code_items: List[UserCode]) -> Any: +def load_approved_policy_code( + user_code_items: List[UserCode], context: Optional[AuthedServiceContext] +) -> Any: """Reload the policy code in memory for user code that is approved.""" try: for user_code in user_code_items: - if user_code.status.approved: + if context is None: + status = user_code.status + else: + status = user_code.get_status(context) + + if status.approved: if isinstance(user_code.input_policy_type, UserPolicy): load_policy_code(user_code.input_policy_type) if isinstance(user_code.output_policy_type, UserPolicy): diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index f6a683cdd89..c91792c28e7 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -1,5 +1,4 @@ # stdlib -from copy import deepcopy from typing import Any from typing import Dict from typing import List @@ -28,6 +27,7 @@ from ..action.action_permissions import ActionPermission from ..context import AuthedServiceContext from ..network.routes import route_to_connection +from ..output.output_service import ExecutionOutput from ..policy.policy import OutputPolicy from ..request.request import Request from ..request.request import SubmitRequest @@ -40,6 +40,7 @@ from ..service import SERVICE_TO_TYPES from ..service import TYPE_TO_SERVICE from ..service import service_method +from ..user.user_roles import ADMIN_ROLE_LEVEL from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL from ..user.user_roles import GUEST_ROLE_LEVEL from ..user.user_roles import ServiceRole @@ -72,27 +73,22 @@ def submit( def _submit( self, context: AuthedServiceContext, code: Union[UserCode, SubmitUserCode] - ) -> Result: + ) -> Result[UserCode, str]: if not isinstance(code, UserCode): - code = code.to(UserCode, context=context) + code = code.to(UserCode, context=context) # type: ignore[unreachable] + result = self.stash.set(context.credentials, code) return result - @service_method( - path="code.sync_code_from_request", - name="sync_code_from_request", - roles=GUEST_ROLE_LEVEL, - ) - def sync_code_from_request( - self, - context: AuthedServiceContext, - request: Request, + @service_method(path="code.delete", name="delete", roles=ADMIN_ROLE_LEVEL) + def delete( + self, context: AuthedServiceContext, uid: UID ) -> Union[SyftSuccess, SyftError]: - """Re-submit request from a different node""" - - # This request is from a different node, ensure worker pool is not set - code: UserCode = deepcopy(request.code) - return self.submit(context=context, code=code) + """Delete User Code""" + result = self.stash.delete_by_uid(context.credentials, uid) + if result.is_err(): + return SyftError(message=str(result.err())) + return SyftSuccess(message="User Code Deleted") @service_method( path="code.get_by_service_func_name", @@ -181,10 +177,12 @@ def _request_code_execution_inner( ] ) - linked_obj = LinkedObject.from_obj(user_code, node_uid=context.node.id) + code_link = LinkedObject.from_obj(user_code, node_uid=context.node.id) CODE_EXECUTE = UserCodeStatusChange( - value=UserCodeStatus.APPROVED, linked_obj=linked_obj + value=UserCodeStatus.APPROVED, + linked_obj=user_code.status_link, + linked_user_code=code_link, ) changes = [CODE_EXECUTE] @@ -258,7 +256,7 @@ def load_user_code(self, context: AuthedServiceContext) -> None: result = self.stash.get_all(credentials=context.credentials) if result.is_ok(): user_code_items = result.ok() - load_approved_policy_code(user_code_items=user_code_items) + load_approved_policy_code(user_code_items=user_code_items, context=context) @service_method(path="code.get_results", name="get_results", roles=GUEST_ROLE_LEVEL) def get_results( @@ -280,6 +278,10 @@ def get_results( connection=connection, credentials=context.node.signing_key, ) + if enclave_client.code is None: + return SyftError( + message=f"{enclave_client} can't access the user code api" + ) outputs = enclave_client.code.get_results(code.id) if isinstance(outputs, list): for output in outputs: @@ -290,15 +292,19 @@ def get_results( # if the current node is the enclave else: - if not code.status.approved: + if not code.get_status(context.as_root_context()).approved: return code.status.get_status_message() - if (output_policy := code.output_policy) is None: - return SyftError(message=f"Output policy not approved {code}") + output_history = code.get_output_history( + context=context.as_root_context() + ) + if isinstance(output_history, SyftError): + return output_history - if len(output_policy.output_history) > 0: + if len(output_history) > 0: return resolve_outputs( - context=context, output_ids=output_policy.last_output_ids + context=context, + output_ids=output_history[-1].output_ids, ) else: return SyftError(message="No results available") @@ -306,17 +312,22 @@ def get_results( return SyftError(message="Endpoint only supported for enclave code") def is_execution_allowed( - self, code: UserCode, context: AuthedServiceContext, output_policy: OutputPolicy + self, + code: UserCode, + context: AuthedServiceContext, + output_policy: Optional[OutputPolicy], ) -> Union[bool, SyftSuccess, SyftError, SyftNotReady]: - if not code.status.approved: + if not code.get_status(context).approved: return code.status.get_status_message() # Check if the user has permission to execute the code. elif not (has_code_permission := self.has_code_permission(code, context)): return has_code_permission - elif not code.output_policy_approved: + elif not code.is_output_policy_approved(context): return SyftError("Output policy not approved", code) - elif not output_policy.valid: - return output_policy.valid + + policy_is_valid = output_policy is not None and output_policy._is_valid(context) + if not policy_is_valid: + return policy_is_valid else: return True @@ -394,35 +405,38 @@ def _call( override_execution_permission = ( context.has_execute_permissions or context.role == ServiceRole.ADMIN ) + # Override permissions bypasses the cache, since we do not check in/out policies skip_fill_cache = override_execution_permission # We do not read from output policy cache if there are mock arguments skip_read_cache = len(self.keep_owned_kwargs(kwargs, context)) > 0 # Check output policy - output_policy = code.output_policy - if output_policy is not None and not override_execution_permission: - can_execute: Any = self.is_execution_allowed( - code=code, context=context, output_policy=output_policy + output_policy = code.get_output_policy(context) + if not override_execution_permission: + output_history = code.get_output_history(context=context) + if isinstance(output_history, SyftError): + return Err(output_history.message) + can_execute = self.is_execution_allowed( + code=code, + context=context, + output_policy=output_policy, ) if not can_execute: - if not code.output_policy_approved: + if not code.is_output_policy_approved(context): return Err( "Execution denied: Your code is waiting for approval" ) - if not (is_valid := output_policy.valid): - if ( - len(output_policy.output_history) > 0 - and not skip_read_cache - ): + if not (is_valid := output_policy._is_valid(context)): # type: ignore + if len(output_history) > 0 and not skip_read_cache: result = resolve_outputs( context=context, - output_ids=output_policy.last_output_ids, + output_ids=output_history[-1].output_ids, ) return Ok(result.as_empty()) else: return is_valid.to_result() - return can_execute.to_result() + return can_execute.to_result() # type: ignore # Execute the code item context.node = cast(AbstractNode, context.node) @@ -441,7 +455,7 @@ def _call( result_action_object = result_action_object.ok() output_result = action_service.set_result_to_store( - result_action_object, context, code.output_policy + result_action_object, context, code.get_output_policy(context) ) if output_result.is_err(): @@ -453,14 +467,11 @@ def _call( # this currently only works for nested syft_functions # and admins executing on high side (TODO, decide if we want to increment counter) if not skip_fill_cache and output_policy is not None: - output_policy.apply_output(context=context, outputs=result) - code.output_policy = output_policy - if not ( - update_success := self.update_code_state( - context=context, code_item=code - ) - ): - return update_success.to_result() + res = code.apply_output( + context=context, outputs=result, job_id=context.job_id + ) + if isinstance(res, SyftError): + return Err(res.message) has_result_read_permission = context.extra_kwargs.get( "has_result_read_permission", False ) @@ -497,6 +508,27 @@ def has_code_permission( ) return SyftSuccess(message="you have permission") + @service_method( + path="code.apply_output", name="apply_output", roles=GUEST_ROLE_LEVEL + ) + def apply_output( + self, + context: AuthedServiceContext, + user_code_id: UID, + outputs: Any, + job_id: Optional[UID] = None, + ) -> Union[ExecutionOutput, SyftError]: + code_result = self.stash.get_by_uid(context.credentials, user_code_id) + if code_result.is_err(): + return SyftError(message=code_result.err()) + + code: UserCode = code_result.ok() + if not code.get_status(context).approved: + return SyftError(message="Code is not approved") + + res = code.apply_output(context=context, outputs=outputs, job_id=job_id) + return res + def resolve_outputs( context: AuthedServiceContext, diff --git a/packages/syft/src/syft/service/code_history/code_history.py b/packages/syft/src/syft/service/code_history/code_history.py index ed9b2655a75..fcd36d06d26 100644 --- a/packages/syft/src/syft/service/code_history/code_history.py +++ b/packages/syft/src/syft/service/code_history/code_history.py @@ -70,8 +70,14 @@ def _repr_html_(self) -> str: # rows = sorted(rows, key=lambda x: x["Version"]) return create_table_template(rows, "CodeHistory", table_icon=None) - def __getitem__(self, index: int) -> Union[UserCode, SyftError]: + def __getitem__(self, index: Union[int, str]) -> Union[UserCode, SyftError]: + if isinstance(index, str): + raise TypeError(f"index {index} must be an integer, not a string") api = APIRegistry.api_for(self.syft_node_location, self.syft_client_verify_key) + if api is None: + return SyftError( + message=f"Can't access the api. You must login to {self.node_uid}" + ) if api.user_role.value >= ServiceRole.DATA_OWNER.value and index < 0: return SyftError( message="For security concerns we do not allow negative indexing. \ @@ -97,7 +103,9 @@ def _repr_html_(self) -> str: def add_func(self, versions: CodeHistoryView) -> Any: self.code_versions[versions.service_func_name] = versions - def __getitem__(self, name: str) -> Any: + def __getitem__(self, name: Union[str, int]) -> Any: + if isinstance(name, int): + raise TypeError("name argument ({name}) must be a string, not an integer.") return self.code_versions[name] def __getattr__(self, name: str) -> Any: @@ -123,8 +131,12 @@ class UsersCodeHistoriesDict(SyftObject): def available_keys(self) -> str: return json.dumps(self.user_dict, sort_keys=True, indent=4) - def __getitem__(self, key: int) -> Union[CodeHistoriesDict, SyftError]: + def __getitem__(self, key: Union[str, int]) -> Union[CodeHistoriesDict, SyftError]: api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key) + if api is None: + return SyftError( + message=f"Can't access the api. You must login to {self.node_uid}" + ) return api.services.code_history.get_history_for_user(key) def _repr_html_(self) -> str: diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py index 71be2d0da91..994c39f31de 100644 --- a/packages/syft/src/syft/service/code_history/code_history_service.py +++ b/packages/syft/src/syft/service/code_history/code_history_service.py @@ -55,8 +55,7 @@ def submit_version( if result.is_err(): return SyftError(message=str(result.err())) code = result.ok() - - elif isinstance(code, UserCode): + elif isinstance(code, UserCode): # type: ignore[unreachable] result = user_code_service.get_by_uid(context=context, uid=code.id) if isinstance(result, SyftError): return result diff --git a/packages/syft/src/syft/service/context.py b/packages/syft/src/syft/service/context.py index 5ac07649e37..a26bde54efa 100644 --- a/packages/syft/src/syft/service/context.py +++ b/packages/syft/src/syft/service/context.py @@ -1,4 +1,5 @@ # stdlib +from typing import Any from typing import Dict from typing import List from typing import Optional @@ -24,8 +25,9 @@ class NodeServiceContext(Context, SyftObject): __canonical_name__ = "NodeServiceContext" __version__ = SYFT_OBJECT_VERSION_1 - id: Optional[UID] - node: Optional[AbstractNode] + + id: Optional[UID] = None # type: ignore[assignment] + node: Optional[AbstractNode] = None class AuthedServiceContext(NodeServiceContext): @@ -34,10 +36,14 @@ class AuthedServiceContext(NodeServiceContext): credentials: SyftVerifyKey role: ServiceRole = ServiceRole.NONE - job_id: Optional[UID] + job_id: Optional[UID] = None extra_kwargs: Dict = {} has_execute_permissions: bool = False + @property + def dev_mode(self) -> Any: + return self.node.dev_mode # type: ignore + def capabilities(self) -> List[ServiceRoleCapability]: return ROLE_TO_CAPABILITIES.get(self.role, []) @@ -68,7 +74,7 @@ class UnauthedServiceContext(NodeServiceContext): __version__ = SYFT_OBJECT_VERSION_1 login_credentials: UserLoginCredentials - node: Optional[AbstractNode] + node: Optional[AbstractNode] = None role: ServiceRole = ServiceRole.NONE @@ -77,8 +83,8 @@ class ChangeContext(SyftBaseObject): __version__ = SYFT_OBJECT_VERSION_1 node: Optional[AbstractNode] = None - approving_user_credentials: Optional[SyftVerifyKey] - requesting_user_credentials: Optional[SyftVerifyKey] + approving_user_credentials: Optional[SyftVerifyKey] = None + requesting_user_credentials: Optional[SyftVerifyKey] = None extra_kwargs: Dict = {} @classmethod diff --git a/packages/syft/src/syft/service/data_subject/data_subject.py b/packages/syft/src/syft/service/data_subject/data_subject.py index 301d24881f3..409462d4bce 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject.py +++ b/packages/syft/src/syft/service/data_subject/data_subject.py @@ -14,6 +14,7 @@ from ...serde.serializable import serializable from ...store.document_store import PartitionKey from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.transforms import TransformContext from ...types.transforms import add_node_uid_for_key @@ -34,7 +35,7 @@ class DataSubject(SyftObject): node_uid: UID name: str - description: Optional[str] + description: Optional[str] = None aliases: List[str] = [] @property @@ -64,7 +65,7 @@ def __repr_syft_nested__(self) -> str: def __repr__(self) -> str: return f"<DataSubject: {self.name}>" - def _repr_markdown_(self) -> str: + def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: _repr_str = f"DataSubject: {self.name}\n" _repr_str += f"Description: {self.description}\n" _repr_str += f"Aliases: {self.aliases}\n" @@ -76,11 +77,11 @@ def _repr_markdown_(self) -> str: class DataSubjectCreate(SyftObject): # version __canonical_name__ = "DataSubjectCreate" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 - id: Optional[UID] = None + id: Optional[UID] = None # type: ignore[assignment] name: str - description: Optional[str] + description: Optional[str] = None aliases: Optional[List[str]] = [] members: Dict[str, "DataSubjectCreate"] = {} @@ -120,7 +121,7 @@ def member_relationships(self) -> Set[Tuple[str, str]]: def __repr__(self) -> str: return f"<DataSubject: {self.name}>" - def _repr_markdown_(self) -> str: + def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: _repr_str = f"DataSubject: {self.name}\n" _repr_str += f"Description: {self.description}\n" _repr_str += f"Aliases: {self.aliases}\n" @@ -129,7 +130,8 @@ def _repr_markdown_(self) -> str: def remove_members_list(context: TransformContext) -> TransformContext: - context.output.pop("members", []) + if context.output is not None: + context.output.pop("members", []) return context diff --git a/packages/syft/src/syft/service/data_subject/data_subject_service.py b/packages/syft/src/syft/service/data_subject/data_subject_service.py index dfed8d7b951..f514566d4c0 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_service.py @@ -47,7 +47,10 @@ def get_by_name( return self.query_one(credentials, qks=qks) def update( - self, credentials: SyftVerifyKey, data_subject: DataSubject + self, + credentials: SyftVerifyKey, + data_subject: DataSubject, + has_permission: bool = False, ) -> Result[DataSubject, str]: res = self.check_type(data_subject, DataSubject) # we dont use and_then logic here as it is hard because of the order of the arguments diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index a624956172b..fe39765ad35 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -15,12 +15,13 @@ from IPython.display import display import itables import pandas as pd -from pydantic import ValidationError -from pydantic import root_validator -from pydantic import validator +from pydantic import ConfigDict +from pydantic import field_validator +from pydantic import model_validator from result import Err from result import Ok from result import Result +from typing_extensions import Self # relative from ...serde.serializable import serializable @@ -28,6 +29,7 @@ from ...types.datetime import DateTime from ...types.dicttuple import DictTuple from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.transforms import TransformContext from ...types.transforms import generate_id @@ -62,10 +64,10 @@ class Contributor(SyftObject): __version__ = SYFT_OBJECT_VERSION_1 name: str - role: Optional[str] + role: Optional[str] = None email: str - phone: Optional[str] - note: Optional[str] + phone: Optional[str] = None + note: Optional[str] = None __repr_attrs__ = ["name", "role", "email"] @@ -101,7 +103,7 @@ class MarkdownDescription(SyftObject): text: str - def _repr_markdown_(self) -> str: + def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: style = """ <style> .jp-RenderedHTMLCommon pre { @@ -131,9 +133,9 @@ class Asset(SyftObject): contributors: Set[Contributor] = set() data_subjects: List[DataSubject] = [] mock_is_real: bool = False - shape: Optional[Tuple] + shape: Optional[Tuple] = None created_at: DateTime = DateTime.now() - uploader: Optional[Contributor] + uploader: Optional[Contributor] = None __repr_attrs__ = ["name", "shape"] @@ -208,7 +210,7 @@ def _repr_html_(self) -> Any: {mock_table_line} </div>""" - def _repr_markdown_(self) -> str: + def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: _repr_str = f"Asset: {self.name}\n" _repr_str += f"Pointer Id: {self.action_id}\n" _repr_str += f"Description: {self.description}\n" @@ -257,8 +259,6 @@ def mock(self) -> Union[SyftError, Any]: ) if api is None: return SyftError(message=f"You must login to {self.node_uid}") - if api.services is None: - return SyftError(message=f"Services for {api} is None") result = api.services.action.get_mock(self.action_id) try: if isinstance(result, SyftObject): @@ -310,7 +310,7 @@ def check_mock(data: Any, mock: Any) -> bool: if type(data) == type(mock): return True - return _is_action_data_empty(mock) + return _is_action_data_empty(mock) or _is_action_data_empty(data) @serializable() @@ -319,45 +319,34 @@ class CreateAsset(SyftObject): __canonical_name__ = "CreateAsset" __version__ = SYFT_OBJECT_VERSION_1 - id: Optional[UID] = None + id: Optional[UID] = None # type:ignore[assignment] name: str description: Optional[MarkdownDescription] = None contributors: Set[Contributor] = set() data_subjects: List[DataSubjectCreate] = [] - node_uid: Optional[UID] - action_id: Optional[UID] - data: Optional[Any] - mock: Optional[Any] - shape: Optional[Tuple] + node_uid: Optional[UID] = None + action_id: Optional[UID] = None + data: Optional[Any] = None + mock: Optional[Any] = None + shape: Optional[Tuple] = None mock_is_real: bool = False - created_at: Optional[DateTime] - uploader: Optional[Contributor] + created_at: Optional[DateTime] = None + uploader: Optional[Contributor] = None __repr_attrs__ = ["name"] - - class Config: - validate_assignment = True + model_config = ConfigDict(validate_assignment=True) def __init__(self, description: Optional[str] = "", **data: Any) -> None: super().__init__(**data, description=MarkdownDescription(text=str(description))) - @root_validator() - def __empty_mock_cannot_be_real(cls, values: dict[str, Any]) -> Dict: - """set mock_is_real to False whenever mock is None or empty""" - - if (mock := values.get("mock")) is None or _is_action_data_empty(mock): - values["mock_is_real"] = False - - return values + @model_validator(mode="after") + def __mock_is_real_for_empty_mock_must_be_false(self) -> Self: + if self.mock_is_real and ( + self.mock is None or _is_action_data_empty(self.mock) + ): + self.__dict__["mock_is_real"] = False - @validator("mock_is_real") - def __mock_is_real_for_empty_mock_must_be_false( - cls, v: bool, values: dict[str, Any], **kwargs: Any - ) -> bool: - if v and ((mock := values.get("mock")) is None or _is_action_data_empty(mock)): - raise ValueError("mock_is_real must be False if mock is not provided") - - return v + return self def add_data_subject(self, data_subject: DataSubject) -> None: self.data_subjects.append(data_subject) @@ -399,14 +388,11 @@ def set_mock(self, mock_data: Any, mock_is_real: bool) -> None: if isinstance(mock_data, SyftError): raise SyftException(mock_data) - current_mock = self.mock - self.mock = mock_data + if mock_is_real and (mock_data is None or _is_action_data_empty(mock_data)): + raise SyftException("`mock_is_real` must be False if mock is empty") - try: - self.mock_is_real = mock_is_real - except ValidationError as e: - self.mock = current_mock - raise e + self.mock = mock_data + self.mock_is_real = mock_is_real def no_mock(self) -> None: # relative @@ -457,20 +443,20 @@ def get_shape_or_len(obj: Any) -> Optional[Union[Tuple[int, ...], int]]: @serializable() class Dataset(SyftObject): # version - __canonical_name__ = "Dataset" - __version__ = SYFT_OBJECT_VERSION_1 + __canonical_name__: str = "Dataset" + __version__ = SYFT_OBJECT_VERSION_2 id: UID name: str - node_uid: Optional[UID] + node_uid: Optional[UID] = None asset_list: List[Asset] = [] contributors: Set[Contributor] = set() - citation: Optional[str] - url: Optional[str] + citation: Optional[str] = None + url: Optional[str] = None description: Optional[MarkdownDescription] = None - updated_at: Optional[str] + updated_at: Optional[str] = None requests: Optional[int] = 0 - mb_size: Optional[int] + mb_size: Optional[float] = None created_at: DateTime = DateTime.now() uploader: Contributor @@ -558,7 +544,7 @@ def _old_repr_markdown_(self) -> str: _repr_str += f"Description: {self.description.text}\n" return as_markdown_python_code(_repr_str) - def _repr_markdown_(self) -> str: + def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: # return self._old_repr_markdown_() return self._markdown_() @@ -631,22 +617,22 @@ class DatasetPageView(SyftObject): class CreateDataset(Dataset): # version __canonical_name__ = "CreateDataset" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 asset_list: List[CreateAsset] = [] __repr_attrs__ = ["name", "url"] - id: Optional[UID] = None - created_at: Optional[DateTime] - uploader: Optional[Contributor] # type: ignore[assignment] + id: Optional[UID] = None # type: ignore[assignment] + created_at: Optional[DateTime] = None # type: ignore[assignment] + uploader: Optional[Contributor] = None # type: ignore[assignment] - class Config: - validate_assignment = True + model_config = ConfigDict(validate_assignment=True) def _check_asset_must_contain_mock(self) -> None: _check_asset_must_contain_mock(self.asset_list) - @validator("asset_list") + @field_validator("asset_list") + @classmethod def __assets_must_contain_mock( cls, asset_list: List[CreateAsset] ) -> List[CreateAsset]: @@ -740,6 +726,9 @@ def check(self) -> Result[SyftSuccess, List[SyftError]]: def create_and_store_twin(context: TransformContext) -> TransformContext: + if context.output is None: + raise ValueError("f{context}'s output is None. No trasformation happened") + action_id = context.output["action_id"] if action_id is None: # relative @@ -748,37 +737,49 @@ def create_and_store_twin(context: TransformContext) -> TransformContext: private_obj = context.output.pop("data", None) mock_obj = context.output.pop("mock", None) if private_obj is None and mock_obj is None: - raise Exception("No data and no action_id means this asset has no data") + raise ValueError("No data and no action_id means this asset has no data") twin = TwinObject( private_obj=private_obj, mock_obj=mock_obj, ) + if context.node is None: + raise ValueError( + "f{context}'s node is None, please log in. No trasformation happened" + ) action_service = context.node.get_service("actionservice") result = action_service.set( context=context.to_node_context(), action_object=twin ) if result.is_err(): - raise Exception(f"Failed to create and store twin. {result}") + raise RuntimeError(f"Failed to create and store twin. Error: {result}") context.output["action_id"] = twin.id else: private_obj = context.output.pop("data", None) mock_obj = context.output.pop("mock", None) + return context def infer_shape(context: TransformContext) -> TransformContext: - if context.output["shape"] is None: - if not _is_action_data_empty(context.obj.mock): + if context.output is not None and context.output["shape"] is None: + if context.obj is not None and not _is_action_data_empty(context.obj.mock): context.output["shape"] = get_shape_or_len(context.obj.mock) + else: + print("f{context}'s output is None. No trasformation happened") return context -def set_data_subjects(context: TransformContext) -> TransformContext: +def set_data_subjects(context: TransformContext) -> Union[TransformContext, SyftError]: + if context.output is None: + return SyftError("f{context}'s output is None. No trasformation happened") + if context.node is None: + return SyftError( + "f{context}'s node is None, please log in. No trasformation happened" + ) data_subjects = context.output["data_subjects"] get_data_subject = context.node.get_service_method(DataSubjectService.get_by_name) - resultant_data_subjects = [] for data_subject in data_subjects: result = get_data_subject(context=context, name=data_subject.name) @@ -790,13 +791,19 @@ def set_data_subjects(context: TransformContext) -> TransformContext: def add_msg_creation_time(context: TransformContext) -> TransformContext: + if context.output is None: + return context + context.output["created_at"] = DateTime.now() return context def add_default_node_uid(context: TransformContext) -> TransformContext: - if context.output["node_uid"] is None: - context.output["node_uid"] = context.node.id + if context.output is not None: + if context.output["node_uid"] is None and context.node is not None: + context.output["node_uid"] = context.node.id + else: + print("f{context}'s output is None. No trasformation happened.") return context @@ -813,18 +820,26 @@ def createasset_to_asset() -> List[Callable]: def convert_asset(context: TransformContext) -> TransformContext: + if context.output is None: + return context + assets = context.output.pop("asset_list", []) for idx, create_asset in enumerate(assets): asset_context = TransformContext.from_context(obj=create_asset, context=context) assets[idx] = create_asset.to(Asset, context=asset_context) context.output["asset_list"] = assets + return context def add_current_date(context: TransformContext) -> TransformContext: + if context.output is None: + return context + current_date = datetime.now() formatted_date = current_date.strftime("%b %d, %Y") context.output["updated_at"] = formatted_date + return context diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py index be35bc0cf3e..19abea2e3eb 100644 --- a/packages/syft/src/syft/service/dataset/dataset_stash.py +++ b/packages/syft/src/syft/service/dataset/dataset_stash.py @@ -40,7 +40,10 @@ def get_by_name( return self.query_one(credentials=credentials, qks=qks) def update( - self, credentials: SyftVerifyKey, dataset_update: DatasetUpdate + self, + credentials: SyftVerifyKey, + dataset_update: DatasetUpdate, + has_permission: bool = False, ) -> Result[Dataset, str]: res = self.check_type(dataset_update, DatasetUpdate) # we dont use and_then logic here as it is hard because of the order of the arguments diff --git a/packages/syft/src/syft/service/enclave/enclave_service.py b/packages/syft/src/syft/service/enclave/enclave_service.py index 938858faa28..f543f55e9b2 100644 --- a/packages/syft/src/syft/service/enclave/enclave_service.py +++ b/packages/syft/src/syft/service/enclave/enclave_service.py @@ -1,5 +1,6 @@ # stdlib from typing import Dict +from typing import Optional from typing import Type from typing import Union @@ -14,11 +15,12 @@ from ...types.twin_object import TwinObject from ...types.uid import UID from ..action.action_object import ActionObject -from ..code.user_code_service import UserCode -from ..code.user_code_service import UserCodeStatus +from ..code.user_code import UserCode +from ..code.user_code import UserCodeStatus from ..context import AuthedServiceContext from ..context import ChangeContext from ..network.routes import route_to_connection +from ..policy.policy import InputPolicy from ..service import AbstractService from ..service import service_method @@ -59,7 +61,7 @@ def send_user_code_inputs_to_enclave( return user_code reason: str = context.extra_kwargs.get("reason", "") - status_update = user_code.status.mutate( + status_update = user_code.get_status(root_context).mutate( value=(UserCodeStatus.APPROVED, reason), node_name=node_name, node_id=node_id, @@ -68,13 +70,9 @@ def send_user_code_inputs_to_enclave( if isinstance(status_update, SyftError): return status_update - user_code.status = status_update - - user_code_update = user_code_service.update_code_state( - context=root_context, code_item=user_code - ) - if isinstance(user_code_update, SyftError): - return user_code_update + res = user_code.status_link.update_with_context(root_context, status_update) + if isinstance(res, SyftError): + return res root_context = context.as_root_context() if not action_service.exists(context=context, obj_id=user_code_id): @@ -157,9 +155,12 @@ def propagate_inputs_to_enclave( else: return SyftSuccess(message="Current Request does not require Enclave Transfer") - if user_code.input_policy is None: + input_policy: Optional[InputPolicy] = user_code.get_input_policy( + context.to_service_ctx() + ) + if input_policy is None: return SyftError(message=f"{user_code}'s input policy is None") - inputs = user_code.input_policy._inputs_for_context(context) + inputs = input_policy._inputs_for_context(context) if isinstance(inputs, SyftError): return inputs diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index c0276b31f69..70a6d343ef8 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -1,4 +1,5 @@ # stdlib +from typing import Any from typing import List from typing import Union from typing import cast @@ -12,14 +13,18 @@ from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission from ..action.action_permissions import ActionPermission +from ..code.user_code import UserCode from ..context import AuthedServiceContext +from ..log.log_service import LogService from ..queue.queue_stash import ActionQueueItem from ..response import SyftError from ..response import SyftSuccess from ..service import AbstractService from ..service import service_method +from ..user.user_roles import ADMIN_ROLE_LEVEL from ..user.user_roles import DATA_OWNER_ROLE_LEVEL from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL +from ..user.user_roles import GUEST_ROLE_LEVEL from .job_stash import Job from .job_stash import JobStash from .job_stash import JobStatus @@ -38,7 +43,7 @@ def __init__(self, store: DocumentStore) -> None: @service_method( path="job.get", name="get", - roles=DATA_SCIENTIST_ROLE_LEVEL, + roles=GUEST_ROLE_LEVEL, ) def get( self, context: AuthedServiceContext, uid: UID @@ -77,6 +82,19 @@ def get_by_user_code_id( res = res.ok() return res + @service_method( + path="job.delete", + name="delete", + roles=ADMIN_ROLE_LEVEL, + ) + def delete( + self, context: AuthedServiceContext, uid: UID + ) -> Union[SyftSuccess, SyftError]: + res = self.stash.delete_by_uid(context.credentials, uid) + if res.is_err(): + return SyftError(message=res.err()) + return SyftSuccess(message="Great Success!") + @service_method( path="job.restart", name="restart", @@ -181,6 +199,36 @@ def get_active(self, context: AuthedServiceContext) -> Union[List[Job], SyftErro return SyftError(message=res.err()) return res.ok() + @service_method( + path="job.add_read_permission_job_for_code_owner", + name="add_read_permission_job_for_code_owner", + roles=DATA_OWNER_ROLE_LEVEL, + ) + def add_read_permission_job_for_code_owner( + self, context: AuthedServiceContext, job: Job, user_code: UserCode + ) -> None: + permission = ActionObjectPermission( + job.id, ActionPermission.READ, user_code.user_verify_key + ) + return self.stash.add_permission(permission=permission) + + @service_method( + path="job.add_read_permission_log_for_code_owner", + name="add_read_permission_log_for_code_owner", + roles=DATA_OWNER_ROLE_LEVEL, + ) + def add_read_permission_log_for_code_owner( + self, context: AuthedServiceContext, log_id: UID, user_code: UserCode + ) -> Any: + context.node = cast(AbstractNode, context.node) + log_service = context.node.get_service("logservice") + log_service = cast(LogService, log_service) + return log_service.stash.add_permission( + ActionObjectPermission( + log_id, ActionPermission.READ, user_code.user_verify_key + ) + ) + @service_method( path="job.create_job_for_user_code_id", name="create_job_for_user_code_id", @@ -206,21 +254,19 @@ def create_job_for_user_code_id( return user_code # The owner of the code should be able to read the job - permission = ActionObjectPermission( - job.id, ActionPermission.READ, user_code.user_verify_key - ) - self.stash.set(context.credentials, job, add_permissions=[permission]) + self.stash.set(context.credentials, job) + self.add_read_permission_job_for_code_owner(context, job, user_code) - context.node = cast(AbstractNode, context.node) log_service = context.node.get_service("logservice") res = log_service.add(context, job.log_id) if isinstance(res, SyftError): return res # The owner of the code should be able to read the job log - log_service.stash.add_permission( - ActionObjectPermission( - job.log_id, ActionPermission.READ, user_code.user_verify_key - ) - ) + self.add_read_permission_log_for_code_owner(context, job.log_id, user_code) + # log_service.stash.add_permission( + # ActionObjectPermission( + # job.log_id, ActionPermission.READ, user_code.user_verify_key + # ) + # ) return job diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 7fa85144ec7..7cbce206770 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -3,14 +3,14 @@ from datetime import timedelta from enum import Enum from typing import Any -from typing import Callable from typing import Dict from typing import List from typing import Optional from typing import Union # third party -import pydantic +from pydantic import field_validator +from pydantic import model_validator from result import Err from result import Ok from result import Result @@ -30,14 +30,10 @@ from ...store.document_store import QueryKeys from ...store.document_store import UIDPartitionKey from ...types.datetime import DateTime -from ...types.syft_migration import migrate -from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import SyftObject from ...types.syft_object import short_uid -from ...types.transforms import drop -from ...types.transforms import make_set_default from ...types.uid import UID from ...util import options from ...util.colors import SURFACE @@ -62,43 +58,6 @@ class JobStatus(str, Enum): INTERRUPTED = "interrupted" -@serializable() -class JobV1(SyftObject): - __canonical_name__ = "JobItem" - __version__ = SYFT_OBJECT_VERSION_1 - - id: UID - node_uid: UID - result: Optional[Any] - resolved: bool = False - status: JobStatus = JobStatus.CREATED - log_id: Optional[UID] - parent_job_id: Optional[UID] - n_iters: Optional[int] = 0 - current_iter: Optional[int] = None - creation_time: Optional[str] = None - action: Optional[Action] = None - - -@serializable() -class JobV2(SyftObject): - __canonical_name__ = "JobItem" - __version__ = SYFT_OBJECT_VERSION_2 - - id: UID - node_uid: UID - result: Optional[Any] - resolved: bool = False - status: JobStatus = JobStatus.CREATED - log_id: Optional[UID] - parent_job_id: Optional[UID] - n_iters: Optional[int] = 0 - current_iter: Optional[int] = None - creation_time: Optional[str] = None - action: Optional[Action] = None - job_pid: Optional[int] = None - - @serializable() class Job(SyftObject): __canonical_name__ = "JobItem" @@ -106,11 +65,11 @@ class Job(SyftObject): id: UID node_uid: UID - result: Optional[Any] + result: Optional[Any] = None resolved: bool = False status: JobStatus = JobStatus.CREATED - log_id: Optional[UID] - parent_job_id: Optional[UID] + log_id: Optional[UID] = None + parent_job_id: Optional[UID] = None n_iters: Optional[int] = 0 current_iter: Optional[int] = None creation_time: Optional[str] = None @@ -122,27 +81,25 @@ class Job(SyftObject): __attr_searchable__ = ["parent_job_id", "job_worker_id", "status", "user_code_id"] __repr_attrs__ = ["id", "result", "resolved", "progress", "creation_time"] + __exclude_sync_diff_attrs__ = ["action"] - @pydantic.root_validator() - def check_time(cls, values: dict) -> dict: - if values.get("creation_time", None) is None: - values["creation_time"] = str(datetime.now()) - return values - - @pydantic.root_validator() - def check_user_code_id(cls, values: dict) -> dict: - action = values.get("action") - user_code_id = values.get("user_code_id") - - if action is not None: - if user_code_id is None: - values["user_code_id"] = action.user_code_id - elif action.user_code_id != user_code_id: - raise pydantic.ValidationError( - "user_code_id does not match the action's user_code_id", cls + @field_validator("creation_time") + @classmethod + def check_time(cls, time: Any) -> Any: + return str(datetime.now()) if time is None else time + + @model_validator(mode="after") + def check_user_code_id(self) -> Self: + if self.action is not None: + if self.user_code_id is None: + self.user_code_id = self.action.user_code_id + elif self.action.user_code_id != self.user_code_id: + raise ValueError( + "user_code_id does not match the action's user_code_id", + self.__class__, ) - return values + return self @property def action_display_name(self) -> str: @@ -177,6 +134,10 @@ def worker(self) -> Union[SyftWorker, SyftError]: node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, ) + if api is None: + return SyftError( + message=f"Can't access Syft API. You must login to {self.syft_node_location}" + ) return api.services.worker.get(self.job_worker_id) @property @@ -267,6 +228,10 @@ def restart(self, kill: bool = False) -> None: node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, ) + if api is None: + raise ValueError( + f"Can't access Syft API. You must login to {self.syft_node_location}" + ) call = SyftAPICall( node_uid=self.node_uid, path="job.restart", @@ -288,6 +253,10 @@ def kill(self) -> Optional[SyftError]: node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, ) + if api is None: + return SyftError( + message=f"Can't access Syft API. You must login to {self.syft_node_location}" + ) call = SyftAPICall( node_uid=self.node_uid, path="job.kill", @@ -307,6 +276,10 @@ def fetch(self) -> None: node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, ) + if api is None: + raise ValueError( + f"Can't access Syft API. You must login to {self.syft_node_location}" + ) call = SyftAPICall( node_uid=self.node_uid, path="job.get", @@ -329,6 +302,10 @@ def subjobs(self) -> Union[list[QueueItem], SyftError]: node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, ) + if api is None: + return SyftError( + message=f"Can't access Syft API. You must login to {self.syft_node_location}" + ) return api.services.job.get_subjobs(self.id) @property @@ -337,8 +314,21 @@ def owner(self) -> Union[UserView, SyftError]: node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, ) + if api is None: + return SyftError( + message=f"Can't access Syft API. You must login to {self.syft_node_location}" + ) return api.services.user.get_current_user(self.id) + def _get_log_objs(self) -> Union[SyftObject, SyftError]: + api = APIRegistry.api_for( + node_uid=self.node_uid, + user_verify_key=self.syft_client_verify_key, + ) + if api is None: + raise ValueError(f"api is None. You must login to {self.node_uid}") + return api.services.log.get(self.log_id) + def logs( self, stdout: bool = True, stderr: bool = True, _print: bool = True ) -> Optional[str]: @@ -346,10 +336,15 @@ def logs( node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, ) + if api is None: + return f"Can't access Syft API. You must login to {self.syft_node_location}" results = [] if stdout: - stdout_log = api.services.log.get(self.log_id) - results.append(stdout_log) + stdout_log = api.services.log.get_stdout(self.log_id) + if isinstance(stdout_log, SyftError): + results.append(f"Log {self.log_id} not available") + else: + results.append(stdout_log) if stderr: try: @@ -403,7 +398,7 @@ def _coll_repr_(self) -> Dict[str, Any]: def has_parent(self) -> bool: return self.parent_job_id is not None - def _repr_markdown_(self) -> str: + def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: _ = self.resolve logs = self.logs(_print=False) if logs is not None: @@ -433,7 +428,6 @@ def wait(self, job_only: bool = False) -> Union[Any, SyftNotReady]: node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, ) - # todo: timeout if self.resolved: return self.resolve @@ -441,6 +435,10 @@ def wait(self, job_only: bool = False) -> Union[Any, SyftNotReady]: if not job_only and self.result is not None: self.result.wait() + if api is None: + raise ValueError( + f"Can't access Syft API. You must login to {self.syft_node_location}" + ) print_warning = True while True: self.fetch() @@ -470,11 +468,27 @@ def resolve(self) -> Union[Any, SyftNotReady]: return self.result return SyftNotReady(message=f"{self.id} not ready yet.") + def get_sync_dependencies(self, **kwargs: Dict) -> List[UID]: + dependencies = [] + if self.result is not None: + dependencies.append(self.result.id.id) + + if self.log_id: + dependencies.append(self.log_id) + + subjob_ids = [subjob.id for subjob in self.subjobs] + dependencies.extend(subjob_ids) + + if self.user_code_id is not None: + dependencies.append(self.user_code_id) + + return dependencies + @serializable() class JobInfo(SyftObject): __canonical_name__ = "JobInfo" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 __repr_attrs__ = [ "resolved", "status", @@ -501,7 +515,7 @@ class JobInfo(SyftObject): current_iter: Optional[int] = None creation_time: Optional[str] = None - result: Optional[Any] = None + result: Optional[ActionObject] = None def _repr_html_(self) -> str: metadata_str = "" @@ -550,36 +564,11 @@ def from_job( raise ValueError("Cannot sync result of unresolved job") if not isinstance(job.result, ActionObject): raise ValueError("Could not sync result of job") - info.result = job.result.get() + info.result = job.result return info -@migrate(Job, JobV2) -def downgrade_job_v3_to_v2() -> list[Callable]: - return [drop(["job_worker_id", "user_code_id"])] - - -@migrate(JobV2, Job) -def upgrade_job_v2_to_v3() -> list[Callable]: - return [ - make_set_default("job_worker_id", None), - make_set_default("user_code_id", None), - ] - - -@migrate(JobV2, JobV1) -def downgrade_job_v2_to_v1() -> list[Callable]: - return [ - drop("job_pid"), - ] - - -@migrate(JobV1, JobV2) -def upgrade_job_v1_to_v2() -> list[Callable]: - return [make_set_default("job_pid", None)] - - @instrument @serializable() class JobStash(BaseStash): diff --git a/packages/syft/src/syft/service/log/log.py b/packages/syft/src/syft/service/log/log.py index 930e06b7bc4..e2687c2f8bc 100644 --- a/packages/syft/src/syft/service/log/log.py +++ b/packages/syft/src/syft/service/log/log.py @@ -1,6 +1,8 @@ +# stdlib +from typing import List + # relative from ...serde.serializable import serializable -from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject @@ -8,21 +10,17 @@ @serializable() class SyftLog(SyftObject): __canonical_name__ = "SyftLog" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 + + __repr_attrs__ = ["stdout", "stderr"] + __exclude_sync_diff_attrs__: List[str] = [] stdout: str = "" + stderr: str = "" def append(self, new_str: str) -> None: self.stdout += new_str - -@serializable() -class SyftLogV2(SyftLog): - __canonical_name__ = "SyftLog" - __version__ = SYFT_OBJECT_VERSION_2 - - stderr: str = "" - def append_error(self, new_str: str) -> None: self.stderr += new_str diff --git a/packages/syft/src/syft/service/log/log_service.py b/packages/syft/src/syft/service/log/log_service.py index b920be29f80..2a47321215b 100644 --- a/packages/syft/src/syft/service/log/log_service.py +++ b/packages/syft/src/syft/service/log/log_service.py @@ -16,7 +16,7 @@ from ..service import service_method from ..user.user_roles import ADMIN_ROLE_LEVEL from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL -from .log import SyftLogV2 +from .log import SyftLog from .log_stash import LogStash @@ -34,7 +34,7 @@ def __init__(self, store: DocumentStore) -> None: def add( self, context: AuthedServiceContext, uid: UID ) -> Union[SyftSuccess, SyftError]: - new_log = SyftLogV2(id=uid) + new_log = SyftLog(id=uid) result = self.stash.set(context.credentials, new_log) if result.is_err(): return SyftError(message=str(result.err())) @@ -71,6 +71,18 @@ def get( if result.is_err(): return SyftError(message=str(result.err())) + return result + + @service_method( + path="log.get_stdout", name="get_stdout", roles=DATA_SCIENTIST_ROLE_LEVEL + ) + def get_stdout( + self, context: AuthedServiceContext, uid: UID + ) -> Union[SyftSuccess, SyftError]: + result = self.stash.get_by_uid(context.credentials, uid) + if result.is_err(): + return SyftError(message=str(result.err())) + return Ok(result.ok().stdout) @service_method(path="log.restart", name="restart", roles=DATA_SCIENTIST_ROLE_LEVEL) diff --git a/packages/syft/src/syft/service/log/log_stash.py b/packages/syft/src/syft/service/log/log_stash.py index 511b917d2cf..f1c37d9f6b2 100644 --- a/packages/syft/src/syft/service/log/log_stash.py +++ b/packages/syft/src/syft/service/log/log_stash.py @@ -4,15 +4,15 @@ from ...store.document_store import DocumentStore from ...store.document_store import PartitionSettings from ...util.telemetry import instrument -from .log import SyftLogV2 +from .log import SyftLog @instrument @serializable() class LogStash(BaseUIDStoreStash): - object_type = SyftLogV2 + object_type = SyftLog settings: PartitionSettings = PartitionSettings( - name=SyftLogV2.__canonical_name__, object_type=SyftLogV2 + name=SyftLog.__canonical_name__, object_type=SyftLog ) def __init__(self, store: DocumentStore) -> None: diff --git a/packages/syft/src/syft/service/metadata/migrations.py b/packages/syft/src/syft/service/metadata/migrations.py index be36a07e123..91de59d220c 100644 --- a/packages/syft/src/syft/service/metadata/migrations.py +++ b/packages/syft/src/syft/service/metadata/migrations.py @@ -2,54 +2,23 @@ from typing import Callable # relative -from ...types.syft_migration import migrate from ...types.transforms import TransformContext -from ...types.transforms import drop -from ...types.transforms import rename -from .node_metadata import NodeMetadata -from .node_metadata import NodeMetadataV2 -from .node_metadata import NodeMetadataV3 - - -@migrate(NodeMetadata, NodeMetadataV2) -def upgrade_metadata_v1_to_v2() -> list[Callable]: - return [ - rename("highest_object_version", "highest_version"), - rename("lowest_object_version", "lowest_version"), - ] - - -@migrate(NodeMetadataV2, NodeMetadata) -def downgrade_metadata_v2_to_v1() -> list[Callable]: - return [ - rename("highest_version", "highest_object_version"), - rename("lowest_version", "lowest_object_version"), - ] - - -@migrate(NodeMetadataV2, NodeMetadataV3) -def upgrade_metadata_v2_to_v3() -> list[Callable]: - return [drop(["deployed_on", "on_board", "signup_enabled", "admin_email"])] def _downgrade_metadata_v3_to_v2() -> Callable: def set_defaults_from_settings(context: TransformContext) -> TransformContext: # Extract from settings if node is attached to context - if context.node is not None: - context.output["deployed_on"] = context.node.settings.deployed_on - context.output["on_board"] = context.node.settings.on_board - context.output["signup_enabled"] = context.node.settings.signup_enabled - context.output["admin_email"] = context.node.settings.admin_email - else: - # Else set default value - context.output["signup_enabled"] = False - context.output["admin_email"] = "" + if context.output is not None: + if context.node is not None: + context.output["deployed_on"] = context.node.settings.deployed_on + context.output["on_board"] = context.node.settings.on_board + context.output["signup_enabled"] = context.node.settings.signup_enabled + context.output["admin_email"] = context.node.settings.admin_email + else: + # Else set default value + context.output["signup_enabled"] = False + context.output["admin_email"] = "" return context return set_defaults_from_settings - - -@migrate(NodeMetadataV3, NodeMetadataV2) -def downgrade_metadata_v3_to_v2() -> list[Callable]: - return [_downgrade_metadata_v3_to_v2()] diff --git a/packages/syft/src/syft/service/metadata/node_metadata.py b/packages/syft/src/syft/service/metadata/node_metadata.py index 13eb097b0ac..fcb5c91d091 100644 --- a/packages/syft/src/syft/service/metadata/node_metadata.py +++ b/packages/syft/src/syft/service/metadata/node_metadata.py @@ -9,7 +9,7 @@ # third party from packaging import version from pydantic import BaseModel -from pydantic import root_validator +from pydantic import model_validator # relative from ...abstract_node import NodeType @@ -17,7 +17,6 @@ from ...protocol.data_protocol import get_data_protocol from ...serde.serializable import serializable from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import StorableObjectType from ...types.syft_object import SyftObject @@ -51,74 +50,16 @@ class NodeMetadataUpdate(SyftObject): __canonical_name__ = "NodeMetadataUpdate" __version__ = SYFT_OBJECT_VERSION_1 - name: Optional[str] - organization: Optional[str] - description: Optional[str] - on_board: Optional[bool] - id: Optional[UID] - verify_key: Optional[SyftVerifyKey] - highest_object_version: Optional[int] - lowest_object_version: Optional[int] - syft_version: Optional[str] - admin_email: Optional[str] - - -@serializable() -class NodeMetadata(SyftObject): - __canonical_name__ = "NodeMetadata" - __version__ = SYFT_OBJECT_VERSION_1 - - name: str - id: UID - verify_key: SyftVerifyKey - highest_object_version: int - lowest_object_version: int - syft_version: str - node_type: NodeType = NodeType.DOMAIN - deployed_on: str = "Date" - organization: str = "OpenMined" - on_board: bool = False - description: str = "Text" - signup_enabled: bool - admin_email: str - node_side_type: str - show_warnings: bool - - def check_version(self, client_version: str) -> bool: - return check_version( - client_version=client_version, - server_version=self.syft_version, - server_name=self.name, - ) - - -@serializable() -class NodeMetadataV2(SyftObject): - __canonical_name__ = "NodeMetadata" - __version__ = SYFT_OBJECT_VERSION_2 - - name: str - highest_version: int - lowest_version: int - id: UID - verify_key: SyftVerifyKey - syft_version: str - node_type: NodeType = NodeType.DOMAIN - deployed_on: str = "Date" - organization: str = "OpenMined" - on_board: bool = False - description: str = "Text" - signup_enabled: bool - admin_email: str - node_side_type: str - show_warnings: bool - - def check_version(self, client_version: str) -> bool: - return check_version( - client_version=client_version, - server_version=self.syft_version, - server_name=self.name, - ) + name: Optional[str] = None + organization: Optional[str] = None + description: Optional[str] = None + on_board: Optional[bool] = None + id: Optional[UID] = None # type: ignore[assignment] + verify_key: Optional[SyftVerifyKey] = None + highest_object_version: Optional[int] = None + lowest_object_version: Optional[int] = None + syft_version: Optional[str] = None + admin_email: Optional[str] = None @serializable() @@ -152,8 +93,8 @@ class NodeMetadataJSON(BaseModel, StorableObjectType): name: str id: str verify_key: str - highest_object_version: Optional[int] - lowest_object_version: Optional[int] + highest_object_version: Optional[int] = None + lowest_object_version: Optional[int] = None syft_version: str node_type: str = NodeType.DOMAIN.value organization: str = "OpenMined" @@ -164,7 +105,8 @@ class NodeMetadataJSON(BaseModel, StorableObjectType): show_warnings: bool supported_protocols: List = [] - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def add_protocol_versions(cls, values: dict) -> dict: if "supported_protocols" not in values: data_protocol = get_data_protocol() diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 16a13ec3aac..499dddb3798 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -70,7 +70,10 @@ def get_by_name( return self.query_one(credentials=credentials, qks=qks) def update( - self, credentials: SyftVerifyKey, peer: NodePeer + self, + credentials: SyftVerifyKey, + peer: NodePeer, + has_permission: bool = False, ) -> Result[NodePeer, str]: valid = self.check_type(peer, NodePeer) if valid.is_err(): @@ -229,7 +232,6 @@ def add_peer( except Exception as e: return SyftError(message=str(e)) - context.node = cast(AbstractNode, context.node) result = self.stash.update_peer(context.node.verify_key, peer) if result.is_err(): return SyftError(message=str(result.err())) @@ -243,7 +245,6 @@ def add_peer( # Q,TODO: Should the returned node peer also be signed # as the challenge is already signed - context.node = cast(AbstractNode, context.node) challenge_signature = context.node.signing_key.signing_key.sign( challenge ).signature @@ -310,7 +311,6 @@ def verify_route( ) ) peer.update_routes([route]) - context.node = cast(AbstractNode, context.node) result = self.stash.update_peer(context.node.verify_key, peer) if result.is_err(): return SyftError(message=str(result.err())) @@ -376,13 +376,15 @@ def get_peers_by_type( def from_grid_url(context: TransformContext) -> TransformContext: - url = context.obj.url.as_container_host() - context.output["host_or_ip"] = url.host_or_ip - context.output["protocol"] = url.protocol - context.output["port"] = url.port - context.output["private"] = False - context.output["proxy_target_uid"] = context.obj.proxy_target_uid - context.output["priority"] = 1 + if context.obj is not None and context.output is not None: + url = context.obj.url.as_container_host() + context.output["host_or_ip"] = url.host_or_ip + context.output["protocol"] = url.protocol + context.output["port"] = url.port + context.output["private"] = False + context.output["proxy_target_uid"] = context.obj.proxy_target_uid + context.output["priority"] = 1 + return context @@ -392,9 +394,10 @@ def http_connection_to_node_route() -> List[Callable]: def get_python_node_route(context: TransformContext) -> TransformContext: - context.output["id"] = context.obj.node.id - context.output["worker_settings"] = WorkerSettings.from_node(context.obj.node) - context.output["proxy_target_uid"] = context.obj.proxy_target_uid + if context.output is not None and context.obj is not None: + context.output["id"] = context.obj.node.id + context.output["worker_settings"] = WorkerSettings.from_node(context.obj.node) + context.output["proxy_target_uid"] = context.obj.proxy_target_uid return context diff --git a/packages/syft/src/syft/service/network/node_peer.py b/packages/syft/src/syft/service/network/node_peer.py index cc9764947e5..4f5f6ac5593 100644 --- a/packages/syft/src/syft/service/network/node_peer.py +++ b/packages/syft/src/syft/service/network/node_peer.py @@ -35,7 +35,7 @@ class NodePeer(SyftObject): __attr_unique__ = ["verify_key"] __repr_attrs__ = ["name", "node_type", "admin_email"] - id: Optional[UID] + id: Optional[UID] = None # type: ignore[assignment] name: str verify_key: SyftVerifyKey node_routes: List[NodeRouteType] = [] diff --git a/packages/syft/src/syft/service/network/routes.py b/packages/syft/src/syft/service/network/routes.py index eb64cca0546..15027a97ab8 100644 --- a/packages/syft/src/syft/service/network/routes.py +++ b/packages/syft/src/syft/service/network/routes.py @@ -145,4 +145,4 @@ def connection_to_route(connection: NodeConnection) -> NodeRoute: if isinstance(connection, HTTPConnection): return connection.to(HTTPNodeRoute) else: - return connection.to(PythonNodeRoute) + return connection.to(PythonNodeRoute) # type: ignore[unreachable] diff --git a/packages/syft/src/syft/service/notification/email_templates.py b/packages/syft/src/syft/service/notification/email_templates.py new file mode 100644 index 00000000000..09de0297bf1 --- /dev/null +++ b/packages/syft/src/syft/service/notification/email_templates.py @@ -0,0 +1,407 @@ +# stdlib +from typing import TYPE_CHECKING +from typing import cast + +# relative +from ...abstract_node import AbstractNode +from ...store.linked_obj import LinkedObject +from ..context import AuthedServiceContext + +if TYPE_CHECKING: + # relative + from .notifications import Notification + + +class EmailTemplate: + @staticmethod + def email_title(notification: "Notification", context: AuthedServiceContext) -> str: + return "" + + @staticmethod + def email_body(notification: "Notification", context: AuthedServiceContext) -> str: + return "" + + +class OnBoardEmailTemplate(EmailTemplate): + @staticmethod + def email_title(notification: "Notification", context: AuthedServiceContext) -> str: + context.node = cast(AbstractNode, context.node) + return f"Welcome to {context.node.name} node!" + + @staticmethod + def email_body(notification: "Notification", context: AuthedServiceContext) -> str: + context.node = cast(AbstractNode, context.node) + user_service = context.node.get_service("userservice") + admin_name = user_service.get_by_verify_key( + user_service.admin_verify_key() + ).name + + head = ( + f""" + <head> + <title>Welcome to {context.node.name}</title> + """ + + """ + <style> + body { + font-family: Arial, sans-serif; + background-color: #f4f4f4; + color: #333; + line-height: 1.6; + } + .container { + max-width: 600px; + margin: 20px auto; + padding: 20px; + background: #fff; + } + h1 { + color: #0056b3; + } + .feature { + background-color: #e7f1ff; + padding: 10px; + margin: 10px 0; + border-radius: 5px; + } + .footer { + text-align: center; + font-size: 14px; + color: #aaa; + } + </style> + </head> + """ + ) + + body = f""" + <body> + <div class="container"> + <h1>Welcome to {context.node.name} node!</h1> + <p>Hello,</p> + <p>We're thrilled to have you on board and + excited to help you get started with our powerful features:</p> + + <div class="feature"> + <h3>Remote Data Science</h3> + <p>Access and analyze data from anywhere, using our comprehensive suite of data science tools.</p> + </div> + + <div class="feature"> + <h3>Remote Code Execution</h3> + <p>Execute code remotely on private data, ensuring flexibility and efficiency in your research.</p> + </div> + + <!-- Add more features here if needed --> + + <p>Explore these features and much more within your account. + If you have any questions or need assistance, don't hesitate to reach out.</p> + + <p>Cheers,</p> + <p>{admin_name}</p> + + <div class="footer"> + This is an automated message, please do not reply directly to this email. <br> + For assistance, please contact our support team. + </div> + </div> + </body> + """ + return f"""<html>{head} {body}</html>""" + + +class RequestEmailTemplate(EmailTemplate): + @staticmethod + def email_title(notification: "Notification", context: AuthedServiceContext) -> str: + context.node = cast(AbstractNode, context.node) + notification.linked_obj = cast(LinkedObject, notification.linked_obj) + request_obj = notification.linked_obj.resolve_with_context(context=context).ok() + + return f"Domain {context.node.name}: A New Request ({str(request_obj.id)[:4]}) has been received!" + + @staticmethod + def email_body(notification: "Notification", context: AuthedServiceContext) -> str: + context.node = cast(AbstractNode, context.node) + notification.linked_obj = cast(LinkedObject, notification.linked_obj) + request_obj = notification.linked_obj.resolve_with_context(context=context).ok() + + head = """ + <head> + <title>Access Request Notification</title> + <style> + body { + font-family: Arial, sans-serif; + background-color: #f4f4f4; + color: #333; + padding: 20px; + } + .container { + max-width: 600px; + margin: 0 auto; + background: #fff; + padding: 20px; + border-radius: 8px; + box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); + } + .header { + font-size: 24px; + color: #333; + text-align: center; + } + .content { + font-size: 16px; + line-height: 1.6; + } + + .request-card { + background-color: #ffffff; + border: 1px solid #ddd; + padding: 15px; + margin-top: 20px; + border-radius: 8px; + box-shadow: 0 2px 4px rgba(0,0,0,0.1); + } + + .request-header { + font-size: 18px; + color: #333; + margin-bottom: 10px; + font-weight: bold; + } + + .request-content { + font-size: 14px; + line-height: 1.5; + color: #555; + } + .badge { + padding: 4px 10px; + border-radius: 15px; + color: white; + font-weight: bold; + margin: 10px; + box-shadow: 0px 4px 6px rgba(0,0,0,0.1); + text-transform: uppercase; + } + + .yellow { + background-color: #fdd835; + } + + .green { + background-color: #6dbf67; + } + + .red { + background-color: #f7786b; + } + + .button { + display: block; + width: max-content; + background-color: #007bff; + color: white; + padding: 10px 20px; + text-align: center; + text-color: white; + border-radius: 5px; + text-decoration: none; + font-weight: bold; + margin: 20px auto; + } + .footer { + text-align: center; + font-size: 14px; + color: #aaa; + } + </style> + </head>""" + + body = f""" + <body> + <div class="container"> + <div class="header"> + Request Notification + </div> + <div class="content"> + <p>Hello,</p> + <p>A new request has been submitted and requires your attention. + Please review the details below:</p> + + <div class="request-card"> + <div class="request-header">Request Details</div> + <div class="request-content"> + + <p><strong>ID:</strong> {request_obj.id}</p> + <p> + <strong>Submitted By:</strong> + {request_obj.requesting_user_name} {request_obj.requesting_user_email or ""} + </p> + <p><strong>Date:</strong> {request_obj.request_time}</p> + <div style="display: flex"><p><strong>Status:</strong><div class="badge yellow">{ + request_obj.status.name + }</div></div> + <p><strong>Changes:</strong>{ + ",".join([change.__class__.__name__ for change in request_obj.changes]) + }</p> + </div> + </div> + <p>If you did not expect this request or have concerns about it, + please contact our support team immediately.</p> + </div> + <div class="footer"> + This is an automated message, please do not reply directly to this email. <br> + For assistance, please contact our support team. + </div> + </div> + </body> + """ + return f"""<html>{head} {body}</html>""" + + +class RequestUpdateEmailTemplate(EmailTemplate): + @staticmethod + def email_title(notification: "Notification", context: AuthedServiceContext) -> str: + context.node = cast(AbstractNode, context.node) + return f"Domain {context.node.name}: {notification.subject}" + + @staticmethod + def email_body(notification: "Notification", context: AuthedServiceContext) -> str: + context.node = cast(AbstractNode, context.node) + notification.linked_obj = cast(LinkedObject, notification.linked_obj) + request_obj = notification.linked_obj.resolve_with_context(context=context).ok() + badge_color = "red" if request_obj.status.name == "REJECTED" else "green" + head = """ + <head> + <title>Access Request Notification</title> + <style> + body { + font-family: Arial, sans-serif; + background-color: #f4f4f4; + color: #333; + padding: 20px; + } + .container { + max-width: 600px; + margin: 0 auto; + background: #fff; + padding: 20px; + border-radius: 8px; + box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); + } + .header { + font-size: 24px; + color: #333; + text-align: center; + } + .content { + font-size: 16px; + line-height: 1.6; + } + + .request-card { + background-color: #ffffff; + border: 1px solid #ddd; + padding: 15px; + margin-top: 20px; + border-radius: 8px; + box-shadow: 0 2px 4px rgba(0,0,0,0.1); + } + + .request-header { + font-size: 18px; + color: #333; + margin-bottom: 10px; + font-weight: bold; + } + + .badge { + padding: 4px 10px; + border-radius: 15px; + color: white; + font-weight: bold; + margin: 10px; + box-shadow: 0px 4px 6px rgba(0,0,0,0.1); + text-transform: uppercase; + } + + .yellow { + background-color: #fdd835; + } + + .green { + background-color: #6dbf67; + } + + .red { + background-color: #f7786b; + } + + .request-content { + font-size: 14px; + line-height: 1.5; + color: #555; + } + + .button { + display: block; + width: max-content; + background-color: #007bff; + color: white; + padding: 10px 20px; + text-align: center; + text-color: white; + border-radius: 5px; + text-decoration: none; + font-weight: bold; + margin: 20px auto; + } + .footer { + text-align: center; + font-size: 14px; + color: #aaa; + } + </style> + </head>""" + + body = f""" + <body> + <div class="container"> + <div class="header"> + Request Notification + </div> + <div class="content"> + <p>Hello,</p> + <p>The status of your recent request has been updated. + Below is the latest information regarding it:</p> + + <div class="request-card"> + <div class="request-header">Request Details</div> + <div class="request-content"> + + <p><strong>ID:</strong> {request_obj.id}</p> + <p> + <strong>Submitted By:</strong> + {request_obj.requesting_user_name} {request_obj.requesting_user_email or ""} + </p> + <p><strong>Date:</strong> {request_obj.request_time}</p> + <div style="display: flex"><p><strong>Status:</strong><div class="badge {badge_color}">{ + request_obj.status.name + }</div></div> + <p> + <strong>Changes:</strong> + {",".join([change.__class__.__name__ for change in request_obj.changes])} + </p> + </div> + </div> + <p>If you did not expect this request or have concerns about it, + please contact our support team immediately.</p> + </div> + <div class="footer"> + This is an automated message, please do not reply directly to this email. <br> + For assistance, please contact our support team. + </div> + </div> + </body> + """ + return f"""<html>{head} {body}</html>""" diff --git a/packages/syft/src/syft/service/notification/notification_service.py b/packages/syft/src/syft/service/notification/notification_service.py index 7930404837a..19d089fd733 100644 --- a/packages/syft/src/syft/service/notification/notification_service.py +++ b/packages/syft/src/syft/service/notification/notification_service.py @@ -11,12 +11,14 @@ from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectREAD from ..context import AuthedServiceContext +from ..notifier.notifier import NotifierSettings from ..response import SyftError from ..response import SyftSuccess from ..service import AbstractService from ..service import SERVICE_TO_TYPES from ..service import TYPE_TO_SERVICE from ..service import service_method +from ..user.user_roles import ADMIN_ROLE_LEVEL from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL from ..user.user_roles import GUEST_ROLE_LEVEL from .notification_stash import NotificationStash @@ -42,7 +44,6 @@ def send( self, context: AuthedServiceContext, notification: CreateNotification ) -> Union[Notification, SyftError]: """Send a new notification""" - new_notification = notification.to(Notification, context=context) # Add read permissions to person receiving this message @@ -55,6 +56,12 @@ def send( result = self.stash.set( context.credentials, new_notification, add_permissions=permissions ) + context.node = cast(AbstractNode, context.node) + notifier_service = context.node.get_service("notifierservice") + + res = notifier_service.dispatch_notification(context, new_notification) + if isinstance(res, SyftError): + return res if result.is_err(): return SyftError(message=str(result.err())) @@ -85,6 +92,60 @@ def reply( return result.ok() + @service_method( + path="notifications.user_settings", + name="user_settings", + ) + def user_settings( + self, + context: AuthedServiceContext, + ) -> Union[NotifierSettings, SyftError]: + context.node = cast(AbstractNode, context.node) + notifier_service = context.node.get_service("notifierservice") + return notifier_service.user_settings(context) + + @service_method( + path="notifications.settings", + name="settings", + roles=ADMIN_ROLE_LEVEL, + ) + def settings( + self, + context: AuthedServiceContext, + ) -> Union[NotifierSettings, SyftError]: + context.node = cast(AbstractNode, context.node) + notifier_service = context.node.get_service("notifierservice") + result = notifier_service.settings(context) + return result + + @service_method( + path="notifications.activate", + name="activate", + roles=DATA_SCIENTIST_ROLE_LEVEL, + ) + def activate( + self, + context: AuthedServiceContext, + ) -> Union[Notification, SyftError]: + context.node = cast(AbstractNode, context.node) + notifier_service = context.node.get_service("notifierservice") + result = notifier_service.activate(context) + return result + + @service_method( + path="notifications.deactivate", + name="deactivate", + roles=DATA_SCIENTIST_ROLE_LEVEL, + ) + def deactivate( + self, + context: AuthedServiceContext, + ) -> Union[Notification, SyftError]: + context.node = cast(AbstractNode, context.node) + notifier_service = context.node.get_service("notifierservice") + result = notifier_service.deactivate(context) + return result + @service_method( path="notifications.get_all", name="get_all", diff --git a/packages/syft/src/syft/service/notification/notifications.py b/packages/syft/src/syft/service/notification/notifications.py index fd8be6675d1..616950b71da 100644 --- a/packages/syft/src/syft/service/notification/notifications.py +++ b/packages/syft/src/syft/service/notification/notifications.py @@ -3,14 +3,18 @@ from typing import Callable from typing import List from typing import Optional +from typing import Type +from typing import cast # relative from ...client.api import APIRegistry +from ...client.api import SyftAPI from ...node.credentials import SyftVerifyKey from ...serde.serializable import serializable from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.transforms import TransformContext from ...types.transforms import add_credentials_for_key @@ -20,6 +24,8 @@ from ...types.uid import UID from ...util import options from ...util.colors import SURFACE +from ..notifier.notifier_enums import NOTIFIERS +from .email_templates import EmailTemplate @serializable() @@ -45,14 +51,14 @@ class ReplyNotification(SyftObject): text: str target_msg: UID - id: Optional[UID] - from_user_verify_key: Optional[SyftVerifyKey] + id: Optional[UID] = None # type: ignore[assignment] + from_user_verify_key: Optional[SyftVerifyKey] = None @serializable() class Notification(SyftObject): __canonical_name__ = "Notification" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 subject: str node_uid: UID @@ -60,7 +66,9 @@ class Notification(SyftObject): to_user_verify_key: SyftVerifyKey created_at: DateTime status: NotificationStatus = NotificationStatus.UNREAD - linked_obj: Optional[LinkedObject] + linked_obj: Optional[LinkedObject] = None + notifier_types: List[NOTIFIERS] = [] + email_template: Optional[Type[EmailTemplate]] = None replies: Optional[List[ReplyNotification]] = [] __attr_searchable__ = [ @@ -93,27 +101,29 @@ def link(self) -> Optional[SyftObject]: return None def _coll_repr_(self) -> dict[str, str]: - linked_obj_name: str = "" - linked_obj_uid: Optional[UID] = None - if self.linked_obj is not None: - linked_obj_name = self.linked_obj.object_type.__canonical_name__ - linked_obj_uid = self.linked_obj.object_uid + self.linked_obj = cast(LinkedObject, self.linked_obj) return { "Subject": self.subject, "Status": self.determine_status().name.capitalize(), "Created At": str(self.created_at), - "Linked object": f"{linked_obj_name} ({linked_obj_uid})", + "Linked object": f"{self.linked_obj.object_type.__canonical_name__} ({self.linked_obj.object_uid})", } def mark_read(self) -> None: - api = APIRegistry.api_for( - self.node_uid, user_verify_key=self.syft_client_verify_key + api: SyftAPI = cast( + SyftAPI, + APIRegistry.api_for( + self.node_uid, user_verify_key=self.syft_client_verify_key + ), ) return api.services.notifications.mark_as_read(uid=self.id) def mark_unread(self) -> None: - api = APIRegistry.api_for( - self.node_uid, user_verify_key=self.syft_client_verify_key + api: SyftAPI = cast( + SyftAPI, + APIRegistry.api_for( + self.node_uid, user_verify_key=self.syft_client_verify_key + ), ) return api.services.notifications.mark_as_unread(uid=self.id) @@ -121,25 +131,31 @@ def determine_status(self) -> Enum: # relative from ..request.request import Request - if self.linked_obj is not None and isinstance(self.linked_obj.resolve, Request): + self.linked_obj = cast(LinkedObject, self.linked_obj) + if isinstance(self.linked_obj.resolve, Request): return self.linked_obj.resolve.status - return NotificationRequestStatus.NO_ACTION + return NotificationRequestStatus.NO_ACTION # type: ignore[unreachable] @serializable() -class CreateNotification(Notification): +class CreateNotification(SyftObject): __canonical_name__ = "CreateNotification" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 - id: Optional[UID] - node_uid: Optional[UID] - from_user_verify_key: Optional[SyftVerifyKey] - created_at: Optional[DateTime] + subject: str + from_user_verify_key: Optional[SyftVerifyKey] = None # type: ignore[assignment] + to_user_verify_key: Optional[SyftVerifyKey] = None # type: ignore[assignment] + linked_obj: Optional[LinkedObject] = None + notifier_types: List[NOTIFIERS] = [] + email_template: Optional[Type[EmailTemplate]] = None def add_msg_creation_time(context: TransformContext) -> TransformContext: - context.output["created_at"] = DateTime.now() + if context.output is not None: + context.output["created_at"] = DateTime.now() + else: + print("f{context}'s output is None. No trasformation happened.") return context diff --git a/packages/syft/src/syft/service/notifier/notifier.py b/packages/syft/src/syft/service/notifier/notifier.py new file mode 100644 index 00000000000..cc597209099 --- /dev/null +++ b/packages/syft/src/syft/service/notifier/notifier.py @@ -0,0 +1,240 @@ +# stdlib + +# stdlib +from typing import Dict +from typing import List +from typing import Optional +from typing import Type +from typing import TypeVar +from typing import Union +from typing import cast + +# third party +from result import Err +from result import Ok +from result import Result + +# relative +from ...abstract_node import AbstractNode +from ...node.credentials import SyftVerifyKey +from ...serde.serializable import serializable +from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SyftObject +from ..context import AuthedServiceContext +from ..notification.notifications import Notification +from ..response import SyftError +from ..response import SyftSuccess +from .notifier_enums import NOTIFIERS +from .smtp_client import SMTPClient + + +class BaseNotifier: + def send( + self, target: SyftVerifyKey, notification: Notification + ) -> Union[SyftSuccess, SyftError]: + return SyftError(message="Not implemented") + + +TBaseNotifier = TypeVar("TBaseNotifier", bound=BaseNotifier) + + +class EmailNotifier(BaseNotifier): + smtp_client: SMTPClient + sender = "" + + def __init__( + self, + username: str, + password: str, + sender: str, + server: str, + port: int = 587, + ) -> None: + self.sender = sender + self.smtp_client = SMTPClient( + server=server, + port=port, + username=username, + password=password, + ) + + @classmethod + def check_credentials( + cls, + username: str, + password: str, + server: str, + port: int = 587, + ) -> Result[Ok, Err]: + return SMTPClient.check_credentials( + server=server, + port=port, + username=username, + password=password, + ) + + def send( + self, context: AuthedServiceContext, notification: Notification + ) -> Result[Ok, Err]: + try: + context.node = cast(AbstractNode, context.node) + + user_service = context.node.get_service("userservice") + + receiver = user_service.get_by_verify_key(notification.to_user_verify_key) + + if not receiver.notifications_enabled[NOTIFIERS.EMAIL]: + return Ok( + "Email notifications are disabled for this user." + ) # TODO: Should we return an error here? + + receiver_email = receiver.email + + if notification.email_template: + subject = notification.email_template.email_title( + notification, context=context + ) + body = notification.email_template.email_body( + notification, context=context + ) + else: + subject = notification.subject + body = notification._repr_html_() + + if isinstance(receiver_email, str): + receiver_email = [receiver_email] + + self.smtp_client.send( + sender=self.sender, receiver=receiver_email, subject=subject, body=body + ) + return Ok("Email sent successfully!") + except Exception: + return Err( + "Some notifications failed to be delivered. Please check the health of the mailing server." + ) + + +@serializable() +class NotificationPreferences(SyftObject): + __canonical_name__ = "NotificationPreferences" + __version__ = SYFT_OBJECT_VERSION_1 + __repr_attrs__ = [ + "email", + "sms", + "slack", + "app", + ] + + email: bool = False + sms: bool = False + slack: bool = False + app: bool = False + + +@serializable() +class NotifierSettings(SyftObject): + __canonical_name__ = "NotifierSettings" + __version__ = SYFT_OBJECT_VERSION_1 + __repr_attrs__ = [ + "active", + "email_enabled", + ] + active: bool = False + # Flag to identify which notification is enabled + # For now, consider only the email notification + # In future, Admin, must be able to have a better + # control on diff notifications. + + notifiers: Dict[NOTIFIERS, Type[TBaseNotifier]] = { + NOTIFIERS.EMAIL: EmailNotifier, + } + + notifiers_status: Dict[NOTIFIERS, bool] = { + NOTIFIERS.EMAIL: True, + NOTIFIERS.SMS: False, + NOTIFIERS.SLACK: False, + NOTIFIERS.APP: False, + } + + email_sender: Optional[str] = "" + email_server: Optional[str] = "" + email_port: Optional[int] = 587 + email_username: Optional[str] = "" + email_password: Optional[str] = "" + + @property + def email_enabled(self) -> bool: + return self.notifiers_status[NOTIFIERS.EMAIL] + + @property + def sms_enabled(self) -> bool: + return self.notifiers_status[NOTIFIERS.SMS] + + @property + def slack_enabled(self) -> bool: + return self.notifiers_status[NOTIFIERS.SLACK] + + @property + def app_enabled(self) -> bool: + return self.notifiers_status[NOTIFIERS.APP] + + def validate_email_credentials( + self, + username: str, + password: str, + server: str, + port: int, + ) -> Result[Ok, Err]: + return self.notifiers[NOTIFIERS.EMAIL].check_credentials( + server=server, + port=port, + username=username, + password=password, + ) + + def send_notifications( + self, + context: AuthedServiceContext, + notification: Notification, + ) -> Result[Ok, Err]: + notifier_objs: List = self.select_notifiers(notification) + + for notifier in notifier_objs: + result = notifier.send(context, notification) + if result.err(): + return result + + return Ok("Notification sent successfully!") + + def select_notifiers(self, notification: Notification) -> List[BaseNotifier]: + """ + Return a list of the notifiers enabled for the given notification" + + Args: + notification (Notification): The notification object + Returns: + List[BaseNotifier]: A list of enabled notifier objects + """ + notifier_objs = [] + for notifier_type in notification.notifier_types: + # Check if the notifier is enabled and if it is, create the notifier object + if ( + self.notifiers_status[notifier_type] + and self.notifiers[notifier_type] is not None + ): + # If notifier is email, we need to pass the parameters + if notifier_type == NOTIFIERS.EMAIL: + notifier_objs.append( + self.notifiers[notifier_type]( # type: ignore[misc] + username=self.email_username, + password=self.email_password, + sender=self.email_sender, + server=self.email_server, + ) + ) + # If notifier is not email, we just create the notifier object + # TODO: Add the other notifiers, and its auth methods + else: + notifier_objs.append(self.notifiers[notifier_type]()) # type: ignore[misc] + + return notifier_objs diff --git a/packages/syft/src/syft/service/notifier/notifier_enums.py b/packages/syft/src/syft/service/notifier/notifier_enums.py new file mode 100644 index 00000000000..0cb438b4fb2 --- /dev/null +++ b/packages/syft/src/syft/service/notifier/notifier_enums.py @@ -0,0 +1,14 @@ +# stdlib +from enum import Enum +from enum import auto + +# relative +from ...serde.serializable import serializable + + +@serializable() +class NOTIFIERS(Enum): + EMAIL = auto() + SMS = auto() + SLACK = auto() + APP = auto() diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py new file mode 100644 index 00000000000..bd82aa8acf0 --- /dev/null +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -0,0 +1,310 @@ +# stdlib + +# stdlib +from typing import Optional +from typing import Union +from typing import cast + +# third party +from pydantic import EmailStr +from result import Err +from result import Ok +from result import Result + +# relative +from ...abstract_node import AbstractNode +from ...serde.serializable import serializable +from ...store.document_store import DocumentStore +from ..context import AuthedServiceContext +from ..notification.notifications import Notification +from ..response import SyftError +from ..response import SyftSuccess +from ..service import AbstractService +from .notifier import NotificationPreferences +from .notifier import NotifierSettings +from .notifier_enums import NOTIFIERS +from .notifier_stash import NotifierStash + + +@serializable() +class NotifierService(AbstractService): + store: DocumentStore + stash: NotifierStash # Which stash should we use? + + def __init__(self, store: DocumentStore) -> None: + self.store = store + self.stash = NotifierStash(store=store) + + def settings( # Maybe just notifier.settings + self, + context: AuthedServiceContext, + ) -> Union[NotifierSettings, SyftError]: + """Get Notifier Settings + + Args: + context: The request context + Returns: + Union[NotifierSettings, SyftError]: Notifier Settings or SyftError + """ + result = self.stash.get(credentials=context.credentials) + if result.is_err(): + return SyftError(message="Error getting notifier settings") + + return result.ok() + + def user_settings( + self, + context: AuthedServiceContext, + ) -> NotificationPreferences: + context.node = cast(AbstractNode, context.node) + user_service = context.node.get_service("userservice") + user_view = user_service.get_current_user(context) + notifications = user_view.notifications_enabled + return NotificationPreferences( + email=notifications[NOTIFIERS.EMAIL], + sms=notifications[NOTIFIERS.SMS], + slack=notifications[NOTIFIERS.SLACK], + app=notifications[NOTIFIERS.APP], + ) + + def turn_on( + self, + context: AuthedServiceContext, + email_username: Optional[str] = None, + email_password: Optional[str] = None, + email_sender: Optional[str] = None, + email_server: Optional[str] = None, + email_port: Optional[int] = 587, + ) -> Union[SyftSuccess, SyftError]: + """Turn on email notifications. + + Args: + email_username (Optional[str]): Email server username. Defaults to None. + email_password (Optional[str]): Email email server password. Defaults to None. + sender_email (Optional[str]): Email sender email. Defaults to None. + Returns: + Union[SyftSuccess, SyftError]: A union type representing the success or error response. + + Raises: + None + + """ + + result = self.stash.get(credentials=context.credentials) + + # 1 - If something went wrong at db level, return the error + if result.is_err(): + return SyftError(message=result.err()) + + # 2 - If one of the credentials are set alone, return an error + if ( + email_username + and not email_password + or email_password + and not email_username + ): + return SyftError(message="You must provide both username and password") + + notifier = result.ok() + + # 3 - If notifier doesn't have a email server / port and the user didn't provide them, return an error + if not (email_server and email_port) and not notifier.email_server: + return SyftError( + message="You must provide both server and port to enable notifications." + ) + + print("[LOG] Got notifier from db") + # If no new credentials provided, check for existing ones + if not (email_username and email_password): + if not (notifier.email_username and notifier.email_password): + return SyftError( + message="No valid token has been added to the domain." + + "You can add a pair of SMTP credentials via " + + "<client>.settings.enable_notifications(email=<>, password=<>)" + ) + else: + print("[LOG] No new credentials provided. Using existing ones.") + email_password = notifier.email_password + email_username = notifier.email_username + print("[LOG] Validating credentials...") + + validation_result = notifier.validate_email_credentials( + username=email_username, + password=email_password, + server=email_server if email_server else notifier.email_server, + port=email_port if email_port else notifier.email_port, + ) + + if validation_result.is_err(): + return SyftError( + message="Invalid SMTP credentials. Please check your username and password." + ) + + notifier.email_password = email_password + notifier.email_username = email_username + + if email_server: + notifier.email_server = email_server + if email_port: + notifier.email_port = email_port + + # Email sender verification + if not email_sender and not notifier.email_sender: + return SyftError( + message="You must provide a sender email address to enable notifications." + ) + + if email_sender: + try: + EmailStr._validate(email_sender) + except ValueError: + return SyftError( + message="Invalid sender email address. Please check your email address." + ) + notifier.email_sender = email_sender + + notifier.active = True + print( + "[LOG] Email credentials are valid. Updating the notifier settings in the db." + ) + + result = self.stash.update(credentials=context.credentials, settings=notifier) + if result.is_err(): + return SyftError(message=result.err()) + return SyftSuccess(message="Notifications enabled successfully.") + + def turn_off( + self, + context: AuthedServiceContext, + ) -> Union[SyftSuccess, SyftError]: + """ + Turn off email notifications service. + PySyft notifications will still work. + """ + + result = self.stash.get(credentials=context.credentials) + + if result.is_err(): + return SyftError(message=result.err()) + + notifier = result.ok() + notifier.active = False + result = self.stash.update(credentials=context.credentials, settings=notifier) + if result.is_err(): + return SyftError(message=result.err()) + return SyftSuccess(message="Notifications disabled succesfullly") + + def activate( + self, context: AuthedServiceContext, notifier_type: NOTIFIERS = NOTIFIERS.EMAIL + ) -> Union[SyftSuccess, SyftError]: + """ + Activate email notifications for the authenticated user. + This will only work if the domain owner has enabled notifications. + """ + context.node = cast(AbstractNode, context.node) + user_service = context.node.get_service("userservice") + return user_service.enable_notifications(context, notifier_type=notifier_type) + + def deactivate( + self, context: AuthedServiceContext, notifier_type: NOTIFIERS = NOTIFIERS.EMAIL + ) -> Union[SyftSuccess, SyftError]: + """Deactivate email notifications for the authenticated user + This will only work if the domain owner has enabled notifications. + """ + context.node = cast(AbstractNode, context.node) + user_service = context.node.get_service("userservice") + return user_service.disable_notifications(context, notifier_type=notifier_type) + + @staticmethod + def init_notifier( + node: AbstractNode, + email_username: Optional[str] = None, + email_password: Optional[str] = None, + email_sender: Optional[str] = None, + smtp_port: Optional[str] = None, + smtp_host: Optional[str] = None, + ) -> Result[Ok, Err]: + """Initialize Notifier settings for a Node. + If settings already exist, it will use the existing one. + If not, it will create a new one. + + Args: + node: Node to initialize the notifier + active: If notifier should be active + email_username: Email username to send notifications + email_password: Email password to send notifications + Raises: + Exception: If something went wrong + Returns: + Union: SyftSuccess or SyftError + """ + try: + # Create a new NotifierStash since its a static method. + notifier_stash = NotifierStash(store=node.document_store) + result = notifier_stash.get(node.signing_key.verify_key) + if result.is_err(): + raise Exception(f"Could not create notifier: {result}") + + # Get the notifier + notifier = result.ok() + # If notifier doesn't exist, create a new one + + if not notifier: + notifier = NotifierSettings() + notifier.active = False # Default to False + + # TODO: this should be a method in NotifierSettings + if email_username and email_password: + validation_result = notifier.validate_email_credentials( + username=email_username, + password=email_password, + server=smtp_host, + port=smtp_port, + ) + + sender_not_set = not email_sender and not notifier.email_sender + if validation_result.is_err() or sender_not_set: + print( + "Ops something went wrong while trying to setup your notification system.", + "Please check your credentials and configuration.", + ) + notifier.active = False + else: + notifier.email_password = email_password + notifier.email_username = email_username + notifier.email_sender = email_sender + notifier.email_server = smtp_host + notifier.email_port = int(smtp_port) + notifier.active = True + + notifier_stash.set(node.signing_key.verify_key, notifier) + return Ok("Notifier initialized successfully") + + except Exception as e: + raise Exception(f"Error initializing notifier. \n {e}") + + # This is not a public API. + # This method is used by other services to dispatch notifications internally + def dispatch_notification( + self, context: AuthedServiceContext, notification: Notification + ) -> Union[SyftError]: + context.node = cast(AbstractNode, context.node) + admin_key = context.node.get_service("userservice").admin_verify_key() + notifier = self.stash.get(admin_key) + if notifier.is_err(): + return SyftError( + message="The mail service ran out of quota or some notifications failed to be delivered.\n" + + "Please check the health of the mailing server." + ) + + notifier = notifier.ok() + # If notifier is active + if notifier.active: + resp = notifier.send_notifications( + context=context, notification=notification + ) + if resp.is_err(): + return SyftError(message=resp.err()) + + # If notifier isn't active, return None + return SyftSuccess(message="Notifications dispatched successfully") diff --git a/packages/syft/src/syft/service/notifier/notifier_stash.py b/packages/syft/src/syft/service/notifier/notifier_stash.py new file mode 100644 index 00000000000..e382900f226 --- /dev/null +++ b/packages/syft/src/syft/service/notifier/notifier_stash.py @@ -0,0 +1,84 @@ +# stdlib +from typing import List +from typing import Optional + +# third party +from result import Err +from result import Ok +from result import Result + +# relative +from ...node.credentials import SyftVerifyKey +from ...serde.serializable import serializable +from ...store.document_store import BaseStash +from ...store.document_store import DocumentStore +from ...store.document_store import PartitionKey +from ...store.document_store import PartitionSettings +from ...types.uid import UID +from ...util.telemetry import instrument +from ..action.action_permissions import ActionObjectPermission +from .notifier import NotifierSettings + +NamePartitionKey = PartitionKey(key="name", type_=str) +ActionIDsPartitionKey = PartitionKey(key="action_ids", type_=List[UID]) + + +@instrument +@serializable() +class NotifierStash(BaseStash): + object_type = NotifierSettings + settings: PartitionSettings = PartitionSettings( + name=NotifierSettings.__canonical_name__, object_type=NotifierSettings + ) + + def __init__(self, store: DocumentStore) -> None: + super().__init__(store=store) + + def admin_verify_key(self) -> SyftVerifyKey: + return self.partition.root_verify_key + + # TODO: should this method behave like a singleton? + def get(self, credentials: SyftVerifyKey) -> Result[NotifierSettings, Err]: + """Get Settings""" + result = self.get_all(credentials) + if result.is_ok(): + settings = result.ok() + if len(settings) == 0: + return Ok( + None + ) # TODO: Stash shouldn't be empty after init. Return Err instead? + result = settings[ + 0 + ] # TODO: Should we check if theres more than one? => Report corruption + return Ok(result) + else: + return Err(message=result.err()) + + def set( + self, + credentials: SyftVerifyKey, + settings: NotifierSettings, + add_permissions: Optional[List[ActionObjectPermission]] = None, + ignore_duplicates: bool = False, + ) -> Result[NotifierSettings, Err]: + result = self.check_type(settings, self.object_type) + # we dont use and_then logic here as it is hard because of the order of the arguments + if result.is_err(): + return Err(message=result.err()) + return super().set( + credentials=credentials, obj=result.ok() + ) # TODO check if result isInstance(Ok) + + def update( + self, + credentials: SyftVerifyKey, + settings: NotifierSettings, + has_permission: bool = False, + ) -> Result[NotifierSettings, Err]: + result = self.check_type(settings, self.object_type) + # we dont use and_then logic here as it is hard because of the order of the arguments + if result.is_err(): + return Err(message=result.err()) + return super().update( + credentials=credentials, obj=result.ok() + ) # TODO check if result isInstance(Ok) diff --git a/packages/syft/src/syft/service/notifier/smtp_client.py b/packages/syft/src/syft/service/notifier/smtp_client.py new file mode 100644 index 00000000000..1f4df6531e5 --- /dev/null +++ b/packages/syft/src/syft/service/notifier/smtp_client.py @@ -0,0 +1,70 @@ +# stdlib +from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText +import smtplib + +# third party +from result import Err +from result import Ok +from result import Result + + +class SMTPClient: + SOCKET_TIMEOUT = 5 # seconds + + def __init__( + self, + server: str, + port: int, + username: str, + password: str, + ) -> None: + if not (username and password): + raise ValueError("Both username and password must be provided") + + self.username = username + self.password = password + self.server = server + self.port = port + + def send(self, sender: str, receiver: list[str], subject: str, body: str) -> None: + if not (subject and body and receiver): + raise ValueError("Subject, body, and recipient email(s) are required") + + msg = MIMEMultipart("alternative") + msg["From"] = sender + msg["To"] = ", ".join(receiver) + msg["Subject"] = subject + msg.attach(MIMEText(body, "html")) + + with smtplib.SMTP( + self.server, self.port, timeout=self.SOCKET_TIMEOUT + ) as server: + server.ehlo() + if server.has_extn("STARTTLS"): + server.starttls() + server.ehlo() + server.login(self.username, self.password) + text = msg.as_string() + server.sendmail(sender, ", ".join(receiver), text) + # TODO: Add error handling + + @classmethod + def check_credentials( + cls, server: str, port: int, username: str, password: str + ) -> Result[Ok, Err]: + """Check if the credentials are valid. + + Returns: + bool: True if the credentials are valid, False otherwise. + """ + try: + with smtplib.SMTP(server, port, timeout=cls.SOCKET_TIMEOUT) as smtp_server: + smtp_server.ehlo() + if smtp_server.has_extn("STARTTLS"): + smtp_server.starttls() + smtp_server.ehlo() + smtp_server.login(username, password) + return Ok("Credentials are valid.") + except Exception as e: + return Err(e) diff --git a/packages/syft/src/syft/service/object_search/migration_state_service.py b/packages/syft/src/syft/service/object_search/migration_state_service.py index fefb28dc60c..c16360a4354 100644 --- a/packages/syft/src/syft/service/object_search/migration_state_service.py +++ b/packages/syft/src/syft/service/object_search/migration_state_service.py @@ -15,7 +15,7 @@ @serializable() class MigrateStateService(AbstractService): store: DocumentStore - stash: SyftObjectMigrationState + stash: SyftMigrationStateStash def __init__(self, store: DocumentStore) -> None: self.store = store diff --git a/packages/syft/src/syft/service/object_search/object_migration_state.py b/packages/syft/src/syft/service/object_search/object_migration_state.py index 686c1ccb8fd..e6bab0fb8b3 100644 --- a/packages/syft/src/syft/service/object_search/object_migration_state.py +++ b/packages/syft/src/syft/service/object_search/object_migration_state.py @@ -33,7 +33,7 @@ def latest_version(self) -> Optional[int]: available_versions = SyftMigrationRegistry.get_versions( canonical_name=self.canonical_name, ) - if available_versions is None: + if not available_versions: return None return sorted(available_versions, reverse=True)[0] @@ -62,6 +62,7 @@ def set( credentials: SyftVerifyKey, migration_state: SyftObjectMigrationState, add_permissions: Optional[List[ActionObjectPermission]] = None, + ignore_duplicates: bool = False, ) -> Result[SyftObjectMigrationState, str]: res = self.check_type(migration_state, self.object_type) # we dont use and_then logic here as it is hard because of the order of the arguments diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py new file mode 100644 index 00000000000..7e0a190b366 --- /dev/null +++ b/packages/syft/src/syft/service/output/output_service.py @@ -0,0 +1,275 @@ +# stdlib +from typing import Any +from typing import ClassVar +from typing import Dict +from typing import List +from typing import Optional +from typing import Type +from typing import Union + +# third party +from pydantic import model_validator +from result import Result + +# relative +from ...client.api import APIRegistry +from ...node.credentials import SyftVerifyKey +from ...serde.serializable import serializable +from ...store.document_store import BaseUIDStoreStash +from ...store.document_store import DocumentStore +from ...store.document_store import PartitionKey +from ...store.document_store import PartitionSettings +from ...store.document_store import QueryKeys +from ...store.linked_obj import LinkedObject +from ...types.datetime import DateTime +from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SyftObject +from ...types.uid import UID +from ...util.telemetry import instrument +from ..action.action_object import ActionObject +from ..context import AuthedServiceContext +from ..response import SyftError +from ..service import AbstractService +from ..service import service_method +from ..user.user_roles import GUEST_ROLE_LEVEL + +CreatedAtPartitionKey = PartitionKey(key="created_at", type_=DateTime) +UserCodeIdPartitionKey = PartitionKey(key="user_code_id", type_=UID) +OutputPolicyIdPartitionKey = PartitionKey(key="output_policy_id", type_=UID) + + +@serializable() +class ExecutionOutput(SyftObject): + __canonical_name__ = "ExecutionOutput" + __version__ = SYFT_OBJECT_VERSION_1 + + executing_user_verify_key: SyftVerifyKey + user_code_link: LinkedObject + output_ids: Optional[Union[List[UID], Dict[str, UID]]] = None + job_link: Optional[LinkedObject] = None + created_at: DateTime = DateTime.now() + + # Required for __attr_searchable__, set by model_validator + user_code_id: UID + + # Output policy is not a linked object because its saved on the usercode + output_policy_id: Optional[UID] = None + + __attr_searchable__: ClassVar[List[str]] = [ + "user_code_id", + "created_at", + "output_policy_id", + ] + __repr_attrs__: ClassVar[List[str]] = [ + "created_at", + "user_code_id", + "job_id", + "output_ids", + ] + + @model_validator(mode="before") + @classmethod + def add_user_code_id(cls, values: dict) -> dict: + if "user_code_link" in values: + values["user_code_id"] = values["user_code_link"].object_uid + return values + + @classmethod + def from_ids( + cls: Type["ExecutionOutput"], + output_ids: Union[UID, List[UID], Dict[str, UID]], + user_code_id: UID, + executing_user_verify_key: SyftVerifyKey, + node_uid: UID, + job_id: Optional[UID] = None, + output_policy_id: Optional[UID] = None, + ) -> "ExecutionOutput": + # relative + from ..code.user_code_service import UserCode + from ..code.user_code_service import UserCodeService + from ..job.job_service import Job + from ..job.job_service import JobService + + if isinstance(output_ids, UID): + output_ids = [output_ids] + + user_code_link = LinkedObject.from_uid( + object_uid=user_code_id, + object_type=UserCode, + service_type=UserCodeService, + node_uid=node_uid, + ) + + if job_id: + job_link = LinkedObject.from_uid( + object_uid=job_id, + object_type=Job, + service_type=JobService, + node_uid=node_uid, + ) + else: + job_link = None + return cls( + output_ids=output_ids, + user_code_link=user_code_link, + job_link=job_link, + executing_user_verify_key=executing_user_verify_key, + output_policy_id=output_policy_id, + ) + + @property + def outputs(self) -> Optional[Union[List[ActionObject], Dict[str, ActionObject]]]: + api = APIRegistry.api_for( + node_uid=self.syft_node_location, + user_verify_key=self.syft_client_verify_key, + ) + if api is None: + raise ValueError( + f"Can't access the api. Please log in to {self.syft_node_location}" + ) + action_service = api.services.action + + # TODO: error handling for action_service.get + if isinstance(self.output_ids, dict): + return {k: action_service.get(v) for k, v in self.output_ids.items()} + elif isinstance(self.output_ids, list): + return [action_service.get(v) for v in self.output_ids] + else: + return None + + @property + def output_id_list(self) -> List[UID]: + ids = self.output_ids + if isinstance(ids, dict): + return list(ids.values()) + elif isinstance(ids, list): + return ids + return [] + + @property + def job_id(self) -> Optional[UID]: + return self.job_link.object_uid if self.job_link else None + + def get_sync_dependencies(self, api: Any = None) -> List[UID]: + # Output ids, user code id, job id + res = [] + + res.extend(self.output_id_list) + res.append(self.user_code_id) + if self.job_id: + res.append(self.job_id) + + return res + + +@instrument +@serializable() +class OutputStash(BaseUIDStoreStash): + object_type = ExecutionOutput + settings: PartitionSettings = PartitionSettings( + name=ExecutionOutput.__canonical_name__, object_type=ExecutionOutput + ) + + def __init__(self, store: DocumentStore) -> None: + super().__init__(store) + self.store = store + self.settings = self.settings + self._object_type = self.object_type + + def get_by_user_code_id( + self, credentials: SyftVerifyKey, user_code_id: UID + ) -> Result[List[ExecutionOutput], str]: + qks = QueryKeys( + qks=[UserCodeIdPartitionKey.with_obj(user_code_id)], + ) + return self.query_all( + credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey + ) + + def get_by_output_policy_id( + self, credentials: SyftVerifyKey, output_policy_id: UID + ) -> Result[List[ExecutionOutput], str]: + qks = QueryKeys( + qks=[OutputPolicyIdPartitionKey.with_obj(output_policy_id)], + ) + return self.query_all( + credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey + ) + + +@instrument +@serializable() +class OutputService(AbstractService): + store: DocumentStore + stash: OutputStash + + def __init__(self, store: DocumentStore): + self.store = store + self.stash = OutputStash(store=store) + + @service_method( + path="output.create", + name="create", + roles=GUEST_ROLE_LEVEL, + ) + def create( + self, + context: AuthedServiceContext, + user_code_id: UID, + output_ids: Union[UID, List[UID], Dict[str, UID]], + executing_user_verify_key: SyftVerifyKey, + job_id: Optional[UID] = None, + output_policy_id: Optional[UID] = None, + ) -> Union[ExecutionOutput, SyftError]: + output = ExecutionOutput.from_ids( + output_ids=output_ids, + user_code_id=user_code_id, + executing_user_verify_key=executing_user_verify_key, + node_uid=context.node.id, # type: ignore + job_id=job_id, + output_policy_id=output_policy_id, + ) + + res = self.stash.set(context.credentials, output) + return res + + @service_method( + path="output.get_by_user_code_id", + name="get_by_user_code_id", + roles=GUEST_ROLE_LEVEL, + ) + def get_by_user_code_id( + self, context: AuthedServiceContext, user_code_id: UID + ) -> Union[List[ExecutionOutput], SyftError]: + result = self.stash.get_by_user_code_id( + credentials=context.node.verify_key, # type: ignore + user_code_id=user_code_id, + ) + if result.is_ok(): + return result.ok() + return SyftError(message=result.err()) + + @service_method( + path="output.get_by_output_policy_id", + name="get_by_output_policy_id", + roles=GUEST_ROLE_LEVEL, + ) + def get_by_output_policy_id( + self, context: AuthedServiceContext, output_policy_id: UID + ) -> Union[List[ExecutionOutput], SyftError]: + result = self.stash.get_by_output_policy_id( + credentials=context.node.verify_key, # type: ignore + output_policy_id=output_policy_id, # type: ignore + ) + if result.is_ok(): + return result.ok() + return SyftError(message=result.err()) + + @service_method(path="output.get_all", name="get_all", roles=GUEST_ROLE_LEVEL) + def get_all( + self, context: AuthedServiceContext + ) -> Union[List[ExecutionOutput], SyftError]: + result = self.stash.get_all(context.credentials) + if result.is_ok(): + return result.ok() + return SyftError(message=result.err()) diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index e0bcc0d0356..745abf8daef 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -28,6 +28,7 @@ # relative from ...abstract_node import AbstractNode from ...abstract_node import NodeType +from ...client.api import APIRegistry from ...client.api import NodeIdentity from ...node.credentials import SyftVerifyKey from ...serde.recursive_primitives import recursive_serde_register_type @@ -35,6 +36,7 @@ from ...store.document_store import PartitionKey from ...types.datetime import DateTime from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.transforms import TransformContext from ...types.transforms import generate_id @@ -70,7 +72,7 @@ def extract_uid(v: Any) -> UID: return value -def filter_only_uids(results: Any) -> Union[list, dict]: +def filter_only_uids(results: Any) -> Union[list[UID], dict[str, UID], UID]: if not hasattr(results, "__len__"): results = [results] @@ -89,7 +91,7 @@ def filter_only_uids(results: Any) -> Union[list, dict]: class Policy(SyftObject): # version - __canonical_name__ = "Policy" + __canonical_name__: str = "Policy" __version__ = SYFT_OBJECT_VERSION_1 id: UID @@ -115,13 +117,12 @@ def policy_code(self) -> str: op_code += "\n" return op_code + def is_valid(self, *args: List, **kwargs: Dict) -> Union[SyftSuccess, SyftError]: # type: ignore + return SyftSuccess(message="Policy is valid.") + def public_state(self) -> Any: raise NotImplementedError - @property - def valid(self) -> Union[SyftSuccess, SyftError]: - return SyftSuccess(message="Policy is valid.") - @serializable() class UserPolicyStatus(Enum): @@ -130,7 +131,7 @@ class UserPolicyStatus(Enum): APPROVED = "approved" -def partition_by_node(kwargs: Dict[str, Any]) -> Dict[str, UID]: +def partition_by_node(kwargs: Dict[str, Any]) -> dict[NodeIdentity, dict[str, UID]]: # relative from ...client.api import APIRegistry from ...client.api import NodeIdentity @@ -262,7 +263,7 @@ def retrieve_from_db( def allowed_ids_only( - allowed_inputs: Dict[str, UID], + allowed_inputs: dict[NodeIdentity, Any], kwargs: Dict[str, Any], context: AuthedServiceContext, ) -> Dict[str, UID]: @@ -324,7 +325,7 @@ class OutputHistory(SyftObject): __version__ = SYFT_OBJECT_VERSION_1 output_time: DateTime - outputs: Optional[Union[List[UID], Dict[str, UID]]] + outputs: Optional[Union[List[UID], Dict[str, UID]]] = None executing_user_verify_key: SyftVerifyKey @@ -333,9 +334,8 @@ class OutputPolicy(Policy): __canonical_name__ = "OutputPolicy" __version__ = SYFT_OBJECT_VERSION_1 - output_history: List[OutputHistory] = [] output_kwargs: List[str] = [] - node_uid: Optional[UID] + node_uid: Optional[UID] = None output_readers: List[SyftVerifyKey] = [] def apply_output( @@ -343,54 +343,69 @@ def apply_output( context: NodeServiceContext, outputs: Any, ) -> Any: - output_uids: Union[Dict[str, Any], list] = filter_only_uids(outputs) - if isinstance(output_uids, UID): - output_uids = [output_uids] - history = OutputHistory( - output_time=DateTime.now(), - outputs=output_uids, - executing_user_verify_key=context.credentials, - ) - self.output_history.append(history) + # output_uids: Union[Dict[str, Any], list] = filter_only_uids(outputs) + # if isinstance(output_uids, UID): + # output_uids = [output_uids] + # history = OutputHistory( + # output_time=DateTime.now(), + # outputs=output_uids, + # executing_user_verify_key=context.credentials, + # ) + # self.output_history.append(history) + return outputs - @property - def outputs(self) -> List[str]: - return self.output_kwargs - - @property - def last_output_ids(self) -> Optional[Union[List[UID], Dict[str, UID]]]: - return self.output_history[-1].outputs + def is_valid(self, context: AuthedServiceContext) -> Union[SyftSuccess, SyftError]: # type: ignore + raise NotImplementedError() @serializable() class OutputPolicyExecuteCount(OutputPolicy): __canonical_name__ = "OutputPolicyExecuteCount" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 - count: int = 0 limit: int - def apply_output( - self, - context: NodeServiceContext, - outputs: Any, - ) -> Optional[Any]: - if self.count < self.limit: - super().apply_output(context, outputs) - self.count += 1 - return outputs - return None + @property + def count(self) -> Union[SyftError, int]: + api = APIRegistry.api_for(self.syft_node_location, self.syft_client_verify_key) + if api is None: + raise ValueError( + f"api is None. You must login to {self.syft_node_location}" + ) + output_history = api.services.output.get_by_output_policy_id(self.id) + + if isinstance(output_history, SyftError): + return output_history + return len(output_history) @property - def valid(self) -> Union[SyftSuccess, SyftError]: - is_valid = self.count < self.limit + def is_valid(self) -> Union[SyftSuccess, SyftError]: # type: ignore + execution_count = self.count + is_valid = execution_count < self.limit if is_valid: return SyftSuccess( - message=f"Policy is still valid. count: {self.count} < limit: {self.limit}" + message=f"Policy is still valid. count: {execution_count} < limit: {self.limit}" ) return SyftError( - message=f"Policy is no longer valid. count: {self.count} >= limit: {self.limit}" + message=f"Policy is no longer valid. count: {execution_count} >= limit: {self.limit}" + ) + + def _is_valid(self, context: AuthedServiceContext) -> Union[SyftSuccess, SyftError]: + context.node = cast(AbstractNode, context.node) + output_service = context.node.get_service("outputservice") + output_history = output_service.get_by_output_policy_id(context, self.id) + if isinstance(output_history, SyftError): + return output_history + execution_count = len(output_history) + + is_valid = execution_count < self.limit + if is_valid: + return SyftSuccess( + message=f"Policy is still valid. count: {execution_count} < limit: {self.limit}" + ) + return SyftError( + message=f"Policy is no longer valid. count: {execution_count} >= limit: {self.limit}" ) def public_state(self) -> dict[str, int]: @@ -400,7 +415,7 @@ def public_state(self) -> dict[str, int]: @serializable() class OutputPolicyExecuteOnce(OutputPolicyExecuteCount): __canonical_name__ = "OutputPolicyExecuteOnce" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 limit: int = 1 @@ -451,11 +466,11 @@ class CustomInputPolicy(metaclass=CustomPolicy): @serializable() class UserPolicy(Policy): - __canonical_name__ = "UserPolicy" + __canonical_name__: str = "UserPolicy" __version__ = SYFT_OBJECT_VERSION_1 id: UID - node_uid: Optional[UID] + node_uid: Optional[UID] = None user_verify_key: SyftVerifyKey raw_code: str parsed_code: str @@ -463,7 +478,6 @@ class UserPolicy(Policy): class_name: str unique_name: str code_hash: str - byte_code: PyCodeObject status: UserPolicyStatus = UserPolicyStatus.SUBMITTED # TODO: fix the mypy issue @@ -525,7 +539,7 @@ class SubmitUserPolicy(Policy): __canonical_name__ = "SubmitUserPolicy" __version__ = SYFT_OBJECT_VERSION_1 - id: Optional[UID] + id: Optional[UID] = None # type: ignore[assignment] code: str class_name: str input_kwargs: List[str] @@ -545,20 +559,27 @@ def from_obj(policy_obj: CustomPolicy) -> SubmitUserPolicy: def hash_code(context: TransformContext) -> TransformContext: + if context.output is None: + return context code = context.output["code"] del context.output["code"] context.output["raw_code"] = code code_hash = hashlib.sha256(code.encode("utf8")).hexdigest() context.output["code_hash"] = code_hash + return context def generate_unique_class_name(context: TransformContext) -> TransformContext: # TODO: Do we need to check if the initial name contains underscores? - code_hash = context.output["code_hash"] - service_class_name = context.output["class_name"] - unique_name = f"{service_class_name}_{context.credentials}_{code_hash}" - context.output["unique_name"] = unique_name + if context.output is not None: + code_hash = context.output["code_hash"] + service_class_name = context.output["class_name"] + unique_name = f"{service_class_name}_{context.credentials}_{code_hash}" + context.output["unique_name"] = unique_name + else: + print("f{context}'s output is None. No trasformation happened.") + return context @@ -657,6 +678,9 @@ def check_class_code(context: TransformContext) -> TransformContext: # check for Policy template -> __init__, apply_output, public_state # parse init signature # check dangerous libraries, maybe compile_restricted already does that + if context.output is None: + return context + try: processed_code = process_class_code( raw_code=context.output["raw_code"], @@ -665,34 +689,45 @@ def check_class_code(context: TransformContext) -> TransformContext: context.output["parsed_code"] = processed_code except Exception as e: raise e + return context def compile_code(context: TransformContext) -> TransformContext: - byte_code = compile_byte_code(context.output["parsed_code"]) - if byte_code is None: - raise Exception( - "Unable to compile byte code from parsed code. " - + context.output["parsed_code"] - ) + if context.output is not None: + byte_code = compile_byte_code(context.output["parsed_code"]) + if byte_code is None: + raise Exception( + "Unable to compile byte code from parsed code. " + + context.output["parsed_code"] + ) + else: + print("f{context}'s output is None. No trasformation happened.") + return context def add_credentials_for_key(key: str) -> Callable: def add_credentials(context: TransformContext) -> TransformContext: - context.output[key] = context.credentials + if context.output is not None: + context.output[key] = context.credentials + return context return add_credentials def generate_signature(context: TransformContext) -> TransformContext: + if context.output is None: + return context + params = [ Parameter(name=k, kind=Parameter.POSITIONAL_OR_KEYWORD) for k in context.output["input_kwargs"] ] sig = Signature(parameters=params) context.output["signature"] = sig + return context @@ -768,5 +803,6 @@ def load_policy_code(user_policy: UserPolicy) -> Any: def init_policy(user_policy: UserPolicy, init_args: Dict[str, Any]) -> Any: policy_class = load_policy_code(user_policy) policy_object = policy_class() + init_args = {k: v for k, v in init_args.items() if k != "id"} policy_object.__user_init__(**init_args) return policy_object diff --git a/packages/syft/src/syft/service/project/project.py b/packages/syft/src/syft/service/project/project.py index 4f55bc444f5..41388d27080 100644 --- a/packages/syft/src/syft/service/project/project.py +++ b/packages/syft/src/syft/service/project/project.py @@ -18,8 +18,8 @@ from typing import Union # third party -import pydantic -from pydantic import validator +from pydantic import Field +from pydantic import field_validator from rich.progress import Progress from typing_extensions import Self @@ -58,6 +58,7 @@ from ..response import SyftInfo from ..response import SyftNotReady from ..response import SyftSuccess +from ..user.user import UserView @serializable() @@ -78,17 +79,17 @@ class ProjectEvent(SyftObject): # 1. Creation attrs id: UID - timestamp: DateTime + timestamp: DateTime = Field(default_factory=DateTime.now) allowed_sub_types: Optional[List] = [] # 2. Rebase attrs - project_id: Optional[UID] - seq_no: Optional[int] - prev_event_uid: Optional[UID] - prev_event_hash: Optional[str] - event_hash: Optional[str] + project_id: Optional[UID] = None + seq_no: Optional[int] = None + prev_event_uid: Optional[UID] = None + prev_event_hash: Optional[str] = None + event_hash: Optional[str] = None # 3. Signature attrs - creator_verify_key: Optional[SyftVerifyKey] - signature: Optional[bytes] # dont use in signing + creator_verify_key: Optional[SyftVerifyKey] = None + signature: Optional[bytes] = None # dont use in signing def __repr_syft_nested__(self) -> tuple[str, str]: return ( @@ -96,12 +97,6 @@ def __repr_syft_nested__(self) -> tuple[str, str]: f"{str(self.id)[:4]}...{str(self.id)[-3:]}", ) - @pydantic.root_validator(pre=True) - def make_timestamp(cls, values: Dict[str, Any]) -> Dict[str, Any]: - if "timestamp" not in values or values["timestamp"] is None: - values["timestamp"] = DateTime.now() - return values - def _pre_add_update(self, project: Project) -> None: pass @@ -146,7 +141,7 @@ def valid_descendant( return valid if prev_event: - prev_event_id = prev_event.id + prev_event_id: Optional[UID] = prev_event.id prev_event_hash = prev_event.event_hash prev_seq_no = prev_event.seq_no else: @@ -277,8 +272,9 @@ class ProjectRequest(ProjectEventAddObject): linked_request: LinkedObject allowed_sub_types: List[Type] = [ProjectRequestResponse] - @validator("linked_request", pre=True) - def _validate_linked_request(cls, v: Any) -> Union[Request, LinkedObject]: + @field_validator("linked_request", mode="before") + @classmethod + def _validate_linked_request(cls, v: Any) -> LinkedObject: if isinstance(v, Request): linked_request = LinkedObject.from_obj(v, node_uid=v.node_uid) return linked_request @@ -295,16 +291,16 @@ def request(self) -> Request: __repr_attrs__ = [ "request.status", - "request.changes[-1].link.service_func_name", + "request.changes[-1].code.service_func_name", ] - def _repr_markdown_(self) -> str: + def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: func_name = None if len(self.request.changes) > 0: - func_name = self.request.changes[-1].link.service_func_name + func_name = self.request.changes[-1].code.service_func_name repr_dict = { "request.status": self.request.status, - "request.changes[-1].link.service_func_name": func_name, + "request.changes[-1].code.service_func_name": func_name, } return markdown_as_class_with_fields(self, repr_dict) @@ -332,13 +328,12 @@ def status(self, project: Project) -> Optional[Union[SyftInfo, SyftError]]: During Request status calculation, we do not allow multiple responses """ - responses = project.get_children(self) + responses: list[ProjectEvent] = project.get_children(self) if len(responses) == 0: return SyftInfo( "No one has responded to the request yet. Kindly recheck later 🙂" ) - - if len(responses) > 1: + elif len(responses) > 1: return SyftError( message="The Request Contains more than one Response" "which is currently not possible" @@ -348,7 +343,7 @@ def status(self, project: Project) -> Optional[Union[SyftInfo, SyftError]]: ) response = responses[0] if not isinstance(response, ProjectRequestResponse): - return SyftError( + return SyftError( # type: ignore[unreachable] message=f"Response : {type(response)} is not of type ProjectRequestResponse" ) @@ -557,8 +552,9 @@ class ProjectMultipleChoicePoll(ProjectEventAddObject): choices: List[str] allowed_sub_types: List[Type] = [AnswerProjectPoll] - @validator("choices") - def choices_min_length(cls, v: str) -> str: + @field_validator("choices") + @classmethod + def choices_min_length(cls, v: list[str]) -> list[str]: if len(v) < 1: raise ValueError("choices must have at least one item") return v @@ -587,7 +583,7 @@ def status( respondents = {} for poll_answer in poll_answers[::-1]: if not isinstance(poll_answer, AnswerProjectPoll): - return SyftError( + return SyftError( # type: ignore[unreachable] message=f"Poll answer: {type(poll_answer)} is not of type AnswerProjectPoll" ) creator_verify_key = poll_answer.creator_verify_key @@ -627,7 +623,7 @@ def __hash__(self) -> int: def add_code_request_to_project( project: Union[ProjectSubmit, Project], code: SubmitUserCode, - client: SyftClient, + client: Union[SyftClient, Any], reason: Optional[str] = None, ) -> Union[SyftError, SyftSuccess]: # TODO: fix the mypy issue @@ -681,14 +677,14 @@ class Project(SyftObject): "event_id_hashmap", ] - id: Optional[UID] + id: Optional[UID] = None # type: ignore[assignment] name: str - description: Optional[str] + description: Optional[str] = None members: List[NodeIdentity] users: List[UserIdentity] = [] - username: Optional[str] + username: Optional[str] = None created_by: str - start_hash: Optional[str] + start_hash: Optional[str] = None # WARNING: Do not add it to hash keys or print directly user_signing_key: Optional[SyftSigningKey] = None @@ -698,7 +694,7 @@ class Project(SyftObject): # Project sync state_sync_leader: NodeIdentity - leader_node_peer: Optional[NodePeer] + leader_node_peer: Optional[NodePeer] = None # Unused consensus_model: ConsensusModel @@ -998,7 +994,7 @@ def reply_message( reply_event: Union[ProjectMessage, ProjectThreadMessage] if isinstance(message, ProjectMessage): reply_event = message.reply(reply) - elif isinstance(message, ProjectThreadMessage): + elif isinstance(message, ProjectThreadMessage): # type: ignore[unreachable] reply_event = ProjectThreadMessage( message=reply, parent_event_id=message.parent_event_id ) @@ -1043,7 +1039,7 @@ def answer_poll( poll = self.event_id_hashmap[poll] if not isinstance(poll, ProjectMultipleChoicePoll): - return SyftError( + return SyftError( # type: ignore[unreachable] message=f"You can only reply to a poll: {type(poll)}" "Kindly re-check the poll" ) @@ -1087,7 +1083,7 @@ def approve_request( if isinstance(request_event, SyftError): return request_event else: - return SyftError( + return SyftError( # type: ignore[unreachable] message=f"You can only approve a request: {type(request)}" "Kindly re-check the request" ) @@ -1172,19 +1168,19 @@ class ProjectSubmit(SyftObject): # Init args name: str - description: Optional[str] + description: Optional[str] = None members: Union[List[SyftClient], List[NodeIdentity]] # These will be automatically populated users: List[UserIdentity] = [] created_by: Optional[str] = None - username: Optional[str] + username: Optional[str] = None clients: List[SyftClient] = [] # List of member clients start_hash: str = "" # Project sync args - leader_node_route: Optional[NodeRoute] - state_sync_leader: Optional[NodeIdentity] + leader_node_route: Optional[NodeRoute] = None + state_sync_leader: Optional[NodeIdentity] = None bootstrap_events: Optional[List[ProjectEvent]] = [] # Unused at the moment @@ -1209,7 +1205,10 @@ def __init__(self, *args: Any, **kwargs: Any): self.users = [UserIdentity.from_client(client) for client in self.clients] # Assign logged in user name as project creator - self.username = self.clients[0].me.name or "" + if isinstance(self.clients[0].me, UserView): + self.username = self.clients[0].me.name + else: + self.username = "" # Convert SyftClients to NodeIdentities self.members = list(map(self.to_node_identity, self.members)) @@ -1228,7 +1227,8 @@ def _repr_html_(self) -> Any: + "</div>" ) - @validator("members", pre=True) + @field_validator("members", mode="before") + @classmethod def verify_members( cls, val: Union[List[SyftClient], List[NodeIdentity]] ) -> Union[List[SyftClient], List[NodeIdentity]]: @@ -1252,7 +1252,7 @@ def get_syft_clients( def to_node_identity(val: Union[SyftClient, NodeIdentity]) -> NodeIdentity: if isinstance(val, NodeIdentity): return val - elif isinstance(val, SyftClient): + elif isinstance(val, SyftClient) and val.metadata is not None: metadata = val.metadata.to(NodeMetadataV3) return metadata.to(NodeIdentity) else: @@ -1353,19 +1353,21 @@ def add_members_as_owners(members: List[SyftVerifyKey]) -> Set[str]: def elect_leader(context: TransformContext) -> TransformContext: - if len(context.output["members"]) == 0: - raise Exception("Project's require at least one member") - - context.output["state_sync_leader"] = context.output["members"][0] + if context.output is not None: + if len(context.output["members"]) == 0: + raise ValueError("Project's require at least one member") + context.output["state_sync_leader"] = context.output["members"][0] return context def check_permissions(context: TransformContext) -> TransformContext: + if context.output is None: + return context + if len(context.output["members"]) > 1: # more than 1 node pass - # check at least one owner if len(context.output["project_permissions"]) == 0: project_permissions = context.output["project_permissions"] @@ -1378,7 +1380,8 @@ def check_permissions(context: TransformContext) -> TransformContext: def add_creator_name(context: TransformContext) -> TransformContext: - context.output["username"] = context.obj.username + if context.output is not None and context.obj is not None: + context.output["username"] = context.obj.username return context diff --git a/packages/syft/src/syft/service/project/project_service.py b/packages/syft/src/syft/service/project/project_service.py index afed1eaefc9..6de6c644259 100644 --- a/packages/syft/src/syft/service/project/project_service.py +++ b/packages/syft/src/syft/service/project/project_service.py @@ -105,10 +105,11 @@ def create_project( verify_key=leader_node.verify_key, ) if peer.is_err(): + this_node_id = context.node.id.short() if context.node.id else "" return SyftError( message=( f"Leader Node(id={leader_node.id.short()}) is not a " - f"peer of this Node(id={context.node.id.short()})" + f"peer of this Node(id={this_node_id})" ) ) leader_node_peer = peer.ok() diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py index b72fd11b2ac..fb4eb83cf17 100644 --- a/packages/syft/src/syft/service/queue/queue.py +++ b/packages/syft/src/syft/service/queue/queue.py @@ -5,6 +5,7 @@ from typing import Optional from typing import Type from typing import Union +from typing import cast # third party import psutil @@ -146,6 +147,8 @@ def handle_message_multiprocessing( # this is a temp hack to prevent some multithreading issues time.sleep(0.5) queue_config = worker_settings.queue_config + if queue_config is None: + raise ValueError(f"{worker_settings} has no queue configurations!") queue_config.client_config.create_producer = False queue_config.client_config.n_consumers = 0 @@ -305,7 +308,7 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: queue_item.node_uid = worker.id job_item.status = JobStatus.PROCESSING - job_item.node_uid = worker.id + job_item.node_uid = cast(UID, worker.id) job_item.updated_at = DateTime.now() # try: diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py index e98641dc83a..669507fb463 100644 --- a/packages/syft/src/syft/service/queue/queue_stash.py +++ b/packages/syft/src/syft/service/queue/queue_stash.py @@ -1,7 +1,6 @@ # stdlib from enum import Enum from typing import Any -from typing import Callable from typing import Dict from typing import List from typing import Optional @@ -22,13 +21,9 @@ from ...store.document_store import QueryKeys from ...store.document_store import UIDPartitionKey from ...store.linked_obj import LinkedObject -from ...types.syft_migration import migrate -from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import SyftObject -from ...types.transforms import drop -from ...types.transforms import make_set_default from ...types.uid import UID from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission @@ -48,38 +43,6 @@ class Status(str, Enum): StatusPartitionKey = PartitionKey(key="status", type_=Status) -@serializable() -class QueueItemV1(SyftObject): - __canonical_name__ = "QueueItem" - __version__ = SYFT_OBJECT_VERSION_1 - - id: UID - node_uid: UID - result: Optional[Any] - resolved: bool = False - status: Status = Status.CREATED - - -@serializable() -class QueueItemV2(SyftObject): - __canonical_name__ = "QueueItem" - __version__ = SYFT_OBJECT_VERSION_2 - - id: UID - node_uid: UID - result: Optional[Any] - resolved: bool = False - status: Status = Status.CREATED - - method: str - service: str - args: List - kwargs: Dict[str, Any] - job_id: Optional[UID] - worker_settings: Optional[WorkerSettings] - has_execute_permissions: bool = False - - @serializable() class QueueItem(SyftObject): __canonical_name__ = "QueueItem" @@ -89,7 +52,7 @@ class QueueItem(SyftObject): id: UID node_uid: UID - result: Optional[Any] + result: Optional[Any] = None resolved: bool = False status: Status = Status.CREATED @@ -97,15 +60,15 @@ class QueueItem(SyftObject): service: str args: List kwargs: Dict[str, Any] - job_id: Optional[UID] - worker_settings: Optional[WorkerSettings] + job_id: Optional[UID] = None + worker_settings: Optional[WorkerSettings] = None has_execute_permissions: bool = False worker_pool: LinkedObject def __repr__(self) -> str: return f"<QueueItem: {self.id}>: {self.status}" - def _repr_markdown_(self) -> str: + def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: return f"<QueueItem: {self.id}>: {self.status}" @property @@ -119,45 +82,6 @@ def action(self) -> Union[Any, SyftError]: return SyftError(message="QueueItem not an Action") -@migrate(QueueItem, QueueItemV1) -def downgrade_queueitem_v2_to_v1() -> list[Callable]: - return [ - drop( - [ - "method", - "service", - "args", - "kwargs", - "job_id", - "worker_settings", - "has_execute_permissions", - ] - ), - ] - - -@migrate(QueueItemV1, QueueItem) -def upgrade_queueitem_v1_to_v2() -> list[Callable]: - return [ - make_set_default("method", ""), - make_set_default("service", ""), - make_set_default("args", []), - make_set_default("kwargs", {}), - make_set_default("job_id", None), - make_set_default("worker_settings", None), - make_set_default("has_execute_permissions", False), - ] - - -@serializable() -class ActionQueueItemV1(QueueItemV2): - __canonical_name__ = "ActionQueueItem" - __version__ = SYFT_OBJECT_VERSION_1 - - method: str = "execute" - service: str = "actionservice" - - @serializable() class ActionQueueItem(QueueItem): __canonical_name__ = "ActionQueueItem" diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index 215dcdfe5c2..42e62da9421 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -7,7 +7,6 @@ import time from time import sleep from typing import Any -from typing import Callable from typing import DefaultDict from typing import Dict from typing import List @@ -17,7 +16,7 @@ # third party from loguru import logger -from pydantic import validator +from pydantic import field_validator from zmq import Frame from zmq import LINGER from zmq.error import ContextTerminated @@ -30,17 +29,13 @@ from ...service.action.action_object import ActionObject from ...service.context import AuthedServiceContext from ...types.base import SyftBaseModel -from ...types.syft_migration import migrate -from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import SyftObject -from ...types.transforms import drop -from ...types.transforms import make_set_default from ...types.uid import UID from ...util.util import get_queue_address from ..response import SyftError from ..response import SyftSuccess +from ..service import AbstractService from ..worker.worker_pool import ConsumerState from ..worker.worker_stash import WorkerStash from .base_queue import AbstractMessageHandler @@ -120,8 +115,11 @@ class Worker(SyftBaseModel): syft_worker_id: Optional[UID] = None expiry_t: Timeout = Timeout(WORKER_TIMEOUT_SEC) - @validator("syft_worker_id", pre=True, always=True) - def set_syft_worker_id(cls, v: Any, values: Any) -> Union[UID, Any]: + # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. + # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. + @field_validator("syft_worker_id", mode="before") + @classmethod + def set_syft_worker_id(cls, v: Any) -> Any: if isinstance(v, str): return UID(v) return v @@ -200,7 +198,7 @@ def close(self) -> None: self._stop.clear() @property - def action_service(self) -> Callable: + def action_service(self) -> AbstractService: if self.auth_context.node is not None: return self.auth_context.node.get_service("ActionService") else: @@ -788,57 +786,19 @@ def alive(self) -> bool: return not self.socket.closed and self.is_producer_alive() -@serializable() -class ZMQClientConfigV1(SyftObject, QueueClientConfig): - __canonical_name__ = "ZMQClientConfig" - __version__ = SYFT_OBJECT_VERSION_1 - - id: Optional[UID] - hostname: str = "127.0.0.1" - - -class ZMQClientConfigV2(SyftObject, QueueClientConfig): - __canonical_name__ = "ZMQClientConfig" - __version__ = SYFT_OBJECT_VERSION_2 - - id: Optional[UID] - hostname: str = "127.0.0.1" - queue_port: Optional[int] = None - # TODO: setting this to false until we can fix the ZMQ - # port issue causing tests to randomly fail - create_producer: bool = False - n_consumers: int = 0 - - @serializable() class ZMQClientConfig(SyftObject, QueueClientConfig): __canonical_name__ = "ZMQClientConfig" __version__ = SYFT_OBJECT_VERSION_3 - id: Optional[UID] + id: Optional[UID] = None # type: ignore[assignment] hostname: str = "127.0.0.1" queue_port: Optional[int] = None # TODO: setting this to false until we can fix the ZMQ # port issue causing tests to randomly fail create_producer: bool = False n_consumers: int = 0 - consumer_service: Optional[str] - - -@migrate(ZMQClientConfig, ZMQClientConfigV1) -def downgrade_zmqclientconfig_v2_to_v1() -> list[Callable]: - return [ - drop(["queue_port", "create_producer", "n_consumers"]), - ] - - -@migrate(ZMQClientConfigV1, ZMQClientConfig) -def upgrade_zmqclientconfig_v1_to_v2() -> list[Callable]: - return [ - make_set_default("queue_port", None), - make_set_default("create_producer", False), - make_set_default("n_consumers", 0), - ] + consumer_service: Optional[str] = None @serializable(attrs=["host"]) diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 878f451e8e3..3621d6dc13e 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -29,15 +29,12 @@ from ...serde.serialize import _serialize from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime -from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SYFT_OBJECT_VERSION_2 +from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import SyftObject from ...types.transforms import TransformContext from ...types.transforms import add_node_uid_for_key -from ...types.transforms import drop from ...types.transforms import generate_id -from ...types.transforms import make_set_default from ...types.transforms import transform from ...types.twin_object import TwinObject from ...types.uid import LineageID @@ -54,6 +51,7 @@ from ..blob_storage.service import BlobStorageService from ..code.user_code import UserCode from ..code.user_code import UserCodeStatus +from ..code.user_code import UserCodeStatusCollection from ..context import AuthedServiceContext from ..context import ChangeContext from ..job.job_stash import Job @@ -77,10 +75,10 @@ class Change(SyftObject): __canonical_name__ = "Change" __version__ = SYFT_OBJECT_VERSION_1 - linked_obj: Optional[LinkedObject] + linked_obj: Optional[LinkedObject] = None - def is_type(self, type_: type) -> bool: - return (self.linked_obj is not None) and (type_ == self.linked_obj.object_type) + def change_object_is_type(self, type_: type) -> bool: + return self.linked_obj is not None and type_ == self.linked_obj.object_type @serializable() @@ -88,7 +86,7 @@ class ChangeStatus(SyftObject): __canonical_name__ = "ChangeStatus" __version__ = SYFT_OBJECT_VERSION_1 - id: Optional[UID] + id: Optional[UID] = None # type: ignore[assignment] change_id: UID applied: bool = False @@ -113,7 +111,7 @@ def _run( try: if context.node is None: return Err(SyftError(message=f"context {context}'s node is None")) - action_service: ActionService = context.node.get_service(ActionService) + action_service: ActionService = context.node.get_service(ActionService) # type: ignore[assignment] blob_storage_service = context.node.get_service(BlobStorageService) action_store = action_service.store @@ -157,6 +155,9 @@ def _run( permission=self.apply_permission_type, ) if apply: + print( + "ADDING PERMISSION", requesting_permission_action_obj, id_action + ) action_store.add_permission(requesting_permission_action_obj) blob_storage_service.stash.add_permission( requesting_permission_blob_obj @@ -199,7 +200,7 @@ class CreateCustomImageChange(Change): config: WorkerConfig tag: str - registry_uid: Optional[UID] + registry_uid: Optional[UID] = None __repr_attrs__ = ["config", "tag"] @@ -278,8 +279,8 @@ class CreateCustomWorkerPoolChange(Change): pool_name: str num_workers: int - image_uid: Optional[UID] - config: Optional[WorkerConfig] + image_uid: Optional[UID] = None + config: Optional[WorkerConfig] = None __repr_attrs__ = ["pool_name", "num_workers", "image_uid"] @@ -348,9 +349,9 @@ class Request(SyftObject): requesting_user_name: str = "" requesting_user_email: Optional[str] = "" requesting_user_institution: Optional[str] = "" - approving_user_verify_key: Optional[SyftVerifyKey] + approving_user_verify_key: Optional[SyftVerifyKey] = None request_time: DateTime - updated_at: Optional[DateTime] + updated_at: Optional[DateTime] = None node_uid: UID request_hash: str changes: List[Change] @@ -401,8 +402,10 @@ def _repr_html_(self) -> Any: f"outputs are <strong>shared</strong> with the owners of {owners_string} once computed" ) - metadata = api.services.metadata.get_metadata() - node_name = api.node_name.capitalize() if api.node_name is not None else "" + if api is not None: + metadata = api.services.metadata.get_metadata() + node_name = api.node_name.capitalize() if api.node_name is not None else "" + node_type = metadata.node_type.value.capitalize() email_str = ( f"({self.requesting_user_email})" if self.requesting_user_email else "" @@ -425,7 +428,7 @@ def _repr_html_(self) -> Any: {shared_with_line} <p><strong>Status: </strong>{self.status}</p> <p><strong>Requested on: </strong> {node_name} of type <strong> \ - {metadata.node_type.value.capitalize()}</strong></p> + {node_type}</strong></p> <p><strong>Requested by:</strong> {self.requesting_user_name} {email_str} {institution_str}</p> <p><strong>Changes: </strong> {str_changes}</p> </div> @@ -454,6 +457,15 @@ def _coll_repr_(self) -> Dict[str, Union[str, Dict[str, str]]]: "Status": status_badge, } + @property + def code_id(self) -> UID: + for change in self.changes: + if isinstance(change, UserCodeStatusChange): + return change.linked_user_code.object_uid + return SyftError( + message="This type of request does not have code associated with it." + ) + @property def codes(self) -> Any: for change in self.changes: @@ -513,8 +525,13 @@ def approve( self.node_uid, self.syft_client_verify_key, ) + if api is None: + return SyftError(message=f"api is None. You must login to {self.node_uid}") # TODO: Refactor so that object can also be passed to generate warnings - metadata = api.connection.get_node_metadata(api.signing_key) + if api.connection: + metadata = api.connection.get_node_metadata(api.signing_key) + else: + metadata = None message, is_enclave = None, False is_code_request = not isinstance(self.codes, SyftError) @@ -529,17 +546,20 @@ def approve( if is_enclave: message = "On approval, the result will be released to the enclave." - elif metadata.node_side_type == NodeSideType.HIGH_SIDE.value: + elif metadata and metadata.node_side_type == NodeSideType.HIGH_SIDE.value: message = ( "You're approving a request on " f"{metadata.node_side_type} side {metadata.node_type} " "which may host datasets with private information." ) - if message and metadata.show_warnings and not disable_warnings: + if message and metadata and metadata.show_warnings and not disable_warnings: prompt_warning_message(message=message, confirm=True) print(f"Approving request for domain {api.node_name}") - return api.services.request.apply(self.id, **kwargs) + res = api.services.request.apply(self.id, **kwargs) + # if isinstance(res, SyftSuccess): + + return res def deny(self, reason: str) -> Union[SyftSuccess, SyftError]: """Denies the particular request. @@ -551,6 +571,8 @@ def deny(self, reason: str) -> Union[SyftSuccess, SyftError]: self.node_uid, self.syft_client_verify_key, ) + if api is None: + return SyftError(message=f"api is None. You must login to {self.node_uid}") return api.services.request.undo(uid=self.id, reason=reason) def approve_with_client(self, client: SyftClient) -> Result[SyftSuccess, SyftError]: @@ -624,6 +646,8 @@ def save(self, context: AuthedServiceContext) -> Result[SyftSuccess, SyftError]: def _get_latest_or_create_job(self) -> Union[Job, SyftError]: """Get the latest job for this requests user_code, or creates one if no jobs exist""" api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key) + if api is None: + return SyftError(message=f"api is None. You must login to {self.node_uid}") job_service = api.services.job existing_jobs = job_service.get_by_user_code_id(self.code.id) @@ -631,12 +655,31 @@ def _get_latest_or_create_job(self) -> Union[Job, SyftError]: return existing_jobs if len(existing_jobs) == 0: + print("Creating job for existing user code") job = job_service.create_job_for_user_code_id(self.code.id) else: + print("returning existing job") + print("setting permission") job = existing_jobs[-1] + res = job_service.add_read_permission_job_for_code_owner(job, self.code) + print(res) + res = job_service.add_read_permission_log_for_code_owner( + job.log_id, self.code + ) + print(res) return job + def _is_action_object_from_job(self, action_object: ActionObject) -> Optional[Job]: # type: ignore + api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key) + if api is None: + raise ValueError(f"Can't access the api. You must login to {self.node_uid}") + job_service = api.services.job + existing_jobs = job_service.get_by_user_code_id(self.code.id) + for job in existing_jobs: + if job.result and job.result.id == action_object.id: + return job + def accept_by_depositing_result( self, result: Any, force: bool = False ) -> Union[SyftError, SyftSuccess]: @@ -651,6 +694,16 @@ def accept_by_depositing_result( message="JobInfo should not include result. Use sync_job instead." ) result = job_info.result + elif isinstance(result, ActionObject): + # Do not allow accepting a result produced by a Job, + # This can cause an inconsistent Job state + if self._is_action_object_from_job(result): + action_object_job = self._is_action_object_from_job(result) + if action_object_job is not None: + return SyftError( + message=f"This ActionObject is the result of Job {action_object_job.id}, " + f"please use the `Job.info` instead." + ) else: # NOTE result is added at the end of function (once ActionObject is created) job_info = JobInfo( @@ -660,35 +713,41 @@ def accept_by_depositing_result( resolved=True, ) - change = self.changes[0] - if not change.is_type(UserCode): - if change.linked_obj is not None: - raise TypeError( - f"accept_by_depositing_result can only be run on {UserCode} not " - f"{change.linked_obj.object_type}" - ) - else: - raise TypeError( - f"accept_by_depositing_result can only be run on {UserCode}" - ) - if not type(change) == UserCodeStatusChange: + user_code_status_change: UserCodeStatusChange = self.changes[0] + if not user_code_status_change.change_object_is_type(UserCodeStatusCollection): + raise TypeError( + f"accept_by_depositing_result can only be run on {UserCodeStatusCollection} not " + f"{user_code_status_change.linked_obj.object_type}" + ) + if not type(user_code_status_change) == UserCodeStatusChange: raise TypeError( f"accept_by_depositing_result can only be run on {UserCodeStatusChange} not " - f"{type(change)}" + f"{type(user_code_status_change)}" ) api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key) if not api: - raise Exception(f"Login to {self.node_uid} first.") - - is_approved = change.approved + raise Exception( + f"No access to Syft API. Please login to {self.node_uid} first." + ) + if api.signing_key is None: + raise ValueError(f"{api}'s signing key is None") + is_approved = user_code_status_change.approved permission_request = self.approve(approve_nested=True) if isinstance(permission_request, SyftError): return permission_request - code = change.linked_obj.resolve - state = code.output_policy + code = user_code_status_change.code + output_history = code.output_history + if isinstance(output_history, SyftError): + return output_history + output_policy = code.output_policy + if isinstance(output_policy, SyftError): + return output_policy + job = self._get_latest_or_create_job() + if isinstance(job, SyftError): + return job # This weird order is due to the fact that state is None before calling approve # we could fix it in a future release @@ -697,41 +756,70 @@ def accept_by_depositing_result( return SyftError( message="Already approved, if you want to force updating the result use force=True" ) - action_obj_id = state.output_history[0].outputs[0] - action_object = ActionObject.from_obj( - result, - id=action_obj_id, - syft_client_verify_key=api.signing_key.verify_key, - syft_node_location=api.node_uid, + # TODO: this should overwrite the output history instead + action_obj_id = output_history[0].output_ids[0] # type: ignore + + if not isinstance(result, ActionObject): + action_object = ActionObject.from_obj( + result, + id=action_obj_id, + syft_client_verify_key=api.signing_key.verify_key, + syft_node_location=api.node_uid, + ) + else: + action_object = result + action_object_is_from_this_node = ( + self.syft_node_location == action_object.syft_node_location ) - blob_store_result = action_object._save_to_blob_storage() - if isinstance(blob_store_result, SyftError): - return blob_store_result - result = api.services.action.set(action_object) - if isinstance(result, SyftError): - return result + if ( + action_object.syft_blob_storage_entry_id is None + or not action_object_is_from_this_node + ): + action_object.reload_cache() + action_object.syft_node_location = self.syft_node_location + action_object.syft_client_verify_key = self.syft_client_verify_key + blob_store_result = action_object._save_to_blob_storage() + if isinstance(blob_store_result, SyftError): + return blob_store_result + result = api.services.action.set(action_object) + if isinstance(result, SyftError): + return result else: - action_object = ActionObject.from_obj( - result, - syft_client_verify_key=api.signing_key.verify_key, - syft_node_location=api.node_uid, - ) - blob_store_result = action_object._save_to_blob_storage() - if isinstance(blob_store_result, SyftError): - return blob_store_result - result = api.services.action.set(action_object) - if isinstance(result, SyftError): - return result - - ctx = AuthedServiceContext(credentials=api.signing_key.verify_key) + if not isinstance(result, ActionObject): + action_object = ActionObject.from_obj( + result, + syft_client_verify_key=api.signing_key.verify_key, + syft_node_location=api.node_uid, + ) + else: + action_object = result - state.apply_output(context=ctx, outputs=result) - policy_state_mutation = ObjectMutation( - linked_obj=change.linked_obj, - attr_name="output_policy", - match_type=True, - value=state, + # TODO: proper check for if actionobject is already uploaded + # we also need this for manualy syncing + action_object_is_from_this_node = ( + self.syft_node_location == action_object.syft_node_location ) + if ( + action_object.syft_blob_storage_entry_id is None + or not action_object_is_from_this_node + ): + action_object.reload_cache() + action_object.syft_node_location = self.syft_node_location + action_object.syft_client_verify_key = self.syft_client_verify_key + blob_store_result = action_object._save_to_blob_storage() + if isinstance(blob_store_result, SyftError): + return blob_store_result + result = api.services.action.set(action_object) + if isinstance(result, SyftError): + return result + + # Do we still need this? + # policy_state_mutation = ObjectMutation( + # linked_obj=user_code_status_change.linked_obj, + # attr_name="output_policy", + # match_type=True, + # value=output_policy, + # ) action_object_link = LinkedObject.from_obj(result, node_uid=self.node_uid) permission_change = ActionStoreChange( @@ -739,7 +827,7 @@ def accept_by_depositing_result( apply_permission_type=ActionPermission.READ, ) - new_changes = [policy_state_mutation, permission_change] + new_changes = [permission_change] result_request = api.services.request.add_changes( uid=self.id, changes=new_changes ) @@ -751,11 +839,20 @@ def accept_by_depositing_result( if isinstance(approved, SyftError): return approved + res = api.services.code.apply_output( + user_code_id=code.id, outputs=result, job_id=job.id + ) + if isinstance(res, SyftError): + return res + job_info.result = action_object - job = self._get_latest_or_create_job() + + existing_result = job.result.id if job.result is not None else None + print( + f"Job({job.id}) Setting new result {existing_result} -> {job_info.result.id}" + ) job.apply_info(job_info) - api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key) job_service = api.services.job res = job_service.update(job) if isinstance(res, SyftError): @@ -771,13 +868,28 @@ def sync_job( message="This JobInfo includes a Result. Please use Request.accept_by_depositing_result instead." ) - api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key) + api = APIRegistry.api_for( + node_uid=self.node_uid, user_verify_key=self.syft_client_verify_key + ) + if api is None: + return SyftError(message=f"api is None. You must login to {self.node_uid}") job_service = api.services.job job = self._get_latest_or_create_job() job.apply_info(job_info) return job_service.update(job) + def get_sync_dependencies(self, api: Any = None) -> Union[List[UID], SyftError]: + dependencies = [] + + code_id = self.code_id + if isinstance(code_id, SyftError): + return code_id + + dependencies.append(code_id) + + return dependencies + @serializable() class RequestInfo(SyftObject): @@ -796,7 +908,7 @@ class RequestInfoFilter(SyftObject): __canonical_name__ = "RequestInfoFilter" __version__ = SYFT_OBJECT_VERSION_1 - name: Optional[str] + name: Optional[str] = None @serializable() @@ -805,10 +917,13 @@ class SubmitRequest(SyftObject): __version__ = SYFT_OBJECT_VERSION_1 changes: List[Change] - requesting_user_verify_key: Optional[SyftVerifyKey] + requesting_user_verify_key: Optional[SyftVerifyKey] = None def hash_changes(context: TransformContext) -> TransformContext: + if context.output is None: + return context + request_time = context.output["request_time"] key = context.output["requesting_user_verify_key"] changes = context.output["changes"] @@ -821,38 +936,44 @@ def hash_changes(context: TransformContext) -> TransformContext: final_hash = hashlib.sha256(time_hash + key_hash + changes_hash).hexdigest() context.output["request_hash"] = final_hash + return context def add_request_time(context: TransformContext) -> TransformContext: - context.output["request_time"] = DateTime.now() + if context.output is not None: + context.output["request_time"] = DateTime.now() return context def check_requesting_user_verify_key(context: TransformContext) -> TransformContext: - if context.obj.requesting_user_verify_key and context.node.is_root( - context.credentials - ): - context.output[ - "requesting_user_verify_key" - ] = context.obj.requesting_user_verify_key - else: - context.output["requesting_user_verify_key"] = context.credentials + if context.output and context.node and context.obj: + if context.obj.requesting_user_verify_key and context.node.is_root( + context.credentials + ): + context.output[ + "requesting_user_verify_key" + ] = context.obj.requesting_user_verify_key + else: + context.output["requesting_user_verify_key"] = context.credentials + return context def add_requesting_user_info(context: TransformContext) -> TransformContext: - try: - user_key = context.output["requesting_user_verify_key"] - user_service = context.node.get_service("UserService") - user = user_service.get_by_verify_key(user_key) - context.output["requesting_user_name"] = user.name - context.output["requesting_user_email"] = user.email - context.output["requesting_user_institution"] = ( - user.institution if user.institution else "" - ) - except Exception: - context.output["requesting_user_name"] = "guest_user" + if context.output is not None and context.node is not None: + try: + user_key = context.output["requesting_user_verify_key"] + user_service = context.node.get_service("UserService") + user = user_service.get_by_verify_key(user_key) + context.output["requesting_user_name"] = user.name + context.output["requesting_user_email"] = user.email + context.output["requesting_user_institution"] = ( + user.institution if user.institution else "" + ) + except Exception: + context.output["requesting_user_name"] = "guest_user" + return context @@ -873,11 +994,11 @@ class ObjectMutation(Change): __canonical_name__ = "ObjectMutation" __version__ = SYFT_OBJECT_VERSION_1 - linked_obj: Optional[LinkedObject] + linked_obj: Optional[LinkedObject] = None attr_name: str - value: Optional[Any] + value: Optional[Any] = None match_type: bool - previous_value: Optional[Any] + previous_value: Optional[Any] = None __repr_attrs__ = ["linked_obj", "attr_name"] @@ -945,7 +1066,7 @@ class EnumMutation(ObjectMutation): __version__ = SYFT_OBJECT_VERSION_1 enum_type: Type[Enum] - value: Optional[Enum] + value: Optional[Enum] = None match_type: bool = True __repr_attrs__ = ["linked_obj", "attr_name", "value"] @@ -1011,45 +1132,30 @@ def link(self) -> Optional[SyftObject]: return None -@serializable() -class UserCodeStatusChangeV1(Change): - __canonical_name__ = "UserCodeStatusChange" - __version__ = SYFT_OBJECT_VERSION_1 - - value: UserCodeStatus - linked_obj: LinkedObject - match_type: bool = True - __repr_attrs__ = [ - "link.service_func_name", - "link.input_policy_type.__canonical_name__", - "link.output_policy_type.__canonical_name__", - "link.status.approved", - ] - - @serializable() class UserCodeStatusChange(Change): __canonical_name__ = "UserCodeStatusChange" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_3 value: UserCodeStatus linked_obj: LinkedObject + linked_user_code: LinkedObject nested_solved: bool = False match_type: bool = True __repr_attrs__ = [ - "link.service_func_name", - "link.input_policy_type.__canonical_name__", - "link.output_policy_type.__canonical_name__", - "link.worker_pool_name", - "link.status.approved", + "code.service_func_name", + "code.input_policy_type.__canonical_name__", + "code.output_policy_type.__canonical_name__", + "code.worker_pool_name", + "code.status.approved", ] @property - def code(self) -> Optional[SyftObject]: - return self.link + def code(self) -> UserCode: + return self.linked_user_code.resolve @property - def codes(self) -> List: + def codes(self) -> List[UserCode]: def recursive_code(node: Any) -> List: codes = [] for _, (obj, new_node) in node.items(): @@ -1057,68 +1163,59 @@ def recursive_code(node: Any) -> List: codes.extend(recursive_code(new_node)) return codes - codes = [self.link] - if self.link is not None: - codes.extend(recursive_code(self.link.nested_codes)) - + codes = [self.code] + codes.extend(recursive_code(self.code.nested_codes)) return codes def nested_repr(self, node: Optional[Any] = None, level: int = 0) -> str: msg = "" - if node is None and self.link is not None: - node = self.link.nested_codes if node is None: - return msg - for service_func_name, (_, new_node) in node.items(): + node = self.code.nested_codes + + for service_func_name, (_, new_node) in node.items(): # type: ignore msg = "├──" + "──" * level + f"{service_func_name}<br>" msg += self.nested_repr(node=new_node, level=level + 1) return msg def __repr_syft_nested__(self) -> str: - if self.link is not None: - msg = ( - f"Request to change <b>{self.link.service_func_name}</b> " - f"(Pool Id: <b>{self.link.worker_pool_name}</b>) " - ) - msg += "to permission <b>RequestStatus.APPROVED</b>" - if self.nested_solved: - if self.link.nested_codes == {}: - msg += ". No nested requests" - else: - msg += ".<br><br>This change requests the following nested functions calls:<br>" - msg += self.nested_repr() + msg = ( + f"Request to change <b>{self.code.service_func_name}</b> " + f"(Pool Id: <b>{self.code.worker_pool_name}</b>) " + ) + msg += "to permission <b>RequestStatus.APPROVED</b>" + if self.nested_solved: + if self.link.nested_codes == {}: # type: ignore + msg += ". No nested requests" else: - msg += ". Nested Requests not resolved" + msg += ".<br><br>This change requests the following nested functions calls:<br>" + msg += self.nested_repr() else: - msg = f"LinkedObject of {self} is None." + msg += ". Nested Requests not resolved" return msg - def _repr_markdown_(self) -> str: - link = self.link - if link is None: - return f"{self}'s linked object is None" - + def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: + code = self.code input_policy_type = ( - link.input_policy_type.__canonical_name__ - if link.input_policy_type is not None + code.input_policy_type.__canonical_name__ + if code.input_policy_type is not None else None ) output_policy_type = ( - link.output_policy_type.__canonical_name__ - if link.output_policy_type is not None + code.output_policy_type.__canonical_name__ + if code.output_policy_type is not None else None ) repr_dict = { - "function": link.service_func_name, + "function": code.service_func_name, "input_policy_type": f"{input_policy_type}", "output_policy_type": f"{output_policy_type}", - "approved": f"{link.status.approved}", + "approved": f"{code.status.approved}", } return markdown_as_class_with_fields(self, repr_dict) @property def approved(self) -> bool: - return self.linked_obj.resolve.status.approved + return self.linked_obj.resolve.approved @property def valid(self) -> Union[SyftSuccess, SyftError]: @@ -1143,12 +1240,18 @@ def valid(self) -> Union[SyftSuccess, SyftError]: # return approved_nested_codes - def mutate(self, obj: UserCode, context: ChangeContext, undo: bool) -> Any: + def mutate( + self, + status: UserCodeStatusCollection, + context: ChangeContext, + undo: bool, + ) -> Union[UserCodeStatusCollection, SyftError]: if context.node is None: return SyftError(message=f"context {context}'s node is None") reason: str = context.extra_kwargs.get("reason", "") + if not undo: - res = obj.status.mutate( + res = status.mutate( value=(self.value, reason), node_name=context.node.name, node_id=context.node.id, @@ -1157,15 +1260,12 @@ def mutate(self, obj: UserCode, context: ChangeContext, undo: bool) -> Any: if isinstance(res, SyftError): return res else: - res = obj.status.mutate( + res = status.mutate( value=(UserCodeStatus.DENIED, reason), node_name=context.node.name, node_id=context.node.id, verify_key=context.node.signing_key.verify_key, ) - if not isinstance(res, SyftError): - obj.status = res - return obj return res def is_enclave_request(self, user_code: UserCode) -> bool: @@ -1181,35 +1281,39 @@ def _run( valid = self.valid if not valid: return Err(valid) - obj = self.linked_obj.resolve_with_context(context) - if obj.is_err(): - return Err(SyftError(message=obj.err())) - obj = obj.ok() + user_code = self.linked_user_code.resolve_with_context(context) + if user_code.is_err(): + return Err(SyftError(message=user_code.err())) + user_code = user_code.ok() + user_code_status = self.linked_obj.resolve_with_context(context) + if user_code_status.is_err(): + return Err(SyftError(message=user_code_status.err())) + user_code_status = user_code_status.ok() + if apply: - res = self.mutate(obj, context, undo=False) + # Only mutate, does not write to stash + updated_status = self.mutate(user_code_status, context, undo=False) - if isinstance(res, SyftError): - return Err(res.message) + if isinstance(updated_status, SyftError): + return Err(updated_status.message) # relative from ..enclave.enclave_service import propagate_inputs_to_enclave - user_code = res - + self.linked_obj.update_with_context(context, updated_status) if self.is_enclave_request(user_code): enclave_res = propagate_inputs_to_enclave( - user_code=res, context=context + user_code=user_code, context=context ) if isinstance(enclave_res, SyftError): return enclave_res - self.linked_obj.update_with_context(context, user_code) else: - res = self.mutate(obj, context, undo=True) - if isinstance(res, SyftError): - return Err(res.message) + updated_status = self.mutate(user_code_status, context, undo=True) + if isinstance(updated_status, SyftError): + return Err(updated_status.message) # TODO: Handle Enclave approval. - self.linked_obj.update_with_context(context, res) + self.linked_obj.update_with_context(context, updated_status) return Ok(SyftSuccess(message=f"{type(self)} Success")) except Exception as e: print(f"failed to apply {type(self)}. {e}") @@ -1226,17 +1330,3 @@ def link(self) -> Optional[SyftObject]: if self.linked_obj: return self.linked_obj.resolve return None - - -@migrate(UserCodeStatusChange, UserCodeStatusChangeV1) -def downgrade_usercodestatuschange_v2_to_v1() -> List[Callable]: - return [ - drop("nested_solved"), - ] - - -@migrate(UserCodeStatusChangeV1, UserCodeStatusChange) -def upgrade_usercodestatuschange_v1_to_v2() -> List[Callable]: - return [ - make_set_default("nested_solved", True), - ] diff --git a/packages/syft/src/syft/service/request/request_service.py b/packages/syft/src/syft/service/request/request_service.py index 99a9e19ff67..4d8bcede74d 100644 --- a/packages/syft/src/syft/service/request/request_service.py +++ b/packages/syft/src/syft/service/request/request_service.py @@ -18,9 +18,12 @@ from ..action.action_permissions import ActionObjectPermission from ..action.action_permissions import ActionPermission from ..context import AuthedServiceContext +from ..notification.email_templates import RequestEmailTemplate +from ..notification.email_templates import RequestUpdateEmailTemplate from ..notification.notification_service import CreateNotification from ..notification.notification_service import NotificationService from ..notification.notifications import Notification +from ..notifier.notifier_enums import NOTIFIERS from ..response import SyftError from ..response import SyftSuccess from ..service import AbstractService @@ -86,6 +89,8 @@ def submit( from_user_verify_key=context.credentials, to_user_verify_key=root_verify_key, linked_obj=link, + notifier_types=[NOTIFIERS.EMAIL], + email_template=RequestEmailTemplate, ) method = context.node.get_service_method(NotificationService.send) result = method(context=context, notification=message) @@ -210,20 +215,26 @@ def apply( link = LinkedObject.with_context(request, context=context) if not request.status == RequestStatus.PENDING: - mark_as_read = context.node.get_service_method( - NotificationService.mark_as_read - ) - mark_as_read(context=context, uid=request_notification.id) + if request_notification is not None and not isinstance( + request_notification, SyftError + ): + mark_as_read = context.node.get_service_method( + NotificationService.mark_as_read + ) + mark_as_read(context=context, uid=request_notification.id) - notification = CreateNotification( - subject=f"{request.changes} for Request id: {uid} has status updated to {request.status}", - to_user_verify_key=request.requesting_user_verify_key, - linked_obj=link, - ) - send_notification = context.node.get_service_method( - NotificationService.send - ) - send_notification(context=context, notification=notification) + notification = CreateNotification( + subject=f"Your request ({str(uid)[:4]}) has been approved!", + from_user_verify_key=context.credentials, + to_user_verify_key=request.requesting_user_verify_key, + linked_obj=link, + notifier_types=[NOTIFIERS.EMAIL], + email_template=RequestUpdateEmailTemplate, + ) + send_notification = context.node.get_service_method( + NotificationService.send + ) + send_notification(context=context, notification=notification) # TODO: check whereever we're return SyftError encapsulate it in Result. if hasattr(result, "value"): @@ -254,15 +265,15 @@ def undo( ) link = LinkedObject.with_context(request, context=context) - message_subject = ( - f"Your request for uid: {uid} has been denied. " - f"Reason specified by Data Owner: {reason}." - ) + message_subject = f"Your request ({str(uid)[:4]}) has been denied. " notification = CreateNotification( subject=message_subject, + from_user_verify_key=context.credentials, to_user_verify_key=request.requesting_user_verify_key, linked_obj=link, + notifier_types=[NOTIFIERS.EMAIL], + email_template=RequestUpdateEmailTemplate, ) context.node = cast(AbstractNode, context.node) send_notification = context.node.get_service_method(NotificationService.send) diff --git a/packages/syft/src/syft/service/service.py b/packages/syft/src/syft/service/service.py index fa9a233e9fe..0ee5517d00e 100644 --- a/packages/syft/src/syft/service/service.py +++ b/packages/syft/src/syft/service/service.py @@ -10,6 +10,7 @@ from typing import List from typing import Optional from typing import Set +from typing import TYPE_CHECKING from typing import Tuple from typing import Type from typing import Union @@ -46,6 +47,10 @@ from .user.user_roles import ServiceRole from .warnings import APIEndpointWarning +if TYPE_CHECKING: + # relative + from ..client.api import APIModule + TYPE_TO_SERVICE: dict = {} SERVICE_TO_TYPES: defaultdict = defaultdict(set) @@ -77,6 +82,9 @@ def resolve_link( obj = Ok(obj) return obj + def get_all(*arg: Any, **kwargs: Any) -> Any: + pass + @serializable() class BaseConfig(SyftBaseObject): @@ -87,10 +95,10 @@ class BaseConfig(SyftBaseObject): private_path: str public_name: str method_name: str - doc_string: Optional[str] - signature: Optional[Signature] + doc_string: Optional[str] = None + signature: Optional[Signature] = None is_from_lib: bool = False - warning: Optional[APIEndpointWarning] + warning: Optional[APIEndpointWarning] = None @serializable() @@ -439,7 +447,7 @@ def from_api_or_context( func_or_path: str, syft_node_location: Optional[UID] = None, syft_client_verify_key: Optional[SyftVerifyKey] = None, -) -> Optional[Union[Callable, SyftError, partial]]: +) -> Optional[Union["APIModule", SyftError, partial]]: # relative from ..client.api import APIRegistry from ..node.node import AuthNodeContextRegistry @@ -464,7 +472,7 @@ def from_api_or_context( node_uid=syft_node_location, user_verify_key=syft_client_verify_key, ) - if node_context is not None: + if node_context is not None and node_context.node is not None: user_config_registry = UserServiceConfigRegistry.from_role( node_context.role, ) diff --git a/packages/syft/src/syft/service/settings/migrations.py b/packages/syft/src/syft/service/settings/migrations.py index 42ac9935bb7..739dbbbcac5 100644 --- a/packages/syft/src/syft/service/settings/migrations.py +++ b/packages/syft/src/syft/service/settings/migrations.py @@ -2,31 +2,13 @@ from typing import Callable # relative -from ...types.syft_migration import migrate from ...types.transforms import TransformContext -from ...types.transforms import drop -from .settings import NodeSettings -from .settings import NodeSettingsV2 def set_from_node_to_key(node_attr: str, key: str) -> Callable: def extract_from_node(context: TransformContext) -> TransformContext: - context.output[key] = getattr(context.node, node_attr) + if context.output is not None: + context.output[key] = getattr(context.node, node_attr) return context return extract_from_node - - -@migrate(NodeSettings, NodeSettingsV2) -def upgrade_metadata_v1_to_v2() -> list[Callable]: - return [ - set_from_node_to_key("verify_key", "verify_key"), - set_from_node_to_key("node_type", "node_type"), - ] - - -@migrate(NodeSettingsV2, NodeSettings) -def downgrade_metadata_v2_to_v1() -> list[Callable]: - return [ - drop(["verify_key", "node_type"]), - ] diff --git a/packages/syft/src/syft/service/settings/settings.py b/packages/syft/src/syft/service/settings/settings.py index fc16195c0aa..7f22fff0a77 100644 --- a/packages/syft/src/syft/service/settings/settings.py +++ b/packages/syft/src/syft/service/settings/settings.py @@ -6,7 +6,6 @@ from ...node.credentials import SyftVerifyKey from ...serde.serializable import serializable from ...types.syft_object import PartialSyftObject -from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.uid import UID @@ -15,7 +14,7 @@ @serializable() class NodeSettingsUpdate(PartialSyftObject): __canonical_name__ = "NodeSettingsUpdate" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 id: UID name: str @@ -26,29 +25,6 @@ class NodeSettingsUpdate(PartialSyftObject): admin_email: str -@serializable() -class NodeSettings(SyftObject): - __canonical_name__ = "NodeSettings" - __version__ = SYFT_OBJECT_VERSION_1 - __repr_attrs__ = [ - "name", - "organization", - "deployed_on", - "signup_enabled", - "admin_email", - ] - - name: str = "Node" - deployed_on: str - organization: str = "OpenMined" - on_board: bool = True - description: str = "Text" - signup_enabled: bool - admin_email: str - node_side_type: NodeSideType = NodeSideType.HIGH_SIDE - show_warnings: bool - - @serializable() class NodeSettingsV2(SyftObject): __canonical_name__ = "NodeSettings" diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py index 3dd5dfe8729..21342f68bbe 100644 --- a/packages/syft/src/syft/service/settings/settings_service.py +++ b/packages/syft/src/syft/service/settings/settings_service.py @@ -1,6 +1,7 @@ # stdlib # stdlib +from typing import Optional from typing import Union from typing import cast @@ -20,6 +21,7 @@ from ..response import SyftSuccess from ..service import AbstractService from ..service import service_method +from ..user.user_roles import ADMIN_ROLE_LEVEL from ..warnings import HighSideCRUDWarning from .settings import NodeSettingsUpdate from .settings import NodeSettingsV2 @@ -83,6 +85,44 @@ def update( else: return SyftError(message=result.err()) + @service_method( + path="settings.enable_notifications", + name="enable_notifications", + roles=ADMIN_ROLE_LEVEL, + ) + def enable_notifications( + self, + context: AuthedServiceContext, + email_username: Optional[str] = None, + email_password: Optional[str] = None, + email_sender: Optional[str] = None, + email_server: Optional[str] = None, + email_port: Optional[int] = None, + ) -> Union[SyftSuccess, SyftError]: + context.node = cast(AbstractNode, context.node) + notifier_service = context.node.get_service("notifierservice") + return notifier_service.turn_on( + context=context, + email_username=email_username, + email_password=email_password, + email_sender=email_sender, + email_server=email_server, + email_port=email_port, + ) + + @service_method( + path="settings.disable_notifications", + name="disable_notifications", + roles=ADMIN_ROLE_LEVEL, + ) + def disable_notifications( + self, + context: AuthedServiceContext, + ) -> Union[SyftSuccess, SyftError]: + context.node = cast(AbstractNode, context.node) + notifier_service = context.node.get_service("notifierservice") + return notifier_service.turn_off(context=context) + @service_method( path="settings.allow_guest_signup", name="allow_guest_signup", diff --git a/packages/syft/src/syft/service/settings/settings_stash.py b/packages/syft/src/syft/service/settings/settings_stash.py index 35aa58486e8..fb2f2bb9582 100644 --- a/packages/syft/src/syft/service/settings/settings_stash.py +++ b/packages/syft/src/syft/service/settings/settings_stash.py @@ -1,5 +1,6 @@ # stdlib from typing import List +from typing import Optional # third party from result import Result @@ -13,6 +14,7 @@ from ...store.document_store import PartitionSettings from ...types.uid import UID from ...util.telemetry import instrument +from ..action.action_permissions import ActionObjectPermission from .settings import NodeSettingsV2 NamePartitionKey = PartitionKey(key="name", type_=str) @@ -31,7 +33,11 @@ def __init__(self, store: DocumentStore) -> None: super().__init__(store=store) def set( - self, credentials: SyftVerifyKey, settings: NodeSettingsV2 + self, + credentials: SyftVerifyKey, + settings: NodeSettingsV2, + add_permissions: Optional[List[ActionObjectPermission]] = None, + ignore_duplicates: bool = False, ) -> Result[NodeSettingsV2, str]: res = self.check_type(settings, self.object_type) # we dont use and_then logic here as it is hard because of the order of the arguments @@ -40,7 +46,10 @@ def set( return super().set(credentials=credentials, obj=res.ok()) def update( - self, credentials: SyftVerifyKey, settings: NodeSettingsV2 + self, + credentials: SyftVerifyKey, + settings: NodeSettingsV2, + has_permission: bool = False, ) -> Result[NodeSettingsV2, str]: res = self.check_type(settings, self.object_type) # we dont use and_then logic here as it is hard because of the order of the arguments diff --git a/packages/syft/src/syft/service/sync/__init__.py b/packages/syft/src/syft/service/sync/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py new file mode 100644 index 00000000000..0def3b1fa94 --- /dev/null +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -0,0 +1,771 @@ +""" +How to check differences between two objects: + * by default merge every attr + * check if there is a custom implementation of the check function + * check if there are exceptions we do not want to merge + * check if there are some restrictions on the attr set +""" + +# stdlib +import html +import textwrap +from typing import Any +from typing import ClassVar +from typing import Dict +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import Union + +# third party +from pydantic import model_validator +from rich import box +from rich.console import Console +from rich.console import Group +from rich.markdown import Markdown +from rich.padding import Padding +from rich.panel import Panel +from typing_extensions import Self + +# relative +from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SyftObject +from ...types.uid import LineageID +from ...types.uid import UID +from ...util import options +from ...util.colors import SURFACE +from ...util.fonts import ITABLES_CSS +from ...util.fonts import fonts_css +from ..action.action_object import ActionObject +from ..action.action_permissions import ActionObjectPermission +from ..code.user_code import UserCode +from ..code.user_code import UserCodeStatusCollection +from ..job.job_stash import Job +from ..log.log import SyftLog +from ..output.output_service import ExecutionOutput +from ..request.request import Request +from ..response import SyftError +from .sync_state import SyncState + +sketchy_tab = " " * 4 + + +class AttrDiff(SyftObject): + # version + __canonical_name__ = "AttrDiff" + __version__ = SYFT_OBJECT_VERSION_1 + attr_name: str + low_attr: Any = None + high_attr: Any = None + + def _repr_html_(self) -> str: + return f"""{self.attr_name}: + Low Side value: {self.low_attr} + High Side value: {self.high_attr} + """ + + def __repr_side__(self, side: str) -> str: + if side == "low": + return recursive_attr_repr(self.low_attr) + else: + return recursive_attr_repr(self.high_attr) + + def _coll_repr_(self) -> Dict[str, Any]: + return { + "attr name": self.attr_name, + "low attr": html.escape(f"{self.low_attr}"), + "high attr": html.escape(str(self.high_attr)), + } + + +class ListDiff(AttrDiff): + # version + __canonical_name__ = "ListDiff" + __version__ = SYFT_OBJECT_VERSION_1 + diff_ids: List[int] = [] + new_low_ids: List[int] = [] + new_high_ids: List[int] = [] + + @property + def is_empty(self) -> bool: + return ( + len(self.diff_ids) == 0 + and len(self.new_low_ids) == 0 + and len(self.new_high_ids) == 0 + ) + + @classmethod + def from_lists(cls, attr_name: str, low_list: List, high_list: List) -> "ListDiff": + diff_ids = [] + new_low_ids = [] + new_high_ids = [] + if len(low_list) != len(high_list): + if len(low_list) > len(high_list): + common_length = len(high_list) + new_low_ids = list(range(common_length, len(low_list))) + else: + common_length = len(low_list) + new_high_ids = list(range(common_length, len(high_list))) + else: + common_length = len(low_list) + + for i in range(common_length): + # if hasattr(low_list[i], 'syft_eq'): + # if not low_list[i].syft_eq(high_list[i]): + # diff_ids.append(i) + if low_list[i] != high_list[i]: + diff_ids.append(i) + + change_diff = ListDiff( + attr_name=attr_name, + low_attr=low_list, + high_attr=high_list, + diff_ids=diff_ids, + new_low_ids=new_low_ids, + new_high_ids=new_high_ids, + ) + return change_diff + + +def recursive_attr_repr(value_attr: Union[List, Dict, bytes], num_tabs: int = 0) -> str: + new_num_tabs = num_tabs + 1 + + if isinstance(value_attr, list): + list_repr = "[\n" + for elem in value_attr: + list_repr += recursive_attr_repr(elem, num_tabs=num_tabs + 1) + "\n" + list_repr += "]" + return list_repr + + elif isinstance(value_attr, dict): + dict_repr = "{\n" + for key, elem in value_attr.items(): + dict_repr += f"{sketchy_tab * new_num_tabs}{key}: {str(elem)}\n" + dict_repr += "}" + return dict_repr + + elif isinstance(value_attr, bytes): + value_attr = repr(value_attr) # type: ignore + if len(value_attr) > 50: + value_attr = value_attr[:50] + "..." # type: ignore + return f"{sketchy_tab*num_tabs}{str(value_attr)}" + + +class ObjectDiff(SyftObject): # StateTuple (compare 2 objects) + # version + __canonical_name__ = "ObjectDiff" + __version__ = SYFT_OBJECT_VERSION_1 + low_obj: Optional[SyftObject] = None + high_obj: Optional[SyftObject] = None + low_permissions: List[ActionObjectPermission] = [] + high_permissions: List[ActionObjectPermission] = [] + + new_low_permissions: List[ActionObjectPermission] = [] + new_high_permissions: List[ActionObjectPermission] = [] + obj_type: Type + diff_list: List[AttrDiff] = [] + + __repr_attrs__ = [ + "low_state", + "high_state", + ] + + @classmethod + def from_objects( + cls, + low_obj: Optional[SyftObject], + high_obj: Optional[SyftObject], + low_permissions: List[ActionObjectPermission], + high_permissions: List[ActionObjectPermission], + ) -> "ObjectDiff": + if low_obj is None and high_obj is None: + raise ValueError("Both low and high objects are None") + obj_type = type(low_obj if low_obj is not None else high_obj) + + if low_obj is None or high_obj is None: + diff_list = [] + else: + diff_list = low_obj.get_diffs(high_obj) + + return cls( + low_obj=low_obj, + high_obj=high_obj, + obj_type=obj_type, + low_permissions=low_permissions, + high_permissions=high_permissions, + diff_list=diff_list, + ) + + def __hash__(self) -> int: + return hash(self.id) + hash(self.low_obj) + hash(self.high_obj) + + @property + def status(self) -> str: + if self.low_obj is None or self.high_obj is None: + return "NEW" + if len(self.diff_list) == 0: + return "SAME" + return "DIFF" + + @property + def object_id(self) -> UID: + uid: Union[UID, LineageID] = ( + self.low_obj.id if self.low_obj is not None else self.high_obj.id # type: ignore + ) + if isinstance(uid, LineageID): + return uid.id + return uid + + @property + def non_empty_object(self) -> Optional[SyftObject]: + return self.low_obj or self.high_obj + + @property + def object_type(self) -> str: + return self.obj_type.__name__ + + @property + def high_state(self) -> str: + return self.state_str("high") + + @property + def low_state(self) -> str: + return self.state_str("low") + + @property + def object_uid(self) -> UID: + return self.low_obj.id if self.low_obj is not None else self.high_obj.id # type: ignore + + def diff_attributes_str(self, side: str) -> str: + obj = self.low_obj if side == "low" else self.high_obj + + if obj is None: + return "" + + repr_attrs = getattr(obj, "__repr_attrs__", []) + if self.status == "SAME": + repr_attrs = repr_attrs[:3] + + if self.status in {"SAME", "NEW"}: + attrs_str = "" + for attr in repr_attrs: + value = getattr(obj, attr) + attrs_str += f"{attr}: {recursive_attr_repr(value)}\n" + return attrs_str + + elif self.status == "DIFF": + attrs_str = "" + for diff in self.diff_list: + attrs_str += f"{diff.attr_name}: {diff.__repr_side__(side)}\n" + return attrs_str + else: + raise ValueError("") + + def diff_side_str(self, side: str) -> str: + obj = self.low_obj if side == "low" else self.high_obj + if obj is None: + return "" + res = f"{self.obj_type.__name__.upper()} #{obj.id}:\n" + res += self.diff_attributes_str(side) + return res + + def state_str(self, side: str) -> str: + other_obj: Optional[SyftObject] = None + if side == "high": + obj = self.high_obj + other_obj = self.low_obj + else: + obj = self.low_obj + other_obj = self.high_obj + + if obj is None: + return "-" + if self.status == "SAME": + return f"SAME\n{self.obj_type.__name__}" + + if isinstance(obj, ActionObject): + return obj.__repr__() + + if other_obj is None: # type: ignore[unreachable] + attrs_str = "" + attrs = getattr(obj, "__repr_attrs__", []) + for attr in attrs: + value = getattr(obj, attr) + attrs_str += f"{sketchy_tab}{attr} = {recursive_attr_repr(value)}\n" + attrs_str = attrs_str[:-1] + return f"NEW\n\nclass {self.object_type}:\n{attrs_str}" + + attr_text = f"DIFF\nclass {self.object_type}:\n" + for diff in self.diff_list: + # TODO + attr_text += ( + f"{sketchy_tab}{diff.attr_name}={diff.__repr_side__(side)}," + "\n" + ) + if len(self.diff_list) > 0: + attr_text = attr_text[:-2] + + return attr_text + + def get_obj(self) -> Optional[SyftObject]: + if self.status == "NEW": + return self.low_obj if self.low_obj is not None else self.high_obj + else: + raise ValueError("ERROR") + + def _coll_repr_(self) -> Dict[str, Any]: + low_state = f"{self.status}\n{self.diff_side_str('low')}" + high_state = f"{self.status}\n{self.diff_side_str('high')}" + return { + "low_state": html.escape(low_state), + "high_state": html.escape(high_state), + } + + def _repr_html_(self) -> str: + if self.low_obj is None and self.high_obj is None: + return SyftError(message="Something broke") + + base_str = f""" + <style> + {fonts_css} + .syft-dataset {{color: {SURFACE[options.color_theme]};}} + .syft-dataset h3, + .syft-dataset p + {{font-family: 'Open Sans';}} + {ITABLES_CSS} + </style> + <div class='syft-diff'> + """ + + obj_repr: str + attr_text: str + if self.low_obj is None: + if hasattr(self.high_obj, "_repr_html_"): + obj_repr = self.high_obj._repr_html_() # type: ignore + elif hasattr(self.high_obj, "_inner_repr"): + obj_repr = self.high_obj._inner_repr() # type: ignore + else: + obj_repr = self.__repr__() + attr_text = ( + f""" + <h3>{self.object_type} ObjectDiff (New {self.object_type} on the High Side):</h3> + """ + + obj_repr + ) + + elif self.high_obj is None: + if hasattr(self.low_obj, "_repr_html_"): + obj_repr = self.low_obj._repr_html_() # type: ignore + elif hasattr(self.low_obj, "_inner_repr"): + obj_repr = self.low_obj._inner_repr() # type: ignore + else: + obj_repr = self.__repr__() + attr_text = ( + f""" + <h3>{self.object_type} ObjectDiff (New {self.object_type} on the High Side):</h3> + """ + + obj_repr + ) + + elif self.status == "SAME": + obj_repr = "No changes between low side and high side" + else: + obj_repr = "" + for diff in self.diff_list: + obj_repr += diff.__repr__() + "<br>" + + obj_repr = obj_repr.replace("\n", "<br>") + # print("New lines", res) + + attr_text = f"<h3>{self.object_type} ObjectDiff:</h3>\n{obj_repr}" + return base_str + attr_text + + +def _wrap_text(text: str, width: int, indent: int = 4) -> str: + """Wrap text, preserving existing line breaks""" + return "\n".join( + [ + "\n".join( + textwrap.wrap( + line, + width, + break_long_words=False, + replace_whitespace=False, + subsequent_indent=" " * indent, + ) + ) + for line in text.splitlines() + if line.strip() != "" + ] + ) + + +class ObjectDiffBatch(SyftObject): + __canonical_name__ = "DiffHierarchy" + __version__ = SYFT_OBJECT_VERSION_1 + LINE_LENGTH: ClassVar[int] = 100 + INDENT: ClassVar[int] = 4 + ORDER: ClassVar[Dict] = {"low": 0, "high": 1} + + # Diffs are ordered in depth-first order, + # so the first diff is the root of the hierarchy + diffs: List[ObjectDiff] + hierarchy_levels: List[int] + dependencies: Dict[UID, List[UID]] = {} + dependents: Dict[UID, List[UID]] = {} + + @property + def visual_hierarchy(self) -> Tuple[Type, dict]: + # Returns + root_obj: Union[Request, UserCodeStatusCollection, ExecutionOutput, Any] = ( + self.root.low_obj if self.root.low_obj is not None else self.root.high_obj + ) + if isinstance(root_obj, Request): + return Request, { + Request: [UserCode], + UserCode: [UserCode], + } + if isinstance(root_obj, UserCodeStatusCollection): + return UserCode, { + UserCode: [UserCodeStatusCollection], + } + if isinstance(root_obj, ExecutionOutput): + return UserCode, { + UserCode: [Job], + Job: [ExecutionOutput, SyftLog, Job], + ExecutionOutput: [ActionObject], + } + raise ValueError(f"Unknown root type: {self.root.obj_type}") + + @model_validator(mode="after") + def make_dependents(self) -> Self: + dependents: Dict = {} + for parent, children in self.dependencies.items(): + for child in children: + dependents[child] = dependents.get(child, []) + [parent] + self.dependents = dependents + return self + + @property + def root(self) -> ObjectDiff: + return self.diffs[0] + + def __len__(self) -> int: + return len(self.diffs) + + def __repr__(self) -> str: + return f"""{self.hierarchy_str('low')} + +{self.hierarchy_str('high')} +""" + + def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: + return "" # Turns off the _repr_markdown_ of SyftObject + + def _get_visual_hierarchy(self, node: ObjectDiff) -> dict[ObjectDiff, dict]: + _, child_types_map = self.visual_hierarchy + child_types = child_types_map.get(node.obj_type, []) + dep_ids = self.dependencies.get(node.object_id, []) + self.dependents.get( + node.object_id, [] + ) + + result = {} + for child_type in child_types: + children = [ + n + for n in self.diffs + if n.object_id in dep_ids + and isinstance(n.low_obj or n.high_obj, child_type) + ] + for child in children: + result[child] = self._get_visual_hierarchy(child) + + return result + + def get_visual_hierarchy(self) -> "ObjectDiffBatch": + visual_root_type = self.visual_hierarchy[0] + # First diff with a visual root type is the visual root + # because diffs are in depth-first order + visual_root = [ + diff + for diff in self.diffs + if isinstance(diff.low_obj or diff.high_obj, visual_root_type) + ][0] + return {visual_root: self._get_visual_hierarchy(visual_root)} # type: ignore + + def _get_obj_str(self, diff_obj: ObjectDiff, level: int, side: str) -> str: + obj = diff_obj.low_obj if side == "low" else diff_obj.high_obj + if obj is None: + return "" + indent = " " * level * self.INDENT + obj_str = diff_obj.diff_side_str(side) + obj_str = _wrap_text(obj_str, width=self.LINE_LENGTH - len(indent)) + + line_prefix = indent + f"―――― {diff_obj.status} " + line = "―" * (self.LINE_LENGTH - len(line_prefix)) + return f"""{line_prefix}{line} + +{textwrap.indent(obj_str, indent)} + +""" + + def hierarchy_str(self, side: str) -> str: + def _hierarchy_str_recursive(tree: Dict, level: int) -> str: + result = "" + for node, children in tree.items(): + result += self._get_obj_str(node, level, side) + result += _hierarchy_str_recursive(children, level + 1) + return result + + visual_hierarchy = self.get_visual_hierarchy() + res = _hierarchy_str_recursive(visual_hierarchy, 0) + if res == "": + res = f"No {side} side changes." + return f"""{side.upper()} SIDE STATE: + +{res}""" + + +class NodeDiff(SyftObject): + __canonical_name__ = "NodeDiff" + __version__ = SYFT_OBJECT_VERSION_1 + + obj_uid_to_diff: Dict[UID, ObjectDiff] = {} + dependencies: Dict[UID, List[UID]] = {} + + @classmethod + def from_sync_state( + cls: Type["NodeDiff"], low_state: SyncState, high_state: SyncState + ) -> "NodeDiff": + obj_uid_to_diff = {} + for obj_id in set(low_state.objects.keys()) | set(high_state.objects.keys()): + low_obj = low_state.objects.get(obj_id, None) + low_permissions: List = low_state.permissions.get(obj_id, []) + high_obj = high_state.objects.get(obj_id, None) + high_permissions: List = high_state.permissions.get(obj_id, []) + diff = ObjectDiff.from_objects( + low_obj, high_obj, low_permissions, high_permissions + ) + obj_uid_to_diff[diff.object_id] = diff + + node_diff = cls(obj_uid_to_diff=obj_uid_to_diff) + + node_diff._init_dependencies(low_state, high_state) + return node_diff + + def _init_dependencies(self, low_state: SyncState, high_state: SyncState) -> None: + all_parents = set(low_state.dependencies.keys()) | set( + high_state.dependencies.keys() + ) + for parent in all_parents: + low_deps = low_state.dependencies.get(parent, []) + high_deps = high_state.dependencies.get(parent, []) + self.dependencies[parent] = list(set(low_deps) | set(high_deps)) + + @property + def diffs(self) -> List[ObjectDiff]: + diffs_depthfirst = [ + diff for hierarchy in self.hierarchies for diff in hierarchy.diffs + ] + # deduplicate + diffs = [] + ids = set() + for diff in diffs_depthfirst: + if diff.object_id not in ids: + diffs.append(diff) + ids.add(diff.object_id) + return diffs + + def _repr_html_(self) -> Any: + return self.diffs._repr_html_() + + def _sort_hierarchies( + self, hierarchies: List[ObjectDiffBatch] + ) -> List[ObjectDiffBatch]: + without_usercode = [] + grouped_by_usercode: Dict[UID, List[ObjectDiffBatch]] = {} + for hierarchy in hierarchies: + has_usercode = False + for diff in hierarchy.diffs: + obj = diff.low_obj if diff.low_obj is not None else diff.high_obj + if isinstance(obj, UserCode): + grouped_by_usercode[obj.id] = hierarchy + has_usercode = True + break + if not has_usercode: + without_usercode.append(hierarchy) + + # Order of hierarchies, by root object type + hierarchy_order = [UserCodeStatusCollection, Request, ExecutionOutput] + # Sort group by hierarchy_order, then by root object id + for hierarchy_group in grouped_by_usercode.values(): + hierarchy_group.sort( + key=lambda x: ( + hierarchy_order.index(x.root.obj_type), + x.root.object_id, + ) + ) + + # sorted = sorted groups + without_usercode + sorted_hierarchies = [] + for grp in grouped_by_usercode.values(): + sorted_hierarchies.extend(grp) + sorted_hierarchies.extend(without_usercode) + return sorted_hierarchies + + @property + def hierarchies(self) -> List[ObjectDiffBatch]: + # Returns a list of hierarchies, where each hierarchy is a list of tuples (ObjectDiff, level), + # in depth-first order. + + # Each hierarchy only contains one root, at the first position + # Example: [(Diff1, 0), (Diff2, 1), (Diff3, 2), (Diff4, 1)] + # Diff1 + # -- Diff2 + # ---- Diff3 + # -- Diff4 + + def _build_hierarchy_helper( + uid: UID, level: int = 0, visited: Optional[Set] = None + ) -> List: + visited = visited if visited is not None else set() + + if uid in visited: + return [] + + result = [(uid, level)] + visited.add(uid) + if uid in self.dependencies: + deps = self.dependencies[uid] + for dep_uid in self.dependencies[uid]: + if dep_uid not in visited: + # NOTE we pass visited + deps to recursive calls, to have + # all objects at the highest level in the hierarchy + # Example: + # ExecutionOutput + # -- Job + # ---- Result + # -- Result + # We want to omit Job.Result, because it's already in ExecutionOutput.Result + result.extend( + _build_hierarchy_helper( + uid=dep_uid, + level=level + 1, + visited=visited | set(deps) - {dep_uid}, + ) + ) + return result + + hierarchies = [] + all_ids = set(self.obj_uid_to_diff.keys()) + child_ids = {child for deps in self.dependencies.values() for child in deps} + # Root ids are object ids with no parents + root_ids = list(all_ids - child_ids) + + for root_uid in root_ids: + uid_hierarchy = _build_hierarchy_helper(root_uid) + diffs = [self.obj_uid_to_diff[uid] for uid, _ in uid_hierarchy] + levels = [level for _, level in uid_hierarchy] + + batch_uids = {uid for uid, _ in uid_hierarchy} + dependencies = { + uid: [d for d in self.dependencies.get(uid, []) if d in batch_uids] + for uid in batch_uids + } + + batch = ObjectDiffBatch( + diffs=diffs, hierarchy_levels=levels, dependencies=dependencies + ) + hierarchies.append(batch) + + return hierarchies + + def objs_to_sync(self) -> List[SyftObject]: + objs: list[SyftObject] = [] + for diff in self.diffs: + if diff.status == "NEW": + objs.append(diff.get_obj()) + return objs + + +class ResolvedSyncState(SyftObject): + __canonical_name__ = "SyncUpdate" + __version__ = SYFT_OBJECT_VERSION_1 + + create_objs: List[SyftObject] = [] + update_objs: List[SyftObject] = [] + delete_objs: List[SyftObject] = [] + new_permissions: List[ActionObjectPermission] = [] + alias: str + + def add_cruds_from_diff(self, diff: ObjectDiff, decision: str) -> None: + if diff.status == "SAME": + return + + my_obj = diff.low_obj if self.alias == "low" else diff.high_obj + other_obj = diff.low_obj if self.alias == "high" else diff.high_obj + + if decision != self.alias: # chose for the other + if diff.status == "DIFF": + if other_obj not in self.update_objs: + self.update_objs.append(other_obj) + elif diff.status == "NEW": + if my_obj is None: + if other_obj not in self.create_objs: + self.create_objs.append(other_obj) + elif other_obj is None: + if my_obj not in self.delete_objs: + self.delete_objs.append(my_obj) + + def __repr__(self) -> str: + return ( + f"ResolvedSyncState(\n" + f" create_objs={self.create_objs},\n" + f" update_objs={self.update_objs},\n" + f" delete_objs={self.delete_objs}\n" + f" new_permissions={self.new_permissions}\n" + f")" + ) + + +def display_diff_object(obj_state: Optional[str]) -> Panel: + if obj_state is None: + return Panel(Markdown("None"), box=box.ROUNDED, expand=False) + return Panel( + Markdown(f"```python\n{obj_state}\n```", code_theme="default"), + box=box.ROUNDED, + expand=False, + ) + + +def display_diff_hierarchy(diff_hierarchy: List[Tuple[ObjectDiff, int]]) -> None: + console = Console() + + for diff, level in diff_hierarchy: + title = f"{diff.obj_type.__name__}({diff.object_id}) - State: {diff.status}" + + low_side_panel = display_diff_object(diff.low_state if diff.low_obj else None) + low_side_panel.title = "Low side" + low_side_panel.title_align = "left" + high_side_panel = display_diff_object( + diff.high_state if diff.high_obj is not None else None + ) + high_side_panel.title = "High side" + high_side_panel.title_align = "left" + + grouped_panels = Group(low_side_panel, high_side_panel) + + diff_panel = Panel( + grouped_panels, + title=title, + title_align="left", + box=box.ROUNDED, + expand=False, + padding=(1, 2), + ) + + if level > 0: + diff_panel = Padding(diff_panel, (0, 0, 0, 5 * level)) + + console.print(diff_panel) diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py new file mode 100644 index 00000000000..d25c2904e11 --- /dev/null +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -0,0 +1,265 @@ +# stdlib +from collections import defaultdict +from typing import Any +from typing import Dict +from typing import List +from typing import Set +from typing import Union +from typing import cast + +# third party +from result import Result + +# relative +from ...abstract_node import AbstractNode +from ...client.api import NodeIdentity +from ...node.credentials import SyftVerifyKey +from ...serde.serializable import serializable +from ...store.document_store import BaseStash +from ...store.document_store import DocumentStore +from ...store.linked_obj import LinkedObject +from ...types.syft_object import SyftObject +from ...types.uid import UID +from ...util.telemetry import instrument +from ..action.action_object import ActionObject +from ..action.action_permissions import ActionObjectPermission +from ..action.action_permissions import ActionPermission +from ..code.user_code import UserCodeStatusCollection +from ..context import AuthedServiceContext +from ..job.job_stash import Job +from ..output.output_service import ExecutionOutput +from ..response import SyftError +from ..response import SyftSuccess +from ..service import AbstractService +from ..service import service_method +from ..user.user_roles import ADMIN_ROLE_LEVEL +from .sync_stash import SyncStash +from .sync_state import SyncState + + +@instrument +@serializable() +class SyncService(AbstractService): + store: DocumentStore + stash: SyncStash + + def __init__(self, store: DocumentStore): + self.store = store + self.stash = SyncStash(store=store) + + def add_actionobject_read_permissions( + self, + context: AuthedServiceContext, + action_object: ActionObject, + permissions_other: List[str], + ) -> None: + read_permissions = [x for x in permissions_other if "READ" in x] + + _id = action_object.id.id + blob_id = action_object.syft_blob_storage_entry_id + + store_to = context.node.get_service("actionservice").store # type: ignore + store_to_blob = context.node.get_service("blobstorageservice").stash.partition # type: ignore + + for read_permission in read_permissions: + creds, perm_str = read_permission.split("_") + perm = ActionPermission[perm_str] + permission = ActionObjectPermission( + uid=_id, permission=perm, credentials=SyftVerifyKey(creds) + ) + store_to.add_permission(permission) + + permission_blob = ActionObjectPermission( + uid=blob_id, permission=perm, credentials=SyftVerifyKey(creds) + ) + store_to_blob.add_permission(permission_blob) + + def set_obj_ids(self, context: AuthedServiceContext, x: Any) -> None: + if hasattr(x, "__dict__") and isinstance(x, SyftObject): + for val in x.__dict__.values(): + if isinstance(val, (list, tuple)): + for v in val: + self.set_obj_ids(context, v) + elif isinstance(val, dict): + for v in val.values(): + self.set_obj_ids(context, v) + else: + self.set_obj_ids(context, val) + x.syft_node_location = context.node.id # type: ignore + x.syft_client_verify_key = context.credentials + if hasattr(x, "node_uid"): + x.node_uid = context.node.id # type: ignore + + def transform_item( + self, context: AuthedServiceContext, item: SyftObject + ) -> SyftObject: + if isinstance(item, UserCodeStatusCollection): + identity = NodeIdentity.from_node(context.node) + res = {} + for key in item.status_dict.keys(): + # todo, check if they are actually only two nodes + res[identity] = item.status_dict[key] + item.status_dict = res + + self.set_obj_ids(context, item) + return item + + def get_stash_for_item( + self, context: AuthedServiceContext, item: SyftObject + ) -> BaseStash: + services = list(context.node.service_path_map.values()) # type: ignore + + all_stashes = {} + for serv in services: + if (_stash := getattr(serv, "stash", None)) is not None: + all_stashes[_stash.object_type] = _stash + + stash = all_stashes.get(type(item), None) + return stash + + def add_permissions_for_item( + self, + context: AuthedServiceContext, + item: SyftObject, + permissions_other: Set[ActionObjectPermission], + ) -> None: + if isinstance(item, Job) and context.node.node_side_type.value == "low": # type: ignore + _id = item.id + read_permissions = [x for x in permissions_other if "READ" in x] # type: ignore + job_store = context.node.get_service("jobservice").stash.partition # type: ignore + for read_permission in read_permissions: + creds, perm_str = read_permission.split("_") + perm = ActionPermission[perm_str] + permission = ActionObjectPermission( + uid=_id, permission=perm, credentials=SyftVerifyKey(creds) + ) + job_store.add_permission(permission) + + def set_object( + self, context: AuthedServiceContext, item: SyftObject + ) -> Result[SyftObject, str]: + stash = self.get_stash_for_item(context, item) + creds = context.credentials + + exists = stash.get_by_uid(context.credentials, item.id).ok() is not None + if exists: + res = stash.update(creds, item) + else: + # res = stash.delete_by_uid(node.python_node.verify_key, item.id) + res = stash.set(creds, item) + return res + + @service_method( + path="sync.sync_items", + name="sync_items", + roles=ADMIN_ROLE_LEVEL, + ) + def sync_items( + self, + context: AuthedServiceContext, + items: List[Union[ActionObject, SyftObject]], + permissions: Dict[UID, Set[str]], + ) -> Union[SyftSuccess, SyftError]: + permissions = defaultdict(list, permissions) + for item in items: + other_node_permissions = permissions[item.id.id] + if isinstance(item, ActionObject): + self.add_actionobject_read_permissions( + context, item, other_node_permissions + ) + else: + item = self.transform_item(context, item) # type: ignore[unreachable] + res = self.set_object(context, item) + + if res.is_ok(): + self.add_permissions_for_item(context, item, other_node_permissions) + else: + return SyftError(message=f"Failed to sync {res.err()}") + return SyftSuccess(message=f"Synced {len(items)} items") + + @service_method( + path="sync.get_permissions", + name="get_permissions", + roles=ADMIN_ROLE_LEVEL, + ) + def get_permissions( + self, + context: AuthedServiceContext, + items: List[Union[ActionObject, SyftObject]], + ) -> Dict: + permissions: Dict = {} + + def get_store(item): # type: ignore + if isinstance(item, ActionObject): + return context.node.get_service("actionservice").store + elif isinstance(item, Job): + return context.node.get_service("jobservice").stash.partition + else: + return None + + for item in items: + store = get_store(item) + if store is not None: + _id = item.id.id + permissions[item.id.id] = store.permissions[_id] + return permissions + + @service_method( + path="sync.get_state", + name="get_state", + roles=ADMIN_ROLE_LEVEL, + ) + def get_state( + self, context: AuthedServiceContext, add_to_store: bool = False + ) -> Union[SyncState, SyftError]: + new_state = SyncState() + + node = cast(AbstractNode, context.node) + + services_to_sync = [ + "projectservice", + "requestservice", + "usercodeservice", + "jobservice", + "logservice", + "outputservice", + "usercodestatusservice", + ] + + for service_name in services_to_sync: + service = node.get_service(service_name) + items = service.get_all(context) + new_state.add_objects(items, api=node.root_client.api) # type: ignore + + # TODO workaround, we only need action objects from outputs for now + action_object_ids = set() + for obj in new_state.objects.values(): + if isinstance(obj, ExecutionOutput): + action_object_ids |= set(obj.output_id_list) + elif isinstance(obj, Job) and obj.result is not None: + action_object_ids.add(obj.result.id) + + action_objects = [] + for uid in action_object_ids: + action_object = node.get_service("actionservice").get(context, uid) # type: ignore + if action_object.is_err(): + return SyftError(message=action_object.err()) + action_objects.append(action_object.ok()) + new_state.add_objects(action_objects) + + new_state._build_dependencies(api=node.root_client.api) # type: ignore + + new_state.permissions = self.get_permissions(context, new_state.objects) + + previous_state = self.stash.get_latest(context=context) + if previous_state is not None: + new_state.previous_state_link = LinkedObject.from_obj( + obj=previous_state, + service_type=SyncService, + node_uid=context.node.id, # type: ignore + ) + + if add_to_store: + self.stash.set(context.credentials, new_state) + + return new_state diff --git a/packages/syft/src/syft/service/sync/sync_stash.py b/packages/syft/src/syft/service/sync/sync_stash.py new file mode 100644 index 00000000000..9ce8aeabeb2 --- /dev/null +++ b/packages/syft/src/syft/service/sync/sync_stash.py @@ -0,0 +1,48 @@ +# stdlib +from typing import Optional +from typing import Union + +# relative +from ...serde.serializable import serializable +from ...store.document_store import BaseUIDStoreStash +from ...store.document_store import DocumentStore +from ...store.document_store import PartitionKey +from ...store.document_store import PartitionSettings +from ...types.datetime import DateTime +from ...util.telemetry import instrument +from ..context import AuthedServiceContext +from ..response import SyftError +from .sync_state import SyncState + +OrderByDatePartitionKey = PartitionKey(key="created_at", type_=DateTime) + + +@instrument +@serializable() +class SyncStash(BaseUIDStoreStash): + object_type = SyncState + settings: PartitionSettings = PartitionSettings( + name=SyncState.__canonical_name__, object_type=SyncState + ) + + def __init__(self, store: DocumentStore): + super().__init__(store) + self.store = store + self.settings = self.settings + self._object_type = self.object_type + + def get_latest( + self, context: AuthedServiceContext + ) -> Union[Optional[SyncState], SyftError]: + all_states = self.get_all( + credentials=context.node.verify_key, # type: ignore + order_by=OrderByDatePartitionKey, + ) + + if all_states.is_err(): + return SyftError(message=all_states.err()) + + all_states = all_states.ok() + if len(all_states) > 0: + return all_states[-1] + return None diff --git a/packages/syft/src/syft/service/sync/sync_state.py b/packages/syft/src/syft/service/sync/sync_state.py new file mode 100644 index 00000000000..0e6ecb28074 --- /dev/null +++ b/packages/syft/src/syft/service/sync/sync_state.py @@ -0,0 +1,152 @@ +# stdlib +import html +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Set +from typing import TYPE_CHECKING + +# relative +from ...serde.serializable import serializable +from ...store.linked_obj import LinkedObject +from ...types.datetime import DateTime +from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SyftObject +from ...types.uid import LineageID +from ...types.uid import UID +from ..action.action_permissions import ActionPermission + +if TYPE_CHECKING: + # relative + from .diff_state import NodeDiff + + +def get_hierarchy_level_prefix(level: int) -> str: + if level == 0: + return "" + else: + return "--" * level + " " + + +@serializable() +class SyncStateRow(SyftObject): + """A row in the SyncState table""" + + __canonical_name__ = "SyncStateItem" + __version__ = SYFT_OBJECT_VERSION_1 + + object: SyftObject + previous_object: Optional[SyftObject] = None + current_state: str + previous_state: str + level: int = 0 + + # TODO table formatting + __repr_attrs__ = [ + "previous_state", + "current_state", + ] + + def _coll_repr_(self) -> Dict[str, Any]: + current_state = f"{self.status}\n{self.current_state}" + previous_state = f"{self.status}\n{self.previous_state}" + return { + "previous_state": html.escape(previous_state), + "current_state": html.escape(current_state), + } + + @property + def object_type(self) -> str: + prefix = get_hierarchy_level_prefix(self.level) + return f"{prefix}{type(self.object).__name__}" + + @property + def status(self) -> str: + # TODO use Diffs to determine status + if self.previous_object is None: + return "NEW" + elif self.previous_object.syft_eq(ext_obj=self.object): + return "SAME" + else: + return "UPDATED" + + +@serializable() +class SyncState(SyftObject): + __canonical_name__ = "SyncState" + __version__ = SYFT_OBJECT_VERSION_1 + + objects: Dict[UID, SyftObject] = {} + dependencies: Dict[UID, List[UID]] = {} + created_at: DateTime = DateTime.now() + previous_state_link: Optional[LinkedObject] = None + permissions: Dict[UID, List[ActionPermission]] = {} + + __attr_searchable__ = ["created_at"] + + @property + def previous_state(self) -> Optional["SyncState"]: + if self.previous_state_link is not None: + return self.previous_state_link.resolve + return None + + @property + def all_ids(self) -> Set[UID]: + return set(self.objects.keys()) + + def add_objects(self, objects: List[SyftObject], api: Any = None) -> None: + for obj in objects: + if isinstance(obj.id, LineageID): + self.objects[obj.id.id] = obj + else: + self.objects[obj.id] = obj + + # TODO might get slow with large states, + # need to build dependencies every time to not have UIDs + # in dependencies that are not in objects + self._build_dependencies(api=api) + + def _build_dependencies(self, api: Any = None) -> None: + self.dependencies = {} + + all_ids = self.all_ids + for obj in self.objects.values(): + if hasattr(obj, "get_sync_dependencies"): + deps = obj.get_sync_dependencies(api=api) + deps = [d.id for d in deps if d.id in all_ids] + if len(deps): + self.dependencies[obj.id] = deps + + def get_previous_state_diff(self) -> "NodeDiff": + # Re-use DiffState to compare to previous state + # Low = previous, high = current + # relative + from .diff_state import NodeDiff + + previous_state = self.previous_state or SyncState() + return NodeDiff.from_sync_state(previous_state, self) + + @property + def rows(self) -> List[SyncStateRow]: + result = [] + ids = set() + + previous_diff = self.get_previous_state_diff() + for hierarchy in previous_diff.hierarchies: + for diff, level in zip(hierarchy.diffs, hierarchy.hierarchy_levels): + if diff.object_id in ids: + continue + ids.add(diff.object_id) + row = SyncStateRow( + object=diff.high_obj, + previous_object=diff.low_obj, + current_state=diff.diff_side_str("high"), + previous_state=diff.diff_side_str("low"), + level=level, + ) + result.append(row) + return result + + def _repr_html_(self) -> str: + return self.rows._repr_html_() diff --git a/packages/syft/src/syft/service/user/user.py b/packages/syft/src/syft/service/user/user.py index 71a07d2dfb0..038a198c645 100644 --- a/packages/syft/src/syft/service/user/user.py +++ b/packages/syft/src/syft/service/user/user.py @@ -6,15 +6,16 @@ from typing import List from typing import Optional from typing import Tuple +from typing import Type from typing import Union # third party from bcrypt import checkpw from bcrypt import gensalt from bcrypt import hashpw -import pydantic +from pydantic import EmailStr from pydantic import ValidationError -from pydantic.networks import EmailStr +from pydantic import field_validator # relative from ...client.api import APIRegistry @@ -22,10 +23,10 @@ from ...node.credentials import SyftVerifyKey from ...serde.serializable import serializable from ...types.syft_metaclass import Empty -from ...types.syft_migration import migrate from ...types.syft_object import PartialSyftObject from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 +from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import SyftObject from ...types.transforms import TransformContext from ...types.transforms import drop @@ -35,48 +36,35 @@ from ...types.transforms import transform from ...types.transforms import validate_email from ...types.uid import UID +from ..notifier.notifier_enums import NOTIFIERS from ..response import SyftError from ..response import SyftSuccess from .user_roles import ServiceRole -class UserV1(SyftObject): - __canonical_name__ = "User" - __version__ = SYFT_OBJECT_VERSION_1 - - email: Optional[EmailStr] - name: Optional[str] - hashed_password: Optional[str] - salt: Optional[str] - signing_key: Optional[SyftSigningKey] - verify_key: Optional[SyftVerifyKey] - role: Optional[ServiceRole] - institution: Optional[str] - website: Optional[str] = None - created_at: Optional[str] = None - - @serializable() class User(SyftObject): # version __canonical_name__ = "User" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_3 - id: Optional[UID] - - @pydantic.validator("email", pre=True, always=True) - def make_email(cls, v: EmailStr) -> EmailStr: - return EmailStr(v) + id: Optional[UID] = None # type: ignore[assignment] # fields - email: Optional[EmailStr] - name: Optional[str] - hashed_password: Optional[str] - salt: Optional[str] - signing_key: Optional[SyftSigningKey] - verify_key: Optional[SyftVerifyKey] - role: Optional[ServiceRole] - institution: Optional[str] + notifications_enabled: Dict[NOTIFIERS, bool] = { + NOTIFIERS.EMAIL: True, + NOTIFIERS.SMS: False, + NOTIFIERS.SLACK: False, + NOTIFIERS.APP: False, + } + email: Optional[EmailStr] = None + name: Optional[str] = None + hashed_password: Optional[str] = None + salt: Optional[str] = None + signing_key: Optional[SyftSigningKey] = None + verify_key: Optional[SyftVerifyKey] = None + role: Optional[ServiceRole] = None + institution: Optional[str] = None website: Optional[str] = None created_at: Optional[str] = None # TODO where do we put this flag? @@ -93,6 +81,9 @@ def default_role(role: ServiceRole) -> Callable: def hash_password(context: TransformContext) -> TransformContext: + if context.output is None: + return context + if context.output["password"] is not None and ( (context.output["password_verify"] is None) or context.output["password"] == context.output["password_verify"] @@ -100,13 +91,16 @@ def hash_password(context: TransformContext) -> TransformContext: salt, hashed = salt_and_hash_password(context.output["password"], 12) context.output["hashed_password"] = hashed context.output["salt"] = salt + return context def generate_key(context: TransformContext) -> TransformContext: - signing_key = SyftSigningKey.generate() - context.output["signing_key"] = signing_key - context.output["verify_key"] = signing_key.verify_key + if context.output is not None: + signing_key = SyftSigningKey.generate() + context.output["signing_key"] = signing_key + context.output["verify_key"] = signing_key.verify_key + return context @@ -126,30 +120,13 @@ def check_pwd(password: str, hashed_password: str) -> bool: ) -class UserUpdateV1(PartialSyftObject): - __canonical_name__ = "UserUpdate" - __version__ = SYFT_OBJECT_VERSION_1 - - email: EmailStr - name: str - role: ServiceRole - password: str - password_verify: str - verify_key: SyftVerifyKey - institution: str - website: str - - @serializable() class UserUpdate(PartialSyftObject): __canonical_name__ = "UserUpdate" - __version__ = SYFT_OBJECT_VERSION_2 - - @pydantic.validator("email", pre=True) - def make_email(cls, v: Any) -> Any: - return EmailStr(v) if isinstance(v, str) and not isinstance(v, EmailStr) else v + __version__ = SYFT_OBJECT_VERSION_3 - @pydantic.validator("role", pre=True) + @field_validator("role", mode="before") + @classmethod def str_to_role(cls, v: Any) -> Any: if isinstance(v, str) and hasattr(ServiceRole, v.upper()): return getattr(ServiceRole, v.upper()) @@ -166,35 +143,20 @@ def str_to_role(cls, v: Any) -> Any: mock_execution_permission: bool -class UserCreateV1(UserUpdateV1): - __canonical_name__ = "UserCreate" - __version__ = SYFT_OBJECT_VERSION_1 - - email: EmailStr - name: str - role: Optional[ServiceRole] = None # type: ignore[assignment] - password: str - password_verify: Optional[str] = None # type: ignore[assignment] - verify_key: Optional[SyftVerifyKey] - institution: Optional[str] # type: ignore[assignment] - website: Optional[str] # type: ignore[assignment] - created_by: Optional[SyftSigningKey] - - @serializable() -class UserCreate(UserUpdate): +class UserCreate(SyftObject): __canonical_name__ = "UserCreate" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_3 email: EmailStr name: str role: Optional[ServiceRole] = None # type: ignore[assignment] password: str password_verify: Optional[str] = None # type: ignore[assignment] - verify_key: Optional[SyftVerifyKey] - institution: Optional[str] # type: ignore[assignment] - website: Optional[str] # type: ignore[assignment] - created_by: Optional[SyftSigningKey] + verify_key: Optional[SyftVerifyKey] = None # type: ignore[assignment] + institution: Optional[str] = "" # type: ignore[assignment] + website: Optional[str] = "" # type: ignore[assignment] + created_by: Optional[SyftSigningKey] = None # type: ignore[assignment] mock_execution_permission: bool = False __repr_attrs__ = ["name", "email"] @@ -203,7 +165,7 @@ class UserCreate(UserUpdate): @serializable() class UserSearch(PartialSyftObject): __canonical_name__ = "UserSearch" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 id: UID email: EmailStr @@ -211,22 +173,17 @@ class UserSearch(PartialSyftObject): name: str -class UserViewV1(SyftObject): - __canonical_name__ = "UserView" - __version__ = SYFT_OBJECT_VERSION_1 - - email: EmailStr - name: str - role: ServiceRole # make sure role cant be set without uid - institution: Optional[str] - website: Optional[str] - - @serializable() class UserView(SyftObject): __canonical_name__ = "UserView" - __version__ = SYFT_OBJECT_VERSION_2 - + __version__ = SYFT_OBJECT_VERSION_3 + + notifications_enabled: Dict[NOTIFIERS, bool] = { + NOTIFIERS.EMAIL: True, + NOTIFIERS.SMS: False, + NOTIFIERS.SLACK: False, + NOTIFIERS.APP: False, + } email: EmailStr name: str role: ServiceRole # make sure role cant be set without uid @@ -234,7 +191,14 @@ class UserView(SyftObject): website: Optional[str] mock_execution_permission: bool - __repr_attrs__ = ["name", "email", "institution", "website", "role"] + __repr_attrs__ = [ + "name", + "email", + "institution", + "website", + "role", + "notifications_enabled", + ] def _coll_repr_(self) -> Dict[str, Any]: return { @@ -243,6 +207,10 @@ def _coll_repr_(self) -> Dict[str, Any]: "Institute": self.institution, "Website": self.website, "Role": self.role.name.capitalize(), + "Notifications": "Email: " + + ( + "Enabled" if self.notifications_enabled[NOTIFIERS.EMAIL] else "Disabled" + ), } def _set_password(self, new_password: str) -> Union[SyftError, SyftSuccess]: @@ -252,8 +220,7 @@ def _set_password(self, new_password: str) -> Union[SyftError, SyftSuccess]: ) if api is None: return SyftError(message=f"You must login to {self.node_uid}") - if api.services is None: - return SyftError(message=f"Services for api {api} is None") + api.services.user.update( uid=self.id, user_update=UserUpdate(password=new_password) ) @@ -290,8 +257,6 @@ def set_email(self, email: str) -> Union[SyftSuccess, SyftError]: except ValidationError: return SyftError(message="{email} is not a valid email address.") - if api.services is None: - return SyftError(message=f"Services for {api} is None") result = api.services.user.update(uid=self.id, user_update=user_update) if isinstance(result, SyftError): @@ -305,11 +270,11 @@ def set_email(self, email: str) -> Union[SyftSuccess, SyftError]: def update( self, - name: Union[Empty, str] = Empty, - institution: Union[Empty, str] = Empty, - website: Union[str, Empty] = Empty, - role: Union[str, Empty] = Empty, - mock_execution_permission: Union[bool, Empty] = Empty, + name: Union[Type[Empty], str] = Empty, + institution: Union[Type[Empty], str] = Empty, + website: Union[Type[Empty], str] = Empty, + role: Union[Type[Empty], str] = Empty, + mock_execution_permission: Union[Type[Empty], bool] = Empty, ) -> Union[SyftSuccess, SyftError]: """Used to update name, institution, website of a user.""" api = APIRegistry.api_for( @@ -325,8 +290,6 @@ def update( role=role, mock_execution_permission=mock_execution_permission, ) - if api.services is None: - return SyftError(message=f"Services for {api} is None") result = api.services.user.update(uid=self.id, user_update=user_update) if isinstance(result, SyftError): @@ -384,6 +347,7 @@ def user_to_view_user() -> List[Callable]: "institution", "website", "mock_execution_permission", + "notifications_enabled", ] ) ] @@ -402,43 +366,3 @@ class UserPrivateKey(SyftObject): @transform(User, UserPrivateKey) def user_to_user_verify() -> List[Callable]: return [keep(["email", "signing_key", "id", "role"])] - - -@migrate(UserV1, User) -def upgrade_user_v1_to_v2() -> List[Callable]: - return [make_set_default(key="mock_execution_permission", value=False)] - - -@migrate(User, UserV1) -def downgrade_user_v2_to_v1() -> List[Callable]: - return [drop(["mock_execution_permission"])] - - -@migrate(UserUpdateV1, UserUpdate) -def upgrade_user_update_v1_to_v2() -> List[Callable]: - return [make_set_default(key="mock_execution_permission", value=False)] - - -@migrate(UserUpdate, UserUpdateV1) -def downgrade_user_update_v2_to_v1() -> List[Callable]: - return [drop(["mock_execution_permission"])] - - -@migrate(UserCreateV1, UserCreate) -def upgrade_user_create_v1_to_v2() -> List[Callable]: - return [make_set_default(key="mock_execution_permission", value=False)] - - -@migrate(UserCreate, UserCreateV1) -def downgrade_user_create_v2_to_v1() -> List[Callable]: - return [drop(["mock_execution_permission"])] - - -@migrate(UserViewV1, UserView) -def upgrade_user_view_v1_to_v2() -> List[Callable]: - return [make_set_default(key="mock_execution_permission", value=False)] - - -@migrate(UserView, UserViewV1) -def downgrade_user_view_v2_to_v1() -> List[Callable]: - return [drop(["mock_execution_permission"])] diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index e46273c4197..56d7cd47a07 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -14,6 +14,7 @@ from ...node.credentials import UserLoginCredentials from ...serde.serializable import serializable from ...store.document_store import DocumentStore +from ...store.linked_obj import LinkedObject from ...types.syft_metaclass import Empty from ...types.uid import UID from ...util.telemetry import instrument @@ -22,6 +23,10 @@ from ..context import AuthedServiceContext from ..context import NodeServiceContext from ..context import UnauthedServiceContext +from ..notification.email_templates import OnBoardEmailTemplate +from ..notification.notification_service import CreateNotification +from ..notification.notification_service import NotificationService +from ..notifier.notifier_enums import NOTIFIERS from ..response import SyftError from ..response import SyftSuccess from ..service import AbstractService @@ -137,7 +142,6 @@ def get_role_for_credentials( self, credentials: Union[SyftVerifyKey, SyftSigningKey] ) -> Union[Optional[ServiceRole], SyftError]: # they could be different - if isinstance(credentials, SyftVerifyKey): result = self.stash.get_by_verify_key( credentials=credentials, verify_key=credentials @@ -225,12 +229,12 @@ def get_current_user( def update( self, context: AuthedServiceContext, uid: UID, user_update: UserUpdate ) -> Union[UserView, SyftError]: - updates_role = user_update.role is not Empty + updates_role = user_update.role is not Empty # type: ignore[comparison-overlap] can_edit_roles = ServiceRoleCapability.CAN_EDIT_ROLES in context.capabilities() if updates_role and not can_edit_roles: return SyftError(message=f"{context.role} is not allowed to edit roles") - if (user_update.mock_execution_permission is not Empty) and not can_edit_roles: + if (user_update.mock_execution_permission is not Empty) and not can_edit_roles: # type: ignore[comparison-overlap] return SyftError( message=f"{context.role} is not allowed to update permissions" ) @@ -361,6 +365,8 @@ def delete(self, context: AuthedServiceContext, uid: UID) -> Union[bool, SyftErr if result.is_err(): return SyftError(message=str(result.err())) + # TODO: Remove notifications for the deleted user + return result.ok() def exchange_credentials( @@ -438,6 +444,7 @@ def register( result = self.stash.get_by_email(credentials=user.verify_key, email=user.email) if result.is_err(): return SyftError(message=str(result.err())) + user_exists = result.ok() is not None if user_exists: return SyftError(message=f"User already exists with email: {user.email}") @@ -457,8 +464,30 @@ def register( user = result.ok() success_message = f"User '{user.name}' successfully registered!" + + # Notification Step + root_key = self.admin_verify_key() + root_context = AuthedServiceContext(node=context.node, credentials=root_key) + link = None + if new_user.created_by: + link = LinkedObject.with_context(user, context=root_context) + message = CreateNotification( + subject=success_message, + from_user_verify_key=root_key, + to_user_verify_key=user.verify_key, + linked_obj=link, + notifier_types=[NOTIFIERS.EMAIL], + email_template=OnBoardEmailTemplate, + ) + + method = context.node.get_service_method(NotificationService.send) + result = method(context=root_context, notification=message) + if request_user_role in DATA_OWNER_ROLE_LEVEL: success_message += " To see users, run `[your_client].users`" + + # TODO: Add a notifications for the new user + msg = SyftSuccess(message=success_message) return (msg, user.to(UserPrivateKey)) @@ -482,6 +511,55 @@ def get_by_verify_key( return result.ok() return SyftError(message=f"No User with verify_key: {verify_key}") + # TODO: This exposed service is only for the development phase. + # enable/disable notifications will be called from Notifier Service + + def _set_notification_status( + self, + notifier_type: NOTIFIERS, + new_status: bool, + verify_key: SyftVerifyKey, + ) -> Optional[SyftError]: + result = self.stash.get_by_verify_key( + credentials=verify_key, verify_key=verify_key + ) + if result.is_ok(): + # this seems weird that we get back None as Ok(None) + user = result.ok() + else: + return SyftError(message=str(result.err())) + + user.notifications_enabled[notifier_type] = new_status + + result = self.stash.update( + credentials=user.verify_key, + user=user, + ) + if result.is_err(): + return SyftError(message=str(result.err())) + else: + return None + + def enable_notifications( + self, context: AuthedServiceContext, notifier_type: NOTIFIERS + ) -> Union[SyftSuccess, SyftError]: + result = self._set_notification_status(notifier_type, True, context.credentials) + if result is not None: + return result + else: + return SyftSuccess(message="Notifications enabled successfully!") + + def disable_notifications( + self, context: AuthedServiceContext, notifier_type: NOTIFIERS + ) -> Union[SyftSuccess, SyftError]: + result = self._set_notification_status( + notifier_type, False, context.credentials + ) + if result is not None: + return result + else: + return SyftSuccess(message="Notifications disabled successfully!") + TYPE_TO_SERVICE[User] = UserService SERVICE_TO_TYPES[UserService].update({User}) diff --git a/packages/syft/src/syft/service/warnings.py b/packages/syft/src/syft/service/warnings.py index 553531046a8..015121c4bfa 100644 --- a/packages/syft/src/syft/service/warnings.py +++ b/packages/syft/src/syft/service/warnings.py @@ -1,6 +1,7 @@ # stdlib from typing import Any from typing import Optional +from typing import cast # third party from IPython.display import display @@ -10,6 +11,7 @@ # relative from ..abstract_node import AbstractNode from ..abstract_node import NodeSideType +from ..abstract_node import NodeType from ..node.credentials import SyftCredentials from ..serde.serializable import serializable from ..types.base import SyftBaseModel @@ -20,8 +22,8 @@ class WarningContext( Context, ): - node: Optional[AbstractNode] - credentials: Optional[SyftCredentials] + node: Optional[AbstractNode] = None + credentials: Optional[SyftCredentials] = None role: ServiceRole @@ -75,17 +77,19 @@ def message_from(self, context: Optional[WarningContext] = None) -> Self: if context is not None: node = context.node if node is not None: - node_side_type = node.node_side_type + node_side_type = cast(NodeSideType, node.node_side_type) node_type = node.node_type + _msg = ( "which could host datasets with private information." if node_side_type.value == NodeSideType.HIGH_SIDE.value else "which only hosts mock or synthetic data." ) - message = ( - "You're performing an operation on " - f"{node_side_type.value} side {node_type.value}, {_msg}" - ) + if node_type is not None: + message = ( + "You're performing an operation on " + f"{node_side_type.value} side {node_type.value}, {_msg}" + ) confirmation = node_side_type.value == NodeSideType.HIGH_SIDE.value return CRUDWarning(confirmation=confirmation, message=message) @@ -101,17 +105,19 @@ def message_from(self, context: Optional[WarningContext] = None) -> Self: if context is not None: node = context.node if node is not None: - node_side_type = node.node_side_type + node_side_type = cast(NodeSideType, node.node_side_type) node_type = node.node_type + _msg = ( "which could host datasets with private information." if node_side_type.value == NodeSideType.HIGH_SIDE.value else "which only hosts mock or synthetic data." ) - message = ( - "You're performing an operation on " - f"{node_side_type.value} side {node_type.value}, {_msg}" - ) + if node_type is not None: + message = ( + "You're performing an operation on " + f"{node_side_type.value} side {node_type.value}, {_msg}" + ) return CRUDReminder(confirmation=confirmation, message=message) @@ -124,8 +130,8 @@ def message_from(self, context: Optional[WarningContext] = None) -> Self: if context is not None: node = context.node if node is not None: - node_side_type = node.node_side_type - node_type = node.node_type + node_side_type = cast(NodeSideType, node.node_side_type) + node_type = cast(NodeType, node.node_type) if node_side_type.value == NodeSideType.LOW_SIDE.value: message = ( "You're performing an operation on " @@ -144,8 +150,8 @@ def message_from(self, context: Optional[WarningContext] = None) -> Self: if context is not None: node = context.node if node is not None: - node_side_type = node.node_side_type - node_type = node.node_type + node_side_type = cast(NodeSideType, node.node_side_type) + node_type = cast(NodeType, node.node_type) if node_side_type.value == NodeSideType.HIGH_SIDE.value: message = ( "You're performing an operation on " diff --git a/packages/syft/src/syft/service/worker/image_identifier.py b/packages/syft/src/syft/service/worker/image_identifier.py index 4651c3f4f2e..ac29f9ed3c9 100644 --- a/packages/syft/src/syft/service/worker/image_identifier.py +++ b/packages/syft/src/syft/service/worker/image_identifier.py @@ -29,7 +29,7 @@ class SyftWorkerImageIdentifier(SyftBaseModel): https://docs.docker.com/engine/reference/commandline/tag/#tag-an-image-referenced-by-name-and-tag """ - registry: Optional[Union[SyftImageRegistry, str]] + registry: Optional[Union[SyftImageRegistry, str]] = None repo: str tag: str diff --git a/packages/syft/src/syft/service/worker/image_registry.py b/packages/syft/src/syft/service/worker/image_registry.py index 7292273c605..bac6b8274a4 100644 --- a/packages/syft/src/syft/service/worker/image_registry.py +++ b/packages/syft/src/syft/service/worker/image_registry.py @@ -3,7 +3,7 @@ from urllib.parse import urlparse # third party -from pydantic import validator +from pydantic import field_validator from typing_extensions import Self # relative @@ -28,7 +28,8 @@ class SyftImageRegistry(SyftObject): id: UID url: str - @validator("url") + @field_validator("url") + @classmethod def validate_url(cls, val: str) -> str: if not val: raise ValueError("Invalid Registry URL. Must not be empty") diff --git a/packages/syft/src/syft/service/worker/image_registry_service.py b/packages/syft/src/syft/service/worker/image_registry_service.py index 8acbd9e6f4b..bf4a111a282 100644 --- a/packages/syft/src/syft/service/worker/image_registry_service.py +++ b/packages/syft/src/syft/service/worker/image_registry_service.py @@ -62,7 +62,7 @@ def add( def delete( self, context: AuthedServiceContext, - uid: UID = None, + uid: Optional[UID] = None, url: Optional[str] = None, ) -> Union[SyftSuccess, SyftError]: # TODO - we need to make sure that there are no workers running an image bound to this registry diff --git a/packages/syft/src/syft/service/worker/utils.py b/packages/syft/src/syft/service/worker/utils.py index d9d467f8b7b..e42b7021a6a 100644 --- a/packages/syft/src/syft/service/worker/utils.py +++ b/packages/syft/src/syft/service/worker/utils.py @@ -169,8 +169,6 @@ def run_container_using_docker( environment["CREATE_PRODUCER"] = "false" environment["N_CONSUMERS"] = 1 environment["PORT"] = str(get_free_tcp_port()) - environment["HTTP_PORT"] = str(88 + worker_count) - environment["HTTPS_PORT"] = str(446 + worker_count) environment["CONSUMER_SERVICE_NAME"] = pool_name environment["SYFT_WORKER_UID"] = syft_worker_uid environment["DEV_MODE"] = debug @@ -430,7 +428,7 @@ def run_workers_in_kubernetes( # create worker object for pod in pool_pods: - status = runner.get_pod_status(pod) + status: Optional[Union[PodStatus, WorkerStatus]] = runner.get_pod_status(pod) status, healthcheck, error = map_pod_to_worker_status(status) # this worker id will be the same as the one in the worker diff --git a/packages/syft/src/syft/service/worker/worker.py b/packages/syft/src/syft/service/worker/worker.py deleted file mode 100644 index ef3fc4aec5d..00000000000 --- a/packages/syft/src/syft/service/worker/worker.py +++ /dev/null @@ -1,61 +0,0 @@ -# stdlib -from typing import Any -from typing import Callable -from typing import Dict -from typing import List - -# relative -from ...serde.serializable import serializable -from ...store.document_store import SYFT_OBJECT_VERSION_1 -from ...store.document_store import SyftObject -from ...types.datetime import DateTime -from ...types.syft_migration import migrate -from ...types.syft_object import SYFT_OBJECT_VERSION_2 -from ...types.transforms import drop -from ...types.transforms import make_set_default - - -@serializable() -class DockerWorkerV1(SyftObject): - # version - __canonical_name__ = "ContainerImage" - __version__ = SYFT_OBJECT_VERSION_1 - - __attr_searchable__ = ["container_id"] - __attr_unique__ = ["container_id"] - __repr_attrs__ = ["container_id", "created_at"] - - container_id: str - created_at: DateTime = DateTime.now() - - -@serializable() -class DockerWorker(SyftObject): - # version - __canonical_name__ = "ContainerImage" - __version__ = SYFT_OBJECT_VERSION_2 - - __attr_searchable__ = ["container_id", "container_name"] - __attr_unique__ = ["container_id"] - __repr_attrs__ = ["container_id", "created_at"] - - container_name: str - container_id: str - created_at: DateTime = DateTime.now() - - def _coll_repr_(self) -> Dict[str, Any]: - return { - "container_name": self.container_name, - "container_id": self.container_id, - "created_at": self.created_at, - } - - -@migrate(DockerWorker, DockerWorkerV1) -def downgrade_job_v2_to_v1() -> List[Callable]: - return [drop(["container_name"])] - - -@migrate(DockerWorkerV1, DockerWorker) -def upgrade_job_v2_to_v3() -> List[Callable]: - return [make_set_default("job_consumer_id", None)] diff --git a/packages/syft/src/syft/service/worker/worker_image_service.py b/packages/syft/src/syft/service/worker/worker_image_service.py index 9359fffce58..0c737c2d799 100644 --- a/packages/syft/src/syft/service/worker/worker_image_service.py +++ b/packages/syft/src/syft/service/worker/worker_image_service.py @@ -94,7 +94,7 @@ def build( if registry_uid: # get registry from image registry service - image_registry_service: SyftImageRegistryService = context.node.get_service( + image_registry_service: AbstractService = context.node.get_service( SyftImageRegistryService ) registry_result = image_registry_service.get_by_id(context, registry_uid) diff --git a/packages/syft/src/syft/service/worker/worker_pool.py b/packages/syft/src/syft/service/worker/worker_pool.py index 13c1cf1c158..2cc89394a49 100644 --- a/packages/syft/src/syft/service/worker/worker_pool.py +++ b/packages/syft/src/syft/service/worker/worker_pool.py @@ -69,14 +69,14 @@ class SyftWorker(SyftObject): id: UID name: str - container_id: Optional[str] + container_id: Optional[str] = None created_at: DateTime = DateTime.now() - healthcheck: Optional[WorkerHealth] + healthcheck: Optional[WorkerHealth] = None status: WorkerStatus - image: Optional[SyftWorkerImage] + image: Optional[SyftWorkerImage] = None worker_pool_name: str consumer_state: ConsumerState = ConsumerState.DETACHED - job_id: Optional[UID] + job_id: Optional[UID] = None @property def logs(self) -> Union[str, SyftError]: @@ -86,8 +86,6 @@ def logs(self) -> Union[str, SyftError]: ) if api is None: return SyftError(message=f"You must login to {self.node_uid}") - if api.services is None: - return SyftError(message=f"Services for {api} is None") return api.services.worker.logs(uid=self.id) def get_job_repr(self) -> str: @@ -98,8 +96,6 @@ def get_job_repr(self) -> str: ) if api is None: return SyftError(message=f"You must login to {self.node_uid}") - if api.services is None: - return f"Services for api {api} is None" job = api.services.job.get(self.job_id) if job.action.user_code_id is not None: func_name = api.services.code.get_by_id( @@ -118,8 +114,6 @@ def refresh_status(self) -> Optional[SyftError]: ) if api is None: return SyftError(message=f"You must login to {self.node_uid}") - if api.services is None: - return SyftError(message=f"Services for {api} is None") res = api.services.worker.status(uid=self.id) if isinstance(res, SyftError): @@ -166,7 +160,7 @@ class WorkerPool(SyftObject): ] name: str - image_id: Optional[UID] + image_id: Optional[UID] = None max_count: int worker_list: List[LinkedObject] created_at: DateTime = DateTime.now() @@ -187,7 +181,7 @@ def image(self) -> Optional[Union[SyftWorkerImage, SyftError]]: return None @property - def running_workers(self) -> Union[List[UID], SyftError]: + def running_workers(self) -> Union[List[SyftWorker], SyftError]: """Query the running workers using an API call to the server""" _running_workers = [] for worker in self.workers: @@ -197,7 +191,7 @@ def running_workers(self) -> Union[List[UID], SyftError]: return _running_workers @property - def healthy_workers(self) -> Union[List[UID], SyftError]: + def healthy_workers(self) -> Union[List[SyftWorker], SyftError]: """ Query the healthy workers using an API call to the server """ @@ -255,7 +249,7 @@ def workers(self) -> List[SyftWorker]: resolved_workers = [] for worker in self.worker_list: resolved_worker = worker.resolve - if resolved_worker is None: + if isinstance(resolved_worker, SyftError) or resolved_worker is None: continue resolved_worker.refresh_status() resolved_workers.append(resolved_worker) @@ -274,8 +268,8 @@ class ContainerSpawnStatus(SyftBaseModel): __repr_attrs__ = ["worker_name", "worker", "error"] worker_name: str - worker: Optional[SyftWorker] - error: Optional[str] + worker: Optional[SyftWorker] = None + error: Optional[str] = None def _get_worker_container( diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index b1e198e4602..cdd2f83aa35 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -117,7 +117,7 @@ def launch( worker_image: SyftWorkerImage = result.ok() context.node = cast(AbstractNode, context.node) - worker_service: WorkerService = context.node.get_service("WorkerService") + worker_service: AbstractService = context.node.get_service("WorkerService") worker_stash = worker_service.stash # Create worker pool from given image, with the given worker pool @@ -406,7 +406,7 @@ def add_workers( worker_image: SyftWorkerImage = result.ok() context.node = cast(AbstractNode, context.node) - worker_service: WorkerService = context.node.get_service("WorkerService") + worker_service: AbstractService = context.node.get_service("WorkerService") worker_stash = worker_service.stash # Add workers to given pool from the given image @@ -586,7 +586,7 @@ def sync_pool_from_request( pool_name = change.pool_name num_workers = change.num_workers image_uid = change.image_uid - elif isinstance(change, CreateCustomImageChange): + elif isinstance(change, CreateCustomImageChange): # type: ignore[unreachable] config = change.config tag = change.tag @@ -598,7 +598,7 @@ def sync_pool_from_request( image_uid=image_uid, ) elif config is not None: - return self.create_image_and_pool_request( + return self.create_image_and_pool_request( # type: ignore[unreachable] context=context, pool_name=pool_name, num_workers=num_workers, @@ -699,7 +699,7 @@ def _create_workers_in_pool( ) if isinstance(result, OkErr): - node = cast(AbstractNode, context.node) + node = context.node if result.is_ok(): worker_obj = LinkedObject.from_obj( obj=result.ok(), diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 2002302a28a..86db5af2329 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -14,6 +14,7 @@ # relative from ...abstract_node import AbstractNode from ...custom_worker.k8s import IN_KUBERNETES +from ...custom_worker.k8s import PodStatus from ...custom_worker.runner_k8s import KubernetesRunner from ...node.credentials import SyftVerifyKey from ...serde.serializable import serializable @@ -68,14 +69,14 @@ def start_workers( @service_method( path="worker.get_all", name="get_all", roles=DATA_SCIENTIST_ROLE_LEVEL ) - def list(self, context: AuthedServiceContext) -> Union[SyftSuccess, SyftError]: + def list(self, context: AuthedServiceContext) -> Union[list[SyftWorker], SyftError]: """List all the workers.""" result = self.stash.get_all(context.credentials) if result.is_err(): return SyftError(message=f"Failed to fetch workers. {result.err()}") - workers: List[SyftWorker] = result.ok() + workers: list[SyftWorker] = result.ok() if context.node is not None and context.node.in_memory_workers: return workers @@ -172,8 +173,8 @@ def delete( # relative from .worker_pool_service import SyftWorkerPoolService - worker_pool_service: SyftWorkerPoolService = context.node.get_service( - "SyftWorkerPoolService" + worker_pool_service: AbstractService = context.node.get_service( + SyftWorkerPoolService ) worker_pool_stash = worker_pool_service.stash result = worker_pool_stash.get_by_name( @@ -256,7 +257,7 @@ def refresh_worker_status( workers: List[SyftWorker], worker_stash: WorkerStash, credentials: SyftVerifyKey, -) -> List[SyftWorker]: +) -> list[SyftWorker]: if IN_KUBERNETES: result = refresh_status_kubernetes(workers) else: @@ -282,7 +283,9 @@ def refresh_status_kubernetes(workers: List[SyftWorker]) -> List[SyftWorker]: updated_workers = [] runner = KubernetesRunner() for worker in workers: - status = runner.get_pod_status(pod_name=worker.name) + status: Optional[Union[PodStatus, WorkerStatus]] = runner.get_pod_status( + pod=worker.name + ) if not status: return SyftError(message=f"Pod does not exist. name={worker.name}") status, health, _ = map_pod_to_worker_status(status) diff --git a/packages/syft/src/syft/store/blob_storage/__init__.py b/packages/syft/src/syft/store/blob_storage/__init__.py index bb86b6ef25a..02370444537 100644 --- a/packages/syft/src/syft/store/blob_storage/__init__.py +++ b/packages/syft/src/syft/store/blob_storage/__init__.py @@ -40,10 +40,10 @@ - use `BlobRetrieval.read` to retrieve the SyftObject `syft_object = blob_retrieval.read()` """ - # stdlib -import os -from pathlib import Path +from io import BytesIO +from typing import Any +from typing import Generator from typing import Optional from typing import Type from typing import Union @@ -66,84 +66,43 @@ from ...types.blob_storage import DEFAULT_CHUNK_SIZE from ...types.blob_storage import SecureFilePathLocation from ...types.grid_url import GridURL -from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SYFT_OBJECT_VERSION_3 +from ...types.syft_object import SYFT_OBJECT_VERSION_4 from ...types.syft_object import SyftObject -from ...types.transforms import drop -from ...types.transforms import make_set_default -from ...types.transforms import str_url_to_grid_url from ...types.uid import UID DEFAULT_TIMEOUT = 10 MAX_RETRIES = 20 -@serializable() -class BlobRetrievalV1(SyftObject): - __canonical_name__ = "BlobRetrieval" - __version__ = SYFT_OBJECT_VERSION_1 - - type_: Optional[Type] - file_name: str - - @serializable() class BlobRetrieval(SyftObject): __canonical_name__ = "BlobRetrieval" __version__ = SYFT_OBJECT_VERSION_2 - type_: Optional[Type] + type_: Optional[Type] = None file_name: str syft_blob_storage_entry_id: Optional[UID] = None - file_size: Optional[int] - - -@migrate(BlobRetrieval, BlobRetrievalV1) -def downgrade_blobretrieval_v2_to_v1(): - return [ - drop(["syft_blob_storage_entry_id", "file_size"]), - ] - - -@migrate(BlobRetrievalV1, BlobRetrieval) -def upgrade_blobretrieval_v1_to_v2(): - return [ - make_set_default("syft_blob_storage_entry_id", None), - make_set_default("file_size", 1), - ] - - -@serializable() -class SyftObjectRetrievalV2(BlobRetrieval): - __canonical_name__ = "SyftObjectRetrieval" - __version__ = SYFT_OBJECT_VERSION_2 - - syft_object: bytes + file_size: Optional[int] = None @serializable() class SyftObjectRetrieval(BlobRetrieval): __canonical_name__ = "SyftObjectRetrieval" - __version__ = SYFT_OBJECT_VERSION_3 + __version__ = SYFT_OBJECT_VERSION_4 syft_object: bytes - path: Path - def _read_data(self, stream=False, _deserialize=True, **kwargs): + def _read_data( + self, stream: bool = False, _deserialize: bool = True, **kwargs: Any + ) -> Any: # development setup, we can access the same filesystem - if os.access(self.path, os.R_OK) and self.path.is_file(): - with open(self.path, "rb") as fp: - res = fp.read() - if _deserialize: - res = deserialize(res, from_bytes=True) - # single container setup, we have to use the data in the object + if not _deserialize: + res = self.syft_object else: - if not _deserialize: - res = self.syft_object - else: - res = deserialize(self.syft_object, from_bytes=True) + res = deserialize(self.syft_object, from_bytes=True) # TODO: implement proper streaming from local files if stream: @@ -151,34 +110,16 @@ def _read_data(self, stream=False, _deserialize=True, **kwargs): else: return res - def read(self, _deserialize=True) -> Union[SyftObject, SyftError]: + def read(self, _deserialize: bool = True) -> Union[SyftObject, SyftError]: return self._read_data(_deserialize=_deserialize) -@migrate(SyftObjectRetrieval, SyftObjectRetrievalV2) -def downgrade_syftobjretrival_v3_to_v2(): - return [ - drop(["path"]), - ] - - -@migrate(SyftObjectRetrievalV2, SyftObjectRetrieval) -def upgrade_syftobjretrival_v2_to_v3(): - return [ - make_set_default("path", Path("")), - ] - - -class BlobRetrievalByURLV1(BlobRetrievalV1): - __canonical_name__ = "BlobRetrievalByURL" - __version__ = SYFT_OBJECT_VERSION_1 - - url: GridURL - - def syft_iter_content( - blob_url, chunk_size, max_retries=MAX_RETRIES, timeout=DEFAULT_TIMEOUT -): + blob_url: Union[str, GridURL], + chunk_size: int, + max_retries: int = MAX_RETRIES, + timeout: int = DEFAULT_TIMEOUT, +) -> Generator: """custom iter content with smart retries (start from last byte read)""" current_byte = 0 for attempt in range(max_retries): @@ -205,13 +146,6 @@ def syft_iter_content( raise -class BlobRetrievalByURLV2(BlobRetrievalV1): - __canonical_name__ = "BlobRetrievalByURL" - __version__ = SYFT_OBJECT_VERSION_2 - - url: GridURL - - @serializable() class BlobRetrievalByURL(BlobRetrieval): __canonical_name__ = "BlobRetrievalByURL" @@ -231,7 +165,13 @@ def read(self) -> Union[SyftObject, SyftError]: else: return self._read_data() - def _read_data(self, stream=False, chunk_size=DEFAULT_CHUNK_SIZE, *args, **kwargs): + def _read_data( + self, + stream: bool = False, + chunk_size: int = DEFAULT_CHUNK_SIZE, + *args: Any, + **kwargs: Any, + ) -> Any: # relative from ...client.api import APIRegistry @@ -239,7 +179,7 @@ def _read_data(self, stream=False, chunk_size=DEFAULT_CHUNK_SIZE, *args, **kwarg node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, ) - if api is not None and isinstance(self.url, GridURL): + if api and api.connection and isinstance(self.url, GridURL): blob_url = api.connection.to_blob_route( self.url.url_path, host=self.url.host_or_ip ) @@ -261,33 +201,6 @@ def _read_data(self, stream=False, chunk_size=DEFAULT_CHUNK_SIZE, *args, **kwarg return SyftError(message=f"Failed to retrieve with Error: {e}") -@migrate(BlobRetrievalByURLV2, BlobRetrievalByURLV1) -def downgrade_blobretrivalbyurl_v2_to_v1(): - return [ - drop(["syft_blob_storage_entry_id", "file_size"]), - ] - - -@migrate(BlobRetrievalByURLV1, BlobRetrievalByURLV2) -def upgrade_blobretrivalbyurl_v1_to_v2(): - return [ - make_set_default("syft_blob_storage_entry_id", None), - make_set_default("file_size", 1), - ] - - -@migrate(BlobRetrievalByURL, BlobRetrievalByURLV2) -def downgrade_blobretrivalbyurl_v3_to_v2(): - return [ - str_url_to_grid_url, - ] - - -@migrate(BlobRetrievalByURLV2, BlobRetrievalByURL) -def upgrade_blobretrivalbyurl_v2_to_v3(): - return [] - - @serializable() class BlobDeposit(SyftObject): __canonical_name__ = "BlobDeposit" @@ -295,8 +208,8 @@ class BlobDeposit(SyftObject): blob_storage_entry_id: UID - def write(self, data: bytes) -> Union[SyftSuccess, SyftError]: - pass + def write(self, data: BytesIO) -> Union[SyftSuccess, SyftError]: + raise NotImplementedError @serializable() @@ -308,7 +221,7 @@ class BlobStorageConnection: def __enter__(self) -> Self: raise NotImplementedError - def __exit__(self, *exc) -> None: + def __exit__(self, *exc: Any) -> None: raise NotImplementedError def read(self, fp: SecureFilePathLocation, type_: Optional[Type]) -> BlobRetrieval: diff --git a/packages/syft/src/syft/store/blob_storage/on_disk.py b/packages/syft/src/syft/store/blob_storage/on_disk.py index 1ceebfdb129..45b5b848880 100644 --- a/packages/syft/src/syft/store/blob_storage/on_disk.py +++ b/packages/syft/src/syft/store/blob_storage/on_disk.py @@ -41,6 +41,8 @@ def write(self, data: BytesIO) -> Union[SyftSuccess, SyftError]: syft_node_location=self.syft_node_location, syft_client_verify_key=self.syft_client_verify_key, ) + if write_to_disk_method is None: + return SyftError(message="write_to_disk_method is None") return write_to_disk_method(data=data.read(), uid=self.blob_storage_entry_id) @@ -53,17 +55,16 @@ def __init__(self, base_directory: Path) -> None: def __enter__(self) -> Self: return self - def __exit__(self, *exc) -> None: + def __exit__(self, *exc: Any) -> None: pass def read( - self, fp: SecureFilePathLocation, type_: Optional[Type], **kwargs + self, fp: SecureFilePathLocation, type_: Optional[Type], **kwargs: Any ) -> BlobRetrieval: file_path = self._base_directory / fp.path return SyftObjectRetrieval( syft_object=file_path.read_bytes(), file_name=file_path.name, - path=file_path, type_=type_, ) diff --git a/packages/syft/src/syft/store/blob_storage/seaweedfs.py b/packages/syft/src/syft/store/blob_storage/seaweedfs.py index 85524236a37..9abb5da6984 100644 --- a/packages/syft/src/syft/store/blob_storage/seaweedfs.py +++ b/packages/syft/src/syft/store/blob_storage/seaweedfs.py @@ -3,7 +3,9 @@ import math from queue import Queue import threading +from typing import Any from typing import Dict +from typing import Generator from typing import List from typing import Optional from typing import Type @@ -35,11 +37,7 @@ from ...types.blob_storage import SeaweedSecureFilePathLocation from ...types.blob_storage import SecureFilePathLocation from ...types.grid_url import GridURL -from ...types.syft_migration import migrate -from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 -from ...types.transforms import drop -from ...types.transforms import make_set_default from ...util.constants import DEFAULT_TIMEOUT WRITE_EXPIRATION_TIME = 900 # seconds @@ -47,14 +45,6 @@ DEFAULT_UPLOAD_CHUNK_SIZE = 819200 -@serializable() -class SeaweedFSBlobDepositV1(BlobDeposit): - __canonical_name__ = "SeaweedFSBlobDeposit" - __version__ = SYFT_OBJECT_VERSION_1 - - urls: List[GridURL] - - @serializable() class SeaweedFSBlobDeposit(BlobDeposit): __canonical_name__ = "SeaweedFSBlobDeposit" @@ -94,7 +84,7 @@ def write(self, data: BytesIO) -> Union[SyftSuccess, SyftError]: self.urls, start=1, ): - if api is not None: + if api is not None and api.connection is not None: blob_url = api.connection.to_blob_route( url.url_path, host=url.host_or_ip ) @@ -103,10 +93,12 @@ def write(self, data: BytesIO) -> Union[SyftSuccess, SyftError]: # read a chunk untill we have read part_size class PartGenerator: - def __init__(self): + def __init__(self) -> None: self.no_lines = 0 - def async_generator(self, chunk_size=DEFAULT_UPLOAD_CHUNK_SIZE): + def async_generator( + self, chunk_size: int = DEFAULT_UPLOAD_CHUNK_SIZE + ) -> Generator: item_queue: Queue = Queue() threading.Thread( target=self.add_chunks_to_queue, @@ -120,8 +112,10 @@ def async_generator(self, chunk_size=DEFAULT_UPLOAD_CHUNK_SIZE): item = item_queue.get() def add_chunks_to_queue( - self, queue, chunk_size=DEFAULT_UPLOAD_CHUNK_SIZE - ): + self, + queue: Queue, + chunk_size: int = DEFAULT_UPLOAD_CHUNK_SIZE, + ) -> None: """Creates a data geneator for the part""" n = 0 @@ -160,25 +154,13 @@ def add_chunks_to_queue( syft_node_location=self.syft_node_location, syft_client_verify_key=self.syft_client_verify_key, ) + if mark_write_complete_method is None: + return SyftError(message="mark_write_complete_method is None") return mark_write_complete_method( etags=etags, uid=self.blob_storage_entry_id, no_lines=no_lines ) -@migrate(SeaweedFSBlobDeposit, SeaweedFSBlobDepositV1) -def downgrade_seaweedblobdeposit_v2_to_v1(): - return [ - drop(["size"]), - ] - - -@migrate(SeaweedFSBlobDepositV1, SeaweedFSBlobDeposit) -def upgrade_seaweedblobdeposit_v1_to_v2(): - return [ - make_set_default("size", 1), - ] - - @serializable() class SeaweedFSClientConfig(BlobStorageClientConfig): host: str @@ -240,11 +222,14 @@ def __init__( def __enter__(self) -> Self: return self - def __exit__(self, *exc) -> None: + def __exit__(self, *exc: Any) -> None: self.client.close() def read( - self, fp: SecureFilePathLocation, type_: Optional[Type], bucket_name=None + self, + fp: SecureFilePathLocation, + type_: Optional[Type], + bucket_name: Optional[str] = None, ) -> BlobRetrieval: if bucket_name is None: bucket_name = self.default_bucket_name diff --git a/packages/syft/src/syft/store/dict_document_store.py b/packages/syft/src/syft/store/dict_document_store.py index 516a2fc85c5..7f0aa6e1e64 100644 --- a/packages/syft/src/syft/store/dict_document_store.py +++ b/packages/syft/src/syft/store/dict_document_store.py @@ -18,11 +18,12 @@ @serializable() -class DictBackingStore(dict, KeyValueBackingStore): +class DictBackingStore(dict, KeyValueBackingStore): # type: ignore[misc] + # TODO: fix the mypy issue """Dictionary-based Store core logic""" def __init__(self, *args: Any, **kwargs: Any) -> None: - super(dict).__init__() + super().__init__() self._ddtype = kwargs.get("ddtype", None) def __getitem__(self, key: Any) -> Any: @@ -46,7 +47,7 @@ class DictStorePartition(KeyValueStorePartition): DictStore specific configuration """ - def prune(self): + def prune(self) -> None: self.init_store() @@ -71,7 +72,7 @@ def __init__( store_config = DictStoreConfig() super().__init__(root_verify_key=root_verify_key, store_config=store_config) - def reset(self): + def reset(self) -> None: for _, partition in self.partitions.items(): partition.prune() @@ -91,7 +92,6 @@ class DictStoreConfig(StoreConfig): * NoLockingConfig: no locking, ideal for single-thread stores. * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. * FileLockingConfig: file based locking, ideal for same-device different-processes/threads stores. - * RedisLockingConfig: Redis-based locking, ideal for multi-device stores. Defaults to ThreadingLockingConfig. """ diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index 5151f43294c..88566a2f9b0 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -29,7 +29,7 @@ from ..service.context import AuthedServiceContext from ..service.response import SyftSuccess from ..types.base import SyftBaseModel -from ..types.syft_object import SYFT_OBJECT_VERSION_1 +from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.syft_object import SyftBaseObject from ..types.syft_object import SyftObject from ..types.uid import UID @@ -59,7 +59,7 @@ def first_or_none(result: Any) -> Ok: if sys.version_info >= (3, 9): - def is_generic_alias(t: type): + def is_generic_alias(t: type) -> bool: return isinstance(t, (types.GenericAlias, typing._GenericAlias)) else: @@ -104,7 +104,7 @@ def extract_list(self, obj: Any) -> List: obj = [obj] # is a list type so lets compare directly - check_type("obj", obj, self.type_) + check_type(obj, self.type_) return obj @property @@ -117,7 +117,7 @@ class PartitionKeys(BaseModel): pks: Union[PartitionKey, Tuple[PartitionKey, ...], List[PartitionKey]] @property - def all(self) -> List[PartitionKey]: + def all(self) -> Union[tuple[PartitionKey, ...], list[PartitionKey]]: # make sure we always return a list even if there's a single value return self.pks if isinstance(self.pks, (tuple, list)) else [self.pks] @@ -140,7 +140,7 @@ def from_dict(cks_dict: Dict[str, type]) -> PartitionKeys: @serializable() class QueryKey(PartitionKey): - value: Any + value: Any = None def __eq__(self, other: Any) -> bool: return ( @@ -170,8 +170,9 @@ def from_obj(partition_key: PartitionKey, obj: Any) -> QueryKey: # object has a method for getting these types # we can't use properties because we don't seem to be able to get the # return types - if isinstance(pk_value, (types.FunctionType, types.MethodType)): - pk_value = pk_value() + # TODO: fix the mypy issue + if isinstance(pk_value, (types.FunctionType, types.MethodType)): # type: ignore[unreachable] + pk_value = pk_value() # type: ignore[unreachable] if pk_value and not isinstance(pk_value, pk_type): raise Exception( @@ -180,11 +181,11 @@ def from_obj(partition_key: PartitionKey, obj: Any) -> QueryKey: return QueryKey(key=pk_key, type_=pk_type, value=pk_value) @property - def as_dict(self): + def as_dict(self) -> dict[str, Any]: return {self.key: self.value} @property - def as_dict_mongo(self): + def as_dict_mongo(self) -> dict[str, Any]: key = self.key if key == "id": key = "_id" @@ -199,8 +200,8 @@ class PartitionKeysWithUID(PartitionKeys): uid_pk: PartitionKey @property - def all(self) -> List[PartitionKey]: - all_keys = self.pks if isinstance(self.pks, (tuple, list)) else [self.pks] + def all(self) -> Union[tuple[PartitionKey, ...], list[PartitionKey]]: + all_keys = list(self.pks) if isinstance(self.pks, (tuple, list)) else [self.pks] if self.uid_pk not in all_keys: all_keys.insert(0, self.uid_pk) return all_keys @@ -211,7 +212,7 @@ class QueryKeys(SyftBaseModel): qks: Union[QueryKey, Tuple[QueryKey, ...], List[QueryKey]] @property - def all(self) -> List[QueryKey]: + def all(self) -> Union[tuple[QueryKey, ...], list[QueryKey]]: # make sure we always return a list even if there's a single value return self.qks if isinstance(self.qks, (tuple, list)) else [self.qks] @@ -260,7 +261,7 @@ def from_dict(qks_dict: Dict[str, Any]) -> QueryKeys: return QueryKeys(qks=qks) @property - def as_dict(self): + def as_dict(self) -> dict: qk_dict = {} for qk in self.all: qk_key = qk.key @@ -269,7 +270,7 @@ def as_dict(self): return qk_dict @property - def as_dict_mongo(self): + def as_dict_mongo(self) -> dict: qk_dict = {} for qk in self.all: qk_key = qk.key @@ -316,7 +317,7 @@ class StorePartition: def __init__( self, - root_verify_key: SyftVerifyKey, + root_verify_key: Optional[SyftVerifyKey], settings: PartitionSettings, store_config: StoreConfig, ) -> None: @@ -337,7 +338,7 @@ def init_store(self) -> Result[Ok, Err]: except BaseException as e: return Err(str(e)) - return Ok() + return Ok(True) def matches_unique_cks(self, partition_key: PartitionKey) -> bool: return partition_key in self.unique_cks @@ -352,7 +353,9 @@ def store_query_keys(self, objs: Any) -> QueryKeys: return QueryKeys(qks=[self.store_query_key(obj) for obj in objs]) # Thread-safe methods - def _thread_safe_cbk(self, cbk: Callable, *args, **kwargs): + def _thread_safe_cbk( + self, cbk: Callable, *args: Any, **kwargs: Any + ) -> Union[Any, Err]: locked = self.lock.acquire(blocking=True) if not locked: print("FAILED TO LOCK") @@ -423,7 +426,7 @@ def update( credentials: SyftVerifyKey, qk: QueryKey, obj: SyftObject, - has_permission=False, + has_permission: bool = False, ) -> Result[SyftObject, str]: return self._thread_safe_cbk( self._update, @@ -444,7 +447,7 @@ def get_all_from_store( ) def delete( - self, credentials: SyftVerifyKey, qk: QueryKey, has_permission=False + self, credentials: SyftVerifyKey, qk: QueryKey, has_permission: bool = False ) -> Result[SyftSuccess, Err]: return self._thread_safe_cbk( self._delete, credentials, qk, has_permission=has_permission @@ -475,12 +478,21 @@ def migrate_data( # These methods are called from the public thread-safe API, and will hang the process. def _set( self, + credentials: SyftVerifyKey, obj: SyftObject, + add_permissions: Optional[List[ActionObjectPermission]] = None, ignore_duplicates: bool = False, ) -> Result[SyftObject, str]: raise NotImplementedError - def _update(self, qk: QueryKey, obj: SyftObject) -> Result[SyftObject, str]: + def _update( + self, + credentials: SyftVerifyKey, + qk: QueryKey, + obj: SyftObject, + has_permission: bool = False, + overwrite: bool = False, + ) -> Result[SyftObject, str]: raise NotImplementedError def _get_all_from_store( @@ -491,10 +503,17 @@ def _get_all_from_store( ) -> Result[List[SyftObject], str]: raise NotImplementedError - def _delete(self, qk: QueryKey) -> Result[SyftSuccess, Err]: + def _delete( + self, credentials: SyftVerifyKey, qk: QueryKey, has_permission: bool = False + ) -> Result[SyftSuccess, Err]: raise NotImplementedError - def _all(self) -> Result[List[BaseStash.object_type], str]: + def _all( + self, + credentials: SyftVerifyKey, + order_by: Optional[PartitionKey] = None, + has_permission: Optional[bool] = False, + ) -> Result[List[BaseStash.object_type], str]: raise NotImplementedError def add_permission(self, permission: ActionObjectPermission) -> None: @@ -688,7 +707,7 @@ def find_and_delete( return self.delete(credentials=credentials, qk=qk) def delete( - self, credentials: SyftVerifyKey, qk: QueryKey, has_permission=False + self, credentials: SyftVerifyKey, qk: QueryKey, has_permission: bool = False ) -> Result[SyftSuccess, Err]: return self.partition.delete( credentials=credentials, qk=qk, has_permission=has_permission @@ -698,7 +717,7 @@ def update( self, credentials: SyftVerifyKey, obj: BaseStash.object_type, - has_permission=False, + has_permission: bool = False, ) -> Result[BaseStash.object_type, str]: qk = self.partition.store_query_key(obj) return self.partition.update( @@ -756,13 +775,12 @@ class StoreConfig(SyftBaseObject): * NoLockingConfig: no locking, ideal for single-thread stores. * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. * FileLockingConfig: file based locking, ideal for same-device different-processes/threads stores. - * RedisLockingConfig: Redis-based locking, ideal for multi-device stores. Defaults to NoLockingConfig. """ __canonical_name__ = "StoreConfig" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 store_type: Type[DocumentStore] - client_config: Optional[StoreClientConfig] + client_config: Optional[StoreClientConfig] = None locking_config: LockingConfig = NoLockingConfig() diff --git a/packages/syft/src/syft/store/kv_document_store.py b/packages/syft/src/syft/store/kv_document_store.py index 4c9ad60b122..1b8ce0f9280 100644 --- a/packages/syft/src/syft/store/kv_document_store.py +++ b/packages/syft/src/syft/store/kv_document_store.py @@ -5,6 +5,7 @@ from collections import defaultdict from enum import Enum from typing import Any +from typing import Dict from typing import List from typing import Optional from typing import Set @@ -59,16 +60,16 @@ def __repr__(self) -> str: def __len__(self) -> int: raise NotImplementedError - def __delitem__(self, key: str): + def __delitem__(self, key: str) -> None: raise NotImplementedError - def clear(self) -> Self: + def clear(self) -> None: raise NotImplementedError def copy(self) -> Self: raise NotImplementedError - def update(self, *args: Any, **kwargs: Any) -> Self: + def update(self, *args: Any, **kwargs: Any) -> None: raise NotImplementedError def keys(self) -> Any: @@ -123,7 +124,8 @@ def init_store(self) -> Result[Ok, Err]: self.searchable_keys = self.store_config.backing_store( "searchable_keys", self.settings, self.store_config ) - self.permissions = self.store_config.backing_store( + # uid -> set['<uid>_permission'] + self.permissions: Dict[UID, Set[str]] = self.store_config.backing_store( "permissions", self.settings, self.store_config, ddtype=set ) @@ -139,7 +141,7 @@ def init_store(self) -> Result[Ok, Err]: except BaseException as e: return Err(str(e)) - return Ok() + return Ok(True) def __len__(self) -> int: return len(self.data) @@ -175,8 +177,8 @@ def _set( ignore_duplicates: bool = False, ) -> Result[SyftObject, str]: try: - if obj.id is None: - obj.id = UID() + # if obj.id is None: + # obj.id = UID() store_query_key: QueryKey = self.settings.store_key.with_obj(obj) uid = store_query_key.value write_permission = ActionObjectWRITE(uid=uid, credentials=credentials) @@ -248,7 +250,7 @@ def add_permission(self, permission: ActionObjectPermission) -> None: permissions.add(permission.permission_string) self.permissions[permission.uid] = permissions - def remove_permission(self, permission: ActionObjectPermission): + def remove_permission(self, permission: ActionObjectPermission) -> None: permissions = self.permissions[permission.uid] permissions.remove(permission.permission_string) self.permissions[permission.uid] = permissions @@ -262,7 +264,10 @@ def has_permission(self, permission: ActionObjectPermission) -> bool: raise Exception(f"ObjectPermission type: {permission.permission} not valid") # TODO: fix for other admins - if self.root_verify_key.verify == permission.credentials.verify: + if ( + permission.credentials + and self.root_verify_key.verify == permission.credentials.verify + ): return True if ( @@ -370,8 +375,8 @@ def _update( credentials: SyftVerifyKey, qk: QueryKey, obj: SyftObject, - has_permission=False, - overwrite=False, + has_permission: bool = False, + overwrite: bool = False, ) -> Result[SyftObject, str]: try: if qk.value not in self.data: @@ -449,7 +454,7 @@ def create(self, obj: SyftObject) -> Result[SyftObject, str]: pass def _delete( - self, credentials: SyftVerifyKey, qk: QueryKey, has_permission=False + self, credentials: SyftVerifyKey, qk: QueryKey, has_permission: bool = False ) -> Result[SyftSuccess, Err]: try: if has_permission or self.has_permission( @@ -486,9 +491,9 @@ def _delete_search_keys_for(self, obj: SyftObject) -> Result[SyftSuccess, str]: def _get_keys_index(self, qks: QueryKeys) -> Result[Set[Any], str]: try: # match AND - subsets = [] + subsets: list = [] for qk in qks.all: - subset = {} + subset: set = set() pk_key, pk_value = qk.key, qk.value if pk_key not in self.unique_keys: return Err(f"Failed to query index with {qk}") @@ -515,7 +520,7 @@ def _find_keys_search(self, qks: QueryKeys) -> Result[Set[QueryKey], str]: # match AND subsets = [] for qk in qks.all: - subset = {} + subset: set = set() pk_key, pk_value = qk.key, qk.value if pk_key not in self.searchable_keys: return Err(f"Failed to search with {qk}") diff --git a/packages/syft/src/syft/store/linked_obj.py b/packages/syft/src/syft/store/linked_obj.py index 97611d56b64..b2ddad102e5 100644 --- a/packages/syft/src/syft/store/linked_obj.py +++ b/packages/syft/src/syft/store/linked_obj.py @@ -29,6 +29,8 @@ class LinkedObject(SyftObject): object_type: Type[SyftObject] object_uid: UID + __exclude_sync_diff_attrs__ = ["node_uid"] + def __str__(self) -> str: resolved_obj_type = ( type(self.resolve) if self.object_type is None else self.object_type @@ -44,15 +46,20 @@ def resolve(self) -> SyftObject: node_uid=self.node_uid, user_verify_key=self.syft_client_verify_key, ) + if api is None: + raise ValueError(f"api is None. You must login to {self.node_uid}") + return api.services.notifications.resolve_object(self) def resolve_with_context(self, context: NodeServiceContext) -> Any: + if context.node is None: + raise ValueError(f"context {context}'s node is None") return context.node.get_service(self.service_type).resolve_link( context=context, linked_obj=self ) def update_with_context( - self, context: NodeServiceContext, obj: Any + self, context: Union[NodeServiceContext, ChangeContext, Any], obj: Any ) -> Union[SyftSuccess, SyftError]: if isinstance(context, AuthedServiceContext): credentials = context.credentials @@ -60,15 +67,19 @@ def update_with_context( credentials = context.approving_user_credentials else: return SyftError(message="wrong context passed") - result = context.node.get_service(self.service_type).stash.update( - credentials, obj - ) + if context.node is None: + return SyftError(message=f"context {context}'s node is None") + service = context.node.get_service(self.service_type) + if hasattr(service, "stash"): + result = service.stash.update(credentials, obj) + else: + return SyftError(message=f"service {service} does not have a stash") return result @classmethod def from_obj( cls, - obj: SyftObject, + obj: Union[SyftObject, Type[SyftObject]], service_type: Optional[Type[Any]] = None, node_uid: Optional[UID] = None, ) -> Self: @@ -119,6 +130,8 @@ def with_context( if object_uid is None: raise Exception(f"{cls} Requires an object UID") + if context.node is None: + raise ValueError(f"context {context}'s node is None") node_uid = context.node.id return LinkedObject( @@ -128,14 +141,15 @@ def with_context( object_uid=object_uid, ) - @staticmethod + @classmethod def from_uid( + cls, object_uid: UID, object_type: Type[SyftObject], service_type: Type[Any], node_uid: UID, ) -> Self: - return LinkedObject( + return cls( node_uid=node_uid, service_type=service_type, object_type=object_type, diff --git a/packages/syft/src/syft/store/locks.py b/packages/syft/src/syft/store/locks.py index a32bcd67c8d..d7fd0e1ef95 100644 --- a/packages/syft/src/syft/store/locks.py +++ b/packages/syft/src/syft/store/locks.py @@ -5,6 +5,7 @@ from pathlib import Path import threading import time +from typing import Any from typing import Callable from typing import Dict from typing import Optional @@ -12,10 +13,8 @@ # third party from pydantic import BaseModel -import redis from sherlock.lock import BaseLock from sherlock.lock import FileLock -from sherlock.lock import RedisLock # relative from ..serde.serializable import serializable @@ -73,34 +72,18 @@ class FileLockingConfig(LockingConfig): client_path: Optional[Path] = None -@serializable() -class RedisClientConfig(BaseModel): - host: str = "localhost" - port: int = 6379 - db: int = 0 - username: Optional[str] = None - password: Optional[str] = None - - -@serializable() -class RedisLockingConfig(LockingConfig): - """Redis locking policy""" - - client: RedisClientConfig = RedisClientConfig() - - class ThreadingLock(BaseLock): """ Threading-based Lock. Used to provide the same API as the rest of the locks. """ - def __init__(self, expire: int, **kwargs): + def __init__(self, expire: int, **kwargs: Any) -> None: self.expire = expire - self.locked_timestamp = 0 + self.locked_timestamp: float = 0.0 self.lock = threading.Lock() @property - def _locked(self): + def _locked(self) -> bool: """ Implementation of method to check if lock has been acquired. Must be :returns: if the lock is acquired or not @@ -116,7 +99,7 @@ def _locked(self): return self.lock.locked() - def _acquire(self): + def _acquire(self) -> bool: """ Implementation of acquiring a lock in a non-blocking fashion. :returns: if the lock was successfully acquired or not @@ -137,7 +120,7 @@ def _acquire(self): self.locked_timestamp = time.time() return status - def _release(self): + def _release(self) -> None: """ Implementation of releasing an acquired lock. """ @@ -166,7 +149,7 @@ class PatchedFileLock(FileLock): """ - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: self._lock_file_enabled = True try: super().__init__(*args, **kwargs) @@ -203,7 +186,7 @@ def _thread_safe_cbk(self, cbk: Callable) -> bool: def _acquire(self) -> bool: return self._thread_safe_cbk(self._acquire_file_lock) - def _release(self) -> None: + def _release(self) -> bool: res = self._thread_safe_cbk(self._release_file_lock) return res @@ -245,12 +228,12 @@ def _acquire_file_lock(self) -> bool: self._data_file.write_text(json.dumps(data)) # We succeeded in writing to the file so we now hold the lock. - self._owner = owner + self._owner: Optional[str] = owner return True @property - def _locked(self): + def _locked(self) -> bool: if self._lock_py_thread.locked(): return True @@ -340,23 +323,16 @@ def __init__(self, config: LockingConfig): elif isinstance(config, ThreadingLockingConfig): self._lock = ThreadingLock(**base_params) elif isinstance(config, FileLockingConfig): - client = config.client_path + client: Optional[Path] = config.client_path self._lock = PatchedFileLock( **base_params, client=client, ) - elif isinstance(config, RedisLockingConfig): - client = redis.StrictRedis(**config.client.dict()) - - self._lock = RedisLock( - **base_params, - client=client, - ) else: raise ValueError("Unsupported config type") @property - def _locked(self): + def _locked(self) -> bool: """ Implementation of method to check if lock has been acquired. @@ -365,8 +341,7 @@ def _locked(self): """ if self.passthrough: return False - - return self._lock.locked() + return self._lock.locked() if self._lock else False def acquire(self, blocking: bool = True) -> bool: """ @@ -380,9 +355,9 @@ def acquire(self, blocking: bool = True) -> bool: if not blocking: return self._acquire() - timeout = self.timeout + timeout: float = float(self.timeout) start_time = time.time() - elapsed = 0 + elapsed: float = 0.0 while timeout >= elapsed: if not self._acquire(): time.sleep(self.retry_interval) @@ -407,21 +382,22 @@ def _acquire(self) -> bool: return True try: - return self._lock._acquire() + return self._lock._acquire() if self._lock else False except BaseException: return False - def _release(self): + def _release(self) -> Optional[bool]: """ Implementation of releasing an acquired lock. """ if self.passthrough: - return - + return None + if not self._lock: + return None try: return self._lock._release() except BaseException: - pass + return None def _renew(self) -> bool: """ @@ -430,4 +406,4 @@ def _renew(self) -> bool: if self.passthrough: return True - return self._lock._renew() + return self._lock._renew() if self._lock else False diff --git a/packages/syft/src/syft/store/mongo_client.py b/packages/syft/src/syft/store/mongo_client.py index cbbf1c5d4f0..c5fc0fae783 100644 --- a/packages/syft/src/syft/store/mongo_client.py +++ b/packages/syft/src/syft/store/mongo_client.py @@ -126,7 +126,7 @@ class MongoStoreClientConfig(StoreClientConfig): class MongoClientCache: - __client_cache__: Dict[str, Type["MongoClient"]] = {} + __client_cache__: Dict[int, Optional[Type["MongoClient"]]] = {} _lock: Lock = Lock() @classmethod @@ -184,7 +184,7 @@ def connect(self, config: MongoStoreClientConfig) -> Result[Ok, Err]: self.client = None return Err(str(e)) - return Ok() + return Ok(True) def with_db(self, db_name: str) -> Result[MongoDatabase, Err]: try: @@ -239,6 +239,6 @@ def with_collection_permissions( return Ok(collection_permissions) - def close(self): + def close(self) -> None: self.client.close() MongoClientCache.__client_cache__.pop(hash(str(self.config)), None) diff --git a/packages/syft/src/syft/store/mongo_document_store.py b/packages/syft/src/syft/store/mongo_document_store.py index efdd6496154..b4a67b41d41 100644 --- a/packages/syft/src/syft/store/mongo_document_store.py +++ b/packages/syft/src/syft/store/mongo_document_store.py @@ -1,11 +1,11 @@ # stdlib from typing import Any +from typing import Callable from typing import Dict from typing import List from typing import Optional from typing import Set from typing import Type -from typing import Union # third party from pymongo import ASCENDING @@ -62,11 +62,11 @@ class MongoDict(SyftBaseObject): def dict(self) -> Dict[Any, Any]: return dict(zip(self.keys, self.values)) - @staticmethod - def from_dict(input: Dict[Any, Any]) -> Self: - return MongoDict(keys=list(input.keys()), values=list(input.values())) + @classmethod + def from_dict(cls, input: Dict[Any, Any]) -> Self: + return cls(keys=list(input.keys()), values=list(input.values())) - def __repr__(self): + def __repr__(self) -> str: return self.dict.__repr__() @@ -82,30 +82,34 @@ def _repr_debug_(value: Any) -> str: def to_mongo(context: TransformContext) -> TransformContext: output = {} - unique_keys_dict = context.obj._syft_unique_keys_dict() - search_keys_dict = context.obj._syft_searchable_keys_dict() - all_dict = unique_keys_dict - all_dict.update(search_keys_dict) - for k in all_dict: - value = getattr(context.obj, k, "") - # if the value is a method, store its value - if callable(value): - output[k] = value() - else: - output[k] = value + if context.obj: + unique_keys_dict = context.obj._syft_unique_keys_dict() + search_keys_dict = context.obj._syft_searchable_keys_dict() + all_dict = unique_keys_dict + all_dict.update(search_keys_dict) + for k in all_dict: + value = getattr(context.obj, k, "") + # if the value is a method, store its value + if callable(value): + output[k] = value() + else: + output[k] = value - if "id" in context.output: + output["__canonical_name__"] = context.obj.__canonical_name__ + output["__version__"] = context.obj.__version__ + output["__blob__"] = _serialize(context.obj, to_bytes=True) + output["__arepr__"] = _repr_debug_(context.obj) # a comes first in alphabet + + if context.output and "id" in context.output: output["_id"] = context.output["id"] - output["__canonical_name__"] = context.obj.__canonical_name__ - output["__version__"] = context.obj.__version__ - output["__blob__"] = _serialize(context.obj, to_bytes=True) - output["__arepr__"] = _repr_debug_(context.obj) # a comes first in alphabet + context.output = output + return context @transform(SyftObject, MongoBsonObject) -def syft_obj_to_mongo(): +def syft_obj_to_mongo() -> list[Callable]: return [to_mongo] @@ -127,7 +131,7 @@ class MongoStorePartition(StorePartition): Mongo specific configuration """ - storage_type: StorableObjectType = MongoBsonObject + storage_type: Type[StorableObjectType] = MongoBsonObject def init_store(self) -> Result[Ok, Err]: store_status = super().init_store() @@ -167,7 +171,9 @@ def _create_update_index(self) -> Result[Ok, Err]: return collection_status collection: MongoCollection = collection_status.ok() - def check_index_keys(current_keys, new_index_keys): + def check_index_keys( + current_keys: list[tuple[str, int]], new_index_keys: list[tuple[str, int]] + ) -> bool: current_keys.sort() new_index_keys.sort() return current_keys == new_index_keys @@ -190,7 +196,7 @@ def check_index_keys(current_keys, new_index_keys): if current_index_keys is not None: keys_same = check_index_keys(current_index_keys["key"], new_index_keys) if keys_same: - return Ok() + return Ok(True) # Drop current index, since incompatible with current object try: @@ -202,7 +208,7 @@ def check_index_keys(current_keys, new_index_keys): # If no new indexes, then skip index creation if len(new_index_keys) == 0: - return Ok() + return Ok(True) try: collection.create_index(new_index_keys, unique=True, name=index_name) @@ -211,7 +217,7 @@ def check_index_keys(current_keys, new_index_keys): f"Failed to create index for {object_name} with index keys: {new_index_keys}" ) - return Ok() + return Ok(True) @property def collection(self) -> Result[MongoCollection, Err]: @@ -231,7 +237,7 @@ def permissions(self) -> Result[MongoCollection, Err]: return Ok(self._permissions) - def set(self, *args, **kwargs): + def set(self, *args: Any, **kwargs: Any) -> Result[SyftObject, str]: return self._set(*args, **kwargs) def _set( @@ -244,7 +250,7 @@ def _set( # TODO: Refactor this function since now it's doing both set and # update at the same time write_permission = ActionObjectWRITE(uid=obj.id, credentials=credentials) - can_write = self.has_permission(write_permission) + can_write: bool = self.has_permission(write_permission) store_query_key: QueryKey = self.settings.store_key.with_obj(obj) collection_status = self.collection @@ -258,7 +264,7 @@ def _set( if (not store_key_exists) and (not self.item_keys_exist(obj, collection)): # attempt to claim ownership for writing ownership_result = self.take_ownership(uid=obj.id, credentials=credentials) - can_write: bool = ownership_result.is_ok() + can_write = ownership_result.is_ok() elif not ignore_duplicates: unique_query_keys: QueryKeys = self.settings.unique_keys.with_obj(obj) keys = ", ".join(f"`{key.key}`" for key in unique_query_keys.all) @@ -291,7 +297,7 @@ def _set( else: return Err(f"No permission to write object with id {obj.id}") - def item_keys_exist(self, obj, collection): + def item_keys_exist(self, obj: SyftObject, collection: MongoCollection) -> bool: qks: QueryKeys = self.settings.unique_keys.with_obj(obj) query = {"$or": [{k: v} for k, v in qks.as_dict_mongo.items()]} res = collection.find_one(query) @@ -303,6 +309,7 @@ def _update( qk: QueryKey, obj: SyftObject, has_permission: bool = False, + overwrite: bool = False, ) -> Result[SyftObject, str]: collection_status = self.collection if collection_status.is_err(): @@ -355,13 +362,13 @@ def _find_index_or_search_keys( order_by: Optional[PartitionKey] = None, ) -> Result[List[SyftObject], str]: # TODO: pass index as hint to find method - qks = QueryKeys(qks=(index_qks.all + search_qks.all)) + qks = QueryKeys(qks=(list(index_qks.all) + list(search_qks.all))) return self._get_all_from_store( credentials=credentials, qks=qks, order_by=order_by ) @property - def data(self): + def data(self) -> dict: values: List = self._all(credentials=None, has_permission=True).ok() return {v.id: v for v in values} @@ -447,7 +454,10 @@ def has_permission(self, permission: ActionObjectPermission) -> bool: return False # TODO: fix for other admins - if self.root_verify_key.verify == permission.credentials.verify: + if ( + permission.credentials + and self.root_verify_key.verify == permission.credentials.verify + ): return True if permission.permission_string in permissions["permissions"]: @@ -535,8 +545,8 @@ def take_ownership( return collection_status collection: MongoCollection = collection_status.ok() - data: List[UID] = collection.find_one({"_id": uid}) - permissions: List[UID] = collection_permissions.find_one({"_id": uid}) + data: Optional[List[UID]] = collection.find_one({"_id": uid}) + permissions: Optional[List[UID]] = collection_permissions.find_one({"_id": uid}) # first person using this UID can claim ownership if permissions is None and data is None: @@ -557,7 +567,7 @@ def _all( credentials: SyftVerifyKey, order_by: Optional[PartitionKey] = None, has_permission: Optional[bool] = False, - ): + ) -> Result[List[SyftObject], str]: qks = QueryKeys(qks=()) return self._get_all_from_store( credentials=credentials, @@ -566,7 +576,7 @@ def _all( has_permission=has_permission, ) - def __len__(self): + def __len__(self) -> int: collection_status = self.collection if collection_status.is_err(): return 0 @@ -653,7 +663,7 @@ def __init__( self.ddtype = ddtype self.init_client() - def init_client(self) -> Union[None, Err]: + def init_client(self) -> Optional[Err]: self.client = MongoClient(config=self.store_config.client_config) collection_status = self.client.with_collection( @@ -664,12 +674,13 @@ def init_client(self) -> Union[None, Err]: if collection_status.is_err(): return collection_status self._collection: MongoCollection = collection_status.ok() + return None @property def collection(self) -> Result[MongoCollection, Err]: if not hasattr(self, "_collection"): res = self.init_client() - if res.is_err(): + if res is not None and res.is_err(): return res return Ok(self._collection) @@ -757,7 +768,7 @@ def _len(self) -> int: def __len__(self) -> int: return self._len() - def _delete(self, key: UID) -> None: + def _delete(self, key: UID) -> Result[SyftSuccess, Err]: collection_status = self.collection if collection_status.is_err(): return collection_status @@ -765,8 +776,9 @@ def _delete(self, key: UID) -> None: result = collection.delete_one({"_id": key}) if result.deleted_count != 1: raise KeyError(f"{key} does not exist") + return Ok(SyftSuccess(message="Deleted")) - def __delitem__(self, key: str): + def __delitem__(self, key: str) -> None: self._delete(key) def _delete_all(self) -> None: @@ -776,7 +788,7 @@ def _delete_all(self) -> None: collection: MongoCollection = collection_status.ok() collection.delete_many({}) - def clear(self) -> Self: + def clear(self) -> None: self._delete_all() def _get_all(self) -> Any: @@ -818,14 +830,14 @@ def copy(self) -> Self: # 🟡 TODO raise NotImplementedError - def update(self, *args: Any, **kwargs: Any) -> Self: + def update(self, *args: Any, **kwargs: Any) -> None: """ Inserts the specified items to the dictionary. """ # 🟡 TODO raise NotImplementedError - def __del__(self): + def __del__(self) -> None: """ Close the mongo client connection: - Cleanup client resources and disconnect from MongoDB @@ -852,13 +864,12 @@ class MongoStoreConfig(StoreConfig): * NoLockingConfig: no locking, ideal for single-thread stores. * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. * FileLockingConfig: file based locking, ideal for same-device different-processes/threads stores. - * RedisLockingConfig: Redis-based locking, ideal for multi-device stores. Defaults to NoLockingConfig. """ client_config: MongoStoreClientConfig store_type: Type[DocumentStore] = MongoDocumentStore db_name: str = "app" - backing_store = MongoBackingStore + backing_store: Type[KeyValueBackingStore] = MongoBackingStore # TODO: should use a distributed lock, with RedisLockingConfig locking_config: LockingConfig = NoLockingConfig() diff --git a/packages/syft/src/syft/store/sqlite_document_store.py b/packages/syft/src/syft/store/sqlite_document_store.py index fe1201ef92b..75e24367376 100644 --- a/packages/syft/src/syft/store/sqlite_document_store.py +++ b/packages/syft/src/syft/store/sqlite_document_store.py @@ -16,7 +16,7 @@ # third party from pydantic import Field -from pydantic import validator +from pydantic import field_validator from result import Err from result import Ok from result import Result @@ -58,7 +58,7 @@ def _repr_debug_(value: Any) -> str: return repr(value) -def raise_exception(table_name: str, e: Exception): +def raise_exception(table_name: str, e: Exception) -> None: if "disk I/O error" in str(e): message = f"Error usually related to concurrent writes. {str(e)}" raise Exception(message) @@ -101,8 +101,10 @@ def __init__( self.settings = settings self.store_config = store_config self._ddtype = ddtype - self.file_path = self.store_config.client_config.file_path - self.db_filename = store_config.client_config.filename + if self.store_config.client_config: + self.file_path = self.store_config.client_config.file_path + if store_config.client_config: + self.db_filename = store_config.client_config.filename # if tempfile.TemporaryDirectory() varies from process to process # could this cause different locks on the same file @@ -127,16 +129,17 @@ def _connect(self) -> None: if not path.exists(): path.parent.mkdir(parents=True, exist_ok=True) - connection = sqlite3.connect( - self.file_path, - timeout=self.store_config.client_config.timeout, - check_same_thread=False, # do we need this if we use the lock? - # check_same_thread=self.store_config.client_config.check_same_thread, - ) - # TODO: Review OSX compatibility. - # Set journal mode to WAL. - # connection.execute("pragma journal_mode=wal") - SQLITE_CONNECTION_POOL_DB[cache_key(self.db_filename)] = connection + if self.store_config.client_config: + connection = sqlite3.connect( + self.file_path, + timeout=self.store_config.client_config.timeout, + check_same_thread=False, # do we need this if we use the lock? + # check_same_thread=self.store_config.client_config.check_same_thread, + ) + # TODO: Review OSX compatibility. + # Set journal mode to WAL. + # connection.execute("pragma journal_mode=wal") + SQLITE_CONNECTION_POOL_DB[cache_key(self.db_filename)] = connection def create_table(self) -> None: try: @@ -183,7 +186,7 @@ def _execute( ) -> Result[Ok[sqlite3.Cursor], Err[str]]: with SyftLock(self.lock_config): cursor: Optional[sqlite3.Cursor] = None - err = None + # err = None try: cursor = self.cur.execute(sql, *args) except Exception as e: @@ -196,8 +199,8 @@ def _execute( # err = Err(str(e)) self.db.commit() # Commit if everything went ok - if err is not None: - return err + # if err is not None: + # return err return Ok(cursor) @@ -323,10 +326,10 @@ def __repr__(self) -> str: def __len__(self) -> int: return self._len() - def __delitem__(self, key: str): + def __delitem__(self, key: str) -> None: self._delete(key) - def clear(self) -> Self: + def clear(self) -> None: self._delete_all() def copy(self) -> Self: @@ -352,7 +355,7 @@ def __contains__(self, key: Any) -> bool: def __iter__(self) -> Any: return iter(self.keys()) - def __del__(self): + def __del__(self) -> None: try: self._close() except BaseException: @@ -434,7 +437,8 @@ class SQLiteStoreClientConfig(StoreClientConfig): # We need this in addition to Field(default_factory=...) # so users can still do SQLiteStoreClientConfig(path=None) - @validator("path", pre=True) + @field_validator("path", mode="before") + @classmethod def __default_path(cls, path: Optional[Union[str, Path]]) -> Union[str, Path]: if path is None: return tempfile.gettempdir() @@ -462,7 +466,6 @@ class SQLiteStoreConfig(StoreConfig): * NoLockingConfig: no locking, ideal for single-thread stores. * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. * FileLockingConfig: file based locking, ideal for same-device different-processes/threads stores. - * RedisLockingConfig: Redis-based locking, ideal for multi-device stores. Defaults to FileLockingConfig. """ diff --git a/packages/syft/src/syft/types/base.py b/packages/syft/src/syft/types/base.py index bb5160aebb0..764be4ae07a 100644 --- a/packages/syft/src/syft/types/base.py +++ b/packages/syft/src/syft/types/base.py @@ -2,8 +2,8 @@ # third party from pydantic import BaseModel +from pydantic import ConfigDict class SyftBaseModel(BaseModel): - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/packages/syft/src/syft/types/blob_storage.py b/packages/syft/src/syft/types/blob_storage.py index e36dc56df4e..623411aad2b 100644 --- a/packages/syft/src/syft/types/blob_storage.py +++ b/packages/syft/src/syft/types/blob_storage.py @@ -8,9 +8,12 @@ import threading from time import sleep from typing import Any +from typing import Callable from typing import ClassVar +from typing import Iterator from typing import List from typing import Optional +from typing import TYPE_CHECKING from typing import Type from typing import Union @@ -21,51 +24,36 @@ from typing_extensions import Self # relative +from ..client.api import SyftAPI +from ..client.client import SyftClient from ..node.credentials import SyftVerifyKey from ..serde import serialize from ..serde.serializable import serializable from ..service.action.action_object import ActionObject +from ..service.action.action_object import ActionObjectPointer from ..service.action.action_object import BASE_PASSTHROUGH_ATTRS from ..service.action.action_types import action_types from ..service.response import SyftError from ..service.response import SyftException from ..service.service import from_api_or_context from ..types.grid_url import GridURL -from ..types.transforms import drop from ..types.transforms import keep -from ..types.transforms import make_set_default from ..types.transforms import transform from .datetime import DateTime -from .syft_migration import migrate from .syft_object import SYFT_OBJECT_VERSION_1 from .syft_object import SYFT_OBJECT_VERSION_2 from .syft_object import SYFT_OBJECT_VERSION_3 from .syft_object import SyftObject from .uid import UID -READ_EXPIRATION_TIME = 1800 # seconds -DEFAULT_CHUNK_SIZE = 10000 * 1024 - - -@serializable() -class BlobFileV1(SyftObject): - __canonical_name__ = "BlobFile" - __version__ = SYFT_OBJECT_VERSION_1 - - file_name: str +if TYPE_CHECKING: + # relative + from ..store.blob_storage import BlobRetrievalByURL + from ..store.blob_storage import BlobStorageConnection - __repr_attrs__ = ["id", "file_name"] - - -class BlobFileV2(SyftObject): - __canonical_name__ = "BlobFile" - __version__ = SYFT_OBJECT_VERSION_2 - - file_name: str - syft_blob_storage_entry_id: Optional[UID] = None - file_size: Optional[int] = None - __repr_attrs__ = ["id", "file_name"] +READ_EXPIRATION_TIME = 1800 # seconds +DEFAULT_CHUNK_SIZE = 10000 * 1024 @serializable() @@ -76,29 +64,37 @@ class BlobFile(SyftObject): file_name: str syft_blob_storage_entry_id: Optional[UID] = None file_size: Optional[int] = None - path: Optional[Path] - uploaded = False + path: Optional[Path] = None + uploaded: bool = False __repr_attrs__ = ["id", "file_name"] - def read(self, stream=False, chunk_size=DEFAULT_CHUNK_SIZE, force=False): + def read( + self, + stream: bool = False, + chunk_size: int = DEFAULT_CHUNK_SIZE, + force: bool = False, + ) -> Any: # get blob retrieval object from api + syft_blob_storage_entry_id read_method = from_api_or_context( "blob_storage.read", self.syft_node_location, self.syft_client_verify_key ) - blob_retrieval_object = read_method(self.syft_blob_storage_entry_id) - return blob_retrieval_object._read_data( - stream=stream, chunk_size=chunk_size, _deserialize=False - ) + if read_method is not None: + blob_retrieval_object = read_method(self.syft_blob_storage_entry_id) + return blob_retrieval_object._read_data( + stream=stream, chunk_size=chunk_size, _deserialize=False + ) + else: + return None @classmethod - def upload_from_path(self, path, client): + def upload_from_path(cls, path: Union[str, Path], client: SyftClient) -> Any: # syft absolute import syft as sy return sy.ActionObject.from_path(path=path).send(client).syft_action_data - def _upload_to_blobstorage_from_api(self, api): + def _upload_to_blobstorage_from_api(self, api: SyftAPI) -> Optional[SyftError]: if self.path is None: raise ValueError("cannot upload BlobFile, no path specified") storage_entry = CreateBlobStorageEntry.from_path(self.path) @@ -117,12 +113,14 @@ def _upload_to_blobstorage_from_api(self, api): self.syft_blob_storage_entry_id = blob_deposit_object.blob_storage_entry_id self.uploaded = True - def upload_to_blobstorage(self, client): + return None + + def upload_to_blobstorage(self, client: SyftClient) -> Optional[SyftError]: self.syft_node_location = client.id self.syft_client_verify_key = client.verify_key return self._upload_to_blobstorage_from_api(client.api) - def _iter_lines(self, chunk_size=DEFAULT_CHUNK_SIZE): + def _iter_lines(self, chunk_size: int = DEFAULT_CHUNK_SIZE) -> Iterator[bytes]: """Synchronous version of the async iter_lines. This implementation is also optimized in terms of splitting chunks, making it faster for larger lines""" @@ -130,7 +128,7 @@ def _iter_lines(self, chunk_size=DEFAULT_CHUNK_SIZE): for chunk in self.read(stream=True, chunk_size=chunk_size): if b"\n" in chunk: if pending is not None: - chunk = pending + chunk + chunk = pending + chunk # type: ignore[unreachable] lines = chunk.splitlines() if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]: pending = lines.pop() @@ -146,7 +144,13 @@ def _iter_lines(self, chunk_size=DEFAULT_CHUNK_SIZE): if pending is not None: yield pending - def read_queue(self, queue, chunk_size, progress=False, buffer_lines=10000): + def read_queue( + self, + queue: Queue, + chunk_size: int, + progress: bool = False, + buffer_lines: int = 10000, + ) -> None: total_read = 0 for _i, line in enumerate(self._iter_lines(chunk_size=chunk_size)): line_size = len(line) + 1 # add byte for \n @@ -165,7 +169,9 @@ def read_queue(self, queue, chunk_size, progress=False, buffer_lines=10000): # Put anything not a string at the end queue.put(0) - def iter_lines(self, chunk_size=DEFAULT_CHUNK_SIZE, progress=False): + def iter_lines( + self, chunk_size: int = DEFAULT_CHUNK_SIZE, progress: bool = False + ) -> Iterator[str]: item_queue: Queue = Queue() threading.Thread( target=self.read_queue, @@ -177,41 +183,26 @@ def iter_lines(self, chunk_size=DEFAULT_CHUNK_SIZE, progress=False): yield item item = item_queue.get() - def _coll_repr_(self): + def _coll_repr_(self) -> dict[str, str]: return {"file_name": self.file_name} -@migrate(BlobFile, BlobFileV1) -def downgrade_blobfile_v2_to_v1(): - return [ - drop(["syft_blob_storage_entry_id", "file_size"]), - ] - - -@migrate(BlobFileV1, BlobFile) -def upgrade_blobfile_v1_to_v2(): - return [ - make_set_default("syft_blob_storage_entry_id", None), - make_set_default("file_size", None), - ] - - class BlobFileType(type): pass -class BlobFileObjectPointer: +class BlobFileObjectPointer(ActionObjectPointer): pass @serializable() class BlobFileObject(ActionObject): __canonical_name__ = "BlobFileOBject" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 syft_internal_type: ClassVar[Type[Any]] = BlobFile - syft_pointer_type = BlobFileObjectPointer - syft_passthrough_attrs = BASE_PASSTHROUGH_ATTRS + syft_pointer_type: ClassVar[Type[ActionObjectPointer]] = BlobFileObjectPointer + syft_passthrough_attrs: List[str] = BASE_PASSTHROUGH_ATTRS @serializable() @@ -225,18 +216,16 @@ class SecureFilePathLocation(SyftObject): def __repr__(self) -> str: return f"{self.path}" - def generate_url(self, *args): + def generate_url( + self, + connection: "BlobStorageConnection", + type_: Optional[Type], + bucket_name: Optional[str], + *args: Any, + ) -> "BlobRetrievalByURL": raise NotImplementedError -@serializable() -class SeaweedSecureFilePathLocationV1(SecureFilePathLocation): - __canonical_name__ = "SeaweedSecureFilePathLocation" - __version__ = SYFT_OBJECT_VERSION_1 - - upload_id: str - - @serializable() class SeaweedSecureFilePathLocation(SecureFilePathLocation): __canonical_name__ = "SeaweedSecureFilePathLocation" @@ -244,7 +233,13 @@ class SeaweedSecureFilePathLocation(SecureFilePathLocation): upload_id: Optional[str] = None - def generate_url(self, connection, type_, bucket_name): + def generate_url( + self, + connection: "BlobStorageConnection", + type_: Optional[Type], + bucket_name: Optional[str], + *args: Any, + ) -> "BlobRetrievalByURL": try: url = connection.client.generate_presigned_url( ClientMethod="get_object", @@ -262,18 +257,6 @@ def generate_url(self, connection, type_, bucket_name): raise SyftException(e) -@migrate(SeaweedSecureFilePathLocationV1, SeaweedSecureFilePathLocation) -def upgrade_seaweedsecurefilepathlocation_v1_to_v2(): - return [make_set_default("bucket_name", "")] - - -@migrate(SeaweedSecureFilePathLocation, SeaweedSecureFilePathLocationV1) -def downgrade_seaweedsecurefilepathlocation_v2_to_v1(): - return [ - drop(["bucket_name"]), - ] - - @serializable() class AzureSecureFilePathLocation(SecureFilePathLocation): __canonical_name__ = "AzureSecureFilePathLocation" @@ -283,7 +266,9 @@ class AzureSecureFilePathLocation(SecureFilePathLocation): azure_profile_name: str # Used by Seaweedfs to refer to a remote config bucket_name: str - def generate_url(self, connection, type_, *args): + def generate_url( + self, connection: "BlobStorageConnection", type_: Optional[Type], *args: Any + ) -> "BlobRetrievalByURL": # SAS is almost the same thing as the presigned url config = connection.config.remote_profiles[self.azure_profile_name] account_name = config.account_name @@ -305,22 +290,6 @@ def generate_url(self, connection, type_, *args): return BlobRetrievalByURL(url=url, file_name=Path(self.path).name, type_=type_) -@serializable() -class BlobStorageEntryV1(SyftObject): - __canonical_name__ = "BlobStorageEntry" - __version__ = SYFT_OBJECT_VERSION_1 - - id: UID - location: Union[SecureFilePathLocation, SeaweedSecureFilePathLocation] - type_: Optional[Type] - mimetype: str = "bytes" - file_size: int - uploaded_by: SyftVerifyKey - created_at: DateTime = DateTime.now() - - __attr_searchable__ = ["bucket_name"] - - @serializable() class BlobStorageEntry(SyftObject): __canonical_name__ = "BlobStorageEntry" @@ -328,69 +297,35 @@ class BlobStorageEntry(SyftObject): id: UID location: Union[SecureFilePathLocation, SeaweedSecureFilePathLocation] - type_: Optional[Type] + type_: Optional[Type] = None mimetype: str = "bytes" file_size: int no_lines: Optional[int] = 0 uploaded_by: SyftVerifyKey created_at: DateTime = DateTime.now() - bucket_name: Optional[str] + bucket_name: Optional[str] = None __attr_searchable__ = ["bucket_name"] -@migrate(BlobStorageEntry, BlobStorageEntryV1) -def downgrade_blobstorageentry_v2_to_v1(): - return [ - drop(["no_lines", "bucket_name"]), - ] - - -@migrate(BlobStorageEntryV1, BlobStorageEntry) -def upgrade_blobstorageentry_v1_to_v2(): - return [make_set_default("no_lines", 1), make_set_default("bucket_name", None)] - - -@serializable() -class BlobStorageMetadataV1(SyftObject): - __canonical_name__ = "BlobStorageMetadata" - __version__ = SYFT_OBJECT_VERSION_1 - - type_: Optional[Type[SyftObject]] - mimetype: str = "bytes" - file_size: int - - @serializable() class BlobStorageMetadata(SyftObject): __canonical_name__ = "BlobStorageMetadata" __version__ = SYFT_OBJECT_VERSION_2 - type_: Optional[Type[SyftObject]] + type_: Optional[Type[SyftObject]] = None mimetype: str = "bytes" file_size: int no_lines: Optional[int] = 0 -@migrate(BlobStorageMetadata, BlobStorageMetadataV1) -def downgrade_blobmeta_v2_to_v1(): - return [ - drop(["no_lines"]), - ] - - -@migrate(BlobStorageMetadataV1, BlobStorageMetadata) -def upgrade_blobmeta_v1_to_v2(): - return [make_set_default("no_lines", 1)] - - @serializable() class CreateBlobStorageEntry(SyftObject): __canonical_name__ = "CreateBlobStorageEntry" __version__ = SYFT_OBJECT_VERSION_1 id: UID - type_: Optional[Type] + type_: Optional[Type] = None mimetype: str = "bytes" file_size: int extensions: List[str] = [] @@ -408,7 +343,7 @@ def from_path(cls, fp: Union[str, Path], mimetype: Optional[str] = None) -> Self if not path.is_file(): raise SyftException(f"{fp} is not a file.") - if fp.suffix.lower() == ".jsonl": + if path.suffix.lower() == ".jsonl": mimetype = "application/json-lines" if mimetype is None: mime_types = mimetypes.guess_type(fp) @@ -433,7 +368,7 @@ def file_name(self) -> str: @transform(BlobStorageEntry, BlobStorageMetadata) -def storage_entry_to_metadata(): +def storage_entry_to_metadata() -> list[Callable]: return [keep(["id", "type_", "mimetype", "file_size"])] diff --git a/packages/syft/src/syft/types/datetime.py b/packages/syft/src/syft/types/datetime.py index c03e1433fd0..79ca1f35311 100644 --- a/packages/syft/src/syft/types/datetime.py +++ b/packages/syft/src/syft/types/datetime.py @@ -1,6 +1,7 @@ # stdlib from datetime import datetime from functools import total_ordering +from typing import Any from typing import Optional # third party @@ -19,7 +20,7 @@ class DateTime(SyftObject): __canonical_name__ = "DateTime" __version__ = SYFT_OBJECT_VERSION_1 - id: Optional[UID] + id: Optional[UID] = None # type: ignore utc_timestamp: float @classmethod @@ -33,7 +34,9 @@ def __str__(self) -> str: def __hash__(self) -> int: return hash(self.utc_timestamp) - def __eq__(self, other: Self) -> bool: + def __eq__(self, other: Any) -> bool: + if other is None: + return False return self.utc_timestamp == other.utc_timestamp def __lt__(self, other: Self) -> bool: diff --git a/packages/syft/src/syft/types/grid_url.py b/packages/syft/src/syft/types/grid_url.py index 9100c03704e..61287649d03 100644 --- a/packages/syft/src/syft/types/grid_url.py +++ b/packages/syft/src/syft/types/grid_url.py @@ -21,7 +21,7 @@ @serializable(attrs=["protocol", "host_or_ip", "port", "path", "query"]) class GridURL: @classmethod - def from_url(cls, url: Union[str, GridURL]) -> Self: + def from_url(cls, url: Union[str, GridURL]) -> GridURL: if isinstance(url, GridURL): return url try: @@ -37,7 +37,7 @@ def from_url(cls, url: Union[str, GridURL]) -> Self: host_or_ip = host_or_ip_parts[0] if parts.scheme == "https": port = 443 - return cls( + return GridURL( host_or_ip=host_or_ip, path=parts.path, port=port, @@ -141,7 +141,7 @@ def base_url_no_port(self) -> str: def url_path(self) -> str: return f"{self.path}{self.query_string}" - def to_tls(self) -> Self: + def to_tls(self) -> GridURL: if self.protocol == "https": return self @@ -165,7 +165,7 @@ def __str__(self) -> str: def __hash__(self) -> int: return hash(self.__str__()) - def __copy__(self) -> Self: + def __copy__(self) -> GridURL: return self.__class__.from_url(self.url) def set_port(self, port: int) -> Self: diff --git a/packages/syft/src/syft/types/identity.py b/packages/syft/src/syft/types/identity.py index a3359570ebb..52c61ef8c0d 100644 --- a/packages/syft/src/syft/types/identity.py +++ b/packages/syft/src/syft/types/identity.py @@ -29,6 +29,8 @@ def __repr__(self) -> str: @classmethod def from_client(cls, client: SyftClient) -> Self: + if not client.credentials: + raise ValueError(f"{client} has no signing key!") return cls(node_id=client.id, verify_key=client.credentials.verify_key) diff --git a/packages/syft/src/syft/types/syft_metaclass.py b/packages/syft/src/syft/types/syft_metaclass.py index 762213b7ba4..08ac3ce32de 100644 --- a/packages/syft/src/syft/types/syft_metaclass.py +++ b/packages/syft/src/syft/types/syft_metaclass.py @@ -1,113 +1,40 @@ -# Reference: https://github.com/pydantic/pydantic/issues/1223#issuecomment-998160737 - - # stdlib -import inspect -import threading from typing import Any -from typing import Dict -from typing import Generator -from typing import Tuple -from typing import Type +from typing import TypeVar +from typing import Union +from typing import final # third party -from pydantic.fields import UndefinedType -from pydantic.main import BaseModel -from pydantic.main import ModelField -from pydantic.main import ModelMetaclass +from pydantic import BaseModel +from pydantic._internal._model_construction import ModelMetaclass # relative -from ..serde.recursive_primitives import recursive_serde_register_type from ..serde.serializable import serializable -TupleGenerator = Generator[Tuple[str, Any], None, None] - - -@serializable() -class Empty: - pass - +_T = TypeVar("_T", bound=BaseModel) -class PartialModelMetaclass(ModelMetaclass): - def __new__( - meta: Type["PartialModelMetaclass"], *args: Any, **kwargs: Any - ) -> "PartialModelMetaclass": - cls = super().__new__(meta, *args, *kwargs) - cls_init = cls.__init__ - # Because the class will be modified temporarily, need to lock __init__ - init_lock = threading.Lock() - # To preserve identical hashes of temporary nested partial models, - # only one instance of each temporary partial class can exist - - def __init__(self: BaseModel, *args: Any, **kwargs: Any) -> None: - with init_lock: - fields = self.__class__.__fields__ - fields_map: Dict[ModelField, Tuple[Any, bool]] = {} - - def optionalize( - fields: Dict[str, ModelField], *, restore: bool = False - ) -> None: - for _, field in fields.items(): - if not restore: - if isinstance(field.required, UndefinedType): - raise Exception(f"{field.name} is a required field.") - fields_map[field] = (field.type_, field.required) - # If field has None allowed as a value - # then it becomes a required field. - if field.allow_none and field.name in kwargs: - field.required = True - else: - field.required = False - if inspect.isclass(field.type_) and issubclass( - field.type_, BaseModel - ): - field.populate_validators() - if field.sub_fields is not None: - for sub_field in field.sub_fields: - sub_field.type_ = field.type_ - sub_field.populate_validators() - optionalize(field.type_.__fields__) - else: - # No need to recursively de-optionalize once original types - # are restored - field.type_, field.required = fields_map[field] - if field.sub_fields is not None: - for sub_field in field.sub_fields: - sub_field.type_ = field.type_ - - # Make fields and fields of nested model types optional - optionalize(fields) - # Transform kwargs that are PartialModels to their dict() forms. This - # will exclude `None` (see below) from the dictionary used to construct - # the temporarily-partial model field, avoiding ValidationErrors of - # type type_error.none.not_allowed. - for kwarg, value in kwargs.items(): - if value.__class__.__class__ is PartialModelMetaclass: - kwargs[kwarg] = value.dict() - elif isinstance(value, (tuple, list)): - kwargs[kwarg] = value.__class__( - v.dict() - if v.__class__.__class__ is PartialModelMetaclass - else v - for v in value - ) +class EmptyType(type): + def __repr__(self) -> str: + return self.__name__ - # Validation is performed in __init__, for which all fields are now optional - cls_init(self, *args, **kwargs) - # Restore requiredness - optionalize(fields, restore=True) + def __bool__(self) -> bool: + return False - cls.__init__ = __init__ - def iter_exclude_empty(self) -> TupleGenerator: - for key, value in self.__dict__.items(): - if value is not Empty: - yield key, value +@serializable() +@final +class Empty(metaclass=EmptyType): + pass - cls.__iter__ = iter_exclude_empty - return cls +class PartialModelMetaclass(ModelMetaclass): + def __call__(cls: type[_T], *args: Any, **kwargs: Any) -> _T: + for field_info in cls.model_fields.values(): + if field_info.annotation is not None and field_info.is_required(): + field_info.annotation = Union[field_info.annotation, EmptyType] + field_info.default = Empty + cls.model_rebuild(force=True) -recursive_serde_register_type(PartialModelMetaclass) + return super().__call__(*args, **kwargs) # type: ignore[misc] diff --git a/packages/syft/src/syft/types/syft_migration.py b/packages/syft/src/syft/types/syft_migration.py index 86f99320d8e..6f7e10795de 100644 --- a/packages/syft/src/syft/types/syft_migration.py +++ b/packages/syft/src/syft/types/syft_migration.py @@ -39,7 +39,7 @@ def migrate( f"{klass_from_str} has version: {version_from}, {klass_to_str} has version: {version_to}" ) - def decorator(function: Callable): + def decorator(function: Callable) -> Callable: transforms = function() wrapper = generate_transform_wrapper( diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index 9d8ed78a5d7..d9a7dab5901 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -8,28 +8,39 @@ import inspect from inspect import Signature import re +import sys import traceback import types +import typing from typing import Any from typing import Callable from typing import ClassVar from typing import Dict +from typing import Generator +from typing import Iterable from typing import KeysView from typing import List from typing import Optional from typing import Sequence +from typing import TYPE_CHECKING from typing import Tuple from typing import Type from typing import Union +from typing import get_args +from typing import get_origin import warnings # third party import pandas as pd import pydantic +from pydantic import ConfigDict from pydantic import EmailStr -from pydantic.fields import Undefined +from pydantic import Field +from pydantic import model_validator +from pydantic.fields import PydanticUndefined from result import OkErr from typeguard import check_type +from typing_extensions import Self # relative from ..node.credentials import SyftVerifyKey @@ -46,6 +57,18 @@ from .syft_metaclass import PartialModelMetaclass from .uid import UID +if sys.version_info >= (3, 10): + # stdlib + from types import NoneType + from types import UnionType +else: + UnionType = Union + NoneType = type(None) + +if TYPE_CHECKING: + # relative + from ..service.sync.diff_state import AttrDiff + IntStr = Union[int, str] AbstractSetIntStr = Set[IntStr] MappingIntStrAny = Mapping[IntStr, Any] @@ -75,8 +98,27 @@ ] +def _is_optional(x: Any) -> bool: + return get_origin(x) in (Optional, UnionType, Union) and any( + arg is NoneType for arg in get_args(x) + ) + + +def _get_optional_inner_type(x: Any) -> Any: + if get_origin(x) not in (Optional, UnionType, Union): + return x + + args = get_args(x) + + if not any(arg is NoneType for arg in args): + return x + + non_none = [arg for arg in args if arg is not NoneType] + return non_none[0] if len(non_none) == 1 else x + + class SyftHashableObject: - __hash_exclude_attrs__ = [] + __hash_exclude_attrs__: list = [] def __hash__(self) -> int: return int.from_bytes(self.__sha256__(), byteorder="big") @@ -91,17 +133,16 @@ def hash(self) -> str: class SyftBaseObject(pydantic.BaseModel, SyftHashableObject): - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) # the name which doesn't change even when there are multiple classes __canonical_name__: str __version__: int # data is always versioned - syft_node_location: Optional[UID] - syft_client_verify_key: Optional[SyftVerifyKey] + syft_node_location: Optional[UID] = Field(default=None, exclude=True) + syft_client_verify_key: Optional[SyftVerifyKey] = Field(default=None, exclude=True) - def _set_obj_location_(self, node_uid: UID, credentials: SyftVerifyKey): + def _set_obj_location_(self, node_uid: UID, credentials: SyftVerifyKey) -> None: self.syft_node_location = node_uid self.syft_client_verify_key = credentials @@ -114,7 +155,9 @@ class Context(SyftBaseObject): class SyftObjectRegistry: - __object_version_registry__: Dict[str, Type["SyftObject"]] = {} + __object_version_registry__: Dict[ + str, Union[Type["SyftObject"], Type["SyftObjectRegistry"]] + ] = {} __object_transform_registry__: Dict[str, Callable] = {} def __init_subclass__(cls, **kwargs: Any) -> None: @@ -145,7 +188,9 @@ def __init_subclass__(cls, **kwargs: Any) -> None: cls.__object_version_registry__[mapping_string] = cls @classmethod - def versioned_class(cls, name: str, version: int) -> Optional[Type["SyftObject"]]: + def versioned_class( + cls, name: str, version: int + ) -> Optional[Union[Type["SyftObject"], Type["SyftObjectRegistry"]]]: mapping_string = f"{name}_{version}" if mapping_string not in cls.__object_version_registry__: return None @@ -214,7 +259,7 @@ def __init_subclass__(cls, **kwargs: Any) -> None: cls.register_version(klass=klass) @classmethod - def register_version(cls, klass: type): + def register_version(cls, klass: type) -> None: if hasattr(klass, "__canonical_name__") and hasattr(klass, "__version__"): mapping_string = klass.__canonical_name__ klass_version = klass.__version__ @@ -304,6 +349,10 @@ def get_migration( return cls.__migration_transform_registry__[klass_from][ mapping_string ] + raise ValueError( + f"No migration found for class type: {type_from} to " + f"type: {type_to} in the migration registry." + ) @classmethod def get_migration_for_version( @@ -336,25 +385,35 @@ def get_migration_for_version( ) -print_type_cache = defaultdict(list) +print_type_cache: dict = defaultdict(list) + + +base_attrs_sync_ignore = [ + "syft_node_location", + "syft_client_verify_key", +] class SyftObject(SyftBaseObject, SyftObjectRegistry, SyftMigrationRegistry): __canonical_name__ = "SyftObject" __version__ = SYFT_OBJECT_VERSION_1 - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + json_encoders={UID: str}, + ) # all objects have a UID id: UID # # move this to transforms - @pydantic.root_validator(pre=True) - def make_id(cls, values: Dict[str, Any]) -> Dict[str, Any]: - id_field = cls.__fields__["id"] - if "id" not in values and id_field.required: - values["id"] = id_field.type_() + @model_validator(mode="before") + @classmethod + def make_id(cls, values: Any) -> Any: + if isinstance(values, dict): + id_field = cls.model_fields["id"] + if "id" not in values and id_field.is_required(): + values["id"] = id_field.annotation() return values __attr_searchable__: ClassVar[ @@ -369,7 +428,7 @@ def make_id(cls, values: Dict[str, Any]) -> Dict[str, Any]: __repr_attrs__: ClassVar[List[str]] = [] # show these in html repr collections __attr_custom_repr__: ClassVar[ - List[str] + Optional[List[str]] ] = None # show these in html repr of an object def __syft_get_funcs__(self) -> List[Tuple[str, Signature]]: @@ -404,7 +463,7 @@ def __str__(self) -> str: def _repr_debug_(self) -> str: class_name = get_qualname_for(type(self)) _repr_str = f"class {class_name}:\n" - fields = getattr(self, "__fields__", {}) + fields = getattr(self, "model_fields", {}) for attr in fields.keys(): if attr in DYNAMIC_SYFT_ATTRIBUTES: continue @@ -417,7 +476,7 @@ def _repr_debug_(self) -> str: _repr_str += f" {attr}: {value_type} = {value}\n" return _repr_str - def _repr_markdown_(self, wrap_as_python=True, indent=0) -> str: + def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: s_indent = " " * indent * 2 class_name = get_qualname_for(type(self)) if self.__attr_custom_repr__ is not None: @@ -425,7 +484,7 @@ def _repr_markdown_(self, wrap_as_python=True, indent=0) -> str: elif self.__repr_attrs__ is not None: fields = self.__repr_attrs__ else: - fields = list(getattr(self, "__fields__", {}).keys()) + fields = list(getattr(self, "__fields__", {}).keys()) # type: ignore[unreachable] if "id" not in fields: fields = ["id"] + fields @@ -434,7 +493,7 @@ def _repr_markdown_(self, wrap_as_python=True, indent=0) -> str: fields = [x for x in fields if x not in dynam_attrs] _repr_str = f"{s_indent}class {class_name}:\n" for attr in fields: - value = self + value: Any = self # if it's a compound string if "." in attr: # break it into it's bits & fetch the attr @@ -450,9 +509,11 @@ def _repr_markdown_(self, wrap_as_python=True, indent=0) -> str: value = value.__repr_syft_nested__() if isinstance(value, list): value = [ - elem.__repr_syft_nested__() - if hasattr(elem, "__repr_syft_nested__") - else elem + ( + elem.__repr_syft_nested__() + if hasattr(elem, "__repr_syft_nested__") + else elem + ) for elem in value ] value = f'"{value}"' if isinstance(value, str) else value @@ -478,8 +539,8 @@ def keys(self) -> KeysView[str]: return self.__dict__.keys() # allows splatting with ** - def __getitem__(self, key: str) -> Any: - return self.__dict__.__getitem__(key) + def __getitem__(self, key: Union[str, int]) -> Any: + return self.__dict__.__getitem__(key) # type: ignore def _upgrade_version(self, latest: bool = True) -> "SyftObject": constructor = SyftObjectRegistry.versioned_class( @@ -523,58 +584,27 @@ def to_dict( new_dict[k] = v return new_dict - def dict( - self, - *, - include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None, - by_alias: bool = False, - skip_defaults: Optional[bool] = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - ): - if exclude is None: - exclude = set() - - for attr in DYNAMIC_SYFT_ATTRIBUTES: - exclude.add(attr) - return super().dict( - include=include, - exclude=exclude, - by_alias=by_alias, - skip_defaults=skip_defaults, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - def __post_init__(self) -> None: pass - def _syft_set_validate_private_attrs_(self, **kwargs): + def _syft_set_validate_private_attrs_(self, **kwargs: Any) -> None: # Validate and set private attributes # https://github.com/pydantic/pydantic/issues/2105 for attr, decl in self.__private_attributes__.items(): value = kwargs.get(attr, decl.get_default()) var_annotation = self.__annotations__.get(attr) - if value is not Undefined: - if decl.default_factory: - # If the value is defined via PrivateAttr with default factory - value = decl.default_factory(value) - elif var_annotation is not None: + if value is not PydanticUndefined: + if var_annotation is not None: # Otherwise validate value against the variable annotation - check_type(attr, value, var_annotation) + check_type(value, var_annotation) setattr(self, attr, value) else: - # check if the private is optional - is_optional_attr = type(None) in getattr(var_annotation, "__args__", []) - if not is_optional_attr: + if not _is_optional(var_annotation): raise ValueError( f"{attr}\n field required (type=value_error.missing)" ) - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self._syft_set_validate_private_attrs_(**kwargs) self.__post_init__() @@ -587,8 +617,8 @@ def __hash__(self) -> int: def _syft_keys_types_dict(cls, attr_name: str) -> Dict[str, type]: kt_dict = {} for key in getattr(cls, attr_name, []): - if key in cls.__fields__: - type_ = cls.__fields__[key].type_ + if key in cls.model_fields: + type_ = _get_optional_inner_type(cls.model_fields[key].annotation) else: try: method = getattr(cls, key) @@ -602,8 +632,9 @@ def _syft_keys_types_dict(cls, attr_name: str) -> Dict[str, type]: # EmailStr seems to be lost every time the value is set even with a validator # this means the incoming type is str so our validators fail - if type(type_) is type and issubclass(type_, EmailStr): + if type_ is EmailStr: type_ = str + kt_dict[key] = type_ return kt_dict @@ -626,6 +657,113 @@ def migrate_to(self, version: int, context: Optional[Context] = None) -> Any: ) return self + def syft_eq(self, ext_obj: Optional[Self]) -> bool: + if ext_obj is None: + return False + attrs_to_check = self.__dict__.keys() + + obj_exclude_attrs = getattr(self, "__exclude_sync_diff_attrs__", []) + for attr in attrs_to_check: + if attr not in base_attrs_sync_ignore and attr not in obj_exclude_attrs: + obj_attr = getattr(self, attr) + ext_obj_attr = getattr(ext_obj, attr) + if hasattr(obj_attr, "syft_eq") and not inspect.isclass(obj_attr): + if not obj_attr.syft_eq(ext_obj=ext_obj_attr): + return False + elif obj_attr != ext_obj_attr: + return False + return True + + def get_diffs(self, ext_obj: Self) -> List["AttrDiff"]: + # self is low, ext is high + # relative + from ..service.sync.diff_state import AttrDiff + from ..service.sync.diff_state import ListDiff + + diff_attrs = [] + + # Sanity check + if self.id != ext_obj.id: + raise Exception("Not the same id for low side and high side requests") + + attrs_to_check = self.__dict__.keys() + + obj_exclude_attrs = getattr(self, "__exclude_sync_diff_attrs__", []) + + for attr in attrs_to_check: + if attr not in base_attrs_sync_ignore and attr not in obj_exclude_attrs: + obj_attr = getattr(self, attr) + ext_obj_attr = getattr(ext_obj, attr) + + if isinstance(obj_attr, list) and isinstance(ext_obj_attr, list): + list_diff = ListDiff.from_lists( + attr_name=attr, low_list=obj_attr, high_list=ext_obj_attr + ) + if not list_diff.is_empty: + diff_attrs.append(list_diff) + + # TODO: to the same check as above for Dicts when we use them + else: + cmp = obj_attr.__eq__ + if hasattr(obj_attr, "syft_eq"): + cmp = obj_attr.syft_eq + + if not cmp(ext_obj_attr): + diff_attr = AttrDiff( + attr_name=attr, + low_attr=obj_attr, + high_attr=ext_obj_attr, + ) + diff_attrs.append(diff_attr) + return diff_attrs + + ## OVERRIDING pydantic.BaseModel.__getattr__ + ## return super().__getattribute__(item) -> return self.__getattribute__(item) + ## so that ActionObject.__getattribute__ works properly, + ## raising AttributeError when underlying object does not have the attribute + if not typing.TYPE_CHECKING: + # We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access + + def __getattr__(self, item: str) -> Any: + private_attributes = object.__getattribute__(self, "__private_attributes__") + if item in private_attributes: + attribute = private_attributes[item] + if hasattr(attribute, "__get__"): + return attribute.__get__(self, type(self)) # type: ignore + + try: + # Note: self.__pydantic_private__ cannot be None if self.__private_attributes__ has items + return self.__pydantic_private__[item] # type: ignore + except KeyError as exc: + raise AttributeError( + f"{type(self).__name__!r} object has no attribute {item!r}" + ) from exc + else: + # `__pydantic_extra__` can fail to be set if the model is not yet fully initialized. + # See `BaseModel.__repr_args__` for more details + try: + pydantic_extra = object.__getattribute__(self, "__pydantic_extra__") + except AttributeError: + pydantic_extra = None + + if pydantic_extra is not None: + try: + return pydantic_extra[item] + except KeyError as exc: + raise AttributeError( + f"{type(self).__name__!r} object has no attribute {item!r}" + ) from exc + else: + if hasattr(self.__class__, item): + return self.__getattribute__( + item + ) # Raises AttributeError if appropriate + else: + # this is the current error + raise AttributeError( + f"{type(self).__name__!r} object has no attribute {item!r}" + ) + def short_qual_name(name: str) -> str: # If the name is a qualname of formax a.b.c.d we will only get d @@ -633,14 +771,18 @@ def short_qual_name(name: str) -> str: return name.split(".")[-1] -def short_uid(uid: UID) -> str: +def short_uid(uid: Optional[UID]) -> Optional[str]: if uid is None: return uid else: return str(uid)[:6] + "..." -def get_repr_values_table(_self, is_homogenous, extra_fields=None): +def get_repr_values_table( + _self: Union[Mapping, Iterable], + is_homogenous: bool, + extra_fields: Optional[list] = None, +) -> dict: if extra_fields is None: extra_fields = [] @@ -705,9 +847,11 @@ def get_repr_values_table(_self, is_homogenous, extra_fields=None): and hasattr(value[0], "__repr_syft_nested__") ): value = [ - x.__repr_syft_nested__() - if hasattr(x, "__repr_syft_nested__") - else x + ( + x.__repr_syft_nested__() + if hasattr(x, "__repr_syft_nested__") + else x + ) for x in value ] if value is None: @@ -726,14 +870,14 @@ def get_repr_values_table(_self, is_homogenous, extra_fields=None): return df.to_dict("records") -def list_dict_repr_html(self) -> str: +def list_dict_repr_html(self: Union[Mapping, Set, Iterable]) -> str: try: max_check = 1 items_checked = 0 has_syft = False - extra_fields = [] + extra_fields: list = [] if isinstance(self, Mapping): - values = list(self.values()) + values: Any = list(self.values()) elif isinstance(self, Set): values = list(self) else: @@ -748,7 +892,7 @@ def list_dict_repr_html(self) -> str: break if hasattr(type(item), "mro") and type(item) != type: - mro = type(item).mro() + mro: Union[list, str] = type(item).mro() elif hasattr(item, "mro") and type(item) != type: mro = item.mro() else: @@ -772,7 +916,12 @@ def list_dict_repr_html(self) -> str: cls_name = first_value.__class__.__name__ else: cls_name = "" - vals = get_repr_values_table(self, is_homogenous, extra_fields=extra_fields) + try: + vals = get_repr_values_table( + self, is_homogenous, extra_fields=extra_fields + ) + except Exception: + return str(self) return create_table_template( vals, @@ -805,6 +954,12 @@ def to(self, projection: type, context: Optional[Context] = None) -> Any: transform = SyftObjectRegistry.get_transform(type(self), projection) return transform(self, context) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +TupleGenerator = Generator[Tuple[str, Any], None, None] + class PartialSyftObject(SyftObject, metaclass=PartialModelMetaclass): """Syft Object to which partial arguments can be provided.""" @@ -812,45 +967,8 @@ class PartialSyftObject(SyftObject, metaclass=PartialModelMetaclass): __canonical_name__ = "PartialSyftObject" __version__ = SYFT_OBJECT_VERSION_1 - def __init__(self, *args, **kwargs) -> None: - # Filter out Empty values from args and kwargs - args_, kwargs_ = (), {} - for arg in args: - if arg is not Empty: - args_.append(arg) - - for key, val in kwargs.items(): - if val is not Empty: - kwargs_[key] = val - - super().__init__(*args_, **kwargs_) - - fields_with_default = set() - for _field_name, _field in self.__fields__.items(): - if _field.default is not None or _field.allow_none: - fields_with_default.add(_field_name) - - # Fields whose values are set via a validator hook - fields_set_via_validator = [] - - for _field_name in self.__validators__.keys(): - _field = self.__fields__[_field_name] - if self.__dict__[_field_name] is None: - # Since all fields are None, only allow None - # where either none is allowed or default is None - if _field.allow_none or _field.default is None: - fields_set_via_validator.append(_field) - - # Exclude unset fields - unset_fields = ( - set(self.__fields__) - - set(self.__fields_set__) - - set(fields_set_via_validator) - ) - - empty_fields = unset_fields - fields_with_default - for field_name in empty_fields: - self.__dict__[field_name] = Empty + def __iter__(self) -> TupleGenerator: + yield from ((k, v) for k, v in super().__iter__() if v is not Empty) recursive_serde_register_type(PartialSyftObject) @@ -867,7 +985,7 @@ def attach_attribute_to_syft_object(result: Any, attr_dict: Dict[str, Any]) -> A result = result.value if isinstance(result, MutableMapping): - iterable_keys = result.keys() + iterable_keys: Iterable = result.keys() elif isinstance(result, MutableSequence): iterable_keys = range(len(result)) elif isinstance(result, tuple): diff --git a/packages/syft/src/syft/types/transforms.py b/packages/syft/src/syft/types/transforms.py index d4d7a0c38ca..1b3a4967ad8 100644 --- a/packages/syft/src/syft/types/transforms.py +++ b/packages/syft/src/syft/types/transforms.py @@ -28,19 +28,21 @@ class NotNone: class TransformContext(Context): - output: Optional[Dict[str, Any]] - node: Optional[AbstractNode] - credentials: Optional[SyftVerifyKey] - obj: Optional[Any] - - @staticmethod - def from_context(obj: Any, context: Optional[Context] = None) -> Self: - t_context = TransformContext() + output: Optional[Dict[str, Any]] = None + node: Optional[AbstractNode] = None + credentials: Optional[SyftVerifyKey] = None + obj: Optional[Any] = None + + @classmethod + def from_context(cls, obj: Any, context: Optional[Context] = None) -> Self: + t_context = cls() t_context.obj = obj try: - t_context.output = dict(obj) + t_context.output = obj.to_dict(exclude_empty=True) except Exception: - t_context.output = obj.to_dict() + t_context.output = dict(obj) + if context is None: + return t_context if hasattr(context, "credentials"): t_context.credentials = context.credentials if hasattr(context, "node"): @@ -67,7 +69,7 @@ def geteitherattr( def make_set_default(key: str, value: Any) -> Callable: def set_default(context: TransformContext) -> TransformContext: - if not geteitherattr(context.obj, context.output, key, None): + if context.output and not geteitherattr(context.obj, context.output, key, None): context.output[key] = value return context @@ -76,9 +78,10 @@ def set_default(context: TransformContext) -> TransformContext: def drop(list_keys: List[str]) -> Callable: def drop_keys(context: TransformContext) -> TransformContext: - for key in list_keys: - if key in context.output: - del context.output[key] + if context.output: + for key in list_keys: + if key in context.output: + del context.output[key] return context return drop_keys @@ -86,9 +89,12 @@ def drop_keys(context: TransformContext) -> TransformContext: def rename(old_key: str, new_key: str) -> Callable: def drop_keys(context: TransformContext) -> TransformContext: - context.output[new_key] = geteitherattr(context.obj, context.output, old_key) - if old_key in context.output: - del context.output[old_key] + if context.output: + context.output[new_key] = geteitherattr( + context.obj, context.output, old_key + ) + if old_key in context.output: + del context.output[old_key] return context return drop_keys @@ -96,6 +102,9 @@ def drop_keys(context: TransformContext) -> TransformContext: def keep(list_keys: List[str]) -> Callable: def drop_keys(context: TransformContext) -> TransformContext: + if context.output is None: + return context + for key in list_keys: if key not in context.output: context.output[key] = getattr(context.obj, key, None) @@ -111,7 +120,9 @@ def drop_keys(context: TransformContext) -> TransformContext: return drop_keys -def convert_types(list_keys: List[str], types: Union[type, List[type]]) -> Callable: +def convert_types( + list_keys: List[str], types: Union[type, List[type]] +) -> Callable[[TransformContext], TransformContext]: if not isinstance(types, list): types = [types] * len(list_keys) @@ -119,42 +130,48 @@ def convert_types(list_keys: List[str], types: Union[type, List[type]]) -> Calla raise Exception("convert types lists must be the same length") def run_convert_types(context: TransformContext) -> TransformContext: - for key, _type in zip(list_keys, types): - context.output[key] = _type(geteitherattr(context.obj, context.output, key)) + if context.output: + for key, _type in zip(list_keys, types): + context.output[key] = _type( + geteitherattr(context.obj, context.output, key) + ) return context return run_convert_types def generate_id(context: TransformContext) -> TransformContext: + if context.output is None: + return context if "id" not in context.output or not isinstance(context.output["id"], UID): context.output["id"] = UID() return context def validate_url(context: TransformContext) -> TransformContext: - if context.output["url"] is not None: + if context.output and context.output["url"] is not None: context.output["url"] = GridURL.from_url(context.output["url"]).url_no_port return context def validate_email(context: TransformContext) -> TransformContext: - if context.output["email"] is not None: - context.output["email"] = EmailStr(context.output["email"]) - EmailStr.validate(context.output["email"]) + if context.output and context.output["email"] is not None: + EmailStr._validate(context.output["email"]) return context def str_url_to_grid_url(context: TransformContext) -> TransformContext: - url = context.output.get("url", None) - if url is not None and isinstance(url, str): - context.output["url"] = GridURL.from_url(str) + if context.output: + url = context.output.get("url", None) + if url is not None and isinstance(url, str): + context.output["url"] = GridURL.from_url(str) return context def add_credentials_for_key(key: str) -> Callable: def add_credentials(context: TransformContext) -> TransformContext: - context.output[key] = context.credentials + if context.output is not None: + context.output[key] = context.credentials return context return add_credentials @@ -162,7 +179,8 @@ def add_credentials(context: TransformContext) -> TransformContext: def add_node_uid_for_key(key: str) -> Callable: def add_node_uid(context: TransformContext) -> TransformContext: - context.output[key] = context.node.id + if context.output is not None and context.node is not None: + context.output[key] = context.node.id return context return add_node_uid @@ -188,7 +206,7 @@ def validate_klass_and_version( klass_to: Union[Type, str], version_from: Optional[int] = None, version_to: Optional[int] = None, -): +) -> tuple[str, Optional[int], str, Optional[int]]: if not isinstance(klass_from, (type, str)): raise NotImplementedError( "Arguments to `klass_from` should be either of `Type` or `str` type." @@ -238,7 +256,7 @@ def transform_method( version_to=version_to, ) - def decorator(function: Callable): + def decorator(function: Callable) -> Callable: SyftObjectRegistry.add_transform( klass_from=klass_from_str, version_from=version_from, @@ -270,7 +288,7 @@ def transform( version_to=version_to, ) - def decorator(function: Callable): + def decorator(function: Callable) -> Callable: transforms = function() wrapper = generate_transform_wrapper( diff --git a/packages/syft/src/syft/types/twin_object.py b/packages/syft/src/syft/types/twin_object.py index 5529daaff53..d06d97d8b77 100644 --- a/packages/syft/src/syft/types/twin_object.py +++ b/packages/syft/src/syft/types/twin_object.py @@ -3,17 +3,20 @@ # stdlib from typing import Any -from typing import Dict +from typing import ClassVar from typing import Optional # third party -import pydantic +from pydantic import field_validator +from pydantic import model_validator +from typing_extensions import Self # relative from ..serde.serializable import serializable from ..service.action.action_object import ActionObject from ..service.action.action_object import TwinMode from ..service.action.action_types import action_types +from ..service.response import SyftError from .syft_object import SyftObject from .uid import UID @@ -24,7 +27,7 @@ def to_action_object(obj: Any) -> ActionObject: if type(obj) in action_types: return action_types[type(obj)](syft_action_data_cache=obj) - raise Exception(f"{type(obj)} not in action_types") + raise ValueError(f"{type(obj)} not in action_types") @serializable() @@ -32,7 +35,7 @@ class TwinObject(SyftObject): __canonical_name__ = "TwinObject" __version__ = 1 - __attr_searchable__ = [] + __attr_searchable__: ClassVar[list[str]] = [] id: UID private_obj: ActionObject @@ -40,21 +43,27 @@ class TwinObject(SyftObject): mock_obj: ActionObject mock_obj_id: UID = None # type: ignore - @pydantic.validator("private_obj", pre=True, always=True) - def make_private_obj(cls, v: ActionObject) -> ActionObject: + @field_validator("private_obj", mode="before") + @classmethod + def make_private_obj(cls, v: Any) -> ActionObject: return to_action_object(v) - @pydantic.validator("private_obj_id", pre=True, always=True) - def make_private_obj_id(cls, v: Optional[UID], values: Dict) -> UID: - return values["private_obj"].id if v is None else v + @model_validator(mode="after") + def make_private_obj_id(self) -> Self: + if self.private_obj_id is None: + self.private_obj_id = self.private_obj.id # type: ignore[unreachable] + return self - @pydantic.validator("mock_obj", pre=True, always=True) - def make_mock_obj(cls, v: ActionObject): + @field_validator("mock_obj", mode="before") + @classmethod + def make_mock_obj(cls, v: Any) -> ActionObject: return to_action_object(v) - @pydantic.validator("mock_obj_id", pre=True, always=True) - def make_mock_obj_id(cls, v: Optional[UID], values: Dict) -> UID: - return values["mock_obj"].id if v is None else v + @model_validator(mode="after") + def make_mock_obj_id(self) -> Self: + if self.mock_obj_id is None: + self.mock_obj_id = self.mock_obj.id # type: ignore[unreachable] + return self @property def private(self) -> ActionObject: @@ -72,7 +81,7 @@ def mock(self) -> ActionObject: mock.id = twin_id return mock - def _save_to_blob_storage(self): + def _save_to_blob_storage(self) -> Optional[SyftError]: # Set node location and verify key self.private_obj._set_obj_location_( self.syft_node_location, diff --git a/packages/syft/src/syft/types/uid.py b/packages/syft/src/syft/types/uid.py index 68f68e8a639..a2867e2e561 100644 --- a/packages/syft/src/syft/types/uid.py +++ b/packages/syft/src/syft/types/uid.py @@ -40,7 +40,7 @@ class UID: __slots__ = "value" value: uuid_type - def __init__(self, value: Optional[Union[uuid_type, str, bytes]] = None): + def __init__(self, value: Optional[Union[uuid_type, str, bytes, "UID"]] = None): """Initializes the internal id using the uuid package. This initializes the object. Normal use for this object is @@ -80,6 +80,7 @@ def from_string(value: str) -> "UID": except ValueError as e: critical(f"Unable to convert {value} to UUID. {e}") traceback_and_raise(e) + raise @staticmethod def with_seed(value: str) -> "UID": @@ -215,7 +216,7 @@ class LineageID(UID): def __init__( self, - value: Optional[Union[uuid_type, str, bytes]] = None, + value: Optional[Union[uuid_type, str, bytes, "LineageID"]] = None, syft_history_hash: Optional[int] = None, ): if isinstance(value, LineageID): @@ -232,7 +233,7 @@ def __init__( def id(self) -> UID: return UID(self.value) - def __hash__(self): + def __hash__(self) -> int: return hash((self.syft_history_hash, self.value)) def __eq__(self, other: Any) -> bool: diff --git a/packages/syft/src/syft/util/notebook_ui/notebook_addons.py b/packages/syft/src/syft/util/notebook_ui/notebook_addons.py index 0e60ad9acf1..fd72e302490 100644 --- a/packages/syft/src/syft/util/notebook_ui/notebook_addons.py +++ b/packages/syft/src/syft/util/notebook_ui/notebook_addons.py @@ -215,6 +215,7 @@ grid-template-columns: 1fr repeat(${cols}, 1fr); grid-template-rows: repeat(2, 1fr); overflow-x: auto; + position: relative; } .grid-std-cells { @@ -236,6 +237,7 @@ align-items: center; padding: 6px 4px; + resize: horizontal; /* Lt On Surface/Surface */ /* Lt On Surface/High */ border: 1px solid #CFCDD6; @@ -294,7 +296,7 @@ } .paginationContainer{ width: 100%; - height: 30px; + /*height: 30px;*/ display: flex; justify-content: center; gap: 8px; @@ -643,9 +645,7 @@ } else { text = String(item[attr]) } - if (text.length > 150){ - text = text.slice(0,150) + "..."; - } + text = text.replaceAll("\\n", "</br>"); div.innerHTML = text; } diff --git a/packages/syft/src/syft/util/schema.py b/packages/syft/src/syft/util/schema.py index 55edd83962a..c5c3e8e12ee 100644 --- a/packages/syft/src/syft/util/schema.py +++ b/packages/syft/src/syft/util/schema.py @@ -29,7 +29,7 @@ def make_fake_type(_type_str: str) -> dict[str, Any]: - jsonschema: dict = {} + jsonschema: dict[str, Any] = {} jsonschema["title"] = _type_str jsonschema["type"] = "object" jsonschema["properties"] = {} @@ -37,7 +37,7 @@ def make_fake_type(_type_str: str) -> dict[str, Any]: return jsonschema -def get_type_mapping(_type: Any) -> str: +def get_type_mapping(_type: Type) -> str: if _type in primitive_mapping: return primitive_mapping[_type] return _type.__name__ @@ -62,9 +62,9 @@ def get_types(cls: Type, keys: List[str]) -> Optional[Dict[str, Type]]: def convert_attribute_types( - cls: Any, attribute_list: Any, attribute_types: Any + cls: Type, attribute_list: list[str], attribute_types: list[Type] ) -> dict[str, Any]: - jsonschema: dict = {} + jsonschema: dict[str, Any] = {} jsonschema["title"] = cls.__name__ jsonschema["type"] = "object" jsonschema["properties"] = {} @@ -81,7 +81,7 @@ def process_type_bank(type_bank: Dict[str, Tuple[Any, ...]]) -> Dict[str, Dict]: # first pass gets each type into basic json schema format json_mappings = {} count = 0 - converted_types: defaultdict = defaultdict(int) + converted_types: Dict[str, int] = defaultdict(int) for k in type_bank: count += 1 t = type_bank[k] diff --git a/packages/syft/src/syft/util/telemetry.py b/packages/syft/src/syft/util/telemetry.py index c31c5c968d4..3e62409d165 100644 --- a/packages/syft/src/syft/util/telemetry.py +++ b/packages/syft/src/syft/util/telemetry.py @@ -27,53 +27,55 @@ def noop(__func_or_class: T, /, *args: Any, **kwargs: Any) -> T: if not TRACE_MODE: instrument = noop - -try: - print("OpenTelemetry Tracing enabled") - service_name = os.environ.get("SERVICE_NAME", "client") - jaeger_host = os.environ.get("JAEGER_HOST", "localhost") - jaeger_port = int(os.environ.get("JAEGER_PORT", "14268")) - - # third party - from opentelemetry import trace - from opentelemetry.exporter.jaeger.thrift import JaegerExporter - from opentelemetry.sdk.resources import Resource - from opentelemetry.sdk.resources import SERVICE_NAME - from opentelemetry.sdk.trace import TracerProvider - from opentelemetry.sdk.trace.export import BatchSpanProcessor - - trace.set_tracer_provider( - TracerProvider(resource=Resource.create({SERVICE_NAME: service_name})) - ) - jaeger_exporter = JaegerExporter( - # agent_host_name=jaeger_host, - # agent_port=jaeger_port, - collector_endpoint=f"http://{jaeger_host}:{jaeger_port}/api/traces?format=jaeger.thrift", - # udp_split_oversized_batches=True, - ) - - trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(jaeger_exporter)) - - # from opentelemetry.sdk.trace.export import ConsoleSpanExporter - # console_exporter = ConsoleSpanExporter() - # span_processor = BatchSpanProcessor(console_exporter) - # trace.get_tracer_provider().add_span_processor(span_processor) - - # third party - import opentelemetry.instrumentation.requests - - opentelemetry.instrumentation.requests.RequestsInstrumentor().instrument() - - # relative - # from opentelemetry.instrumentation.digma.trace_decorator import ( - # instrument as _instrument, - # ) - # - # until this is merged: - # https://github.com/digma-ai/opentelemetry-instrumentation-digma/pull/41 - from .trace_decorator import instrument as _instrument - - instrument = _instrument -except Exception: # nosec - print("Failed to import opentelemetry") - instrument = noop +else: + try: + print("OpenTelemetry Tracing enabled") + service_name = os.environ.get("SERVICE_NAME", "client") + jaeger_host = os.environ.get("JAEGER_HOST", "localhost") + jaeger_port = int(os.environ.get("JAEGER_PORT", "14268")) + + # third party + from opentelemetry import trace + from opentelemetry.exporter.jaeger.thrift import JaegerExporter + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.resources import SERVICE_NAME + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + + trace.set_tracer_provider( + TracerProvider(resource=Resource.create({SERVICE_NAME: service_name})) + ) + jaeger_exporter = JaegerExporter( + # agent_host_name=jaeger_host, + # agent_port=jaeger_port, + collector_endpoint=f"http://{jaeger_host}:{jaeger_port}/api/traces?format=jaeger.thrift", + # udp_split_oversized_batches=True, + ) + + trace.get_tracer_provider().add_span_processor( + BatchSpanProcessor(jaeger_exporter) + ) + + # from opentelemetry.sdk.trace.export import ConsoleSpanExporter + # console_exporter = ConsoleSpanExporter() + # span_processor = BatchSpanProcessor(console_exporter) + # trace.get_tracer_provider().add_span_processor(span_processor) + + # third party + import opentelemetry.instrumentation.requests + + opentelemetry.instrumentation.requests.RequestsInstrumentor().instrument() + + # relative + # from opentelemetry.instrumentation.digma.trace_decorator import ( + # instrument as _instrument, + # ) + # + # until this is merged: + # https://github.com/digma-ai/opentelemetry-instrumentation-digma/pull/41 + from .trace_decorator import instrument as _instrument + + instrument = _instrument + except Exception: # nosec + print("Failed to import opentelemetry") + instrument = noop diff --git a/packages/syft/src/syft/util/version_compare.py b/packages/syft/src/syft/util/version_compare.py index 606606bb6e8..ffef1102ad6 100644 --- a/packages/syft/src/syft/util/version_compare.py +++ b/packages/syft/src/syft/util/version_compare.py @@ -33,7 +33,9 @@ def get_operator(version_string: str) -> Tuple[str, Callable, str]: return version_string, op, op_char -def check_rule(version_string: str, LATEST_STABLE_SYFT: str, __version__: str) -> tuple: +def check_rule( + version_string: str, LATEST_STABLE_SYFT: str, __version__: str +) -> tuple[Any, list[str], list[str]]: version_string, op, op_char = get_operator(version_string) syft_version = version.parse(__version__) stable_version = version.parse(LATEST_STABLE_SYFT) diff --git a/packages/syft/tests/syft/action_graph/action_graph_service_test.py b/packages/syft/tests/syft/action_graph/action_graph_service_test.py index 22199150ad3..3cac37f975e 100644 --- a/packages/syft/tests/syft/action_graph/action_graph_service_test.py +++ b/packages/syft/tests/syft/action_graph/action_graph_service_test.py @@ -105,7 +105,6 @@ def test_action_graph_service_add_action_no_mutagen( assert action_node.status == ExecutionStatus.PROCESSING assert action_node.retry == 0 assert isinstance(action_node.created_at, DateTime) - assert action_node.updated_at is None assert action_node.user_verify_key == authed_context.credentials assert action_node.is_mutated is False assert action_node.is_mutagen is False @@ -117,7 +116,6 @@ def test_action_graph_service_add_action_no_mutagen( assert result_node.status == ExecutionStatus.PROCESSING assert result_node.retry == 0 assert isinstance(result_node.created_at, DateTime) - assert result_node.updated_at is None assert result_node.user_verify_key == authed_context.credentials assert result_node.is_mutated is False assert result_node.is_mutagen is False @@ -168,7 +166,6 @@ def test_action_graph_service_add_action_mutagen( assert result_node.id == action.result_id.id assert action_node.type == NodeType.ACTION assert result_node.type == NodeType.ACTION_OBJECT - assert result_node.updated_at is None assert result_node.is_mutated is False assert result_node.is_mutagen is False assert result_node.next_mutagen_node is None diff --git a/packages/syft/tests/syft/action_graph/action_graph_test.py b/packages/syft/tests/syft/action_graph/action_graph_test.py index d01466afa75..8e7e235105a 100644 --- a/packages/syft/tests/syft/action_graph/action_graph_test.py +++ b/packages/syft/tests/syft/action_graph/action_graph_test.py @@ -30,7 +30,7 @@ from syft.service.action.action_object import Action from syft.service.action.action_object import ActionObject from syft.store.document_store import QueryKeys -from syft.store.locks import NoLockingConfig +from syft.store.locks import ThreadingLockingConfig from syft.types.datetime import DateTime from syft.types.syft_metaclass import Empty from syft.types.uid import UID @@ -164,7 +164,7 @@ def test_in_memory_store_client_config() -> None: def test_in_memory_graph_config() -> None: store_config = InMemoryGraphConfig() default_client_conf = InMemoryStoreClientConfig() - locking_config = NoLockingConfig() + locking_config = ThreadingLockingConfig() assert store_config.client_config == default_client_conf assert store_config.store_type == NetworkXBackingStore diff --git a/packages/syft/tests/syft/blob_storage/blob_storage_test.py b/packages/syft/tests/syft/blob_storage/blob_storage_test.py index f38a5d91d0f..c735750205f 100644 --- a/packages/syft/tests/syft/blob_storage/blob_storage_test.py +++ b/packages/syft/tests/syft/blob_storage/blob_storage_test.py @@ -60,7 +60,7 @@ def test_blob_storage_write_syft_object(): ) blob_data = CreateBlobStorageEntry.from_obj(data) blob_deposit = blob_storage.allocate(authed_context, blob_data) - user = UserCreate(email="info@openmined.org") + user = UserCreate(email="info@openmined.org", name="Jana Doe", password="password") file_data = io.BytesIO(sy.serialize(user, to_bytes=True)) written_data = blob_deposit.write(file_data) diff --git a/packages/syft/tests/syft/dataset/dataset_stash_test.py b/packages/syft/tests/syft/dataset/dataset_stash_test.py index 5befab5f759..0e226397edf 100644 --- a/packages/syft/tests/syft/dataset/dataset_stash_test.py +++ b/packages/syft/tests/syft/dataset/dataset_stash_test.py @@ -3,6 +3,7 @@ # third party import pytest +from typeguard import TypeCheckError # syft absolute from syft.service.dataset.dataset import Dataset @@ -46,7 +47,7 @@ def test_dataset_actionidpartitionkey() -> None: ActionIDsPartitionKey.with_obj(obj="dummy_str") # Not sure what Exception should be raised here, Type or Attibute - with pytest.raises(TypeError): + with pytest.raises(TypeCheckError): ActionIDsPartitionKey.with_obj(obj=["first_str", "second_str"]) diff --git a/packages/syft/tests/syft/locks_test.py b/packages/syft/tests/syft/locks_test.py index f6b20a85f69..8e1c9f3fac0 100644 --- a/packages/syft/tests/syft/locks_test.py +++ b/packages/syft/tests/syft/locks_test.py @@ -12,18 +12,14 @@ from joblib import Parallel from joblib import delayed import pytest -from pytest_mock_resources import create_redis_fixture # syft absolute from syft.store.locks import FileLockingConfig from syft.store.locks import LockingConfig from syft.store.locks import NoLockingConfig -from syft.store.locks import RedisLockingConfig from syft.store.locks import SyftLock from syft.store.locks import ThreadingLockingConfig -redis_server_mock = create_redis_fixture(scope="session") - def_params = { "lock_name": "testing_lock", "expire": 5, # seconds, @@ -55,20 +51,12 @@ def locks_file_config(): return FileLockingConfig(**def_params) -@pytest.fixture(scope="function") -def locks_redis_config(redis_server_mock): - def_params["lock_name"] = generate_lock_name() - redis_config = redis_server_mock.pmr_credentials.as_redis_kwargs() - return RedisLockingConfig(**def_params, client=redis_config) - - @pytest.mark.parametrize( "config", [ pytest.lazy_fixture("locks_nop_config"), pytest.lazy_fixture("locks_threading_config"), pytest.lazy_fixture("locks_file_config"), - pytest.lazy_fixture("locks_redis_config"), ], ) @pytest.mark.skipif( @@ -106,7 +94,6 @@ def test_acquire_nop(config: LockingConfig): [ pytest.lazy_fixture("locks_threading_config"), pytest.lazy_fixture("locks_file_config"), - pytest.lazy_fixture("locks_redis_config"), ], ) @pytest.mark.skipif( @@ -137,7 +124,6 @@ def test_acquire_release(config: LockingConfig): [ pytest.lazy_fixture("locks_threading_config"), pytest.lazy_fixture("locks_file_config"), - pytest.lazy_fixture("locks_redis_config"), ], ) @pytest.mark.skipif( @@ -158,7 +144,6 @@ def test_acquire_release_with(config: LockingConfig): [ pytest.lazy_fixture("locks_threading_config"), pytest.lazy_fixture("locks_file_config"), - pytest.lazy_fixture("locks_redis_config"), ], ) @pytest.mark.skipif( @@ -189,7 +174,6 @@ def test_acquire_expire(config: LockingConfig): [ pytest.lazy_fixture("locks_threading_config"), pytest.lazy_fixture("locks_file_config"), - pytest.lazy_fixture("locks_redis_config"), ], ) @pytest.mark.skipif( @@ -216,7 +200,6 @@ def test_acquire_double_aqcuire_timeout_fail(config: LockingConfig): [ pytest.lazy_fixture("locks_threading_config"), pytest.lazy_fixture("locks_file_config"), - pytest.lazy_fixture("locks_redis_config"), ], ) @pytest.mark.skipif( @@ -245,7 +228,6 @@ def test_acquire_double_aqcuire_timeout_ok(config: LockingConfig): [ pytest.lazy_fixture("locks_threading_config"), pytest.lazy_fixture("locks_file_config"), - pytest.lazy_fixture("locks_redis_config"), ], ) @pytest.mark.skipif( @@ -274,7 +256,6 @@ def test_acquire_double_aqcuire_nonblocking(config: LockingConfig): [ pytest.lazy_fixture("locks_threading_config"), pytest.lazy_fixture("locks_file_config"), - pytest.lazy_fixture("locks_redis_config"), ], ) @pytest.mark.skipif( @@ -304,7 +285,6 @@ def test_acquire_double_aqcuire_retry_interval(config: LockingConfig): [ pytest.lazy_fixture("locks_threading_config"), pytest.lazy_fixture("locks_file_config"), - pytest.lazy_fixture("locks_redis_config"), ], ) @pytest.mark.skipif( @@ -325,7 +305,6 @@ def test_acquire_double_release(config: LockingConfig): [ pytest.lazy_fixture("locks_threading_config"), pytest.lazy_fixture("locks_file_config"), - pytest.lazy_fixture("locks_redis_config"), ], ) @pytest.mark.skipif( @@ -351,7 +330,6 @@ def test_acquire_same_name_diff_namespace(config: LockingConfig): [ pytest.lazy_fixture("locks_threading_config"), pytest.lazy_fixture("locks_file_config"), - pytest.lazy_fixture("locks_redis_config"), ], ) @pytest.mark.skipif( @@ -414,7 +392,6 @@ def _kv_cbk(tid: int) -> None: "config", [ pytest.lazy_fixture("locks_file_config"), - pytest.lazy_fixture("locks_redis_config"), ], ) @pytest.mark.skipif( diff --git a/packages/syft/tests/syft/request/request_code_accept_deny_test.py b/packages/syft/tests/syft/request/request_code_accept_deny_test.py index 7451dd26d91..b3776a56987 100644 --- a/packages/syft/tests/syft/request/request_code_accept_deny_test.py +++ b/packages/syft/tests/syft/request/request_code_accept_deny_test.py @@ -11,6 +11,7 @@ from syft.node.worker import Worker from syft.service.action.action_object import ActionObject from syft.service.action.action_permissions import ActionPermission +from syft.service.code.user_code import UserCode from syft.service.code.user_code import UserCodeStatus from syft.service.context import ChangeContext from syft.service.request.request import ActionStoreChange @@ -137,12 +138,14 @@ def simple_function(data): result = ds_client.code.submit(simple_function) assert isinstance(result, SyftSuccess) - user_code = ds_client.code.get_all()[0] + user_code: UserCode = ds_client.code.get_all()[0] - linked_obj = LinkedObject.from_obj(user_code, node_uid=worker.id) + linked_user_code = LinkedObject.from_obj(user_code, node_uid=worker.id) user_code_change = UserCodeStatusChange( - value=UserCodeStatus.APPROVED, linked_obj=linked_obj + value=UserCodeStatus.APPROVED, + linked_user_code=linked_user_code, + linked_obj=user_code.status_link, ) change_context = ChangeContext( diff --git a/packages/syft/tests/syft/request/request_multiple_nodes_test.py b/packages/syft/tests/syft/request/request_multiple_nodes_test.py index 9ec214ea7fe..4c644790ca7 100644 --- a/packages/syft/tests/syft/request/request_multiple_nodes_test.py +++ b/packages/syft/tests/syft/request/request_multiple_nodes_test.py @@ -124,7 +124,7 @@ def compute_sum(data) -> float: # Submit + execute on second node request_1_do = client_do_1.requests[0] - client_do_2.code.sync_code_from_request(request_1_do) + client_do_2.sync_code_from_request(request_1_do) # DO executes + syncs client_do_2._fetch_api(client_do_2.credentials) @@ -162,7 +162,7 @@ def compute_mean(data) -> float: # Submit + execute on second node request_1_do = client_do_1.requests[0] - client_do_2.code.sync_code_from_request(request_1_do) + client_do_2.sync_code_from_request(request_1_do) client_do_2._fetch_api(client_do_2.credentials) job_2 = client_do_2.code.compute_mean(data=dataset_2, blocking=False) diff --git a/packages/syft/tests/syft/service/action/action_object_test.py b/packages/syft/tests/syft/service/action/action_object_test.py index 131b93d70da..d5eefcd7f77 100644 --- a/packages/syft/tests/syft/service/action/action_object_test.py +++ b/packages/syft/tests/syft/service/action/action_object_test.py @@ -25,13 +25,12 @@ from syft.service.action.action_object import propagate_node_uid from syft.service.action.action_object import send_action_side_effect from syft.service.action.action_types import action_type_for_type +from syft.types.uid import LineageID +from syft.types.uid import UID def helper_make_action_obj(orig_obj: Any): - obj_id = Action.make_id(None) - lin_obj_id = Action.make_result_id(obj_id) - - return ActionObject.from_obj(orig_obj, id=obj_id, syft_lineage_id=lin_obj_id) + return ActionObject.from_obj(orig_obj) def helper_make_action_pointers(worker, obj, *args, **kwargs): @@ -62,15 +61,13 @@ def helper_make_action_pointers(worker, obj, *args, **kwargs): def test_action_sanity(path_op: Tuple[str, str]): path, op = path_op - remote_self = Action.make_result_id(None) - result_id = Action.make_result_id(None) + remote_self = LineageID() new_action = Action( path=path, op=op, remote_self=remote_self, args=[], kwargs={}, - result_id=result_id, ) assert new_action is not None assert new_action.full_path == f"{path}.{op}" @@ -99,22 +96,22 @@ def test_actionobject_from_obj_sanity(orig_obj: Any): assert obj.syft_history_hash is not None # with id - obj_id = Action.make_id(None) + obj_id = UID() obj = ActionObject.from_obj(orig_obj, id=obj_id) assert obj.id == obj_id assert obj.syft_history_hash == hash(obj_id) # with id and lineage id - obj_id = Action.make_id(None) - lin_obj_id = Action.make_result_id(obj_id) + obj_id = UID() + lin_obj_id = LineageID(obj_id) obj = ActionObject.from_obj(orig_obj, id=obj_id, syft_lineage_id=lin_obj_id) assert obj.id == obj_id assert obj.syft_history_hash == lin_obj_id.syft_history_hash def test_actionobject_from_obj_fail_id_mismatch(): - obj_id = Action.make_id(None) - lineage_id = Action.make_result_id(None) + obj_id = UID() + lineage_id = LineageID() with pytest.raises(ValueError): ActionObject.from_obj("abc", id=obj_id, syft_lineage_id=lineage_id) @@ -131,14 +128,14 @@ def test_actionobject_make_empty_sanity(dtype: Type): assert obj.syft_history_hash is not None # with id - obj_id = Action.make_id(None) + obj_id = UID() obj = ActionObject.empty(syft_internal_type=syft_type, id=obj_id) assert obj.id == obj_id assert obj.syft_history_hash == hash(obj_id) # with id and lineage id - obj_id = Action.make_id(None) - lin_obj_id = Action.make_result_id(obj_id) + obj_id = UID() + lin_obj_id = LineageID(obj_id) obj = ActionObject.empty( syft_internal_type=syft_type, id=obj_id, syft_lineage_id=lin_obj_id ) @@ -163,12 +160,12 @@ def test_actionobject_make_empty_sanity(dtype: Type): def test_actionobject_hooks_init(orig_obj: Any): obj = ActionObject.from_obj(orig_obj) - assert HOOK_ALWAYS in obj._syft_pre_hooks__ - assert HOOK_ALWAYS in obj._syft_post_hooks__ + assert HOOK_ALWAYS in obj.syft_pre_hooks__ + assert HOOK_ALWAYS in obj.syft_post_hooks__ - assert make_action_side_effect in obj._syft_pre_hooks__[HOOK_ALWAYS] - assert send_action_side_effect in obj._syft_pre_hooks__[HOOK_ON_POINTERS] - assert propagate_node_uid in obj._syft_post_hooks__[HOOK_ALWAYS] + assert make_action_side_effect in obj.syft_pre_hooks__[HOOK_ALWAYS] + assert send_action_side_effect in obj.syft_pre_hooks__[HOOK_ON_POINTERS] + assert propagate_node_uid in obj.syft_post_hooks__[HOOK_ALWAYS] @pytest.mark.parametrize( @@ -302,7 +299,7 @@ def test_actionobject_hooks_propagate_node_uid_ok(): orig_obj = "abc" op = "capitalize" - obj_id = Action.make_id(None) + obj_id = UID() obj = ActionObject.from_obj(orig_obj) obj.syft_point_to(obj_id) @@ -315,7 +312,7 @@ def test_actionobject_hooks_propagate_node_uid_ok(): def test_actionobject_syft_point_to(): orig_obj = "abc" - obj_id = Action.make_id(None) + obj_id = UID() obj = ActionObject.from_obj(orig_obj) obj.syft_point_to(obj_id) @@ -587,7 +584,7 @@ def test_actionobject_syft_execute_hooks(worker, testcase): ) assert context.result_id is not None - context.obj.syft_node_uid = Action.make_id(None) + context.obj.syft_node_uid = UID() result = obj_pointer._syft_run_post_hooks__(context, name=op, result=obj_pointer) assert result.syft_node_uid == context.obj.syft_node_uid diff --git a/packages/syft/tests/syft/service/dataset/dataset_service_test.py b/packages/syft/tests/syft/service/dataset/dataset_service_test.py index 8b4f1380961..a60bc653c13 100644 --- a/packages/syft/tests/syft/service/dataset/dataset_service_test.py +++ b/packages/syft/tests/syft/service/dataset/dataset_service_test.py @@ -16,6 +16,7 @@ from syft.service.dataset.dataset import CreateDataset as Dataset from syft.service.dataset.dataset import _ASSET_WITH_NONE_MOCK_ERROR_MESSAGE from syft.service.response import SyftError +from syft.service.response import SyftException from syft.service.response import SyftSuccess from syft.types.twin_object import TwinMode @@ -62,8 +63,9 @@ def make_asset_with_empty_mock() -> dict[str, Any]: def test_asset_without_mock_mock_is_real_must_be_false( asset_without_mock: dict[str, Any], ): - with pytest.raises(ValidationError): - Asset(**asset_without_mock, mock_is_real=True) + asset = Asset(**asset_without_mock, mock_is_real=True) + asset.mock_is_real = True + assert not asset.mock_is_real def test_mock_always_not_real_after_calling_no_mock( @@ -85,8 +87,8 @@ def test_mock_always_not_real_after_set_mock_to_empty( asset.no_mock() assert not asset.mock_is_real - with pytest.raises(ValidationError): - asset.mock_is_real = True + asset.mock_is_real = True + assert not asset.mock_is_real asset.mock = mock() asset.mock_is_real = True @@ -102,8 +104,8 @@ def test_mock_always_not_real_after_set_to_empty( asset.mock = ActionObject.empty() assert not asset.mock_is_real - with pytest.raises(ValidationError): - asset.mock_is_real = True + asset.mock_is_real = True + assert not asset.mock_is_real asset.mock = mock() asset.mock_is_real = True @@ -123,7 +125,7 @@ def test_cannot_set_empty_mock_with_true_mock_is_real( asset = Asset(**asset_with_mock, mock_is_real=True) assert asset.mock_is_real - with pytest.raises(ValidationError): + with pytest.raises(SyftException): asset.set_mock(empty_mock, mock_is_real=True) assert asset.mock is asset_with_mock["mock"] diff --git a/packages/syft/tests/syft/service/sync/sync_flow_test.py b/packages/syft/tests/syft/service/sync/sync_flow_test.py new file mode 100644 index 00000000000..e08e43383bb --- /dev/null +++ b/packages/syft/tests/syft/service/sync/sync_flow_test.py @@ -0,0 +1,200 @@ +# stdlib +import sys +from textwrap import dedent + +# third party +import numpy as np +import pytest + +# syft absolute +import syft as sy +from syft.abstract_node import NodeSideType +from syft.client.syncing import compare_states +from syft.client.syncing import resolve +from syft.service.action.action_object import ActionObject + + +@pytest.mark.skipif(sys.platform == "win32", reason="does not run on windows") +@pytest.mark.flaky(reruns=5, reruns_delay=1) +def test_sync_flow(): + # somehow skipif does not work + if sys.platform == "win32": + return + low_worker = sy.Worker( + name="low-test", + local_db=True, + n_consumers=1, + create_producer=True, + node_side_type=NodeSideType.LOW_SIDE, + ) + high_worker = sy.Worker( + name="high-test", + local_db=True, + n_consumers=1, + create_producer=True, + node_side_type=NodeSideType.HIGH_SIDE, + ) + + low_client = low_worker.root_client + high_client = high_worker.root_client + + low_client.register( + email="newuser@openmined.org", + name="John Doe", + password="pw", + password_verify="pw", + ) + client_low_ds = low_worker.guest_client + + mock_high = np.array([10, 11, 12, 13, 14]) + private_high = np.array([15, 16, 17, 18, 19]) + + dataset_high = sy.Dataset( + name="my-dataset", + description="abc", + asset_list=[ + sy.Asset( + name="numpy-data", + mock=mock_high, + data=private_high, + shape=private_high.shape, + mock_is_real=True, + ) + ], + ) + + high_client.upload_dataset(dataset_high) + mock_low = np.array([0, 1, 2, 3, 4]) # do_high.mock + + dataset_low = sy.Dataset( + id=dataset_high.id, + name="my-dataset", + description="abc", + asset_list=[ + sy.Asset( + name="numpy-data", + mock=mock_low, + data=ActionObject.empty(data_node_id=high_client.id), + shape=mock_low.shape, + mock_is_real=True, + ) + ], + ) + + res = low_client.upload_dataset(dataset_low) + + data_low = client_low_ds.datasets[0].assets[0] + + @sy.syft_function_single_use(data=data_low) + def compute_mean(data) -> float: + return data.mean() + + compute_mean.code = dedent(compute_mean.code) + + res = client_low_ds.code.request_code_execution(compute_mean) + print(res) + print("LOW CODE:", low_client.code.get_all()) + + low_state = low_client.sync.get_state() + high_state = high_client.sync.get_state() + + print(low_state.objects, high_state.objects) + + diff_state = compare_states(low_state, high_state) + low_items_to_sync, high_items_to_sync = resolve( + diff_state, decision="low", share_private_objects=True + ) + + print(low_items_to_sync, high_items_to_sync) + + low_client.apply_state(low_items_to_sync) + + high_client.apply_state(high_items_to_sync) + + low_state = low_client.sync.get_state() + high_state = high_client.sync.get_state() + + diff_state = compare_states(low_state, high_state) + + high_client._fetch_api(high_client.credentials) + + data_high = high_client.datasets[0].assets[0] + + print(high_client.code.get_all()) + job_high = high_client.code.compute_mean(data=data_high, blocking=False) + print("Waiting for job...") + job_high.wait() + job_high.result.get() + + # syft absolute + from syft.service.request.request import Request + + request: Request = high_client.requests[0] + job_info = job_high.info(public_metadata=True, result=True) + + print(request.syft_client_verify_key, request.syft_node_location) + print(request.code.syft_client_verify_key, request.code.syft_node_location) + request.accept_by_depositing_result(job_info) + + request = high_client.requests[0] + code = request.code + job_high._get_log_objs() + + action_store_high = high_worker.get_service("actionservice").store + blob_store_high = high_worker.get_service("blobstorageservice").stash.partition + assert ( + f"{client_low_ds.verify_key}_READ" + in action_store_high.permissions[job_high.result.id.id] + ) + assert ( + f"{client_low_ds.verify_key}_READ" + in blob_store_high.permissions[job_high.result.syft_blob_storage_entry_id] + ) + + low_state = low_client.sync.get_state() + high_state = high_client.sync.get_state() + + diff_state_2 = compare_states(low_state, high_state) + + low_items_to_sync, high_items_to_sync = resolve( + diff_state_2, decision="high", share_private_objects=True + ) + for diff in diff_state_2.diffs: + print(diff.status, diff.object_type) + low_client.apply_state(low_items_to_sync) + + action_store_low = low_worker.get_service("actionservice").store + blob_store_low = low_worker.get_service("blobstorageservice").stash.partition + assert ( + f"{client_low_ds.verify_key}_READ" + in action_store_low.permissions[job_high.result.id.id] + ) + assert ( + f"{client_low_ds.verify_key}_READ" + in blob_store_low.permissions[job_high.result.syft_blob_storage_entry_id] + ) + + low_state = low_client.sync.get_state() + high_state = high_client.sync.get_state() + res_low = client_low_ds.code.compute_mean(data=data_low) + print("Res Low", res_low) + + assert res_low.get() == private_high.mean() + + assert ( + res_low.id == job_high.result.id.id == code.output_history[-1].outputs[0].id.id + ) + assert ( + job_high.result.syft_blob_storage_entry_id == res_low.syft_blob_storage_entry_id + ) + + job_low = client_low_ds.code.compute_mean(data=data_low, blocking=False) + + assert job_low.id == job_high.id + assert job_low.result.id == job_high.result.id + assert ( + job_low.result.syft_blob_storage_entry_id + == job_high.result.syft_blob_storage_entry_id + ) + low_worker.close() + high_worker.close() diff --git a/packages/syft/tests/syft/settings/settings_service_test.py b/packages/syft/tests/syft/settings/settings_service_test.py index d0a6def902d..d359eb2848f 100644 --- a/packages/syft/tests/syft/settings/settings_service_test.py +++ b/packages/syft/tests/syft/settings/settings_service_test.py @@ -158,8 +158,12 @@ def mock_stash_get_all(root_verify_key) -> Ok: assert response.is_ok() is True assert len(response.ok()) == len(mock_stash_get_all_output) - assert updated_settings == new_settings # the first settings is updated - assert not_updated_settings == settings # the second settings is not updated + assert ( + updated_settings.model_dump() == new_settings.model_dump() + ) # the first settings is updated + assert ( + not_updated_settings.model_dump() == settings.model_dump() + ) # the second settings is not updated def test_settingsservice_update_stash_get_all_fail( diff --git a/packages/syft/tests/syft/stores/dict_document_store_test.py b/packages/syft/tests/syft/stores/dict_document_store_test.py index d9141c8ad6c..e1280ddfdf9 100644 --- a/packages/syft/tests/syft/stores/dict_document_store_test.py +++ b/packages/syft/tests/syft/stores/dict_document_store_test.py @@ -4,6 +4,7 @@ # syft absolute from syft.store.dict_document_store import DictStorePartition from syft.store.document_store import QueryKeys +from syft.types.uid import UID # relative from .store_mocks_test import MockObjectType @@ -25,7 +26,7 @@ def test_dict_store_partition_set( res = dict_store_partition.init_store() assert res.is_ok() - obj = MockSyftObject(data=1) + obj = MockSyftObject(id=UID(), data=1) res = dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False) assert res.is_ok() diff --git a/packages/syft/tests/syft/stores/sqlite_document_store_test.py b/packages/syft/tests/syft/stores/sqlite_document_store_test.py index 5d738eddb62..11f6dd38b60 100644 --- a/packages/syft/tests/syft/stores/sqlite_document_store_test.py +++ b/packages/syft/tests/syft/stores/sqlite_document_store_test.py @@ -246,9 +246,8 @@ def _kv_cbk(tid: int) -> None: root_verify_key, sqlite_workspace ) for idx in range(repeats): - obj = MockObjectType(data=idx) - for _ in range(10): + obj = MockObjectType(data=idx) res = sqlite_store_partition.set( root_verify_key, obj, ignore_duplicates=False ) diff --git a/packages/syft/tests/syft/transforms/transform_methods_test.py b/packages/syft/tests/syft/transforms/transform_methods_test.py index 4010e454ce3..40669b0db5d 100644 --- a/packages/syft/tests/syft/transforms/transform_methods_test.py +++ b/packages/syft/tests/syft/transforms/transform_methods_test.py @@ -5,8 +5,8 @@ from typing import Optional # third party -from pydantic import EmailError from pydantic import EmailStr +from pydantic_core import PydanticCustomError import pytest # syft absolute @@ -421,7 +421,7 @@ def __iter__(self): ) result = validate_email(transform_context) assert isinstance(result, TransformContext) - assert isinstance(result.output["email"], EmailStr) + assert EmailStr._validate(result.output["email"]) assert result.output["email"] == mock_obj.email mock_obj = MockObject(email=faker.name()) @@ -429,5 +429,5 @@ def __iter__(self): obj=mock_obj, context=node_context ) - with pytest.raises(EmailError): + with pytest.raises(PydanticCustomError): validate_email(transform_context) diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index b4d05f66498..5720244cfdc 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -50,7 +50,7 @@ def test_user_code(worker) -> None: root_domain_client = worker.root_client message = root_domain_client.notifications[-1] request = message.link - user_code = request.changes[0].link + user_code = request.changes[0].code result = user_code.unsafe_function() request.accept_by_depositing_result(result) @@ -124,7 +124,9 @@ def func(asset): c for c in request.changes if (isinstance(c, UserCodeStatusChange)) ) - assert status_change.linked_obj.resolve.assets[0] == asset_input + assert status_change.code.assets[0].model_dump( + mode="json" + ) == asset_input.model_dump(mode="json") @sy.syft_function() diff --git a/packages/syft/tests/syft/users/user_service_test.py b/packages/syft/tests/syft/users/user_service_test.py index 94e3d7a5deb..b372fa5d690 100644 --- a/packages/syft/tests/syft/users/user_service_test.py +++ b/packages/syft/tests/syft/users/user_service_test.py @@ -172,7 +172,7 @@ def mock_get_by_uid(credentials: SyftVerifyKey, uid: UID) -> Ok: monkeypatch.setattr(user_service.stash, "get_by_uid", mock_get_by_uid) response = user_service.view(authed_context, uid_to_view) assert isinstance(response, UserView) - assert response == expected_output + assert response.model_dump() == expected_output.model_dump() def test_userservice_get_all_success( @@ -192,7 +192,10 @@ def mock_get_all(credentials: SyftVerifyKey) -> Ok: response = user_service.get_all(authed_context) assert isinstance(response, List) assert len(response) == len(expected_output) - assert response == expected_output + assert all( + r.model_dump() == expected.model_dump() + for r, expected in zip(response, expected_output) + ) def test_userservice_get_all_error( @@ -230,17 +233,27 @@ def mock_find_all(credentials: SyftVerifyKey, **kwargs) -> Union[Ok, Err]: # Search via id response = user_service.search(authed_context, id=guest_user.id) assert isinstance(response, List) - assert response == expected_output + assert all( + r.model_dump() == expected.model_dump() + for r, expected in zip(response, expected_output) + ) + # assert response.model_dump() == expected_output.model_dump() # Search via email response = user_service.search(authed_context, email=guest_user.email) assert isinstance(response, List) - assert response == expected_output + assert all( + r.model_dump() == expected.model_dump() + for r, expected in zip(response, expected_output) + ) # Search via name response = user_service.search(authed_context, name=guest_user.name) assert isinstance(response, List) - assert response == expected_output + assert all( + r.model_dump() == expected.model_dump() + for r, expected in zip(response, expected_output) + ) # Search via verify_key response = user_service.search( @@ -248,14 +261,20 @@ def mock_find_all(credentials: SyftVerifyKey, **kwargs) -> Union[Ok, Err]: verify_key=guest_user.verify_key, ) assert isinstance(response, List) - assert response == expected_output + assert all( + r.model_dump() == expected.model_dump() + for r, expected in zip(response, expected_output) + ) # Search via multiple kwargs response = user_service.search( authed_context, name=guest_user.name, email=guest_user.email ) assert isinstance(response, List) - assert response == expected_output + assert all( + r.model_dump() == expected.model_dump() + for r, expected in zip(response, expected_output) + ) def test_userservice_search_with_invalid_kwargs( diff --git a/packages/syft/tests/syft/worker_test.py b/packages/syft/tests/syft/worker_test.py index 03772951d71..268e03c10c5 100644 --- a/packages/syft/tests/syft/worker_test.py +++ b/packages/syft/tests/syft/worker_test.py @@ -210,8 +210,8 @@ def post_add(context: Any, name: str, new_result: Any) -> Any: # change return type to sum return Ok(sum(new_result)) - action_object._syft_pre_hooks__["__add__"] = [pre_add] - action_object._syft_post_hooks__["__add__"] = [post_add] + action_object.syft_pre_hooks__["__add__"] = [pre_add] + action_object.syft_post_hooks__["__add__"] = [post_add] result = action_object + action_object x = result.syft_action_data @@ -219,8 +219,8 @@ def post_add(context: Any, name: str, new_result: Any) -> Any: assert y == 18 assert x == y - action_object._syft_pre_hooks__["__add__"] = [] - action_object._syft_post_hooks__["__add__"] = [] + action_object.syft_pre_hooks__["__add__"] = [] + action_object.syft_post_hooks__["__add__"] = [] def test_worker_serde() -> None: diff --git a/scripts/get_k8s_secret_ci.sh b/scripts/get_k8s_secret_ci.sh index 11d5bb9d767..965e8ff0896 100644 --- a/scripts/get_k8s_secret_ci.sh +++ b/scripts/get_k8s_secret_ci.sh @@ -1,6 +1,6 @@ #!/bin/bash -export SYFT_LOGIN_testgateway1_PASSWORD=$(kubectl --context=k3d-testgateway1 get secret syft-default-secret -n testgateway1 \ +export SYFT_LOGIN_testgateway1_PASSWORD=$(kubectl --context=k3d-testgateway1 get secret backend-secret -n syft \ -o jsonpath='{.data.defaultRootPassword}' | base64 --decode) -export SYFT_LOGIN_testdomain1_PASSWORD=$(kubectl get --context=k3d-testdomain1 secret syft-default-secret -n testdomain1 \ - -o jsonpath='{.data.defaultRootPassword}' | base64 --decode) \ No newline at end of file +export SYFT_LOGIN_testdomain1_PASSWORD=$(kubectl get --context=k3d-testdomain1 secret backend-secret -n syft \ + -o jsonpath='{.data.defaultRootPassword}' | base64 --decode) diff --git a/tests/integration/container_workload/blob_storage_test.py b/tests/integration/container_workload/blob_storage_test.py new file mode 100644 index 00000000000..869ae06f1ba --- /dev/null +++ b/tests/integration/container_workload/blob_storage_test.py @@ -0,0 +1,26 @@ +# stdlib +import os + +# third party +import pytest + +# syft absolute +import syft as sy + + +@pytest.mark.container_workload +def test_mount_azure_blob_storage(domain_1_port): + domain_client = sy.login( + email="info@openmined.org", password="changethis", port=domain_1_port + ) + domain_client.api.services.blob_storage.mount_azure( + account_name="citestingstorageaccount", + container_name="citestingcontainer", + account_key=os.environ["AZURE_BLOB_STORAGE_KEY"], + bucket_name="helmazurebucket", + ) + blob_files = domain_client.api.services.blob_storage.get_files_from_bucket( + bucket_name="helmazurebucket" + ) + document = [f for f in blob_files if "testfile.txt" in f.file_name][0] + assert document.read() == b"abc\n" diff --git a/tox.ini b/tox.ini index 0a368808cac..57a20cf08ba 100644 --- a/tox.ini +++ b/tox.ini @@ -28,6 +28,7 @@ envlist = syft.build.helm syft.package.helm syft.test.helm + syft.test.helm.upgrade syft.protocol.check syftcli.test.unit syftcli.publish @@ -232,7 +233,7 @@ commands = bash -c 'HAGRID_ART=$HAGRID_ART hagrid launch test_domain_1 domain to docker:9081 $HAGRID_FLAGS --enable-signup --no-health-checks --verbose --no-warnings' - bash -c '(docker logs test_domain_1-frontend-1 -f &) | grep -q -E "Network:\s+https?://[a-zA-Z0-9.-]+:[0-9]+/" || true' + bash -c '(docker logs test-domain-1-frontend-1 -f &) | grep -q -E "Network:\s+https?://[a-zA-Z0-9.-]+:[0-9]+/" || true' bash -c '(docker logs test_domain_1-backend-1 -f &) | grep -q "Application startup complete" || true' pnpm install @@ -256,15 +257,16 @@ allowlist_externals = sleep bash chcp -passenv=HOME, USER +passenv=HOME, USER, AZURE_BLOB_STORAGE_KEY setenv = - HAGRID_FLAGS = {env:HAGRID_FLAGS:--tag=local --release=development --test} + HAGRID_FLAGS = {env:HAGRID_FLAGS:--tag=local --release=development --dev} EMULATION = {env:EMULATION:false} HAGRID_ART = false PYTHONIOENCODING = utf-8 PYTEST_MODULES = {env:PYTEST_MODULES:frontend container_workload network e2e security redis} commands = bash -c "whoami; id;" + bash -c "env" bash -c "echo Running with HAGRID_FLAGS=$HAGRID_FLAGS EMULATION=$EMULATION PYTEST_MODULES=$PYTEST_MODULES; date" @@ -322,21 +324,11 @@ commands = pytest tests/integration -m frontend -p no:randomly --co; \ pytest tests/integration -m frontend -vvvv -p no:randomly -p no:benchmark -o log_cli=True --capture=no; \ return=$?; \ - docker stop test_domain_1-frontend-1 || true; \ + docker stop test-domain-1-frontend-1 || true; \ echo "Finished frontend"; date; \ exit $return; \ fi' - ; container workload - bash -c 'if [[ "$PYTEST_MODULES" == *"container_workload"* ]]; then \ - echo "Starting Container Workload test"; date; \ - pytest tests/integration -m container_workload -p no:randomly --co; \ - pytest tests/integration -m container_workload -vvvv -p no:randomly -p no:benchmark -o log_cli=True --capture=no; \ - return=$?; \ - echo "Finished container workload"; date; \ - exit $return; \ - fi' - ; network bash -c 'if [[ "$PYTEST_MODULES" == *"network"* ]]; then \ echo "Starting network"; date; \ @@ -347,6 +339,16 @@ commands = exit $return; \ fi' + ; container workload + bash -c 'if [[ "$PYTEST_MODULES" == *"container_workload"* ]]; then \ + echo "Starting Container Workload test"; date; \ + pytest tests/integration -m container_workload -p no:randomly --co; \ + pytest tests/integration -m container_workload -vvvv -p no:randomly -p no:benchmark -o log_cli=True --capture=no; \ + return=$?; \ + echo "Finished container workload"; date; \ + exit $return; \ + fi' + ; shutdown bash -c "echo Killing Nodes; date" bash -c 'HAGRID_ART=false hagrid land all --force' @@ -451,8 +453,8 @@ allowlist_externals = passenv=HOME, USER setenv = LOCAL_ENCLAVE_PORT=8010 - ENABLE_OBLV=true - DOMAIN_CONNECTION_PORT=8010 + OBLV_ENABLED=true + OBLV_LOCALHOST_PORT=8010 ENABLE_SIGNUP=True commands = pip install oblv-ctl==0.3.1 @@ -689,49 +691,25 @@ commands = bash -c "k3d cluster delete testgateway1 || true" bash -c "k3d cluster delete testdomain1 || true" - # Deleting registery & volumes + # Deleting registry & volumes bash -c "k3d registry delete k3d-registry.localhost || true" bash -c "docker volume rm k3d-testgateway1-images --force || true" bash -c "docker volume rm k3d-testdomain1-images --force || true" - - # Creating registry - bash -c 'k3d registry create registry.localhost --port 5800 -v `pwd`/k3d-registry:/var/lib/registry || true' + # Create registry + tox -e dev.k8s.registry # Creating testgateway1 cluster on port 9081 - bash -c 'NODE_NAME=testgateway1 NODE_PORT=9081 && \ - k3d cluster create $NODE_NAME -p "$NODE_PORT:80@loadbalancer" --registry-use k3d-registry.localhost || true \ - k3d cluster start $NODE_NAME' - - bash -c 'NODE_NAME=testgateway1 NODE_PORT=9081 && \ - cd packages/grid && \ - (r=5;while ! \ - devspace --no-warn --kube-context "k3d-$NODE_NAME" --namespace $NODE_NAME \ - -p gateway \ - --var NODE_NAME=$NODE_NAME \ - --var TEST_MODE=1 \ - --var CONTAINER_REGISTRY=k3d-registry.localhost:5800 \ - --var NODE_TYPE=gateway \ - deploy -b; \ - do ((--r))||exit;echo "retrying" && sleep 20;done)' + bash -c '\ + export CLUSTER_NAME=testgateway1 CLUSTER_HTTP_PORT=9081 DEVSPACE_PROFILE=gateway && \ + tox -e dev.k8s.start && \ + tox -e dev.k8s.deploy' # Creating testdomain1 cluster on port 9082 - bash -c 'NODE_NAME=testdomain1 NODE_PORT=9082 && \ - k3d cluster create $NODE_NAME -p "$NODE_PORT:80@loadbalancer" --registry-use k3d-registry.localhost || true \ - k3d cluster start $NODE_NAME' - - # Patches CoreDNS - tox -e dev.k8s.patch.coredns - - bash -c 'NODE_NAME=testdomain1 NODE_PORT=9082 && \ - cd packages/grid && \ - (r=5;while ! \ - devspace --no-warn --kube-context "k3d-$NODE_NAME" --namespace $NODE_NAME \ - --var NODE_NAME=$NODE_NAME \ - --var TEST_MODE=1 \ - --var CONTAINER_REGISTRY=k3d-registry.localhost:5800 \ - deploy -b; \ - do ((--r))||exit;echo "retrying" && sleep 20;done)' + bash -c '\ + export CLUSTER_NAME=testdomain1 CLUSTER_HTTP_PORT=9082 && \ + tox -e dev.k8s.start && \ + tox -e dev.k8s.deploy' # free up build cache after build of images bash -c 'if [[ "$GITHUB_CI" != "false" ]]; then \ @@ -742,25 +720,24 @@ commands = sleep 30 # wait for front end - bash packages/grid/scripts/wait_for.sh service frontend --context k3d-testdomain1 --namespace testdomain1 - bash -c '(kubectl logs service/frontend --context k3d-testdomain1 --namespace testdomain1 -f &) | grep -q -E "Network:\s+https?://[a-zA-Z0-9.-]+:[0-9]+/" || true' + bash packages/grid/scripts/wait_for.sh service frontend --context k3d-testdomain1 --namespace syft + bash -c '(kubectl logs service/frontend --context k3d-testdomain1 --namespace syft -f &) | grep -q -E "Network:\s+https?://[a-zA-Z0-9.-]+:[0-9]+/" || true' # wait for test gateway 1 - bash packages/grid/scripts/wait_for.sh service mongo --context k3d-testgateway1 --namespace testgateway1 - bash packages/grid/scripts/wait_for.sh service backend --context k3d-testgateway1 --namespace testgateway1 - bash packages/grid/scripts/wait_for.sh service proxy --context k3d-testgateway1 --namespace testgateway1 + bash packages/grid/scripts/wait_for.sh service mongo --context k3d-testgateway1 --namespace syft + bash packages/grid/scripts/wait_for.sh service backend --context k3d-testgateway1 --namespace syft + bash packages/grid/scripts/wait_for.sh service proxy --context k3d-testgateway1 --namespace syft # wait for test domain 1 - bash packages/grid/scripts/wait_for.sh service mongo --context k3d-testdomain1 --namespace testdomain1 - bash packages/grid/scripts/wait_for.sh service backend --context k3d-testdomain1 --namespace testdomain1 - bash packages/grid/scripts/wait_for.sh service proxy --context k3d-testdomain1 --namespace testdomain1 - bash packages/grid/scripts/wait_for.sh service seaweedfs --context k3d-testdomain1 --namespace testdomain1 + bash packages/grid/scripts/wait_for.sh service mongo --context k3d-testdomain1 --namespace syft + bash packages/grid/scripts/wait_for.sh service backend --context k3d-testdomain1 --namespace syft + bash packages/grid/scripts/wait_for.sh service proxy --context k3d-testdomain1 --namespace syft + bash packages/grid/scripts/wait_for.sh service seaweedfs --context k3d-testdomain1 --namespace syft # Checking logs generated & startup of test-domain 1 - bash -c '(kubectl logs service/backend --context k3d-testdomain1 --namespace testdomain1 -f &) | grep -q "Application startup complete" || true' + bash -c '(kubectl logs service/backend --context k3d-testdomain1 --namespace syft -f &) | grep -q "Application startup complete" || true' # Checking logs generated & startup of testgateway1 - bash -c '(kubectl logs service/backend --context k3d-testgateway1 --namespace testgateway1 -f &) | grep -q "Application startup complete" || true' - + bash -c '(kubectl logs service/backend --context k3d-testgateway1 --namespace syft -f &) | grep -q "Application startup complete" || true' # frontend bash -c 'if [[ "$PYTEST_MODULES" == *"frontend"* ]]; then \ @@ -772,18 +749,18 @@ commands = exit $return; \ fi' - #Integration + Gateway Connection Tests + # Integration + Gateway Connection Tests # Gateway tests are not run in kuberetes, as currently,it does not have a way to configure # high/low side warning flag. - bash -c " source ./scripts/get_k8s_secret_ci.sh; \ + bash -c "source ./scripts/get_k8s_secret_ci.sh; \ pytest tests/integration/network -k 'not test_domain_gateway_user_code' -p no:randomly -vvvv" # Shutting down the gateway cluster to free up space, as the # below code does not require gateway cluster - bash -c "k3d cluster delete testgateway1 || true" + bash -c "CLUSTER_NAME=testgateway1 tox -e dev.k8s.destroy || true" bash -c "docker volume rm k3d-testgateway1-images --force || true" - ; ; container workload + ; container workload ; bash -c 'if [[ "$PYTEST_MODULES" == *"container_workload"* ]]; then \ ; echo "Starting Container Workload test"; date; \ ; pytest tests/integration -m container_workload -p no:randomly --co; \ @@ -793,27 +770,16 @@ commands = ; exit $return; \ ; fi' - # Since we randomize the password, we retrieve them and store as environment variables - # which would then be used by the notebook - - - # ignore 06 because of opendp on arm64 - # Run 0.8 notebooks - - bash -c " source ./scripts/get_k8s_secret_ci.sh; \ - pytest --nbmake notebooks/api/0.8 -p no:randomly -k 'not 10-container-images.ipynb' -vvvv --nbmake-timeout=1000" - + bash -c "source ./scripts/get_k8s_secret_ci.sh; \ + pytest --nbmake notebooks/api/0.8 -p no:randomly -k 'not 10-container-images.ipynb' -vvvv --nbmake-timeout=1000" # deleting clusters created - bash -c "k3d cluster delete testgateway1 || true" - bash -c "k3d cluster delete testdomain1 || true" + bash -c "CLUSTER_NAME=testdomain1 tox -e dev.k8s.destroy || true" bash -c "k3d registry delete k3d-registry.localhost || true" bash -c "docker rm $(docker ps -aq) --force || true" - bash -c "docker volume rm k3d-testgateway1-images --force || true" bash -c "docker volume rm k3d-testdomain1-images --force || true" - [testenv:syft.build.helm] description = Build Helm Chart for Kubernetes changedir = {toxinidir} @@ -829,6 +795,15 @@ commands = helm lint syft' +[testenv:syft.lint.helm] +description = Lint helm chart +changedir = {toxinidir}/packages/grid/helm +passenv=HOME, USER +allowlist_externals = + bash +commands = + bash -c 'kube-linter lint ./syft --config ./kubelinter-config.yaml' + [testenv:syft.package.helm] description = Package Helm Chart for Kubernetes deps = @@ -893,11 +868,11 @@ commands = # else install the helm charts from the openmined gh-pages branch bash -c 'if [[ $SYFT_VERSION == "local" ]]; then \ echo "Installing local helm charts"; \ - bash -c "cd packages/grid/helm && helm install --kube-context k3d-syft --namespace syft syft ./syft --set configuration.devmode=true"; \ + bash -c "cd packages/grid/helm && helm install --kube-context k3d-syft --namespace syft syft ./syft --set global.useDefaultSecrets=true"; \ else \ echo "Installing helm charts from repo for syft version: ${SYFT_VERSION}"; \ bash -c "helm repo add openmined https://openmined.github.io/PySyft/helm && helm repo update openmined"; \ - bash -c "helm install --kube-context k3d-syft --namespace syft syft openmined/syft --version=${SYFT_VERSION} --set configuration.devmode=true"; \ + bash -c "helm install --kube-context k3d-syft --namespace syft syft openmined/syft --version=${SYFT_VERSION} --set global.useDefaultSecrets=true"; \ fi' ; wait for everything else to be loaded @@ -916,6 +891,17 @@ commands = bash -c "k3d cluster delete syft || true" bash -c "docker volume rm k3d-syft-images --force || true" +[testenv:syft.test.helm.upgrade] +description = Test helm upgrade +changedir = {toxinidir}/packages/grid/ +passenv=HOME,USER,KUBE_CONTEXT +setenv = + UPGRADE_TYPE = {env:UPGRADE_TYPE:ProdToBeta} +allowlist_externals = + bash +commands = + bash ./scripts/helm_upgrade.sh {env:UPGRADE_TYPE} + [testenv:syftcli.test.unit] description = Syft CLI Unit Tests deps = @@ -933,6 +919,9 @@ allowlist_externals = bash sudo commands = + ; check k3d version + bash -c 'k3d --version' + ; create registry bash -c 'k3d registry create registry.localhost --port 5800 -v $HOME/.k3d-registry:/var/lib/registry || true' @@ -940,7 +929,7 @@ commands = bash -c 'if ! grep -q k3d-registry.localhost /etc/hosts; then sudo {envpython} scripts/patch_hosts.py --add-k3d-registry --fix-docker-hosts; fi' ; Fail this command if registry is not working - bash -c 'URL=http://k3d-registry.localhost:5800/v2/_catalog; curl -X GET $URL' + bash -c 'curl --retry 5 --retry-all-errors http://k3d-registry.localhost:5800/v2/_catalog' [testenv:dev.k8s.patch.coredns] description = Patch CoreDNS to resolve k3d-registry.localhost @@ -958,7 +947,10 @@ commands = [testenv:dev.k8s.start] description = Start local Kubernetes registry & cluster with k3d changedir = {toxinidir} -passenv=* +passenv = * +setenv = + CLUSTER_NAME = {env:CLUSTER_NAME:syft-dev} + CLUSTER_HTTP_PORT = {env:CLUSTER_HTTP_PORT:8080} allowlist_externals = bash sleep @@ -968,8 +960,8 @@ commands = tox -e dev.k8s.registry ; for NodePort to work add the following --> -p "NodePort:NodePort@loadbalancer" - bash -c 'k3d cluster create syft-dev -p "8080:80@loadbalancer" --registry-use k3d-registry.localhost:5800; \ - kubectl create namespace syft || true' + bash -c 'k3d cluster create ${CLUSTER_NAME} -p "${CLUSTER_HTTP_PORT}:80@loadbalancer" --registry-use k3d-registry.localhost:5800 && \ + kubectl --context k3d-${CLUSTER_NAME} create namespace syft || true' ; patch coredns tox -e dev.k8s.patch.coredns @@ -980,34 +972,39 @@ commands = [testenv:dev.k8s.deploy] description = Deploy Syft to a local Kubernetes cluster with Devspace changedir = {toxinidir}/packages/grid -passenv=HOME, USER +passenv = HOME, USER, DEVSPACE_PROFILE +setenv= + CLUSTER_NAME = {env:CLUSTER_NAME:syft-dev} allowlist_externals = - tox bash commands = ; deploy syft helm charts - bash -c 'devspace deploy -b --kube-context k3d-syft-dev --namespace syft --var CONTAINER_REGISTRY=k3d-registry.localhost:5800' + bash -c '\ + if [[ -n "${DEVSPACE_PROFILE}" ]]; then export DEVSPACE_PROFILE="-p ${DEVSPACE_PROFILE}"; fi && \ + devspace deploy -b --kube-context k3d-${CLUSTER_NAME} --no-warn ${DEVSPACE_PROFILE} --namespace syft --var CONTAINER_REGISTRY=k3d-registry.localhost:5800' [testenv:dev.k8s.hotreload] description = Start development with hot-reload in Kubernetes changedir = {toxinidir}/packages/grid -passenv=HOME, USER +passenv = HOME, USER, DEVSPACE_PROFILE +setenv= + CLUSTER_NAME = {env:CLUSTER_NAME:syft-dev} allowlist_externals = bash - tox commands = ; deploy syft helm charts with hot-reload - bash -c 'devspace dev --kube-context k3d-syft-dev --namespace syft --var CONTAINER_REGISTRY=k3d-registry.localhost:5800' + bash -c '\ + if [[ -n "${DEVSPACE_PROFILE}" ]]; then export DEVSPACE_PROFILE="-p ${DEVSPACE_PROFILE}"; fi && \ + devspace dev --kube-context k3d-${CLUSTER_NAME} --no-warn ${DEVSPACE_PROFILE} --namespace syft --var CONTAINER_REGISTRY=k3d-registry.localhost:5800' [testenv:dev.k8s.info] description = Gather info about the localKubernetes cluster -passenv=HOME, USER +passenv = HOME, USER ignore_errors = True allowlist_externals = k3d kubectl commands = - kubectl config view k3d cluster list kubectl cluster-info kubectl config current-context @@ -1017,22 +1014,21 @@ commands = description = Cleanup Syft deployment and associated resources, but keep the cluster running changedir = {toxinidir}/packages/grid passenv=HOME, USER +setenv= + CLUSTER_NAME = {env:CLUSTER_NAME:syft-dev} allowlist_externals = bash commands = - bash -c 'devspace purge --force-purge --kube-context k3d-syft-dev --namespace syft; sleep 3' - bash -c 'devspace cleanup images --kube-context k3d-syft-dev --namespace syft --var CONTAINER_REGISTRY=k3d-registry.localhost:5800 || true' - bash -c 'kubectl config use-context k3d-syft-dev' - bash -c 'kubectl delete all --all --namespace syft || true' - bash -c 'kubectl delete pvc --all --namespace syft || true' - bash -c 'kubectl delete secret --all --namespace syft || true' - bash -c 'kubectl delete configmap --all --namespace syft || true' - bash -c 'kubectl delete serviceaccount --all --namespace syft || true' + bash -c 'devspace purge --force-purge --kube-context k3d-${CLUSTER_NAME} --no-warn --namespace syft; sleep 3' + bash -c 'devspace cleanup images --kube-context k3d-${CLUSTER_NAME} --no-warn --namespace syft --var CONTAINER_REGISTRY=k3d-registry.localhost:5800 || true' + bash -c 'kubectl --context k3d-${CLUSTER_NAME} delete namespace syft --now=true || true' [testenv:dev.k8s.destroy] description = Destroy local Kubernetes cluster changedir = {toxinidir}/packages/grid -passenv=HOME, USER +passenv = HOME, USER +setenv= + CLUSTER_NAME = {env:CLUSTER_NAME:syft-dev} allowlist_externals = tox bash @@ -1043,18 +1039,16 @@ commands = ; destroy cluster bash -c '\ rm -rf .devspace; echo ""; \ - k3d cluster delete syft-dev; echo ""; \ - kubectl config view' + k3d cluster delete ${CLUSTER_NAME}' [testenv:dev.k8s.destroyall] description = Destroy both local Kubernetes cluster and registry changedir = {toxinidir} -passenv=HOME, USER +passenv = HOME, USER, CLUSTER_NAME ignore_errors=True allowlist_externals = bash tox - rm commands = ; destroy cluster tox -e dev.k8s.destroy @@ -1138,4 +1132,4 @@ commands = fi" - pytest notebooks/api/0.8 --nbmake -p no:randomly -vvvv --nbmake-timeout=1000 -k '{env:EXCLUDE_NOTEBOOKS:}' \ No newline at end of file + pytest notebooks/api/0.8 --nbmake -p no:randomly -vvvv --nbmake-timeout=1000 -k '{env:EXCLUDE_NOTEBOOKS:}'