diff --git a/.github/workflows/build-schema.yml b/.github/workflows/build-schema.yml
index 67814d94e..4a1f81c1a 100644
--- a/.github/workflows/build-schema.yml
+++ b/.github/workflows/build-schema.yml
@@ -7,6 +7,7 @@ on:
   pull_request:
     branches:
       - main
+      - stable**
   workflow_dispatch:
     inputs:
       ods_branch:
@@ -21,7 +22,7 @@ on:
 
 jobs:
   ods_tools:
-    if: inputs.ods_branch != ''
+    if: ${{ github.event_name != 'workflow_dispatch' }}
     uses: OasisLMF/ODS_Tools/.github/workflows/build.yml@main
     secrets: inherit
     with:
@@ -31,7 +32,9 @@ jobs:
     if: ${{ ! failure() || ! cancelled() }}
     needs: ods_tools
     env:
-      SCHEMA: 'reports/openapi-schema.json'
+      SCHEMA_ALL: 'reports/openapi-schema.json'
+      SCHEMA_V1: 'reports/v1-openapi-schema.json'
+      SCHEMA_V2: 'reports/v2-openapi-schema.json'
     runs-on: ubuntu-22.04
     steps:
     - uses: actions/checkout@v3
@@ -59,15 +62,45 @@ jobs:
 
     - name: Generate OpenAPI
      run: |
-        test -d $(dirname ${{ env.SCHEMA }}) || mkdir -p $(dirname ${{ env.SCHEMA }})
+        test -d $(dirname ${{ env.SCHEMA_ALL }}) || mkdir -p $(dirname ${{ env.SCHEMA_ALL }})
         python ./manage.py migrate
-        python ./manage.py generate_swagger ${{ env.SCHEMA }}
+        python ./manage.py generate_swagger ${{ env.SCHEMA_ALL }}
+
+    - name: Generate OpenAPI (only v1)
+      run: |
+        test -d $(dirname ${{ env.SCHEMA_V1 }}) || mkdir -p $(dirname ${{ env.SCHEMA_V1 }})
+        python ./manage.py migrate
+        python ./manage.py generate_swagger ${{ env.SCHEMA_V1 }}
+      env:
+        OASIS_GEN_SWAGGER_V1: 1
+
+    - name: Generate OpenAPI (only v2)
+      run: |
+        test -d $(dirname ${{ env.SCHEMA_V2 }}) || mkdir -p $(dirname ${{ env.SCHEMA_V2 }})
+        python ./manage.py migrate
+        python ./manage.py generate_swagger ${{ env.SCHEMA_V2 }}
+      env:
+        OASIS_GEN_SWAGGER_V2: 1
 
     - name: Store OpenAPI schema
       uses: actions/upload-artifact@v3
       with:
         name: openapi-schema
-        path: ${{ env.SCHEMA }}
+        path: ${{ env.SCHEMA_ALL }}
+        retention-days: 3
+
+    - name: Store OpenAPI schema (only v1)
+      uses: actions/upload-artifact@v3
+      with:
+        name: v1-openapi-schema
+        path: ${{ env.SCHEMA_V1 }}
+        retention-days: 3
+
+    - name: Store OpenAPI schema (only v2)
+      uses: actions/upload-artifact@v3
+      with:
+        name: v2-openapi-schema
+        path: ${{ env.SCHEMA_V2 }}
         retention-days: 3
 
     - name: Test Schema
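The split can be reproduced locally with the same management command the workflow runs; the OASIS_GEN_SWAGGER_V1/OASIS_GEN_SWAGGER_V2 gate variables are the ones introduced above, the rest is a minimal sketch that assumes a migrated local dev database:

    mkdir -p reports
    python ./manage.py migrate
    python ./manage.py generate_swagger reports/openapi-schema.json                            # combined schema
    OASIS_GEN_SWAGGER_V1=1 python ./manage.py generate_swagger reports/v1-openapi-schema.json  # v1 endpoints only
    OASIS_GEN_SWAGGER_V2=1 python ./manage.py generate_swagger reports/v2-openapi-schema.json  # v2 endpoints only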
diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml
index 1464bdba9..2f39f0125 100644
--- a/.github/workflows/code-quality.yml
+++ b/.github/workflows/code-quality.yml
@@ -30,7 +30,7 @@ jobs:
 
     - name: Run (partial) flake8
       if: ${{ ! cancelled() }}
-      run: flake8 --select F401,F522,F524,F541 --show-source src/
+      run: flake8 --select F401,F522,F524,F541 --show-source src/ --config tox.ini
 
     - name: check PEP8 compliance
       if: ${{ ! cancelled() }}
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index c893e76ae..b395d10b1 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -96,7 +96,7 @@ jobs:
     env:
       pre_release: ${{ inputs.pre_release == '' && 'false' || inputs.pre_release }}
       push_latest: ${{ inputs.push_latest == '' && 'false' || inputs.push_latest }}
-      latest_tag: '2-latest'
+      latest_tag: 'latest'
       release_tag: ${{ inputs.release_tag }}
       prev_release_tag: ${{ inputs.prev_release_tag }}
@@ -367,12 +367,31 @@ jobs:
         name: openapi-schema
         path: ${{ github.workspace }}/
 
+    - name: Download API schema (v1)
+      uses: actions/download-artifact@v3
+      with:
+        name: v1-openapi-schema
+        path: ${{ github.workspace }}/
+
+    - name: Download API schema (v2)
+      uses: actions/download-artifact@v3
+      with:
+        name: v2-openapi-schema
+        path: ${{ github.workspace }}/
+
     - name: Name API schema
       id: api_schema
       run: |
         schema_filename="openapi-schema-${{ env.release_tag }}.json"
+        v1_schema_filename="v1-openapi-schema-${{ env.release_tag }}.json"
+        v2_schema_filename="v2-openapi-schema-${{ env.release_tag }}.json"
+
         mv openapi-schema.json $schema_filename
-        echo "filename=$schema_filename" >> $GITHUB_OUTPUT
+        mv v1-openapi-schema.json $v1_schema_filename
+        mv v2-openapi-schema.json $v2_schema_filename
+        echo "filename_all=$schema_filename" >> $GITHUB_OUTPUT
+        echo "filename_v1=$v1_schema_filename" >> $GITHUB_OUTPUT
+        echo "filename_v2=$v2_schema_filename" >> $GITHUB_OUTPUT
 
 # --- Create Changelog --- #
     - name: Tag Release
@@ -464,17 +483,38 @@ jobs:
         prerelease: ${{ env.pre_release }}
 
 # --- Attach build assest --- #
-    - name: Upload Schema
+    - name: Upload Schema (base)
       id: upload-source-release-asset
       uses: actions/upload-release-asset@v1
       env:
         GITHUB_TOKEN: ${{ secrets.BUILD_GIT_TOKEN }}
       with:
         upload_url: ${{ steps.create_release.outputs.upload_url }}
-        asset_path: ${{ github.workspace }}/${{ steps.api_schema.outputs.filename }}
-        asset_name: ${{ steps.api_schema.outputs.filename }}
+        asset_path: ${{ github.workspace }}/${{ steps.api_schema.outputs.filename_all }}
+        asset_name: ${{ steps.api_schema.outputs.filename_all }}
+        asset_content_type: application/json
+
+    - name: Upload Schema (v1)
+      id: upload-source-release-asset-v1
+      uses: actions/upload-release-asset@v1
+      env:
+        GITHUB_TOKEN: ${{ secrets.BUILD_GIT_TOKEN }}
+      with:
+        upload_url: ${{ steps.create_release.outputs.upload_url }}
+        asset_path: ${{ github.workspace }}/${{ steps.api_schema.outputs.filename_v1 }}
+        asset_name: ${{ steps.api_schema.outputs.filename_v1 }}
         asset_content_type: application/json
 
+    - name: Upload Schema (v2)
+      id: upload-source-release-asset-v2
+      uses: actions/upload-release-asset@v1
+      env:
+        GITHUB_TOKEN: ${{ secrets.BUILD_GIT_TOKEN }}
+      with:
+        upload_url: ${{ steps.create_release.outputs.upload_url }}
+        asset_path: ${{ github.workspace }}/${{ steps.api_schema.outputs.filename_v2 }}
+        asset_name: ${{ steps.api_schema.outputs.filename_v2 }}
+        asset_content_type: application/json
 
 # --- Slack notify --- #
     - name: slack message vars
       id: slack_vars
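Once a release has gone out, the three renamed assets can be spot-checked against the tag with the public GitHub releases API; a sketch, with the tag value as a placeholder:

    TAG=2.3.0   # placeholder release tag
    curl -s "https://api.github.com/repos/OasisLMF/OasisPlatform/releases/tags/$TAG" | jq -r '.assets[].name'
    # expect: openapi-schema-$TAG.json, v1-openapi-schema-$TAG.json, v2-openapi-schema-$TAG.json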
diff --git a/.github/workflows/scan-external.yml b/.github/workflows/scan-external.yml
index eec7851b7..850b18f37 100644
--- a/.github/workflows/scan-external.yml
+++ b/.github/workflows/scan-external.yml
@@ -6,7 +6,7 @@ on:
       - main
   pull_request:
     branches:
-      - main
+      - stable**
   workflow_dispatch:
     inputs:
diff --git a/.github/workflows/scan.yml b/.github/workflows/scan.yml
index 1cf7c5bdc..16e0f4e38 100644
--- a/.github/workflows/scan.yml
+++ b/.github/workflows/scan.yml
@@ -7,6 +7,7 @@ on:
   pull_request:
     branches:
       - main
+      - stable**
   schedule:
     - cron: '0 */6 * * *'    # Run scan every 6 hours
diff --git a/.github/workflows/test-images.yml b/.github/workflows/test-images.yml
index 5066b6f74..3ae27265e 100644
--- a/.github/workflows/test-images.yml
+++ b/.github/workflows/test-images.yml
@@ -7,6 +7,7 @@ on:
   pull_request:
     branches:
       - main
+      - stable**
   workflow_dispatch:
     inputs:
       last_release:
@@ -52,7 +53,13 @@ jobs:
     outputs:
       pytest_opts: ${{ steps.pytest.outputs.opts }}
       piwind_branch: ${{ steps.piwind.outputs.branch }}
+      release_tag: ${{ steps.released_images.outputs.prev_release_tag }}
+      release_stable_15: ${{ steps.released_images.outputs.stable_15 }}
+      release_stable_23: ${{ steps.released_images.outputs.stable_23 }}
+      release_stable_27: ${{ steps.released_images.outputs.stable_27 }}
+      release_stable_28: ${{ steps.released_images.outputs.stable_28 }}
+
       build_server_img: ${{ steps.built_images.outputs.server_img }}
       build_server_tag: ${{ steps.built_images.outputs.server_tag }}
       build_worker_img: ${{ steps.built_images.outputs.worker_img }}
@@ -75,14 +82,26 @@ jobs:
           echo "prev_release_tag=$tag" >> $GITHUB_OUTPUT
 
         # Find tags release accross all branches, limited to matching semver
        elif [[ -z "${{ inputs.last_release }}" ]]; then
-          tag=$( ./scripts/find_latest.sh -j "${{ env.semver_major }}" )
-          #tag=$( ./scripts/find_latest.sh -j "${{ env.semver_major }}" -i "${{ env.semver_minor }}" )
+          tag=$( ./scripts/find_latest.sh -j "${{ env.semver_major }}" -i "${{ env.semver_minor }}" )
           echo "prev_release_tag=$tag" >> $GITHUB_OUTPUT
 
         # Don't search, use the given input
         else
           echo "prev_release_tag=${{ inputs.last_release }}" >> $GITHUB_OUTPUT
         fi
 
+        # Find latest LTS from each stable branch
+        stable_1_15=$( ./scripts/find_latest.sh -j 1 -i 15 )
+        echo "stable_15=$stable_1_15" >> $GITHUB_OUTPUT
+
+        stable_1_23=$( ./scripts/find_latest.sh -j 1 -i 23 )
+        echo "stable_23=$stable_1_23" >> $GITHUB_OUTPUT
+
+        stable_1_27=$( ./scripts/find_latest.sh -j 1 -i 27 )
+        echo "stable_27=$stable_1_27" >> $GITHUB_OUTPUT
+
+        stable_1_28=$( ./scripts/find_latest.sh -j 1 -i 28 )
+        echo "stable_28=$stable_1_28" >> $GITHUB_OUTPUT
+
     - name: Select PiWind branch
       id: piwind
       run: |
@@ -93,12 +112,12 @@ jobs:
           BRANCH=${{ github.ref_name }}
         else
           BRANCH=${{ inputs.piwind_branch }}
-        fi 
+        fi
 
         #override 'main-platform1' -> 'main'
         if [[ "$BRANCH" = 'main-platform1' ]]; then
            BRANCH=main
-        fi 
+        fi
 
         echo "branch=$BRANCH" >> $GITHUB_OUTPUT
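scripts/find_latest.sh itself is not part of this patch, so the following is only a sketch of the tag filtering its -j (major) / -i (minor) flags imply, assuming releases carry plain semver tags:

    # hypothetical stand-in for: ./scripts/find_latest.sh -j 1 -i 27
    MAJOR=1; MINOR=27
    git tag --list "${MAJOR}.${MINOR}.*" | sort -V | tail -n 1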
@@ -130,8 +149,8 @@ jobs:
       echo "deb_worker_img=$deb_worker_img" >> $GITHUB_OUTPUT
       echo "deb_worker_tag=$deb_worker_tag" >> $GITHUB_OUTPUT
 
-  worker_all_checks:
-    name: PiWind all checks
+  all_checks_v1:
+    name: V1 all checks
     secrets: inherit
     needs: [setup]
     uses: OasisLMF/OasisPiWind/.github/workflows/integration.yml@main
@@ -141,10 +160,59 @@ jobs:
     with:
       piwind_branch: ${{ needs.setup.outputs.piwind_branch }}
       server_image: ${{ needs.setup.outputs.build_server_img }}
       server_tag: ${{ needs.setup.outputs.build_server_tag }}
       worker_image: ${{ needs.setup.outputs.build_worker_img }}
       worker_tag: ${{ needs.setup.outputs.build_worker_tag }}
+      worker_api_ver: 'v1'
       debug_mode: 1
-      pytest_opts: "--docker-compose=./docker/plat2.docker-compose.yml "
+      pytest_opts: "--docker-compose=./docker/plat2-v2.docker-compose.yml "
       storage_suffix: '-all-checks'
 
+  all_checks_v2:
+    name: V2 all checks
+    secrets: inherit
+    needs: [setup]
+    uses: OasisLMF/OasisPiWind/.github/workflows/integration.yml@main
+    with:
+      piwind_branch: ${{ needs.setup.outputs.piwind_branch }}
+      server_image: ${{ needs.setup.outputs.build_server_img }}
+      server_tag: ${{ needs.setup.outputs.build_server_tag }}
+      worker_image: ${{ needs.setup.outputs.build_worker_img }}
+      worker_tag: ${{ needs.setup.outputs.build_worker_tag }}
+      worker_api_ver: 'v2'
+      debug_mode: 1
+      pytest_opts: "--docker-compose=./docker/plat2-v2.docker-compose.yml "
+      storage_suffix: '-all-checks'
+
+  storage_s3_v1:
+    name: V1 Storage Compatibility (S3)
+    secrets: inherit
+    needs: [setup]
+    uses: OasisLMF/OasisPiWind/.github/workflows/integration.yml@main
+    with:
+      piwind_branch: ${{ needs.setup.outputs.piwind_branch }}
+      server_image: ${{ needs.setup.outputs.build_server_img }}
+      server_tag: ${{ needs.setup.outputs.build_server_tag }}
+      worker_image: ${{ needs.setup.outputs.build_worker_img }}
+      worker_tag: ${{ needs.setup.outputs.build_worker_tag }}
+      worker_api_ver: 'v1'
+      debug_mode: 1
+      pytest_opts: "--docker-compose=./docker/plat2-v2.s3.docker-compose.yml ${{ needs.setup.outputs.pytest_opts }}"
+      storage_suffix: '_s3'
+
+  storage_s3_v2:
+    name: V2 Storage Compatibility (S3)
+    secrets: inherit
+    needs: [setup]
+    uses: OasisLMF/OasisPiWind/.github/workflows/integration.yml@main
+    with:
+      piwind_branch: ${{ needs.setup.outputs.piwind_branch }}
+      server_image: ${{ needs.setup.outputs.build_server_img }}
+      server_tag: ${{ needs.setup.outputs.build_server_tag }}
+      worker_image: ${{ needs.setup.outputs.build_worker_img }}
+      worker_tag: ${{ needs.setup.outputs.build_worker_tag }}
+      worker_api_ver: 'v2'
+      debug_mode: 1
+      pytest_opts: "--docker-compose=./docker/plat2-v2.s3.docker-compose.yml ${{ needs.setup.outputs.pytest_opts }}"
+      storage_suffix: '_s3'
+
   worker_debian:
     name: Worker Debian
     secrets: inherit
     needs: [setup]
     uses: OasisLMF/OasisPiWind/.github/workflows/integration.yml@main
@@ -156,51 +224,72 @@ jobs:
       server_tag: ${{ needs.setup.outputs.build_server_tag }}
       worker_image: ${{ needs.setup.outputs.build_deb_worker_img }}
       worker_tag: ${{ needs.setup.outputs.build_deb_worker_tag }}
+      worker_api_ver: 'v2'
       debug_mode: 1
-      pytest_opts: "--docker-compose=./docker/plat2.docker-compose.yml ${{ needs.setup.outputs.pytest_opts }}"
+      pytest_opts: "--docker-compose=./docker/plat2-v2.docker-compose.yml ${{ needs.setup.outputs.pytest_opts }}"
       storage_suffix: '-worker-debian'
 
-# server_compatibility:
-#   name: Server Compatibility (${{ needs.setup.outputs.release_tag }})
-#   secrets: inherit
-#   needs: [setup]
-#   uses: OasisLMF/OasisPiWind/.github/workflows/integration.yml@main
-#   with:
-#     piwind_branch: ${{ needs.setup.outputs.piwind_branch }}
-#     server_image: ${{ needs.setup.outputs.build_server_img }}
-#     server_tag: ${{ needs.setup.outputs.build_server_tag }}
-#     worker_image: 'coreoasis/model_worker'
-#     worker_tag: ${{ needs.setup.outputs.release_tag }}
-#     debug_mode: 1
-#     pytest_opts: "--docker-compose=./docker/plat2.docker-compose.yml ${{ needs.setup.outputs.pytest_opts }}"
-#     storage_suffix: '-server-compatibility'
-#
-# worker_compatibility:
-#   name: Worker Compatibility (${{ needs.setup.outputs.release_tag }})
-#   secrets: inherit
-#   needs: [setup]
-#   uses: OasisLMF/OasisPiWind/.github/workflows/integration.yml@main
-#   with:
-#     piwind_branch: ${{ needs.setup.outputs.piwind_branch }}
-#     server_image: 'coreoasis/api_server'
-#     server_tag: ${{ needs.setup.outputs.release_tag }}
-#     worker_image: ${{ needs.setup.outputs.build_worker_img }}
-#     worker_tag: ${{ needs.setup.outputs.build_worker_tag }}
-#     debug_mode: 1
-#     pytest_opts: "--docker-compose=./docker/plat2.docker-compose.yml ${{ needs.setup.outputs.pytest_opts }}"
-#     storage_suffix: '-worker-compatibility'
-
-# storage_s3:
-#   name: Storage Compatibility (S3)
-#   secrets: inherit
-#   needs: [setup]
-#   uses: OasisLMF/OasisPiWind/.github/workflows/integration.yml@main
-#   with:
-#     piwind_branch: ${{ needs.setup.outputs.piwind_branch }}
-#     server_image: ${{ needs.setup.outputs.build_server_img }}
-#     server_tag: ${{ needs.setup.outputs.build_server_tag }}
-#     worker_image: ${{ needs.setup.outputs.build_worker_img }}
-#     worker_tag: ${{ needs.setup.outputs.build_worker_tag }}
-#     debug_mode: 0
-#     pytest_opts: "--docker-compose=./docker/plat2.docker-compose.yml ${{ needs.setup.outputs.pytest_opts }}"
-#     storage_suffix: '-s3'
+  stable_compatibility_15:
+    name: Test stable worker (${{ needs.setup.outputs.release_stable_15 }})
+    secrets: inherit
+    needs: [setup]
+    uses: OasisLMF/OasisPiWind/.github/workflows/integration.yml@main
+    with:
+      piwind_branch: 'stable/1.15.x'
+      server_image: ${{ needs.setup.outputs.build_server_img }}
+      server_tag: ${{ needs.setup.outputs.build_server_tag }}
+      worker_image: 'coreoasis/model_worker'
+      worker_tag: ${{ needs.setup.outputs.release_stable_15 }}
+      worker_api_ver: 'v1'
+      debug_mode: 1
+      pytest_opts: "--docker-compose=./docker/plat2-v2.docker-compose.yml ${{ needs.setup.outputs.pytest_opts }}"
+      storage_suffix: "_worker-${{ needs.setup.outputs.release_stable_15 }}"
+
+  stable_compatibility_23:
+    name: Test stable worker (${{ needs.setup.outputs.release_stable_23 }})
+    secrets: inherit
+    needs: [setup]
+    uses: OasisLMF/OasisPiWind/.github/workflows/integration.yml@main
+    with:
+      piwind_branch: 'stable/1.23.x'
+      server_image: ${{ needs.setup.outputs.build_server_img }}
+      server_tag: ${{ needs.setup.outputs.build_server_tag }}
+      worker_image: 'coreoasis/model_worker'
+      worker_tag: ${{ needs.setup.outputs.release_stable_23 }}
+      worker_api_ver: 'v1'
+      debug_mode: 1
+      pytest_opts: "--docker-compose=./docker/plat2-v2.docker-compose.yml ${{ needs.setup.outputs.pytest_opts }}"
+      storage_suffix: "_worker-${{ needs.setup.outputs.release_stable_23 }}"
+
+  stable_compatibility_27:
+    name: Test stable worker (${{ needs.setup.outputs.release_stable_27 }})
+    secrets: inherit
+    needs: [setup]
+    uses: OasisLMF/OasisPiWind/.github/workflows/integration.yml@main
+    with:
+      piwind_branch: 'stable/1.27.x'
+      server_image: ${{ needs.setup.outputs.build_server_img }}
+      server_tag: ${{ needs.setup.outputs.build_server_tag }}
+      worker_image: 'coreoasis/model_worker'
+      worker_tag: ${{ needs.setup.outputs.release_stable_27 }}
+      worker_api_ver: 'v1'
+      debug_mode: 1
+      pytest_opts: "--docker-compose=./docker/plat2-v2.docker-compose.yml ${{ needs.setup.outputs.pytest_opts }}"
+      storage_suffix: "_worker-${{ needs.setup.outputs.release_stable_27 }}"
+
+  stable_compatibility_28:
+    name: Test stable worker (${{ needs.setup.outputs.release_stable_28 }})
+    secrets: inherit
+    needs: [setup]
+    uses: OasisLMF/OasisPiWind/.github/workflows/integration.yml@main
+    with:
+      piwind_branch: 'stable/1.28.x'
+      server_image: ${{ needs.setup.outputs.build_server_img }}
+      server_tag: ${{ needs.setup.outputs.build_server_tag }}
+      worker_image: 'coreoasis/model_worker'
+      worker_tag: ${{ needs.setup.outputs.release_stable_28 }}
+      worker_api_ver: 'v1'
+      debug_mode: 1
+      pytest_opts: "--docker-compose=./docker/plat2-v2.docker-compose.yml ${{ needs.setup.outputs.pytest_opts }}"
+      storage_suffix: "_worker-${{ needs.setup.outputs.release_stable_28 }}"
+
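The four stable_compatibility_* jobs are deliberate copies that differ only in minor version and worker tag; the pairing they encode can be listed locally with the same lookup the setup job performs:

    for minor in 15 23 27 28; do
        tag=$( ./scripts/find_latest.sh -j 1 -i "$minor" )
        echo "stable/1.${minor}.x -> coreoasis/model_worker:${tag} (worker_api_ver=v1)"
    done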
diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python_api.yml
similarity index 97%
rename from .github/workflows/test-python.yml
rename to .github/workflows/test-python_api.yml
index 49b4e5da1..08358206b 100644
--- a/.github/workflows/test-python.yml
+++ b/.github/workflows/test-python_api.yml
@@ -1,4 +1,4 @@
-name: Platform Python Tests
+name: Python Tests - Platform API
 
 on:
   push:
@@ -7,6 +7,7 @@ on:
   pull_request:
     branches:
       - main
+      - stable**
   workflow_dispatch:
     inputs:
       ods_branch:
diff --git a/.github/workflows/test-python_worker-controller.yml b/.github/workflows/test-python_worker-controller.yml
new file mode 100644
index 000000000..caf009f91
--- /dev/null
+++ b/.github/workflows/test-python_worker-controller.yml
@@ -0,0 +1,52 @@
+name: Python Tests - Worker Controller
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+      - stable**
+  workflow_dispatch:
+  workflow_call:
+
+jobs:
+  unittest:
+    env:
+      JUNIT_REPORT: pytest_worker-controller_report.xml
+      PLAT_BRANCH: ${{ github.ref }}
+    runs-on: ubuntu-22.04
+
+    steps:
+    - name: Branch selection (remote trigger)
+      if: inputs.platform_branch != ''
+      run: echo "PLAT_BRANCH=${{ inputs.platform_branch }}" >> $GITHUB_ENV
+
+    - name: Checkout
+      uses: actions/checkout@v3
+      with:
+        repository: OasisLMF/OasisPlatform
+        ref: ${{ env.PLAT_BRANCH }}
+
+    - name: Set up Python 3.10
+      uses: actions/setup-python@v4
+      with:
+        python-version: '3.10'
+    - run: |
+        pip install pytest
+        pip install -r kubernetes/worker-controller/requirements.txt
+
+    - name: Run Pytest
+      run: |
+        cd kubernetes/worker-controller/src
+        python -m pytest -v
+
+#    - name: Generate Report
+#      uses: dorny/test-reporter@v1
+#      if: success() || failure()    # run this step even if previous step failed
+#      with:
+#        name: Pytest Results              # Name of the check run which will be created
+#        path: ${{ env.JUNIT_REPORT }}     # Path to test results
+#        reporter: java-junit              # Format of test results
+#        fail-on-error: 'false'
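The new workflow boils down to two install commands and a pytest call, so a local run before pushing is cheap (same paths as the workflow, from the repo root):

    pip install pytest
    pip install -r kubernetes/worker-controller/requirements.txt
    cd kubernetes/worker-controller/src && python -m pytest -v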
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 49209d14a..741586757 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,6 +1,38 @@
 OasisPlatform Changelog
 =======================
 
+`2.3.0`_
+ ---------
+* [#898](https://github.com/OasisLMF/OasisPlatform/pull/898) - Fix ods-tools changelog call
+* [#869](https://github.com/OasisLMF/OasisPlatform/pull/899) - Worker Controller crashing under heavy load
+* [#897](https://github.com/OasisLMF/OasisPlatform/pull/906) - Collected WebSocket bug fixes
+* [#912](https://github.com/OasisLMF/OasisPlatform/pull/912) - Fix syntax in flower chart template
+* [#913](https://github.com/OasisLMF/OasisPlatform/pull/914) - Add ENV var to disable http in websocket pod
+* [#918](https://github.com/OasisLMF/OasisPlatform/pull/918) - Fix worker_count_max assignment
+* [#920](https://github.com/OasisLMF/OasisPlatform/pull/921) - ODS Tools link in release notes points to OasisLMF repo
+* [#929](https://github.com/OasisLMF/OasisPlatform/pull/930) - Platform 2 - Keycloak DB reset on restart or redeployment.
+* [#893](https://github.com/OasisLMF/OasisPlatform/pull/931) - Support Platform 1 workers on the v2 server
+* [#942](https://github.com/OasisLMF/OasisPlatform/pull/942) - Updated Package Requirements: oasislmf==1.28.5 ods-tools==3.1.4
+* [#928, #681](https://github.com/OasisLMF/OasisPlatform/pull/944) - Added chunking options to analysis level
+* [#905, #786](https://github.com/OasisLMF/OasisPlatform/pull/945) - Fixed generate and run endpoint
+* [#818](https://github.com/OasisLMF/OasisPlatform/pull/818) - Update/remote trig python tests
+* [#910](https://github.com/OasisLMF/OasisPlatform/pull/947) - Add post analysis hook to platform 2 workflow
+* [#890](https://github.com/OasisLMF/OasisPlatform/pull/948) - Fetch a model's versions when auto-registration is disabled
+* [#903](https://github.com/OasisLMF/OasisPlatform/pull/950) - File linking OED from sub-directories fails to link inside workers
+* [#953](https://github.com/OasisLMF/OasisPlatform/pull/954) - Platform 2.1.3 - No free channel ids error
+* [#955](https://github.com/OasisLMF/OasisPlatform/pull/955) - Revert "Always post model version info on worker startup (platform 2)…
+* [#951](https://github.com/OasisLMF/OasisPlatform/pull/956) - Allow 'single instance' execution from v2 api
+* [#952](https://github.com/OasisLMF/OasisPlatform/pull/957) - Cleaner split between v1 and v2 OpenAPI schemas
+* [#960](https://github.com/OasisLMF/OasisPlatform/pull/960) - Update external images & python packages (2.3.0 release)
+* [#961](https://github.com/OasisLMF/OasisPlatform/pull/963) - Remove the python3-pip from production server images
+* [#966](https://github.com/OasisLMF/OasisPlatform/pull/966) - Fix broken swagger calls when SUB_PATH_URL=True
+* [#968](https://github.com/OasisLMF/OasisPlatform/pull/968) - Fix model registration script for v1 workers
+* [#857](https://github.com/OasisLMF/OasisPlatform/pull/857) - Release 2.2.1 (staging)
+* [#872](https://github.com/OasisLMF/OasisPlatform/pull/882) - Investigate flower error in monitoring chart
+* [#871](https://github.com/OasisLMF/OasisPlatform/pull/883) - Handle exceptions from OedExposure on file Upload
+* [#702](https://github.com/OasisLMF/OasisPlatform/pull/886) - Fix worker controller stability
+
+.. _`2.3.0`:  https://github.com/OasisLMF/OasisPlatform/compare/2.2.1...2.3.0
+
 `2.2.1`_
  ---------
 * [#849](https://github.com/OasisLMF/OasisPlatform/pull/849) - Feautre/1323 reorganize branches plat2
diff --git a/Dockerfile.api_server b/Dockerfile.api_server
index 2914fd809..7417b8a79 100755
--- a/Dockerfile.api_server
+++ b/Dockerfile.api_server
@@ -20,7 +20,7 @@ USER server
 FROM ubuntu:22.04
 RUN apt-get update \
     && apt-get upgrade -y \
-    && apt-get install -y --no-install-recommends sudo python3 python3-pip curl libmariadbclient-dev-compat \
+    && apt-get install -y --no-install-recommends sudo python3 python3-pkg-resources curl libmariadbclient-dev-compat \
     && rm -rf /var/lib/apt/lists/
 
 RUN adduser --home /home/server --shell /bin/bash --disabled-password --gecos "" server
diff --git a/Dockerfile.model_worker b/Dockerfile.model_worker
index 339d98018..8b2b96ac8 100755
--- a/Dockerfile.model_worker
+++ b/Dockerfile.model_worker
@@ -41,15 +41,12 @@ WORKDIR /home/worker
 COPY ./requirements-worker.txt ./requirements.txt
 COPY --chown=worker:worker ./conf.ini ./
 COPY ./src/startup_worker.sh ./startup.sh
-COPY ./src/startup_tester.sh ./runtest
-COPY ./src/startup_tester_S3.sh ./runtest_S3
 COPY ./src/__init__.py ./src/
 COPY ./src/common ./src/common/
 COPY ./src/conf ./src/conf/
 COPY ./src/model_execution_worker/ ./src/model_execution_worker/
 COPY ./src/utils/ ./src/utils/
 COPY ./src/utils/worker_bashrc /home/worker/.bashrc
-COPY ./tests/integration /home/worker/tests/integration
 
 # Add required directories
 RUN mkdir -p /var/oasis && \
diff --git a/Dockerfile.model_worker_debian b/Dockerfile.model_worker_debian
index 57e5b9933..d48a89d7d 100644
--- a/Dockerfile.model_worker_debian
+++ b/Dockerfile.model_worker_debian
@@ -45,8 +45,6 @@ RUN if [ ! -z "$ods_tools_branch" ] ; then \
 
 # Copy startup script + server config
 COPY ./src/startup_worker.sh ./startup.sh
-COPY ./src/startup_tester.sh ./runtest
-COPY ./src/startup_tester_S3.sh ./runtest_S3
 COPY ./conf.ini ./
 COPY ./src/__init__.py ./src/
 COPY ./src/common ./src/common/
@@ -54,6 +52,5 @@ COPY ./src/conf ./src/conf/
 COPY ./src/model_execution_worker/ ./src/model_execution_worker/
 COPY ./src/utils/ ./src/utils/
 COPY ./src/utils/worker_bashrc /home/worker/.bashrc
-COPY ./tests/integration /home/worker/tests/integration
 
 ENTRYPOINT ./startup.sh
diff --git a/VERSION b/VERSION
index c043eea77..276cbf9e2 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-2.2.1
+2.3.0
diff --git a/conf.ini b/conf.ini
index 88a1fa103..e4667ca92 100644
--- a/conf.ini
+++ b/conf.ini
@@ -19,6 +19,8 @@ DO_GZIP_RESPONSE = True
 SECRET_KEY=OmuudYrSFVxcIVIWf6YlYdkP6NXApP
 TOKEN_SIGINING_KEY=JsVzvtWw2EwksaYCZsMmd2zmm
 TOKEN_REFRESH_ROTATE = True
+DISABLE_V2_API = False
+
 #PORTFOLIO_PARQUET_STORAGE = True
 #TOKEN_REFRESH_LIFETIME = minutes=0, hours=0, days=0, weeks=0
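The images touched above build from the repo root with their usual Dockerfiles; the dev tag matches what docker-compose.yml expects, and the -debian suffix mirrors the publish workflow (tags here are illustrative):

    docker build -f Dockerfile.api_server -t coreoasis/api_server:dev .
    docker build -f Dockerfile.model_worker -t coreoasis/model_worker:dev .
    docker build -f Dockerfile.model_worker_debian -t coreoasis/model_worker:dev-debian .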
diff --git a/docker-compose.yml b/docker-compose.yml
index c545d3e88..11b868472 100755
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -23,10 +23,24 @@ x-shared-env: &shared-env
   OASIS_CELERY_DB_PORT: 5432
   OASIS_INPUT_GENERATION_CONTROLLER_QUEUE: task-controller
   OASIS_LOSSES_GENERATION_CONTROLLER_QUEUE: task-controller
+
+x-oasis-env-v1: &oasis-env-v1
+  OASIS_DEBUG: ${DEBUG:-0}
+  OASIS_RABBIT_HOST: broker
+  OASIS_RABBIT_PORT: 5672
+  OASIS_RABBIT_USER: rabbit
+  OASIS_RABBIT_PASS: rabbit
+  OASIS_CELERY_DB_ENGINE: db+postgresql+psycopg2
+  OASIS_CELERY_DB_HOST: celery-db
+  OASIS_CELERY_DB_PASS: password
+  OASIS_CELERY_DB_USER: celery
+  OASIS_CELERY_DB_NAME: celery
+  OASIS_CELERY_DB_PORT: 5432
+
 x-volumes: &shared-volumes
   - filestore-OasisData:/shared-fs:rw
 
 services:
-  server_http:
+  server:
     restart: always
     image: coreoasis/api_server:dev
     command: ["./wsgi/run-wsgi.sh"]
@@ -47,7 +61,7 @@ services:
       OASIS_ADMIN_PASS: password
     volumes:
       - filestore-OasisData:/shared-fs:rw
-      - ./src/server/oasisapi/analyses:/var/www/oasis/src/server/oasisapi/analyses
+      - ./src/server/oasisapi:/var/www/oasis/src/server/oasisapi
   server_websocket:
     restart: always
     image: coreoasis/api_server:dev
@@ -62,11 +76,11 @@ services:
       <<: *shared-env
     volumes:
       - filestore-OasisData:/shared-fs:rw
-      - ./src/server/oasisapi/analyses:/var/www/oasis/src/server/oasisapi/analyses
-  worker-monitor:
+      - ./src/server/oasisapi:/var/www/oasis/src/server/oasisapi
+  v1-worker-monitor:
     restart: always
     image: coreoasis/api_server:dev
-    command: [celery, -A, src.server.oasisapi.celery_app, worker, --loglevel=INFO]
+    command: [celery, -A, 'src.server.oasisapi.celery_app_v1', worker, --loglevel=INFO,]
     links:
       - server-db
       - celery-db
@@ -75,11 +89,24 @@ services:
       <<: *shared-env
     volumes:
       - filestore-OasisData:/shared-fs:rw
-      - ./src/server/oasisapi/analyses:/var/www/oasis/src/server/oasisapi/analyses
-  task-controller:
+      - ./src/server/oasisapi:/var/www/oasis/src/server/oasisapi
+  v2-worker-monitor:
+    restart: always
+    image: coreoasis/api_server:dev
+    command: [celery, -A, 'src.server.oasisapi.celery_app_v2', worker, --loglevel=INFO, -Q, celery-v2]
+    links:
+      - server-db
+      - celery-db
+      - broker
+    environment:
+      <<: *shared-env
+    volumes:
+      - filestore-OasisData:/shared-fs:rw
+      - ./src/server/oasisapi:/var/www/oasis/src/server/oasisapi
+  v2-task-controller:
     restart: always
     image: coreoasis/api_server:dev
-    command: [celery, -A, src.server.oasisapi.celery_app, worker, --loglevel=INFO, -Q, task-controller]
+    command: [celery, -A, 'src.server.oasisapi.celery_app_v2', worker, --loglevel=INFO, -Q, task-controller]
     links:
       - server-db
       - celery-db
@@ -89,10 +116,10 @@ services:
     volumes:
       - filestore-OasisData:/shared-fs:rw
       - ./src/server/oasisapi/analyses:/var/www/oasis/src/server/oasisapi/analyses
-  celery-beat:
+  celery-beat_v2:
     restart: always
     image: coreoasis/api_server:dev
-    command: [celery, -A, src.server.oasisapi.celery_app, beat, --loglevel=INFO]
+    command: [celery, -A, src.server.oasisapi.celery_app_v2, beat, --loglevel=INFO]
     links:
       - server-db
       - celery-db
@@ -100,7 +127,41 @@ services:
       - broker
     environment:
       <<: *shared-env
     volumes: *shared-volumes
-  worker:
+  stable-worker:
+    restart: always
+    image: coreoasis/model_worker:1.28.4
+    links:
+      - celery-db
+      - broker:mybroker
+    environment:
+      <<: *oasis-env-v1
+      OASIS_MODEL_DATA_DIRECTORY: /home/worker/model
+      OASIS_MODEL_SUPPLIER_ID: OasisLMF
+      OASIS_MODEL_ID: PiWind
+      OASIS_MODEL_VERSION_ID: '1.28.4'
+    volumes:
+      - ${OASIS_MODEL_DATA_DIR:-./data/static}:/home/worker/model:rw
+      - filestore-OasisData:/shared-fs:rw
+  v1-worker:
+    restart: always
+    image: coreoasis/model_worker:dev
+    build:
+      context: .
+      dockerfile: Dockerfile.model_worker
+    links:
+      - celery-db
+      - broker:mybroker
+    environment:
+      <<: *shared-env
+      OASIS_MODEL_SUPPLIER_ID: OasisLMF
+      OASIS_MODEL_ID: PiWind
+      OASIS_MODEL_VERSION_ID: 'v1'
+      OASIS_RUN_MODE: v1
+    volumes:
+      - ${OASIS_MODEL_DATA_DIR:-./data/static}:/home/worker/model:rw
+      - ./src/model_execution_worker:/home/worker/src/model_execution_worker
+      - filestore-OasisData:/shared-fs:rw
+  v2-worker:
     restart: always
     image: coreoasis/model_worker:dev
     build:
@@ -113,8 +174,8 @@ services:
       <<: *shared-env
       OASIS_MODEL_SUPPLIER_ID: OasisLMF
       OASIS_MODEL_ID: PiWind
-      OASIS_MODEL_VERSION_ID: 1
-      OASIS_MODEL_NUM_ANALYSIS_CHUNKS: 8
+      OASIS_MODEL_VERSION_ID: 'v2'
+      OASIS_RUN_MODE: v2
     volumes:
       - ${OASIS_MODEL_DATA_DIR:-./data/static}:/home/worker/model:rw
       - ./src/model_execution_worker:/home/worker/src/model_execution_worker
@@ -169,3 +230,14 @@ services:
     image: redis:5.0.7
     ports:
       - 6379:6379
+  user-interface:
+    restart: always
+    image: coreoasis/oasisui_app:${VERS_UI:-latest}
+    container_name: oasisui_app
+    environment:
+      - API_IP=server:8000/api/
+      - API_VERSION=v2
+      - API_SHARE_FILEPATH=./downloads
+      - OASIS_ENVIRONMENT=oasis_localhost
+    ports:
+      - 8080:3838
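With the service split above, a mixed v1/v2 stack can be brought up service by service; a sketch using the compose v2 CLI, relying on the declared links to pull in the databases and broker:

    docker compose up -d server server_websocket v1-worker-monitor v2-worker-monitor v1-worker v2-worker user-interface
    docker compose logs -f v2-worker-monitor   # watch the v2 queue worker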
diff --git a/jenkins/oasis_platform.groovy b/jenkins/oasis_platform.groovy
deleted file mode 100644
index 7c15d7c75..000000000
--- a/jenkins/oasis_platform.groovy
+++ /dev/null
@@ -1,660 +0,0 @@
-//JOB TEMPLATE
-def createStage(stage_name, stage_params, propagate_flag) {
-    return {
-        stage("Test: ${stage_name}") {
-            build job: "${stage_name}", parameters: stage_params, propagate: propagate_flag
-        }
-    }
-}
-
-// LIST of default models sub-jobs to trigger as part of regression testing
-def model_regression_list = """
-oasis_PiWind/develop
-GemFoundation_GMO/master
-corelogic_quake/develop
-"""
-
-node {
-    hasFailed = false
-    sh 'sudo /var/lib/jenkins/jenkins-chown'
-    deleteDir() // wipe out the workspace
-
-    properties([
-        parameters([
-            [$class: 'StringParameterDefinition', description: "Oasis Build scripts branch", name: 'BUILD_BRANCH', defaultValue: 'master'],
-            [$class: 'StringParameterDefinition', description: "OasisPlatform branch", name: 'PLATFORM_BRANCH', defaultValue: BRANCH_NAME],
-            [$class: 'StringParameterDefinition', description: "Install OasisLMF from branch", name: 'MDK_BRANCH', defaultValue: 'develop'],
-            [$class: 'StringParameterDefinition', description: "Test API/Worker using PiWind branch", name: 'PIWIND_BRANCH', defaultValue: 'develop'],
-            [$class: 'StringParameterDefinition', description: "Release tag to publish", name: 'RELEASE_TAG', defaultValue: BRANCH_NAME.split('/').last() + "-${BUILD_NUMBER}"],
-            [$class: 'StringParameterDefinition', description: "Last release, for changelog", name: 'PREV_RELEASE_TAG', defaultValue: ""],
-            [$class: 'StringParameterDefinition', description: "OasisLMF release notes ref", name: 'OASISLMF_TAG', defaultValue: ""],
-            [$class: 'StringParameterDefinition', description: "OasisLMF prev release notes ref", name: 'OASISLMF_PREV_TAG', defaultValue: ""],
-            [$class: 'StringParameterDefinition', description: "Ktools release notes ref", name: 'KTOOLS_TAG', defaultValue: ""],
-            [$class: 'StringParameterDefinition', description: "Ktools prev release notes ref", name: 'KTOOLS_PREV_TAG', defaultValue: ""],
-            [$class: 'StringParameterDefinition', description: "CVE Rating that fails a build", name: 'SCAN_IMAGE_VULNERABILITIES', defaultValue: "HIGH,CRITICAL"],
-            [$class: 'StringParameterDefinition', description: "CVE Rating that fails a build", name: 'SCAN_REPO_VULNERABILITIES', defaultValue: "HIGH,CRITICAL"],
-            [$class: 'TextParameterDefinition', description: "List of models for Regression tests", name: 'MODEL_REGRESSION', defaultValue: model_regression_list],
-            [$class: 'BooleanParameterDefinition', description: "Test previous API and Worker", name: 'CHECK_COMPATIBILITY', defaultValue: Boolean.valueOf(false)],
-            [$class: 'BooleanParameterDefinition', description: "Test S3 storage using LocalStack", name: 'CHECK_S3', defaultValue: Boolean.valueOf(false)],
-            [$class: 'BooleanParameterDefinition', description: "Run API unittests", name: 'UNITTEST', defaultValue: Boolean.valueOf(true)],
-            [$class: 'BooleanParameterDefinition', description: "Run Regression checks", name: 'RUN_REGRESSION', defaultValue: Boolean.valueOf(false)],
-            [$class: 'BooleanParameterDefinition', description: "Purge docker images on completion", name: 'PURGE', defaultValue: Boolean.valueOf(true)],
-            [$class: 'BooleanParameterDefinition', description: "Create release if checked", name: 'PUBLISH', defaultValue: Boolean.valueOf(false)],
-            [$class: 'BooleanParameterDefinition', description: "Mark as pre-released software", name: 'PRE_RELEASE', defaultValue: Boolean.valueOf(true)],
-            [$class: 'BooleanParameterDefinition', description: "Perform a gitflow merge", name: 'AUTO_MERGE', defaultValue: Boolean.valueOf(true)],
-            [$class: 'BooleanParameterDefinition', description: "Send build status to slack", name: 'SLACK_MESSAGE', defaultValue: Boolean.valueOf(true)]
-        ])
-    ])
-
-    // Build vars
-    String build_repo = 'git@github.com:OasisLMF/build.git'
-    String build_branch = params.BUILD_BRANCH
-    String build_workspace = 'oasis_build'
-
-    // docker vars (main)
-    String docker_api = "Dockerfile.api_server"
-    String docker_worker = "Dockerfile.model_worker"
-    String docker_worker_debian = "Dockerfile.model_worker_debian"
-    String docker_piwind = "docker/Dockerfile.piwind_worker"
-    String docker_controller = 'Dockerfile'
-
-    String image_api = "coreoasis/api_server"
-    String image_worker = "coreoasis/model_worker"
-    String image_piwind = "coreoasis/piwind_worker"
-    String image_controller = 'coreoasis/worker_controller'
-
-    // docker vars (slim)
-    //String docker_api_slim = "docker/Dockerfile.api_server_alpine"
-    String docker_worker_slim = "docker/Dockerfile.model_worker_slim"
-
-    // platform vars
-    String oasis_branch = params.PLATFORM_BRANCH    // Git repo branch to build from
-    String mdk_branch = params.MDK_BRANCH
-    String oasis_name = 'OasisPlatform'
-    String oasis_git_url = "git@github.com:OasisLMF/${oasis_name}.git"
-    String oasis_workspace = 'platform_workspace'
-    String utils_sh = '/buildscript/utils.sh'
-    String oasis_func = "oasis_server"
-
-    // oasis base model test
-    String model_branch = params.PIWIND_BRANCH
-    String model_name = 'OasisPiWind'
-    String model_tests = 'control_set'
-    String model_workspace = "${model_name}_workspace"
-    String model_git_url = "git@github.com:OasisLMF/${model_name}.git"
-    String model_test_dir = "${env.WORKSPACE}/${model_workspace}/tests/"
-    String model_test_ini = "test-config.ini"
-
-    String script_dir = env.WORKSPACE + "/${build_workspace}"
-    String git_creds = "1335b248-336a-47a9-b0f6-9f7314d6f1f4"
-    String PIPELINE = script_dir + "/buildscript/pipeline.sh"
-
-    // Docker image scanning
-    String mnt_docker_socket = "-v /var/run/docker.sock:/var/run/docker.sock"
-    String mnt_output_report = "-v ${env.WORKSPACE}/${oasis_workspace}/image_reports:/tmp"
-    String mnt_scan_report = "-v ${env.WORKSPACE}/${oasis_workspace}/scan_reports:/tmp"
-    String mnt_repo = "-v ${env.WORKSPACE}/${oasis_workspace}:/mnt"
-    String mnt_server_deps = "-v ${env.WORKSPACE}/${oasis_workspace}/requirements-server.txt:/mnt/requirements.txt"
-    String mnt_worker_deps = "-v ${env.WORKSPACE}/${oasis_workspace}/requirements-worker.txt:/mnt/requirements.txt"
-    String mnt_controller_deps = "-v ${env.WORKSPACE}/${oasis_workspace}/kubernetes/worker-controller/requirements.txt:/mnt/requirements.txt"
-
-    // Update MDK branch based on model branch
-    if (BRANCH_NAME.matches("master") || BRANCH_NAME.matches("hotfix/(.*)")){
-        MDK_BRANCH='master'
-        MODEL_BRANCH='master'
-    }
-
-    //make sure release candidate versions are tagged correctly
-    if (params.PUBLISH && params.PRE_RELEASE && ! params.RELEASE_TAG.matches('^(\\d+\\.)(\\d+\\.)(\\*|\\d+)rc(\\d+)$')) {
-        sh "echo release candidates must be tagged {version}rc{N}, example: 1.0.0rc1"
-        sh "exit 1"
-    }
-
-    // Set Global ENV
-    env.PIPELINE_LOAD = script_dir + utils_sh
-
-    env.OASIS_MODEL_DATA_DIR = "${env.WORKSPACE}/${model_workspace}"
-    //env.TAG_BASE = params.BASE_TAG         //Build TAG for base set of images
-    env.TAG_RELEASE = params.RELEASE_TAG     //Build TAG for TARGET image
-    env.TAG_RUN_PLATFORM = params.RELEASE_TAG
-    env.TAG_RUN_WORKER = params.RELEASE_TAG
-    env.COMPOSE_PROJECT_NAME = UUID.randomUUID().toString().replaceAll("-","")
-
-    env.IMAGE_WORKER = image_piwind
-    // Should read these values from test/conf.ini
-    env.TEST_MAX_RUNTIME = '190'
-    env.TEST_DATA_DIR = model_test_dir
-    env.MDK_CONFIG = '/home/worker/model/oasislmf.json'
-    env.MODEL_SUPPLIER = 'OasisLMF'
-    env.MODEL_VARIENT = 'PiWind'
-    env.MODEL_ID = '1'
-    sh 'env'
-
-
-    // Param Publish Guards
-    if (params.PUBLISH && ! ( oasis_branch.matches("release/(.*)") || oasis_branch.matches("hotfix/(.*)") || oasis_branch.matches("backports/(.*)")) ){
-        println("Publish Only allowed on a release/* or hotfix/* branches")
-        sh "exit 1"
-    }
-
-    try {
-        parallel(
-            clone_oasis_build: {
-                stage('Clone: ' + build_workspace) {
-                    dir(build_workspace) {
-                        git url: build_repo, credentialsId: git_creds, branch: build_branch
-                    }
-                }
-            },
-            clone_oasis_model: {
-                stage('Clone: ' + model_workspace) {
-                    dir(model_workspace) {
-                        git url: model_git_url, credentialsId: git_creds, branch: model_branch
-                    }
-                }
-            },
-            clone_oasis_platform: {
-                stage('Clone: ' + oasis_func) {
-                    sshagent (credentials: [git_creds]) {
-                        dir(oasis_workspace) {
-                            sh "git clone --recursive ${oasis_git_url} ."
-                            if (oasis_branch.matches("PR-[0-9]+")){
-                                // Checkout PR and merge into target branch, test on the result
-                                sh "git fetch origin pull/$CHANGE_ID/head:$BRANCH_NAME"
-                                sh "git checkout $CHANGE_TARGET"
-                                sh "git merge $BRANCH_NAME"
-                            } else {
-                                // Checkout branch
-                                sh "git checkout ${oasis_branch}"
-                            }
-                        }
-                    }
-                }
-            }
-        )
-
-        if (params.SCAN_REPO_VULNERABILITIES){
-            parallel(
-                scan_platform_repo: {
-                    stage('Scan: repo config') {
-                        dir(oasis_workspace) {
-                            withCredentials([string(credentialsId: 'github-tkn-read', variable: 'gh_token')]) {
-                                sh "docker run -e GITHUB_TOKEN=${gh_token} ${mnt_repo} ${mnt_scan_report} aquasec/trivy fs --skip-dirs /mnt/kubernetes --exit-code 1 --severity ${params.SCAN_REPO_VULNERABILITIES} --output /tmp/cve_repo_general.txt --security-checks vuln,config,secret /mnt"
-                            }
-                        }
-                    }
-                },
-                scan_server_deps: {
-                    stage('Scan: requirments-server.txt') {
-                        dir(oasis_workspace) {
-                            withCredentials([string(credentialsId: 'github-tkn-read', variable: 'gh_token')]) {
-                                sh "docker run -e GITHUB_TOKEN=${gh_token} ${mnt_server_deps} ${mnt_scan_report} aquasec/trivy fs --exit-code 1 --severity ${params.SCAN_REPO_VULNERABILITIES} --output /tmp/cve_python_server.txt /mnt/requirements.txt"
-                            }
-                        }
-                    }
-                },
-                scan_worker_deps: {
-                    stage('Scan: requirments-worker.txt') {
-                        dir(oasis_workspace) {
-                            withCredentials([string(credentialsId: 'github-tkn-read', variable: 'gh_token')]) {
-                                sh "docker run -e GITHUB_TOKEN=${gh_token} ${mnt_worker_deps} ${mnt_scan_report} aquasec/trivy fs --exit-code 1 --severity ${params.SCAN_REPO_VULNERABILITIES} --output /tmp/cve_python_worker.txt /mnt/requirements.txt"
-                            }
-                        }
-                    }
-                },
-                scan_controller_deps: {
-                    stage('Scan: requirments.txt (ctrl)') {
-                        dir(oasis_workspace) {
-                            withCredentials([string(credentialsId: 'github-tkn-read', variable: 'gh_token')]) {
-                                sh "docker run -e GITHUB_TOKEN=${gh_token} ${mnt_controller_deps} ${mnt_scan_report} aquasec/trivy fs --exit-code 1 --severity ${params.SCAN_REPO_VULNERABILITIES} --output /tmp/cve_python_controller.txt /mnt/requirements.txt"
-                            }
-                        }
-                    }
-                }
-            )
-        }
-
-        stage('Shell Env'){
-            sh PIPELINE + ' print_model_vars'
-            if (params.CHECK_COMPATIBILITY) {
-                dir(oasis_workspace) {
-                    if (params.PREV_RELEASE_TAG){
-                        env.LAST_RELEASE_TAG = params.PREV_RELEASE_TAG
-                    } else {
-                        sh "curl https://api.github.com/repos/OasisLMF/OasisPlatform/releases | jq -r '( first ) | .name' > last_release_tag"
-                        env.LAST_RELEASE_TAG = readFile('last_release_tag').trim()
-
-                        println("LAST_RELEASE = $env.LAST_RELEASE_TAG")
-                    }
-                }
-            }
-        }
-        if (mdk_branch && ! params.PUBLISH){
-            stage('Git install MDK'){
-                dir(oasis_workspace) {
-                    // update worker and server install lists
-                    sh "sed -i 's|^oasislmf.*| git+https://github.com/OasisLMF/OasisLMF.git@${mdk_branch}#egg=oasislmf[extra]|g' requirements-worker.txt"
-                    sh "sed -i 's|^oasislmf.*| git+https://github.com/OasisLMF/OasisLMF.git@${mdk_branch}#egg=oasislmf[extra]|g' requirements.txt"
-                }
-            }
-        }
-        stage('Set version file'){
-            dir(oasis_workspace){
-                sh "echo ${env.TAG_RELEASE} - " + '$(git rev-parse --short HEAD), $(date) > VERSION'
-            }
-        }
-        parallel(
-            build_api_server: {
-                stage('Build: API server') {
-                    dir(oasis_workspace) {
-                        sh PIPELINE + " build_image ${docker_api} ${image_api} ${env.TAG_RELEASE}"
-
-                    }
-                }
-            },
-            build_worker_controller: {
-                stage('Build: Worker controller') {
-                    dir(oasis_workspace) {
-                        dir('kubernetes/worker-controller'){
-                            sh PIPELINE + " build_image ${docker_controller} ${image_controller} ${env.TAG_RELEASE}"
-                        }
-                    }
-                }
-            },
-            build_model_worker_ubuntu: {
-                stage('Build: Model worker - Ubuntu') {
-                    dir(oasis_workspace) {
-                        sh PIPELINE + " build_image ${docker_worker} ${image_worker} ${env.TAG_RELEASE}"
-                    }
-                }
-            },
-            build_model_worker_debian: {
-                stage('Build: Model worker - Debian') {
-                    dir(oasis_workspace) {
-                        sh PIPELINE + " build_image ${docker_worker_debian} ${image_worker} ${env.TAG_RELEASE}-debian"
-                    }
-                }
-            }
-        )
-        if(params.PUBLISH){
-            // Build chanagelog image
-            stage("Create Changelog builder") {
-                dir(build_workspace) {
-                    sh "docker build -f docker/Dockerfile.release-notes -t release-builder ."
-                }
-            }
-        }
-
-        if (params.SCAN_IMAGE_VULNERABILITIES.replaceAll(" \\s","")){
-            parallel(
-                scan_api_server: {
-                    stage('Scan: API server'){
-                        dir(oasis_workspace) {
-                            // Scan for Image Efficient
-                            sh " ./imagesize.sh ${image_api}:${env.TAG_RELEASE} image_reports/size_api-server.txt"
-
-                            // Scan for CVE
-                            withCredentials([string(credentialsId: 'github-tkn-read', variable: 'gh_token')]) {
-                                sh "docker run -e GITHUB_TOKEN=${gh_token} ${mnt_docker_socket} ${mnt_output_report} aquasec/trivy image --security-checks vuln,config --exit-code 1 --severity ${params.SCAN_IMAGE_VULNERABILITIES} --output /tmp/cve_api-server.txt ${image_api}:${env.TAG_RELEASE}"
-                                //sh "docker run -e GITHUB_TOKEN=${gh_token} ${mnt_docker_socket} ${mnt_output_report} aquasec/trivy image --output /tmp/cve_api-server.txt ${image_api}:${env.TAG_RELEASE}"
-                                //sh "docker run -e GITHUB_TOKEN=${gh_token} ${mnt_docker_socket} aquasec/trivy image --exit-code 1 --severity ${params.SCAN_IMAGE_VULNERABILITIES} ${image_api}:${env.TAG_RELEASE}"
-                            }
-                        }
-                    }
-                },
-                scan_controller: {
-                    stage('Scan: worker controller'){
-                        dir(oasis_workspace) {
-                            // Scan for Image Efficient
-                            sh " ./imagesize.sh ${image_controller}:${env.TAG_RELEASE} image_reports/size_controller.txt"
-
-                            // Scan for CVE
-                            withCredentials([string(credentialsId: 'github-tkn-read', variable: 'gh_token')]) {
-                                sh "docker run -e GITHUB_TOKEN=${gh_token} ${mnt_docker_socket} ${mnt_output_report} aquasec/trivy image --security-checks vuln,config --exit-code 1 --severity ${params.SCAN_IMAGE_VULNERABILITIES} --output /tmp/cve_controller.txt ${image_controller}:${env.TAG_RELEASE}"
-                                //sh "docker run -e GITHUB_TOKEN=${gh_token} ${mnt_docker_socket} ${mnt_output_report} aquasec/trivy image --output /tmp/cve_controller.txt ${image_controller}:${env.TAG_RELEASE}"
-                                //sh "docker run -e GITHUB_TOKEN=${gh_token} ${mnt_docker_socket} aquasec/trivy image --exit-code 1 --severity ${params.SCAN_IMAGE_VULNERABILITIES} ${image_controller}:${env.TAG_RELEASE}"
-                            }
-                        }
-                    }
-                },
-                scan_model_worker: {
-                    stage('Scan: Model worker'){
-                        dir(oasis_workspace) {
-                            // Scan for Image Efficient
-                            sh " ./imagesize.sh ${image_worker}:${env.TAG_RELEASE} image_reports/size_model-worker.txt"
-
-                            // Scan for CVE
-                            withCredentials([string(credentialsId: 'github-tkn-read', variable: 'gh_token')]) {
-                                sh "docker run -e GITHUB_TOKEN=${gh_token} ${mnt_docker_socket} ${mnt_output_report} aquasec/trivy image --security-checks vuln,config --exit-code 1 --severity ${params.SCAN_IMAGE_VULNERABILITIES} --output /tmp/cve_model-worker.txt ${image_worker}:${env.TAG_RELEASE}"
-                                //sh "docker run -e GITHUB_TOKEN=${gh_token} ${mnt_docker_socket} ${mnt_output_report} aquasec/trivy image --output /tmp/cve_model-worker.txt ${image_worker}:${env.TAG_RELEASE}"
-                                //sh "docker run -e GITHUB_TOKEN=${gh_token} ${mnt_docker_socket} aquasec/trivy image --exit-code 1 --severity ${params.SCAN_IMAGE_VULNERABILITIES} ${image_worker}:${env.TAG_RELEASE}"
-                            }
-                        }
-                    }
-                }
-            )
-        }
-        if (params.UNITTEST){
-            stage('Run: unittest') {
-                dir(oasis_workspace) {
-                    sh " ./runtests.sh"
-                }
-            }
-            stage('Run: Test API schema') {
-                dir(oasis_workspace) {
-                    sh " ./build-maven.sh ${env.TAG_RELEASE}"
-                }
-            }
-        }
-
-        if (params.CHECK_S3 || params.CHECK_COMPATIBILITY) {
-            // Build PiWind worker from new worker
-            stage('Build: PiWind worker') {
-                dir(model_workspace) {
-                    sh "docker build --build-arg worker_ver=${env.TAG_RELEASE} -f ${docker_piwind} -t ${image_piwind}:${env.TAG_RELEASE} ."
-                }
-            }
-
-        }
-
-        if (params.CHECK_COMPATIBILITY) {
-
-            // START API for base model tests
-            stage('Run: API Server') {
-                dir(build_workspace) {
-                    sh PIPELINE + " start_model"
-                }
-            }
-
-            // RUN and test piwind
-            api_server_tests = model_tests.split()
-            for(int i=0; i < api_server_tests.size(); i++) {
-                stage("Run : ${api_server_tests[i]}"){
-                    dir(build_workspace) {
-                        sh PIPELINE + " run_test --config /var/oasis/test/${model_test_ini} --test-case ${api_server_tests[i]}"
-
-                        // show docker logs
-                        sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs server'
-                        sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs worker'
-                        sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs worker-monitor'
-                    }
-                }
-            }
-
-            // CHECK last release compatibility
-            stage("Compatibility with worker:${env.LAST_RELEASE_TAG}") {
-                dir(build_workspace) {
-                    // Set tags
-                    env.IMAGE_WORKER = image_worker
-                    env.TAG_RUN_PLATFORM = params.RELEASE_TAG
-                    env.TAG_RUN_WORKER = env.LAST_RELEASE_TAG
-
-                    // Setup containers
-                    sh PIPELINE + " start_model"
-
-                    // run test
-                    sh PIPELINE + " run_test --config /var/oasis/test/${model_test_ini} --test-case ${api_server_tests[0]}"
-
-                    // show docker logs
-                    sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs server'
-                    sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs worker'
-                    sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs worker-monitor'
-                }
-            }
-            stage("Compatibility with server:${env.LAST_RELEASE_TAG}") {
-                dir(build_workspace) {
-                    // reset db-data
-                    sh PIPELINE + " stop_docker ${env.COMPOSE_PROJECT_NAME}"
-                    env.OASIS_DOCKER_DB_DATA_DIR = './db-data_pre-ver'
-
-                    // Set tags
-                    env.IMAGE_WORKER = image_worker
-                    env.TAG_RUN_PLATFORM = env.LAST_RELEASE_TAG
-                    env.TAG_RUN_WORKER = params.RELEASE_TAG
-
-                    // Setup containers
-                    sh PIPELINE + " start_model"
-
-                    // run test
-                    sh PIPELINE + " run_test --config /var/oasis/test/${model_test_ini} --test-case ${api_server_tests[0]}"
-
-                    // show docker logs
-                    sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs server'
-                    sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs worker'
-                    sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs worker-monitor'
-                }
-            }
-        }
-
-        if (params.CHECK_S3) {
-            stage("Check S3 storage"){
-                dir(build_workspace) {
-                    // Stop prev
-                    if (params.CHECK_COMPATIBILITY) {
-                        sh PIPELINE + " stop_docker ${env.COMPOSE_PROJECT_NAME}"
-                    }
-
-                    // Start S3 compose files
-                    sh PIPELINE + " start_model_s3"
-
-                    // Reset tags
-                    env.IMAGE_WORKER = image_worker
-                    env.TAG_RUN_PLATFORM = params.RELEASE_TAG
-                    env.TAG_RUN_WORKER = params.RELEASE_TAG
-
-                    // run test
-                    sh PIPELINE + " run_test_s3 --config /var/oasis/test/${model_test_ini} --test-case ${model_tests}"
-
-                    // show docker logs
-                    sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs server'
-                    sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs worker'
-                    sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs worker-monitor'
-                }
-            }
-        }
-        if (params.RUN_REGRESSION) {
-            // RUN model regression tests
-            job_params = [
-                [$class: 'StringParameterValue', name: 'TAG_OASIS', value: params.RELEASE_TAG]
-            ]
-            //RUN SEQUENTIAL JOBS - Fail on error
-            if (params.MODEL_REGRESSION){
-                jobs_sequential = params.MODEL_REGRESSION.split()
-                for (pipeline in jobs_sequential){
-                    createStage(pipeline, job_params, true).call()
-                }
-            }
-            /*
-            //RUN PARALLEL JOBS
-            if (params.MODEL_REGRESSION){
-                jobs_parallel = params.MODEL_REGRESSION.split()
-                parallel jobs_parallel.collectEntries {
-                    ["${it}": createStage(it, job_params, true)]
-                }
-            }
-
-            //RUN SEQUENTIAL JOBS - Continue on error
-            if (params.SEQUENTIAL_JOB_LIST_NOFAIL){
-                jobs_sequential = params.SEQUENTIAL_JOB_LIST_NOFAIL.split()
-                for (pipeline in jobs_sequential){
-                    createStage(pipeline, job_params, false).call()
-                }
-            }
-            **/
-        }
-
-        if (params.PUBLISH){
-            parallel(
-                publish_api_server: {
-                    stage ('Publish: api_server') {
-                        dir(build_workspace) {
-                            sh PIPELINE + " push_image ${image_api} ${env.TAG_RELEASE}"
-                            if (! params.PRE_RELEASE){
-                                sh PIPELINE + " push_image ${image_api} latest"
-                            }
-                        }
-                    }
-                },
-                publish_worker_controller: {
-                    stage ('Publish: worker_controller') {
-                        dir(build_workspace) {
-                            sh PIPELINE + " push_image ${image_controller} ${env.TAG_RELEASE}"
-                            if (! params.PRE_RELEASE){
-                                sh PIPELINE + " push_image ${image_controller} latest"
-                            }
-                        }
-                    }
-                },
-                publish_model_worker: {
-                    stage('Publish: model_worker') {
-                        dir(build_workspace) {
-                            sh PIPELINE + " push_image ${image_worker} ${env.TAG_RELEASE}-debian"
-                            sh PIPELINE + " push_image ${image_worker} ${env.TAG_RELEASE}"
-                            if (! params.PRE_RELEASE){
-                                sh PIPELINE + " push_image ${image_worker} latest"
-                            }
-                        }
-                    }
-                },
-                publish_piwind_worker: {
-                    stage('Publish: model_worker') {
-                        dir(build_workspace) {
-                            sh PIPELINE + " push_image ${image_piwind} ${env.TAG_RELEASE}"
-                        }
-                    }
-                }
-            )
-
-            stage("Tag release") {
-                sshagent (credentials: [git_creds]) {
-                    dir(model_workspace) {
-                        // Tag PiWind
-                        sh PIPELINE + " git_tag ${env.TAG_RELEASE}"
-                    }
-                    dir(oasis_workspace) {
-                        // Tag the OasisPlatform
-                        sh PIPELINE + " git_tag ${env.TAG_RELEASE}"
-                    }
-                }
-            }
-
-            // Create release notes
-            stage('Create Changelog'){
-                dir(oasis_workspace){
-                    withCredentials([string(credentialsId: 'github-api-token', variable: 'gh_token')]) {
-                        sh "docker run -v ${env.WORKSPACE}/${oasis_workspace}:/tmp release-builder build-changelog --repo OasisPlatform --from-tag ${params.PREV_RELEASE_TAG} --to-tag ${params.RELEASE_TAG} --github-token ${gh_token} --local-repo-path ./ --output-path ./CHANGELOG.rst --apply-milestone"
-                        sh "docker run -v ${env.WORKSPACE}/${oasis_workspace}:/tmp release-builder build-release-platform --platform-from-tag ${params.PREV_RELEASE_TAG} --platform-to-tag ${params.RELEASE_TAG} --lmf-from-tag ${params.OASISLMF_PREV_TAG} --lmf-to-tag ${params.OASISLMF_TAG} --ktools-from-tag ${params.KTOOLS_PREV_TAG} --ktools-to-tag ${params.KTOOLS_TAG} --github-token ${gh_token} --output-path ./RELEASE.md"
-                    }
-                    sshagent (credentials: [git_creds]) {
-                        sh "git add ./CHANGELOG.rst"
-                        sh "git commit -m 'Update changelog ${params.RELEASE_TAG}'"
-                        sh "git push"
-                    }
-                }
-            }
-            stage ('Create Release: GitHub') {
-                // Create Release
-                withCredentials([string(credentialsId: 'github-api-token', variable: 'gh_token')]) {
-                    dir(oasis_workspace) {
-                        String repo = "OasisLMF/OasisPlatform"
-                        def release_body = readFile(file: "${env.WORKSPACE}/${oasis_workspace}/RELEASE.md")
-                        def json_request = readJSON text: '{}'
-                        json_request['tag_name'] = RELEASE_TAG
-                        json_request['target_commitish'] = 'master'
-                        json_request['name'] = RELEASE_TAG
-                        json_request['body'] = release_body
-                        json_request['draft'] = false
-                        json_request['prerelease'] = params.PRE_RELEASE
-                        writeJSON file: 'gh_request.json', json: json_request
-                        sh 'curl -XPOST -H "Authorization:token ' + gh_token + "\" --data @gh_request.json https://api.github.com/repos/$repo/releases > gh_response.json"
-
-                        // Fetch release ID and post json schema
-                        def response = readJSON file: "gh_response.json"
-                        release_id = response['id']
-                        dir('reports') {
-                            filename='openapi-schema.json'
-                            sh 'curl -XPOST -H "Authorization:token ' + gh_token + '" -H "Content-Type:application/octet-stream" --data-binary @' + filename + " https://uploads.github.com/repos/$repo/releases/$release_id/assets?name=" + "openapi-schema-${RELEASE_TAG}.json"
-                        }
-                    }
-                }
-            }
-        }
-    } catch(hudson.AbortException | org.jenkinsci.plugins.workflow.steps.FlowInterruptedException buildException) {
-        hasFailed = true
-        error('Build Failed')
-    } finally {
-        dir(build_workspace) {
-            sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs server-db > ./stage/log/server-db.log '
-            sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs server > ./stage/log/server.log '
-            sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs celery-db > ./stage/log/celery-db.log '
-            sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs rabbit > ./stage/log/rabbit.log '
-            sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs worker > ./stage/log/worker.log '
-            sh 'docker-compose -f compose/oasis.platform.yml -f compose/model.worker.yml logs worker-monitor > ./stage/log/worker-monitor.log '
-            sh PIPELINE + " stop_docker ${env.COMPOSE_PROJECT_NAME}"
-            if(params.PURGE){
-                sh PIPELINE + " purge_image ${image_api} ${env.TAG_RELEASE}"
-                sh PIPELINE + " purge_image ${image_controller} ${env.TAG_RELEASE}"
-                sh PIPELINE + " purge_image ${image_worker} ${env.TAG_RELEASE}"
-                sh PIPELINE + " purge_image ${image_worker} ${env.TAG_RELEASE}-debian"
-                sh PIPELINE + " purge_image ${image_piwind} ${env.TAG_RELEASE}"
-            }
-        }
-
-        if(params.SLACK_MESSAGE && (params.PUBLISH || hasFailed)){
-            def slackColor = hasFailed ? '#FF0000' : '#27AE60'
-            JOB = env.JOB_NAME.replaceAll('%2F','/')
-            SLACK_GIT_URL = "https://github.com/OasisLMF/${oasis_name}/tree/${oasis_branch}"
-            SLACK_MSG = "*${JOB}* - (<${env.BUILD_URL}|${env.RELEASE_TAG}>): " + (hasFailed ? 'FAILED' : 'PASSED')
-            SLACK_MSG += "\nBranch: <${SLACK_GIT_URL}|${oasis_branch}>"
-            SLACK_MSG += "\nMode: " + (params.PUBLISH ? 'Publish' : 'Build Test')
-            SLACK_CHAN = (params.PUBLISH ? "#builds-release":"#builds-dev")
-            slackSend(channel: SLACK_CHAN, message: SLACK_MSG, color: slackColor)
-        }
-        //Store logs
-        dir(build_workspace) {
-            archiveArtifacts artifacts: "stage/log/**/*.*", excludes: '*stage/log/**/*.gitkeep'
-            archiveArtifacts artifacts: "stage/output/**/*.*"
-        }
-        //Store repo scan reports
-        if (params.SCAN_REPO_VULNERABILITIES.replaceAll(" \\s","")){
-            dir(oasis_workspace){
-                archiveArtifacts artifacts: 'scan_reports/**/*.*'
-            }
-
-        }
-        //Store Docker image reports
-        if (params.SCAN_IMAGE_VULNERABILITIES.replaceAll(" \\s","")){
-            dir(oasis_workspace){
-                archiveArtifacts artifacts: 'image_reports/**/*.*'
-            }
-
-        }
-        //Store reports
-        if (params.UNITTEST){
-            dir(oasis_workspace){
-                archiveArtifacts artifacts: 'reports/**/*.*'
-            }
-        }
-        // Run merge back if publish
-        if (params.PUBLISH && params.AUTO_MERGE && ! hasFailed){
-            dir(oasis_workspace) {
-                sshagent (credentials: [git_creds]) {
-                    if (! params.PRE_RELEASE) {
-                        // Release merge back into master
-                        sh "git stash"
-                        sh "git checkout master && git pull"
-                        sh "git merge ${oasis_branch} && git push"
-                        sh "git checkout develop && git pull"
-                        sh "git merge master && git push"
-                    } else {
-                        // pre_pelease merge back into develop
-                        sh "git stash"
-                        sh "git checkout develop && git pull"
-                        sh "git merge ${oasis_branch} && git push"
-                    }
-                }
-            }
-        }
-    }
-}
diff --git a/kubernetes/charts/oasis-models/resources/model_registration.sh b/kubernetes/charts/oasis-models/resources/model_registration.sh
index f340c6ee8..c3fbbf325 100644
--- a/kubernetes/charts/oasis-models/resources/model_registration.sh
+++ b/kubernetes/charts/oasis-models/resources/model_registration.sh
@@ -5,16 +5,18 @@
 #
 # Script compatible with sh and not bash
 
+#set -x
 set -e
 set -o pipefail
 
-BASE_URL="http://${OASIS_SERVER_HOST}:${OASIS_SERVER_PORT}"
+BASE_URL="http://${OASIS_SERVER_HOST}:${OASIS_SERVER_PORT}/api"
 MODEL_SETTINGS_FILE="$OASIS_MODEL_DATA_DIRECTORY/model_settings.json"
 CHUNKING_CONFIGURATION_FILE="$OASIS_MODEL_DATA_DIRECTORY/chunking_configuration.json"
 SCALING_CONFIGURATION_FILE="$OASIS_MODEL_DATA_DIRECTORY/scaling_configuration.json"
 
 echo
 echo "=== Register model ==="
+echo "Run mode version : $OASIS_RUN_MODE"
 echo "Supplier ID : $OASIS_MODEL_SUPPLIER_ID"
 echo "Model ID : $OASIS_MODEL_ID"
 echo "Model version ID : $OASIS_MODEL_VERSION_ID"
@@ -22,7 +24,7 @@ echo "Model path : $OASIS_MODEL_DATA_DIRECTORY"
 echo "Model groups : $OASIS_MODEL_GROUPS"
 echo
 
-if [ -z "$OASIS_MODEL_SUPPLIER_ID" ] || [ -z "$OASIS_MODEL_ID" ] || [ -z "$OASIS_MODEL_VERSION_ID" ] || [ -z "$OASIS_MODEL_DATA_DIRECTORY" ]; then
+if [ -z "$OASIS_MODEL_SUPPLIER_ID" ] || [ -z "$OASIS_MODEL_ID" ] || [ -z "$OASIS_MODEL_VERSION_ID" ] || [ -z "$OASIS_MODEL_DATA_DIRECTORY" ] || [ -z "$OASIS_RUN_MODE" ]; then
     echo "Missing required model env var(s)"
     exit 1
 fi
@@ -69,19 +71,21 @@ curlf() {
 }
 
 MODEL_ID=$(curl -s -H "Authorization: Bearer ${ACCESS_TOKEN}" -X GET \
-    "${BASE_URL}/v1/models/?supplier_id=${OASIS_MODEL_SUPPLIER_ID}&version_id=${OASIS_MODEL_VERSION_ID}" | \
+    "${BASE_URL}/v2/models/?supplier_id=${OASIS_MODEL_SUPPLIER_ID}&version_id=${OASIS_MODEL_VERSION_ID}" | \
     jq ".[] | select((.supplier_id | ascii_downcase == \"$(echo ${OASIS_MODEL_SUPPLIER_ID} | \
     tr '[:upper:]' '[:lower:]')\") and (.model_id | ascii_downcase == \"$(echo ${OASIS_MODEL_ID} | \
     tr '[:upper:]' '[:lower:]')\") and (.version_id == \"$(echo ${OASIS_MODEL_VERSION_ID} | tr '[:upper:]' '[:lower:]')\")) | .id")
 
-MODEL_JSON_ID_ATTRIBUTES="\"supplier_id\": \"${OASIS_MODEL_SUPPLIER_ID}\",\"model_id\": \"${OASIS_MODEL_ID}\",\"version_id\": \"${OASIS_MODEL_VERSION_ID}\""
+
+MODEL_RUN_MODE=$(echo ${OASIS_RUN_MODE} | tr '[:lower:]' '[:upper:]')  # set model's execution workflow
+MODEL_JSON_ID_ATTRIBUTES="\"supplier_id\": \"${OASIS_MODEL_SUPPLIER_ID}\",\"model_id\": \"${OASIS_MODEL_ID}\",\"version_id\": \"${OASIS_MODEL_VERSION_ID}\",\"run_mode\": \"${MODEL_RUN_MODE}\""
 
 if [ -n "$MODEL_ID" ]; then
     echo "Model exists with id $MODEL_ID"
 else
     echo "Model not found - registers it"
-    MODEL_ID=$(curlf -X POST "${BASE_URL}/v1/models/" -H "Content-Type: application/json" \
+    MODEL_ID=$(curlf -X POST "${BASE_URL}/v2/models/" -H "Content-Type: application/json" \
         -d "{${MODEL_JSON_ID_ATTRIBUTES}"} | jq .id)
 
     echo "Created with id $MODEL_ID"
@@ -93,19 +97,21 @@ if [ -n "$OASIS_MODEL_GROUPS" ]; then
     GROUPS_JSON="{${MODEL_JSON_ID_ATTRIBUTES}, \"groups\": [\"$(echo $OASIS_MODEL_GROUPS | sed 's/,/","/g')\"]}"
 fi
 
"${BASE_URL}/v1/models/${MODEL_ID}/" -H "Content-Type: application/json" -d "$GROUPS_JSON" | jq . +curlf -X PATCH "${BASE_URL}/v2/models/${MODEL_ID}/" -H "Content-Type: application/json" -d "$GROUPS_JSON" | jq . echo "Uploading model settings" -curlf -X POST "${BASE_URL}/v1/models/${MODEL_ID}/settings/" -H "Content-Type: application/json" -d @${MODEL_SETTINGS_FILE} | jq . - -if [ -f "$CHUNKING_CONFIGURATION_FILE" ]; then - echo "Uploading chunking configuration" - curlf -X POST "${BASE_URL}/v1/models/${MODEL_ID}/chunking_configuration/" -H "Content-Type: application/json" -d @${CHUNKING_CONFIGURATION_FILE} | jq . -fi - -if [ -f "$SCALING_CONFIGURATION_FILE" ]; then - echo "Uploading scaling configuration" - curlf -X POST "${BASE_URL}/v1/models/${MODEL_ID}/scaling_configuration/" -H "Content-Type: application/json" -d @${SCALING_CONFIGURATION_FILE} | jq . -fi +curlf -X POST "${BASE_URL}/v2/models/${MODEL_ID}/settings/" -H "Content-Type: application/json" -d @${MODEL_SETTINGS_FILE} | jq . + +if [[ "$OASIS_RUN_MODE" == "v2" ]]; then + if [ -f "$CHUNKING_CONFIGURATION_FILE" ]; then + echo "Uploading chunking configuration" + curlf -X POST "${BASE_URL}/v2/models/${MODEL_ID}/chunking_configuration/" -H "Content-Type: application/json" -d @${CHUNKING_CONFIGURATION_FILE} | jq . + fi + + if [ -f "$SCALING_CONFIGURATION_FILE" ]; then + echo "Uploading scaling configuration" + curlf -X POST "${BASE_URL}/v2/models/${MODEL_ID}/scaling_configuration/" -H "Content-Type: application/json" -d @${SCALING_CONFIGURATION_FILE} | jq . + fi +fi echo "Finished" diff --git a/kubernetes/charts/oasis-models/resources/multiple_model_registration.sh b/kubernetes/charts/oasis-models/resources/multiple_model_registration.sh index cc24a7af9..2f34ddb02 100644 --- a/kubernetes/charts/oasis-models/resources/multiple_model_registration.sh +++ b/kubernetes/charts/oasis-models/resources/multiple_model_registration.sh @@ -17,6 +17,7 @@ while true; do if [ -n "$(sh -c "echo \$OASIS_MODEL_SUPPLIER_ID_${counter}")" ]; then + export OASIS_RUN_MODE="$(sh -c "echo \$OASIS_RUN_MODE_${counter}")" export OASIS_MODEL_SUPPLIER_ID="$(sh -c "echo \$OASIS_MODEL_SUPPLIER_ID_${counter}")" export OASIS_MODEL_ID="$(sh -c "echo \$OASIS_MODEL_ID_${counter}")" export OASIS_MODEL_VERSION_ID="$(sh -c "echo \$OASIS_MODEL_VERSION_ID_${counter}")" diff --git a/kubernetes/charts/oasis-models/templates/_helpers.tpl b/kubernetes/charts/oasis-models/templates/_helpers.tpl index f75b21589..aa30b9b8d 100644 --- a/kubernetes/charts/oasis-models/templates/_helpers.tpl +++ b/kubernetes/charts/oasis-models/templates/_helpers.tpl @@ -26,6 +26,7 @@ oasislmf/type: worker oasislmf/supplier-id: {{ .supplierId | quote }} oasislmf/model-id: {{ .modelId | quote }} oasislmf/model-version-id: {{ .modelVersionId | quote }} +oasislmf/api-version: {{ .apiVersion | quote }} {{- end }} {{/* @@ -144,6 +145,20 @@ Variables for a broker client configMapKeyRef: name: {{ .Values.databases.broker.name }} key: uri +- name: OASIS_RABBIT_HOST + valueFrom: + configMapKeyRef: + name: {{ .Values.databases.broker.name }} + key: host +- name: OASIS_RABBIT_PORT + valueFrom: + configMapKeyRef: + name: {{ .Values.databases.broker.name }} + key: port +- name: OASIS_RABBIT_USER + value: rabbit +- name: OASIS_RABBIT_PASS + value: rabbit {{- end }} {{/* diff --git a/kubernetes/charts/oasis-models/templates/model_registration_job.yaml b/kubernetes/charts/oasis-models/templates/model_registration_job.yaml index 630ff20d2..538668788 100644 --- 
a/kubernetes/charts/oasis-models/templates/model_registration_job.yaml +++ b/kubernetes/charts/oasis-models/templates/model_registration_job.yaml @@ -34,6 +34,8 @@ spec: {{- include "h.oasisServerEnvs" $root | indent 12 }} {{- $index = 0 -}} {{- range $name, $worker := .Values.workers }} + - name: OASIS_RUN_MODE_{{ $index }} + value: {{ .apiVersion | quote}} - name: OASIS_MODEL_SUPPLIER_ID_{{ $index }} value: {{ .supplierId | quote}} - name: OASIS_MODEL_ID_{{ $index }} diff --git a/kubernetes/charts/oasis-models/templates/workers.yaml b/kubernetes/charts/oasis-models/templates/workers.yaml index d914ea011..85383d5bd 100644 --- a/kubernetes/charts/oasis-models/templates/workers.yaml +++ b/kubernetes/charts/oasis-models/templates/workers.yaml @@ -1,6 +1,6 @@ {{- $root := . -}} {{- range $k, $worker := .Values.workers }} -{{- $name := printf "worker-%s-%s-%s" $worker.supplierId $worker.modelId $worker.modelVersionId | lower }} +{{- $name := printf "worker-%s-%s-%s-%s" $worker.supplierId $worker.modelId $worker.modelVersionId $worker.apiVersion | lower }} apiVersion: apps/v1 kind: Deployment metadata: @@ -37,9 +37,26 @@ spec: {{- end }} name: worker env: + - name: OASIS_RUN_MODE + value: {{ $worker.apiVersion }} {{- include "h.modelEnvs" . | indent 12 }} {{- include "h.brokerVars" $root | indent 12 }} {{- include "h.celeryDbVars" $root | indent 12}} +{{- if eq $worker.apiVersion "v1" }} + startupProbe: + exec: + command: ["celery", "-A", "src.model_execution_worker.tasks", "inspect", "ping"] + timeoutSeconds: 5 + periodSeconds: 10 + failureThreshold: 30 + livenessProbe: + exec: + command: ["celery", "-A", "src.model_execution_worker.tasks", "inspect", "ping"] + initialDelaySeconds: 30 + periodSeconds: 60 + timeoutSeconds: 60 + failureThreshold: 15 +{{- else }} startupProbe: exec: command: ["celery", "-A", "src.model_execution_worker.distributed_tasks", "inspect", "ping"] @@ -53,6 +70,7 @@ spec: periodSeconds: 60 timeoutSeconds: 60 failureThreshold: 15 +{{- end }} volumeMounts: - name: shared-fs-persistent-storage mountPath: /shared-fs diff --git a/kubernetes/charts/oasis-models/values.yaml b/kubernetes/charts/oasis-models/values.yaml index b364ace15..e05d3179c 100644 --- a/kubernetes/charts/oasis-models/values.yaml +++ b/kubernetes/charts/oasis-models/values.yaml @@ -26,6 +26,7 @@ workers: supplierId: OasisLMF modelId: PiWind modelVersionId: "1" + apiVersion: "v2" image: coreoasis/model_worker version: dev imagePullPolicy: Never diff --git a/kubernetes/charts/oasis-monitoring/templates/flower.yaml b/kubernetes/charts/oasis-monitoring/templates/flower.yaml index 8bb5a5ae0..6e9ddbc63 100644 --- a/kubernetes/charts/oasis-monitoring/templates/flower.yaml +++ b/kubernetes/charts/oasis-monitoring/templates/flower.yaml @@ -32,6 +32,8 @@ spec: {{- include "h.brokerVars" . 
| indent 12 }} - name: FLOWER_LOGGING value: INFO + - name: FLOWER_UNAUTHENTICATED_API + value: "true" livenessProbe: failureThreshold: 2 httpGet: diff --git a/kubernetes/charts/oasis-platform/templates/_helpers.tpl b/kubernetes/charts/oasis-platform/templates/_helpers.tpl index 75d783b3a..fddfc5a4f 100644 --- a/kubernetes/charts/oasis-platform/templates/_helpers.tpl +++ b/kubernetes/charts/oasis-platform/templates/_helpers.tpl @@ -184,16 +184,6 @@ Variables for a channel layer client secretKeyRef: name: {{ .Values.databases.channel_layer.name }} key: password -- name: OASIS_SERVER_CHANNEL_LAYER_PORT - valueFrom: - configMapKeyRef: - name: {{ .Values.databases.channel_layer.name }} - key: port -- name: OASIS_SERVER_CHANNEL_LAYER_SSL - valueFrom: - configMapKeyRef: - name: {{ .Values.databases.channel_layer.name }} - key: ssl {{- end }} - name: OASIS_INPUT_GENERATION_CONTROLLER_QUEUE value: task-controller diff --git a/kubernetes/charts/oasis-platform/templates/ingress.yaml b/kubernetes/charts/oasis-platform/templates/ingress.yaml index b2a50b0a8..aceb8ff6e 100644 --- a/kubernetes/charts/oasis-platform/templates/ingress.yaml +++ b/kubernetes/charts/oasis-platform/templates/ingress.yaml @@ -26,6 +26,7 @@ spec: {{- end }} rules: - host: {{ .Values.ingress.uiHostname }} + http: paths: - path: / @@ -49,3 +50,10 @@ spec: name: oasis-server port: number: 8000 + - path: /ws + pathType: Prefix + backend: + service: + name: oasis-websocket + port: + number: 8001 diff --git a/kubernetes/charts/oasis-platform/templates/keycloak.yaml b/kubernetes/charts/oasis-platform/templates/keycloak.yaml index 016485315..02ba24e53 100644 --- a/kubernetes/charts/oasis-platform/templates/keycloak.yaml +++ b/kubernetes/charts/oasis-platform/templates/keycloak.yaml @@ -94,44 +94,51 @@ spec: containers: - name: {{ .Values.keycloak.name }} image: {{ .Values.images.keycloak.image }}:{{ .Values.images.keycloak.version }} - args: [ - "start-dev", - "--import-realm", - "--http-relative-path /auth", - "--proxy passthrough", - "--hostname-strict=false"] + args: ["start", "--import-realm"] ports: - containerPort: {{ .Values.keycloak.port }} env: - - name: KEYCLOAK_PROXY_ADDRESS_FORWARDING + - name: KC_LOGLEVEL + value: DEBUG + - name: PROXY_ADDRESS_FORWARDING + value: "true" + - name: KC_PROXY + value: "edge" + - name: KC_IMPORT + value: "/opt/keycloak/data/import/oasis-realm.json" + - name: KC_HTTP_RELATIVE_PATH + value: "/auth" + - name: KC_HOSTNAME_STRICT + value: "false" + - name: KC_PROXY_ADDRESS_FORWARDING value: "true" - - name: DB_VENDOR + - name: KC_DB {{- if eq .Values.databases.keycloak_db.type "mysql" }} value: mysql {{- else }} value: postgres {{- end }} - - name: DB_ADDR + - name: KC_DB_URL_HOST valueFrom: configMapKeyRef: name: {{ .Values.databases.keycloak_db.name }} key: host - - name: DB_PORT + - name: KC_DB_URL_PORT valueFrom: configMapKeyRef: name: {{ .Values.databases.keycloak_db.name }} key: port - - name: DB_DATABASE + - name: KC_DB_URL_DATABASE valueFrom: configMapKeyRef: name: {{ .Values.databases.keycloak_db.name }} key: dbName - - name: DB_USER + - name: KC_DB_USERNAME valueFrom: secretKeyRef: name: {{ .Values.databases.keycloak_db.name }} key: user - - name: DB_PASSWORD + - name: KC_DB_PASSWORD valueFrom: secretKeyRef: name: {{ .Values.databases.keycloak_db.name }} @@ -146,12 +153,6 @@ spec: secretKeyRef: name: {{ .Values.keycloak.name }} key: password - - name: KEYCLOAK_LOGLEVEL - value: INFO - - name: PROXY_ADDRESS_FORWARDING - value: "true" - - name: KEYCLOAK_IMPORT - value: 
"/opt/keycloak/data/import/oasis-realm.json" startupProbe: httpGet: path: /auth/realms/master diff --git a/kubernetes/charts/oasis-platform/templates/oasis.yaml b/kubernetes/charts/oasis-platform/templates/oasis.yaml index 3b29e2dbc..b19cf28e8 100644 --- a/kubernetes/charts/oasis-platform/templates/oasis.yaml +++ b/kubernetes/charts/oasis-platform/templates/oasis.yaml @@ -33,19 +33,19 @@ spec: apiVersion: apps/v1 kind: Deployment metadata: - name: oasis-worker-monitor + name: oasis-v1-worker-monitor labels: {{- include "h.labels" . | nindent 4}} spec: selector: matchLabels: - app: oasis-worker-monitor + app: oasis-v1-worker-monitor strategy: type: Recreate template: metadata: labels: - app: oasis-worker-monitor + app: oasis-v1-worker-monitor {{- include "h.labels" . | nindent 8}} annotations: checksum/oasis-server: {{ toJson .Values.oasisServer | sha256sum }} @@ -58,7 +58,7 @@ spec: containers: - image: {{ .Values.images.oasis.platform.image }}:{{ .Values.images.oasis.platform.version }} imagePullPolicy: {{ .Values.images.oasis.platform.imagePullPolicy }} - name: oasis-worker-monitor + name: oasis-v1-worker-monitor env: {{- include "h.serverDbVars" . | indent 12}} {{- include "h.celeryDbVars" . | indent 12}} @@ -66,7 +66,7 @@ spec: {{- include "h.channelLayerVars" . | indent 12 }} - name: OASIS_DEBUG value: "1" - command: [ "celery", "-A", "src.server.oasisapi.celery_app", "worker", "--loglevel=INFO" ] + command: [ "celery", "-A", "src.server.oasisapi.celery_app_v1", "worker", "--loglevel=INFO", "-Q", "celery" ] volumeMounts: - name: shared-fs-persistent-storage mountPath: /shared-fs @@ -77,13 +77,88 @@ spec: {{- end }} startupProbe: exec: - command: ["celery", "-A", "src.server.oasisapi.celery_app", "inspect", "ping"] + command: ["celery", "-A", "src.server.oasisapi.celery_app_v1", "inspect", "ping"] timeoutSeconds: 10 periodSeconds: 10 failureThreshold: 30 livenessProbe: exec: - command: ["celery", "-A", "src.server.oasisapi.celery_app", "inspect", "ping"] + command: ["celery", "-A", "src.server.oasisapi.celery_app_v1", "inspect", "ping"] + initialDelaySeconds: 30 + periodSeconds: 60 + timeoutSeconds: 10 + volumes: +{{- if ((.Values.volumes.host).sharedFs) }} + - name: shared-fs-persistent-storage + persistentVolumeClaim: + claimName: {{ .Values.volumes.host.sharedFs.name }} +{{- else if ((.Values.volumes.azureFiles).sharedFs) }} + - name: shared-fs-persistent-storage +{{- toYaml .Values.volumes.azureFiles.sharedFs | nindent 10 }} +{{- if (.Values.azure).secretProvider }} + - name: azure-secret-provider + csi: + driver: secrets-store.csi.k8s.io + readOnly: true + volumeAttributes: + secretProviderClass: "azure-secret-provider" +{{- end }} +{{- end }} +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: oasis-v2-worker-monitor + labels: + {{- include "h.labels" . | nindent 4}} +spec: + selector: + matchLabels: + app: oasis-v2-worker-monitor + strategy: + type: Recreate + template: + metadata: + labels: + app: oasis-v2-worker-monitor + {{- include "h.labels" . | nindent 8}} + annotations: + checksum/oasis-server: {{ toJson .Values.oasisServer | sha256sum }} + checksum/{{ .Values.databases.oasis_db.name }}: versions{{ toJson .Values.databases.oasis_db | sha256sum }} + checksum/{{ .Values.databases.celery_db.name }}: versions{{ toJson .Values.databases.celery_db | sha256sum }} + spec: + {{- include "h.affinity" . | nindent 6 }} + initContainers: + {{- include "h.initTcpAvailabilityCheckBySecret" (list . 
.Values.databases.oasis_db.name .Values.databases.celery_db.name .Values.oasisServer.name .Values.databases.broker.name) | nindent 8}} + containers: + - image: {{ .Values.images.oasis.platform.image }}:{{ .Values.images.oasis.platform.version }} + imagePullPolicy: {{ .Values.images.oasis.platform.imagePullPolicy }} + name: oasis-v2-worker-monitor + env: +{{- include "h.serverDbVars" . | indent 12}} +{{- include "h.celeryDbVars" . | indent 12}} +{{- include "h.brokerVars" . | indent 12 }} +{{- include "h.channelLayerVars" . | indent 12 }} + - name: OASIS_DEBUG + value: "1" + command: [ "celery", "-A", "src.server.oasisapi.celery_app_v2", "worker", "--loglevel=INFO", "-Q", "celery-v2" ] + volumeMounts: + - name: shared-fs-persistent-storage + mountPath: /shared-fs +{{- if (.Values.azure).secretProvider }} + - name: azure-secret-provider + mountPath: "/mnt/azure-secrets-store" + readOnly: true +{{- end }} + startupProbe: + exec: + command: ["celery", "-A", "src.server.oasisapi.celery_app_v2", "inspect", "ping"] + timeoutSeconds: 10 + periodSeconds: 10 + failureThreshold: 30 + livenessProbe: + exec: + command: ["celery", "-A", "src.server.oasisapi.celery_app_v2", "inspect", "ping"] initialDelaySeconds: 30 periodSeconds: 60 timeoutSeconds: 10 @@ -132,9 +207,10 @@ spec: - image: {{ .Values.images.oasis.ui.image }}:{{ .Values.images.oasis.ui.version }} name: {{ .Values.oasisUI.name }} env: - {{- include "h.serverApiVars" . | nindent 12}} + - name: API_IP + value: oasis-server:8000/api/ - name: API_VERSION - value: v1 + value: v2 - name: OASIS_ENVIRONMENT value: oasis_localhost - name: API_SHARE_FILEPATH @@ -204,7 +280,7 @@ spec: - image: {{ .Values.images.oasis.platform.image }}:{{ .Values.images.oasis.platform.version }} imagePullPolicy: {{ .Values.images.oasis.platform.imagePullPolicy }} name: celery-beat - command: ["celery", "-A", "src.server.oasisapi.celery_app", "beat", "--loglevel=DEBUG"] + command: ["celery", "-A", "src.server.oasisapi.celery_app_v2", "beat", "--loglevel=DEBUG"] env: {{- include "h.serverDbVars" . | indent 12}} {{- include "h.celeryDbVars" . | indent 12}} @@ -218,13 +294,13 @@ spec: {{- end }} startupProbe: exec: - command: ["celery", "-A", "src.server.oasisapi.celery_app", "inspect", "ping"] + command: ["celery", "-A", "src.server.oasisapi.celery_app_v2", "inspect", "ping"] timeoutSeconds: 10 periodSeconds: 10 failureThreshold: 30 livenessProbe: exec: - command: ["celery", "-A", "src.server.oasisapi.celery_app", "inspect", "ping"] + command: ["celery", "-A", "src.server.oasisapi.celery_app_v2", "inspect", "ping"] initialDelaySeconds: 30 periodSeconds: 60 timeoutSeconds: 10 @@ -268,7 +344,7 @@ spec: - image: {{ .Values.images.oasis.platform.image }}:{{ .Values.images.oasis.platform.version }} imagePullPolicy: {{ .Values.images.oasis.platform.imagePullPolicy }} name: main - command: ["celery", "-A", "src.server.oasisapi.celery_app", "worker", "--loglevel=INFO", "-Q", "task-controller"] + command: ["celery", "-A", "src.server.oasisapi.celery_app_v2", "worker", "--loglevel=INFO", "-Q", "task-controller"] env: {{- include "h.serverDbVars" . | indent 12}} {{- include "h.celeryDbVars" . 
| indent 12}} @@ -284,13 +360,13 @@ spec: {{- end }} startupProbe: exec: - command: ["celery", "-A", "src.server.oasisapi.celery_app", "inspect", "ping"] + command: ["celery", "-A", "src.server.oasisapi.celery_app_v2", "inspect", "ping"] timeoutSeconds: 10 periodSeconds: 10 failureThreshold: 30 livenessProbe: exec: - command: ["celery", "-A", "src.server.oasisapi.celery_app", "inspect", "ping"] + command: ["celery", "-A", "src.server.oasisapi.celery_app_v2", "inspect", "ping"] initialDelaySeconds: 30 periodSeconds: 60 timeoutSeconds: 10 diff --git a/kubernetes/charts/oasis-platform/templates/oasis_server.yaml b/kubernetes/charts/oasis-platform/templates/oasis_server.yaml index 2548fd3d1..d1b1bb5c8 100644 --- a/kubernetes/charts/oasis-platform/templates/oasis_server.yaml +++ b/kubernetes/charts/oasis-platform/templates/oasis_server.yaml @@ -132,18 +132,14 @@ spec: readOnly: true {{- end }} startupProbe: - httpGet: - path: /healthcheck/ + tcpSocket: port: {{.Values.oasisWebsocket.port}} - scheme: HTTP timeoutSeconds: 10 periodSeconds: 10 failureThreshold: 30 livenessProbe: - httpGet: - path: /healthcheck/ + tcpSocket: port: {{.Values.oasisWebsocket.port}} - scheme: HTTP initialDelaySeconds: 30 periodSeconds: 30 timeoutSeconds: 10 diff --git a/kubernetes/charts/oasis-platform/templates/oasis_worker_controller.yaml b/kubernetes/charts/oasis-platform/templates/oasis_worker_controller.yaml index af8d5fed3..5dba31de2 100644 --- a/kubernetes/charts/oasis-platform/templates/oasis_worker_controller.yaml +++ b/kubernetes/charts/oasis-platform/templates/oasis_worker_controller.yaml @@ -75,11 +75,20 @@ spec: - name: OASIS_NEVER_SHUTDOWN_FIXED_WORKERS value: "{{ .Values.workerController.neverShutdownFixedWorkers }}" - name: OASIS_API_HOST - value: {{ .Values.oasisWebsocket.name }} + value: {{ .Values.oasisServer.name }} - name: OASIS_API_PORT + value: {{ .Values.oasisServer.port | quote }} + - name: OASIS_API_SUBPATH + value: "api" + - name: OASIS_WEBSOCKET_HOST + value: {{ .Values.oasisWebsocket.name }} + - name: OASIS_WEBSOCKET_PORT value: {{ .Values.oasisWebsocket.port | quote }} - name: OASIS_CLUSTER_NAMESPACE value: {{ .Release.Namespace | quote }} + - name: OASIS_LOGLEVEL + value: "{{ .Values.workerController.logLevel }}" + {{- with .Values.workerController.totalWorkerLimit }} - name: OASIS_TOTAL_WORKER_LIMIT value: {{ . | quote }} diff --git a/kubernetes/charts/oasis-platform/values.yaml b/kubernetes/charts/oasis-platform/values.yaml index 8fd9fcc9f..ea520e19d 100644 --- a/kubernetes/charts/oasis-platform/values.yaml +++ b/kubernetes/charts/oasis-platform/values.yaml @@ -9,26 +9,26 @@ images: imagePullPolicy: Never ui: image: coreoasis/oasisui_app - version: 1.11.6 + version: 1.11.7 worker_controller: image: coreoasis/worker_controller version: dev imagePullPolicy: Never postgres: image: postgres - version: 15.3-alpine3.18 + version: 15-alpine3.18 mysql: image: mysql version: 8.0.33 keycloak: image: quay.io/keycloak/keycloak - version: 22.0.0 + version: 23.0.6-0 init: image: busybox version: 1.28 redis: image: redis - version: 7.0.11-alpine3.18 + version: 7-alpine3.18 rabbitmq: image: rabbitmq version: 3.11-management-alpine @@ -181,7 +181,7 @@ oasisServer: - name: OASIS_DEBUG value: "1" - name: OASIS_PORTFOLIO_UPLOAD_VALIDATION - value: "1" + value: "0" # Oasis API host, port and credentials (for Websocket connections) oasisWebsocket: @@ -211,6 +211,9 @@ workerController: # Debug option - This prevents workers for a model set to FIXED_WORKERS to be scaled to 0. They will always be started. 
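The worker controller's API and websocket endpoints are now configured independently (OASIS_API_HOST/OASIS_API_PORT/OASIS_API_SUBPATH for HTTP, OASIS_WEBSOCKET_HOST/OASIS_WEBSOCKET_PORT for the queue-status socket). A minimal sketch of how the two base URLs can be assembled from these variables, mirroring the client change further down in this diff (the helper name is illustrative):

    import os

    def build_endpoints():
        # Assemble API and websocket base URLs from the chart's env vars (illustrative helper).
        scheme = 'https' if os.getenv('OASIS_API_SECURE') else 'http'
        api = f"{scheme}://{os.getenv('OASIS_API_HOST', 'localhost')}"
        if os.getenv('OASIS_API_PORT'):
            api += f":{os.getenv('OASIS_API_PORT')}"
        if os.getenv('OASIS_API_SUBPATH'):
            api += f"/{os.getenv('OASIS_API_SUBPATH')}"   # e.g. "api"
        ws_scheme = 'wss' if scheme == 'https' else 'ws'
        ws = f"{ws_scheme}://{os.getenv('OASIS_WEBSOCKET_HOST', 'localhost')}"
        if os.getenv('OASIS_WEBSOCKET_PORT'):
            ws += f":{os.getenv('OASIS_WEBSOCKET_PORT')}"
        return api, ws

This matches the debug_local_env.sh defaults below, where empty port variables leave the ingress hostnames (ui.oasis.local/api, ui.oasis.local/ws) untouched.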
neverShutdownFixedWorkers: false + # Set log level. Warning: setting 'DEBUG' will cause instability under high load; do not set this in production + logLevel: 'INFO' + # Settings related to prioritized runs priority: diff --git a/kubernetes/scripts/k8s/upload_piwind_model_data.sh b/kubernetes/scripts/k8s/upload_piwind_model_data.sh index 0ea3984d9..dbd8e1d1d 100755 --- a/kubernetes/scripts/k8s/upload_piwind_model_data.sh +++ b/kubernetes/scripts/k8s/upload_piwind_model_data.sh @@ -3,7 +3,7 @@ set -e PWP=$1 -MODEL_PATHS="meta-data/model_settings.json oasislmf.json model_data/ keys_data/ tests/" +MODEL_PATHS="meta-data/model_settings.json oasislmf.json model_data/ keys_data/ tests/ meta-data/" OPTIONAL_MODEL_FILES="meta-data/chunking_configuration.json meta-data/scaling_configuration.json" OASIS_CLUSTER_NAMESPACE="${OASIS_CLUSTER_NAMESPACE:-default}" diff --git a/kubernetes/worker-controller/data/test/queue-status-2-running-2-queues.json b/kubernetes/worker-controller/data/test/queue-status-2-running-2-queues.json index d6b620296..d8f9e0801 100644 --- a/kubernetes/worker-controller/data/test/queue-status-2-running-2-queues.json +++ b/kubernetes/worker-controller/data/test/queue-status-2-running-2-queues.json @@ -5,7 +5,7 @@ "content": [ { "queue": { - "name": "OasisLMF-PiWind-1", + "name": "OasisLMF-PiWind-1-v2", "pending_count": 8, "worker_count": 1, "queued_count": 2, @@ -41,8 +41,8 @@ "lookup_chunks": 2, "sub_task_count": 7, "queue_names": [ - "OasisLMF-PiWind-1", - "celery" + "OasisLMF-PiWind-1-v2", + "celery-v2" ], "status_count": { "TOTAL_IN_QUEUE": 6, @@ -61,7 +61,7 @@ }, { "queue": { - "name": "OasisLMF-PiWind-2", + "name": "OasisLMF-PiWind-2-v2", "pending_count": 8, "worker_count": 1, "queued_count": 2, @@ -97,8 +97,8 @@ "lookup_chunks": 2, "sub_task_count": 7, "queue_names": [ - "OasisLMF-PiWind-2", - "celery" + "OasisLMF-PiWind-2-v2", + "celery-v2" ], "status_count": { "TOTAL_IN_QUEUE": 7, @@ -195,4 +195,4 @@ "analyses": [] } ] -} \ No newline at end of file +} diff --git a/kubernetes/worker-controller/debug_local_env.sh b/kubernetes/worker-controller/debug_local_env.sh new file mode 100644 index 000000000..a46a55202 --- /dev/null +++ b/kubernetes/worker-controller/debug_local_env.sh @@ -0,0 +1,20 @@ +# Base local env file for running the worker controller code. Used for debugging. +# +# Usage example: +# 1. deploy platform on minikube and open websocket 'kubectl port-forward deployment/oasis-websocket 8001:8001' +# 2. source this file, '. debug_local_env.sh' +# 3. install requirements, 'pip install -r requirements.txt' +# 4.
Run controller, './src/worker_controller.py' + +export OASIS_USERNAME=admin +export OASIS_PASSWORD=password +export OASIS_CONTINUE_UPDATE_SCALING=0 +export OASIS_NEVER_SHUTDOWN_FIXED_WORKERS=0 +export OASIS_API_HOST=ui.oasis.local/api +export OASIS_API_PORT='' +export OASIS_WEBSOCKET_HOST=ui.oasis.local/ws +export OASIS_WEBSOCKET_PORT='' +export OASIS_CLUSTER_NAMESPACE=default +export CLUSTER=local +export OASIS_TOTAL_WORKER_LIMIT=10 +export OASIS_PRIORITIZED_MODELS_LIMIT=10 diff --git a/kubernetes/worker-controller/requirements.in b/kubernetes/worker-controller/requirements.in index a7532e0c8..5bf44d2cb 100644 --- a/kubernetes/worker-controller/requirements.in +++ b/kubernetes/worker-controller/requirements.in @@ -1,4 +1,4 @@ websockets>=8.1 -aiohttp==3.7.4 +aiohttp>=3.7.4 kubernetes_asyncio==18.20.0 joblib>=1.2.0 diff --git a/kubernetes/worker-controller/requirements.txt b/kubernetes/worker-controller/requirements.txt index c685a7852..b8c3e4f51 100644 --- a/kubernetes/worker-controller/requirements.txt +++ b/kubernetes/worker-controller/requirements.txt @@ -4,18 +4,22 @@ # # pip-compile kubernetes/worker-controller/requirements.in # -aiohttp==3.7.4 +aiohttp==3.9.0 # via # -r kubernetes/worker-controller/requirements.in # kubernetes-asyncio -async-timeout==3.0.1 +aiosignal==1.3.1 + # via aiohttp +async-timeout==4.0.3 # via aiohttp attrs==22.1.0 # via aiohttp -certifi==2022.12.7 +certifi==2023.7.22 # via kubernetes-asyncio -chardet==3.0.4 - # via aiohttp +frozenlist==1.4.0 + # via + # aiohttp + # aiosignal idna==3.4 # via yarl joblib==1.2.0 @@ -34,8 +38,6 @@ six==1.16.0 # via # kubernetes-asyncio # python-dateutil -typing-extensions==4.4.0 - # via aiohttp urllib3==1.26.12 # via kubernetes-asyncio websockets==10.3 diff --git a/kubernetes/worker-controller/src/autoscaler.py b/kubernetes/worker-controller/src/autoscaler.py index 2cb65692e..ea48120b6 100755 --- a/kubernetes/worker-controller/src/autoscaler.py +++ b/kubernetes/worker-controller/src/autoscaler.py @@ -45,17 +45,18 @@ async def process_queue_status_message(self, msg): :param msg: The message content """ + pending_analyses: [RunningAnalysis] = await self.parse_queued_pending(msg) + logging.debug('Analyses pending: %s', pending_analyses) running_analyses: [RunningAnalysis] = await self.parse_running_analyses(msg) + logging.debug('Analyses running: %s', running_analyses) - logging.info('Analyses running: %s', running_analyses) - - model_states = self._aggregate_model_states(running_analyses) - + model_states = self._aggregate_model_states({**pending_analyses, **running_analyses}) logging.debug('Model statuses: %s', model_states) model_states_with_wd = await self._filter_model_states_with_wd(model_states) - prioritized_models = self._clear_unprioritized_models(model_states_with_wd) + v2_models = self._filter_models_by_api_version(model_states_with_wd, api_version='v2') + prioritized_models = self._clear_unprioritized_models(v2_models) await self._scale_models(prioritized_models) @@ -87,50 +88,36 @@ def _aggregate_model_states(self, analyses: []) -> dict: for wd in self.deployments.worker_deployments: id = wd.id_string() - if id not in model_states: model_state = ModelState(tasks=0, analyses=0, priority=1) model_states[id] = model_state return model_states - async def _scale_deployment(self, wd: WorkerDeployment, analysis_in_progress: bool, model_state: ModelState, limit: int) -> int: + async def _scale_deployment(self, wd: WorkerDeployment, model_state: ModelState, limit: int) -> int: """ Update a worker deployments number of 
replicas, based on what autoscaler_rules returns as desired number of replicas (if changed since last time). :param wd: The WorkerDeployment for this model. - :param analysis_in_progress: Is an analysis currently running for this model? :param model_state: Number of tasks available to be processed by the worker pool for this model. :param limit: None or maximum value of total number of replicas/workers for this model deployment. :return: Number of replicas set on deployment """ - desired_replicas = 0 - is_fixed_strategy = wd.auto_scaling.get('scaling_strategy') == 'FIXED_WORKERS' and self.never_shutdown_fixed_workers - - if analysis_in_progress or is_fixed_strategy: - - if analysis_in_progress: - logging.info('Analysis for model %s is running', wd.name) - if is_fixed_strategy: - logging.info('Model %s is set to "FIXED_WORKERS"', wd.name) - - try: - desired_replicas = autoscaler_rules.get_desired_worker_count(wd.auto_scaling, model_state) - - if limit is not None and desired_replicas > limit: - desired_replicas = limit - except ValueError as e: - logging.error('Could not calculate desired replicas count for model %s: %s', wd.id_string(), str(e)) + try: + desired_replicas = autoscaler_rules.get_desired_worker_count(wd.auto_scaling, model_state, self.never_shutdown_fixed_workers) + if limit is not None and desired_replicas > limit: + desired_replicas = limit + except ValueError as e: + desired_replicas = 0 + logging.error('Could not calculate desired replicas count for model %s: %s', wd.id_string(), str(e)) if desired_replicas > 0 and wd.name in self.cleanup_deployments: - if wd.name in self.cleanup_deployments: self.cleanup_deployments.remove(wd.name) if wd.replicas != desired_replicas: - if desired_replicas > 0: await self.cluster.set_replicas(wd.name, desired_replicas) else: @@ -142,7 +129,6 @@ async def _scale_deployment(self, wd: WorkerDeployment, analysis_in_progress: bo loop = asyncio.get_event_loop() self.cleanup_timer = loop.call_later(20, self._cleanup) - return wd.replicas return desired_replicas @@ -164,6 +150,33 @@ def _cleanup(self): self.cleanup_deployments.clear() + async def parse_queued_pending(self, msg) -> [RunningAnalysis]: + """ + Parse the web socket message and return a list of models with pending analyses. + + :param msg: Web socket message + :return: A list of pending analyses and their tasks. + """ + content: List[QueueStatusContentEntry] = msg['content'] + pending_analyses: [RunningAnalysis] = {} + + for entry in content: + analyses_list = entry['analyses'] + queue_name = entry['queue']['name'] + + # Check for pending analyses + if (queue_name not in ['celery', 'celery-v2', 'task-controller']): + queued_task_count = entry.get('queue', {}).get('queued_count', 0) # sub-task queued (API DB status) + queue_message_count = entry.get('queue', {}).get('queue_message_count', 0) # queue has messages + queued_count = max(queued_task_count, queue_message_count) + + if (queued_count > 0) and not analyses_list: + # a task is queued, but no analyses are running. + # worker-controller might have missed the analysis dispatch + pending_analyses[f'pending-task_{queue_name}'] = RunningAnalysis(id=None, tasks=1, queue_names=[queue_name], priority=4) + + return pending_analyses + async def parse_running_analyses(self, msg) -> [RunningAnalysis]: """ Parse the web socket message and return a list of running analyses. @@ -172,7 +185,6 @@ async def parse_running_analyses(self, msg) -> [RunningAnalysis]: :param msg: Web socket message :return: A list of running analyses and their tasks.
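The new parse_queued_pending pass covers a race where a model queue already holds messages but no analysis has been reported to the controller yet, so a synthetic pending entry keeps the model from staying at zero workers. A rough standalone sketch of that check against a cut-down queue-status payload (message shape follows the updated test fixture above; values are illustrative):

    # Cut-down queue-status message; shape follows the updated test fixture.
    msg = {'content': [
        {'queue': {'name': 'OasisLMF-PiWind-1-v2', 'queued_count': 0, 'queue_message_count': 3},
         'analyses': []},   # messages queued, but no analyses reported yet
        {'queue': {'name': 'celery-v2', 'queued_count': 5, 'queue_message_count': 5},
         'analyses': []},   # shared queues are excluded from the pending check
    ]}

    pending = {}
    for entry in msg['content']:
        queue_name = entry['queue']['name']
        if queue_name in ('celery', 'celery-v2', 'task-controller'):
            continue
        queued = max(entry['queue'].get('queued_count', 0),
                     entry['queue'].get('queue_message_count', 0))
        if queued > 0 and not entry['analyses']:
            # synthesize a pending entry so the model is still scaled up
            pending[f'pending-task_{queue_name}'] = {'id': None, 'tasks': 1,
                                                     'queue_names': [queue_name], 'priority': 4}

    print(pending)  # {'pending-task_OasisLMF-PiWind-1-v2': {...}}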
""" content: List[QueueStatusContentEntry] = msg['content'] - running_analyses: [RunningAnalysis] = {} for entry in content: @@ -185,23 +197,16 @@ async def parse_running_analyses(self, msg) -> [RunningAnalysis]: queue_names = set() task_counts = analysis.get('status_count', {}) tasks_in_queue = task_counts.get('TOTAL_IN_QUEUE', 0) + if tasks_in_queue > 0: queue_names = analysis.get('queue_names', []) - ## REPLACED: with the code block above ############ - # queue_names = set() - # sub_task_statuses = analysis['sub_task_statuses'] - # if sub_task_statuses and len(sub_task_statuses) > 0: - # for sub_task in sub_task_statuses: - # if sub_task['status'] != 'COMPLETED': - # queue_names.add(sub_task['queue_name']) - #################################################### - if tasks and tasks > 0 and len(queue_names) > 0: sa_id = analysis['id'] if sa_id not in running_analyses: priority = int(analysis.get('priority', 1)) - running_analyses[sa_id] = RunningAnalysis(id=analysis['id'], tasks=tasks, queue_names=list(queue_names), priority=priority) + running_analyses[sa_id] = RunningAnalysis(id=analysis['id'], tasks=tasks, + queue_names=list(queue_names), priority=priority) return running_analyses @@ -212,13 +217,14 @@ async def _scale_models(self, prioritized_models): :param prioritized_models: A dict of model names and their states. """ - workers_total = 0 + logging.debug('Scaling: %s', prioritized_models) for model, state, wd in prioritized_models: + workers_min = wd.auto_scaling.get('worker_count_min', 0) if wd.auto_scaling else 0 - # Load auto scaling settings everytime we scale up workers from 0 - if not wd.auto_scaling or wd.replicas == 0 or self.continue_update_scaling: + # Load auto scaling settings everytime we scale up workers from 0 or if 'worker_count_min' is set + if not wd.auto_scaling or wd.replicas == 0 or self.continue_update_scaling or workers_min > 0: if not wd.oasis_model_id: logging.info('Get oasis model id from API for model %s', wd.id_string()) @@ -228,19 +234,17 @@ async def _scale_models(self, prioritized_models): logging.error('No model id found for model %s', wd.id_string()) if wd.oasis_model_id: - wd.auto_scaling = await self.oasis_client.get_auto_scaling(wd.oasis_model_id) if wd.auto_scaling: - analysis_in_progress = state.get('tasks', 0) > 0 replicas_limit = self.limit - workers_total if self.limit else None - workers_created = await self._scale_deployment(wd, analysis_in_progress, state, replicas_limit) + workers_created = await self._scale_deployment(wd, state, replicas_limit) workers_total += workers_created else: logging.warning('No auto scaling setting found for model %s', wd.id_string()) - logging.info('Total desired number of workers: ' + str(workers_total)) + logging.debug('Total desired number of workers: ' + str(workers_total)) def _get_highest_model_priorities(self, model_states_with_wd): """ @@ -282,6 +286,20 @@ def _clear_unprioritized_models(self, model_states_with_wd): return result + def _filter_models_by_api_version(self, model_states_with_wd, api_version): + """ + Select model deployments matching an API version + + :param model_states_with_wd: List of (model, model state, WorkerDeployment) + :param api_version: String of either `v1` or `v2` + :return: Models as input model_states_with_wd, but filtered by deployment version + """ + result = [] + for model, state, wd in model_states_with_wd: + if wd.api_version == api_version.lower(): + result.append((model, state, wd), ) + return result + async def _filter_model_states_with_wd(self, model_states): """ 
Filter and exclude models with no worker deployment in the cluster and return a list of @@ -290,13 +308,9 @@ async def _filter_model_states_with_wd(self, model_states): :param model_states: List of model states :return: Models with an attached WorkerDeployment. """ - result = [] - for model, state in model_states.items(): - wd = await self.deployments.get_worker_deployment_by_name_id(model) - if wd: result.append((model, state, wd), ) diff --git a/kubernetes/worker-controller/src/autoscaler_rules.py b/kubernetes/worker-controller/src/autoscaler_rules.py index 55bcab46b..1ad80f629 100644 --- a/kubernetes/worker-controller/src/autoscaler_rules.py +++ b/kubernetes/worker-controller/src/autoscaler_rules.py @@ -3,7 +3,7 @@ from models import ModelState -def get_desired_worker_count(autoscaling_setting: dict, model_state: ModelState): +def get_desired_worker_count(autoscaling_setting: dict, model_state: ModelState, never_shutdown_fixed_workers: bool = False): """ This function is called for each worker deployment (model) having one or more analyses running. @@ -14,38 +14,62 @@ def get_desired_worker_count(autoscaling_setting: dict, model_state: ModelState) :param autoscaling_setting: Auto scaling configuration (see oasis API for more details) :param model_state: State of this model such as number of running analyses and tasks. + :param never_shutdown_fixed_workers: Debug mode which doesn't spin down workers when in fixed mode :return: Desired number of workers to scale to. """ strategy = autoscaling_setting.get('scaling_strategy') + worker_count_min = int(autoscaling_setting.get('worker_count_min', 0)) + analysis_in_progress = any([ + model_state.get('tasks', 0) > 0, + model_state.get('analyses', 0) > 0 + ]) + + # Guard for missing options + if not strategy: + raise ValueError(f'No valid auto scaling configuration for model: {autoscaling_setting}') - if strategy: - - if strategy == 'FIXED_WORKERS': - - count = get_req_setting(autoscaling_setting, 'worker_count_fixed') - - return int(count) - - elif strategy == 'QUEUE_LOAD': - - worker_count_max = get_req_setting(autoscaling_setting, 'worker_count_max') - analyses = model_state['analyses'] - - return min(analyses, worker_count_max) - - elif strategy == 'DYNAMIC_TASKS': - - chunks_per_worker = autoscaling_setting.get('chunks_per_worker') - worker_count_max = get_req_setting(autoscaling_setting, 'worker_count_max') - - workers = math.ceil(int(model_state.get('tasks', 0)) / int(chunks_per_worker)) - return min(workers, int(worker_count_max)) + if strategy in ['QUEUE_LOAD', 'DYNAMIC_TASKS']: + worker_count_max = get_req_setting(autoscaling_setting, 'worker_count_max') + + # Debugging mode (keep all fixed workers alive) + if strategy == 'FIXED_WORKERS' and never_shutdown_fixed_workers: + return max( + int(get_req_setting(autoscaling_setting, 'worker_count_fixed')), + worker_count_min, + ) + + # Scale down to minimum worker count + if not analysis_in_progress: + return worker_count_min + + # Run a fixed set of workers when analysis is on queue + elif strategy == 'FIXED_WORKERS': + count = int(get_req_setting(autoscaling_setting, 'worker_count_fixed')) + return max( + count, + worker_count_min, + ) + + # Run one worker per analysis in progress + elif strategy == 'QUEUE_LOAD': + analyses = model_state['analyses'] + return max( + min(analyses, worker_count_max), + worker_count_min, + ) + + # Run `n` workers based on number of tasks on queue + elif strategy == 'DYNAMIC_TASKS': + chunks_per_worker = autoscaling_setting.get('chunks_per_worker') + workers =
math.ceil(int(model_state.get('tasks', 0)) / int(chunks_per_worker)) + return max( + min(workers, worker_count_max), + worker_count_min, + ) - else: - raise ValueError(f'Unsupported scaling strategy: {strategy}') else: - raise ValueError(f'No valid auto scaling configuration for model: {autoscaling_setting}') + raise ValueError(f'Unsupported scaling strategy: {strategy}') def get_req_setting(autoscaling_setting: dict, name: str): diff --git a/kubernetes/worker-controller/src/cluster_client.py b/kubernetes/worker-controller/src/cluster_client.py index 720aa9d2f..f8d24d7ff 100644 --- a/kubernetes/worker-controller/src/cluster_client.py +++ b/kubernetes/worker-controller/src/cluster_client.py @@ -127,8 +127,9 @@ async def update_deployment(self, deployment, type=None): supplier_id = labels.get('oasislmf/supplier-id', '') model_id = labels.get('oasislmf/model-id', '') model_version_id = labels.get('oasislmf/model-version-id', '') + api_version = labels.get('oasislmf/api-version', '') - await self.deployments.update_worker(name, supplier_id, model_id, model_version_id, replicas) + await self.deployments.update_worker(name, supplier_id, model_id, model_version_id, api_version, replicas) async def load_deployments(self): """ diff --git a/kubernetes/worker-controller/src/oasis_client.py b/kubernetes/worker-controller/src/oasis_client.py index 0f8dc436b..16702f3fb 100644 --- a/kubernetes/worker-controller/src/oasis_client.py +++ b/kubernetes/worker-controller/src/oasis_client.py @@ -1,33 +1,45 @@ import asyncio import json -from urllib.parse import urljoin import aiohttp import time from aiohttp import ClientResponse +def urljoin(*args): + return '/'.join(s.strip('/') for s in args) + '/' + + class OasisClient: """ A simple client for the Oasis API. Takes care of the access token and supports searching for models. """ - def __init__(self, host, port, secure, username, password): + def __init__(self, http_host, http_port, http_subpath, ws_host, ws_port, secure, username, password): """ - :param host: Oasis API hostname. - :param port: Oasis API port. + :param http_host: Oasis API hostname. + :param http_port: Oasis API port. + :param http_subpath: Oasis API subpath, e.g. "api". + :param ws_host: Oasis Websocket hostname. + :param ws_port: Oasis Websocket port. :param secure: Use secure connection. :param username: Username for API authentication. :param password: Password for API authentication.
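The stdlib urllib.parse.urljoin import is replaced above with a naive segment join because the stdlib function drops the new /api subpath whenever the second argument starts with a slash, as /access_token/ does. A quick comparison (hostname is illustrative):

    from urllib.parse import urljoin as std_urljoin

    def urljoin(*args):
        # naive join from the diff above: strip slashes and re-join
        return '/'.join(s.strip('/') for s in args) + '/'

    base = 'http://oasis-server:8000/api'
    print(std_urljoin(base, '/access_token/'))  # http://oasis-server:8000/access_token/ -- subpath lost
    print(urljoin(base, '/access_token/'))      # http://oasis-server:8000/api/access_token/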
""" - self.host = host - self.port = port + self.ws_host = ws_host + self.ws_port = ws_port self.secure = secure - self.http_host = ('https://' if secure else 'http://') + f'{host}:{port}' + + api_proto = 'https://' if secure else 'http://' + api_host = http_host + api_port = f':{http_port}' if http_port else '' + api_path = f'/{http_subpath}' if http_subpath else '' + self.http_host = f'{api_proto}{api_host}{api_port}{api_path}' + self.username = username self.password = password self.access_token = None self.token_expire_time = None + print('Connecting to: ' + self.http_host) def is_authenticated(self) -> bool: """ @@ -47,9 +59,7 @@ async def authenticate(self): params = {'username': self.username, 'password': self.password} async with session.post(urljoin(self.http_host, '/access_token/'), data=params) as response: - data = await self.parse_answer(response) - self.access_token = data['access_token'] self.token_expire_time = time.time() + round(data['expires_in'] / 2) @@ -91,7 +101,7 @@ async def get_oasis_model_id(self, supplier_id: str, model_id: str, model_versio 'model_id': model_id, 'version_id': model_version_id } - models = await self._get('/v1/models/', params) + models = await self._get('/v2/models/', params) for model in models: if model['supplier_id'] == supplier_id and model['model_id'] == model_id and model['version_id'] == model_version_id: @@ -107,7 +117,7 @@ async def get_auto_scaling(self, model_id): await self.authenticate_if_needed() - model = await self._get(f'/v1/models/{model_id}/scaling_configuration/') + model = await self._get(f'/v2/models/{model_id}/scaling_configuration/') return model diff --git a/kubernetes/worker-controller/src/oasis_websocket.py b/kubernetes/worker-controller/src/oasis_websocket.py index c8c0103da..6a6e689d0 100644 --- a/kubernetes/worker-controller/src/oasis_websocket.py +++ b/kubernetes/worker-controller/src/oasis_websocket.py @@ -29,9 +29,11 @@ async def __aenter__(self): access_token = await self.oasis_client.get_access_token() + # https://websockets.readthedocs.io/en/stable/reference/asyncio/client.html#websockets.client.connect self.connection = websockets.connect( - urljoin(f'{self.ws_scheme}{self.oasis_client.host}:{self.oasis_client.port}', '/ws/v1/queue-status/'), - extra_headers={'AUTHORIZATION': f'Bearer {access_token}'} + urljoin(f'{self.ws_scheme}{self.oasis_client.ws_host}:{self.oasis_client.ws_port}', '/ws/v2/queue-status/'), + extra_headers={'AUTHORIZATION': f'Bearer {access_token}'}, + ping_interval=None, ) return await self.connection.__aenter__() @@ -84,15 +86,24 @@ async def watch(self): while running: try: async with WebSocketConnection(self.oasis_client) as socket: - + logging.info(f'Connected to ws: {self.oasis_client.ws_host}:{self.oasis_client.ws_port}') async for msg in next_msg(socket): - logging.info('Socket message: %s', msg) + logging.debug('Socket message: %s', msg) await self.autoscaler.process_queue_status_message(msg) + except ConnectionClosedError as e: - logging.exception(f'Connection to {self.oasis_client.host}:{self.oasis_client.port} was closed') - running = False + """ + Websockets has an auto-reconnect with exponential back-off built in + See: + https://websockets.readthedocs.io/en/stable/reference/asyncio/client.html#opening-a-connection + https://github.com/python-websockets/websockets/issues/414 + """ + logging.info(f'Connection to {self.oasis_client.ws_host}:{self.oasis_client.ws_port} lost, reconnecting') + logging.debug(f'Connection to {self.oasis_client.ws_host}:{self.oasis_client.ws_port} 
was closed', e) + continue + except (WebSocketException, ClientError) as e: - logging.exception(f'Connection to {self.oasis_client.host}:{self.oasis_client.port} failed', e) + logging.exception(f'Connection to {self.oasis_client.ws_host}:{self.oasis_client.ws_port} failed', e) running = False except Exception as e: logging.exception(f'Unexpected web socket exception thrown', e) diff --git a/kubernetes/worker-controller/src/tests/test_autoscaler.py b/kubernetes/worker-controller/src/tests/test_autoscaler.py index e20d87e34..06cca8f44 100644 --- a/kubernetes/worker-controller/src/tests/test_autoscaler.py +++ b/kubernetes/worker-controller/src/tests/test_autoscaler.py @@ -20,16 +20,17 @@ class TestAutoscaler(unittest.TestCase): def test_parse_queue_status(self): - wd1 = WorkerDeployment('worker-oasislmf-piwind-1', 'oasislmf', 'piwind', '1') - wd2 = WorkerDeployment('worker-oasislmf-piwind-2', 'oasislmf', 'piwind', '2') - wd3 = WorkerDeployment('worker-oasislmf-piwind-3', 'oasislmf', 'piwind', '3') + wd1 = WorkerDeployment('worker-oasislmf-piwind-1-v2', 'oasislmf', 'piwind', '1', 'v2') + wd2 = WorkerDeployment('worker-oasislmf-piwind-2-v2', 'oasislmf', 'piwind', '2', 'v2') + wd3 = WorkerDeployment('worker-oasislmf-piwind-3-v2', 'oasislmf', 'piwind', '3', 'v2') + wd4 = WorkerDeployment('worker-oasislmf-piwind-4-v1', 'oasislmf', 'piwind', '4', 'v1') wd1.replicas = 0 wd2.replicas = 2 wd3.replicas = 1 + wd4.replicas = 0 async def get_worker_deployment_by_name_id(model): - - for wd in [wd1, wd2, wd3]: + for wd in [wd1, wd2, wd3, wd4]: if wd.id_string().lower() == model: return wd @@ -67,7 +68,6 @@ async def get_auto_scaling(oasis_model_id): with open(queue_status_json) as json_file: data = json.load(json_file) - autoscaler = AutoScaler(deployments, cluster_client, client, None, None, False, False) asyncio.run(autoscaler.process_queue_status_message(data)) @@ -75,14 +75,12 @@ async def get_auto_scaling(oasis_model_id): cluster_client.set_replicas.assert_has_calls(calls, any_order=True) def test_model_status_prio(self): - - wd1 = WorkerDeployment('worker-oasislmf-piwind-1', 'oasislmf', 'piwind', '1') - wd2 = WorkerDeployment('worker-oasislmf-piwind-2', 'oasislmf', 'piwind', '2') + wd1 = WorkerDeployment('worker-oasislmf-piwind-1', 'oasislmf', 'piwind', '1', 'v1') + wd2 = WorkerDeployment('worker-oasislmf-piwind-2', 'oasislmf', 'piwind', '2', 'v2') wd1.replicas = 0 wd2.replicas = 2 async def get_worker_deployment_by_name_id(model): - for wd in [wd1, wd2]: if wd.id_string().lower() == model: return wd @@ -96,20 +94,20 @@ async def get_worker_deployment_by_name_id(model): model_states = { 'celery': {'tasks': 10, 'analyses': 2, 'priority': 1}, 'model-worker-broadcast': {'tasks': 5, 'analyses': 1, 'priority': 2}, - 'oasislmf-piwind-1': {'tasks': 10, 'analyses': 2, 'priority': 2}, - 'oasislmf-piwind-2': {'tasks': 5, 'analyses': 1, 'priority': 1}, + 'oasislmf-piwind-1-v1': {'tasks': 10, 'analyses': 2, 'priority': 2}, + 'oasislmf-piwind-2-v2': {'tasks': 5, 'analyses': 1, 'priority': 1}, 'oasislmf-piwind-3': {'tasks': 5, 'analyses': 1, 'priority': 10}, } prioritized_model_states = asyncio.run(autoscaler._filter_model_states_with_wd(model_states)) self.assertEqual(2, len(prioritized_model_states)) first = prioritized_model_states[0] - self.assertEqual('oasislmf-piwind-1', first[0]) + self.assertEqual('oasislmf-piwind-1-v1', first[0]) self.assertEqual(10, first[1].get('tasks')) self.assertEqual(2, first[1].get('analyses')) self.assertEqual(2, first[1].get('priority')) second = prioritized_model_states[1] - 
self.assertEqual('oasislmf-piwind-2', second[0]) + self.assertEqual('oasislmf-piwind-2-v2', second[0]) self.assertEqual(1, second[1].get('priority')) def test_get_highest_model_priorities(self): @@ -192,9 +190,6 @@ async def get_auto_scaling(oasis_model_id): self.assertEqual(2, oasis_client.get_auto_scaling.call_count) def test_never_shutdown_fixed_workers(self): - - autoscaler = AutoScaler(None, None, None, None, None, False, True) - wd1 = WorkerDeployment('worker-oasislmf-piwind-1', 'oasislmf', 'piwind', '1') wd1.auto_scaling = { 'scaling_strategy': 'FIXED_WORKERS', @@ -203,10 +198,12 @@ def test_never_shutdown_fixed_workers(self): wd1.replicas = 2 model_state = ModelState(tasks=10, analyses=2, priority=5) - desired_replicas = asyncio.run(autoscaler._scale_deployment(wd1, True, model_state, 10)) + autoscaler = AutoScaler(None, None, None, None, None, False, never_shutdown_fixed_workers=True) + desired_replicas = asyncio.run(autoscaler._scale_deployment(wd1, model_state, 10)) self.assertEqual(2, desired_replicas) - desired_replicas = asyncio.run(autoscaler._scale_deployment(wd1, False, model_state, 10)) + autoscaler = AutoScaler(None, None, None, None, None, False, never_shutdown_fixed_workers=False) + desired_replicas = asyncio.run(autoscaler._scale_deployment(wd1, model_state, 10)) self.assertEqual(2, desired_replicas) diff --git a/kubernetes/worker-controller/src/tests/test_autoscaler_rules.py b/kubernetes/worker-controller/src/tests/test_autoscaler_rules.py index f7c08db7a..0b95459fa 100644 --- a/kubernetes/worker-controller/src/tests/test_autoscaler_rules.py +++ b/kubernetes/worker-controller/src/tests/test_autoscaler_rules.py @@ -20,9 +20,33 @@ def test_fixed_correct(self): 'scaling_strategy': 'FIXED_WORKERS', 'worker_count_fixed': 5 } + state = { + 'analyses': 3 + } + desired_replicas = autoscaler_rules.get_desired_worker_count(as_conf, state) + + self.assertEqual(5, desired_replicas) + + def test_min_workers_correct(self): + + as_conf = { + 'scaling_strategy': 'FIXED_WORKERS', + 'worker_count_fixed': 5, + 'worker_count_min': 3 + } state = {} desired_replicas = autoscaler_rules.get_desired_worker_count(as_conf, state) + self.assertEqual(3, desired_replicas) + + def test_min_workers__scale_up_correct(self): + as_conf = { + 'scaling_strategy': 'FIXED_WORKERS', + 'worker_count_fixed': 5, + 'worker_count_min': 1 + } + state = {'analyses': 1} + desired_replicas = autoscaler_rules.get_desired_worker_count(as_conf, state) self.assertEqual(5, desired_replicas) def test_fixed_incorrect_missing_size(self): @@ -31,7 +55,7 @@ def test_fixed_incorrect_missing_size(self): 'scaling_strategy': 'FIXED_WORKERS' } state = {} - self.assertRaises(ValueError, lambda: autoscaler_rules.get_desired_worker_count(as_conf, state)) + self.assertRaises(ValueError, lambda: autoscaler_rules.get_desired_worker_count(as_conf, state, never_shutdown_fixed_workers=True)) def test_queue_load_correct(self): @@ -131,7 +155,6 @@ def test_chunks_per_worker_incorrect_config(self): as_conf = { 'scaling_strategy': 'DYNAMIC_TASKS', - } state = {} self.assertRaises(ValueError, lambda: autoscaler_rules.get_desired_worker_count(as_conf, state)) diff --git a/kubernetes/worker-controller/src/worker_controller.py b/kubernetes/worker-controller/src/worker_controller.py index bb2f286fe..3ea53feed 100755 --- a/kubernetes/worker-controller/src/worker_controller.py +++ b/kubernetes/worker-controller/src/worker_controller.py @@ -30,8 +30,6 @@ from oasis_websocket import OasisWebSocket from autoscaler import AutoScaler 
-logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO) - def str2bool(v): """ Func type for loading strings to boolean values using argparse @@ -56,21 +54,26 @@ def parse_args(): :return: """ parser = argparse.ArgumentParser('Oasis example model worker controller') - parser.add_argument('--api-host', help='The sever API hostname', default=getenv('OASIS_API_HOST') or 'localhost') - parser.add_argument('--api-port', help='The server API portnumber', default=getenv('OASIS_API_PORT') or 8000) + parser.add_argument('--api-host', help='The server API hostname', default=getenv('OASIS_API_HOST', default='localhost')) + parser.add_argument('--api-port', help='The server API port number', default=getenv('OASIS_API_PORT', default=8000)) + parser.add_argument('--api-subpath', help='The server API subpath, e.g. "api"', default=getenv('OASIS_API_SUBPATH', default='')) + parser.add_argument('--websocket-host', help='The websocket hostname', default=getenv('OASIS_WEBSOCKET_HOST', default='localhost')) + parser.add_argument('--websocket-port', help='The websocket port number', default=getenv('OASIS_WEBSOCKET_PORT', default=8001)) parser.add_argument('--secure', help='Flag if https and wss should be used', default=bool(getenv('OASIS_API_SECURE')), action='store_true') - parser.add_argument('--username', help='The username of the worker controller user', default=getenv('OASIS_USERNAME') or 'admin') - parser.add_argument('--password', help='The password of the worker controller user', default=getenv('OASIS_PASSWORD') or 'password') - parser.add_argument('--namespace', help='Namespace of cluster where oasis is deployed to', default=getenv('OASIS_CLUSTER_NAMESPACE') or 'default') + parser.add_argument('--username', help='The username of the worker controller user', default=getenv('OASIS_USERNAME', default='admin')) + parser.add_argument('--password', help='The password of the worker controller user', default=getenv('OASIS_PASSWORD', default='password')) + parser.add_argument('--namespace', help='Namespace of cluster where oasis is deployed to', + default=getenv('OASIS_CLUSTER_NAMESPACE', default='default')) parser.add_argument('--limit', help='Hard limit for the total number of workers created', default=getenv('OASIS_TOTAL_WORKER_LIMIT')) parser.add_argument('--prioritized-models-limit', help='When prioritized runs are used - create workers for the models with the highest priority', default=getenv('OASIS_PRIORITIZED_MODELS_LIMIT')) parser.add_argument( - '--cluster', help='Type of kubernetes cluster to connect to, either "local" (~/.kube/config) or "in" to connect to the cluster the pod exists in', default=getenv('CLUSTER') or 'in') + '--cluster', help='Type of kubernetes cluster to connect to, either "local" (~/.kube/config) or "in" to connect to the cluster the pod exists in', default=getenv('CLUSTER', default='in')) parser.add_argument('--continue-update-scaling', help='Auto scaling - read the scaling settings from the API for a model on every update.
(for testing)', - type=str2bool, default=getenv('OASIS_CONTINUE_UPDATE_SCALING') or False) + type=str2bool, default=getenv('OASIS_CONTINUE_UPDATE_SCALING', default=False)) parser.add_argument('--never-shutdown-fixed-workers', help='Auto scaling - never scale to 0 for strategy FIXED_WORKERS.', - type=str2bool, default=getenv('OASIS_NEVER_SHUTDOWN_FIXED_WORKERS') or False) + type=str2bool, default=getenv('OASIS_NEVER_SHUTDOWN_FIXED_WORKERS', default=False)) + parser.add_argument('--log-level', help='The logging level', default=getenv('OASIS_LOGLEVEL', default='INFO')) args = parser.parse_args() @@ -89,26 +92,41 @@ def main(): Entrypoint. Parse arguments, creates client for oasis and kubernetes cluster and starts tasks to monitor changes. """ - args = parse_args() - # Create an oasis client - oasis_client = OasisClient(args.api_host, args.api_port, args.secure, args.username, args.password) + cli_args = parse_args() + oasis_client = OasisClient( + cli_args.api_host, + cli_args.api_port, + cli_args.api_subpath, + cli_args.websocket_host, + cli_args.websocket_port, + cli_args.secure, + cli_args.username, + cli_args.password + ) + + # Set worker-controller logger + logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=cli_args.log_level) + + # Set Websocket logger + logger = logging.getLogger('websockets') + logger.setLevel(cli_args.log_level) + logger.addHandler(logging.StreamHandler()) # Cache to keep track of all deployments in the cluster - deployments: worker_deployments.WorkerDeployments = worker_deployments.WorkerDeployments(args) - + deployments: worker_deployments.WorkerDeployments = worker_deployments.WorkerDeployments(cli_args) event_loop = asyncio.get_event_loop() # Create cluster client and load configuration - cluster_client = ClusterClient(args.namespace) - event_loop.run_until_complete(cluster_client.load_config(args.cluster)) + cluster_client = ClusterClient(cli_args.namespace) + event_loop.run_until_complete(cluster_client.load_config(cli_args.cluster)) # Create the autoscaler to bind everything together - autoscaler = AutoScaler(deployments, cluster_client, oasis_client, args.prioritized_models_limit, args.limit, - args.continue_update_scaling, args.never_shutdown_fixed_workers) + autoscaler = AutoScaler(deployments, cluster_client, oasis_client, cli_args.prioritized_models_limit, cli_args.limit, + cli_args.continue_update_scaling, cli_args.never_shutdown_fixed_workers) # Create the deployment watcher and load all available deployments - deployments_watcher = DeploymentWatcher(args.namespace, deployments) + deployments_watcher = DeploymentWatcher(cli_args.namespace, deployments) event_loop.run_until_complete(deployments_watcher.load_deployments()) deployments.print_list() diff --git a/kubernetes/worker-controller/src/worker_deployments.py b/kubernetes/worker-controller/src/worker_deployments.py index 2484e0668..5979fc0f1 100644 --- a/kubernetes/worker-controller/src/worker_deployments.py +++ b/kubernetes/worker-controller/src/worker_deployments.py @@ -11,17 +11,20 @@ class WorkerDeployment: is actually not a part of the worker deployment in the cluster, it comes from the API, but kept here anyway. 
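With api_version threaded through the labels and constructor, a deployment's id string gains a version suffix, matching the version-suffixed queue names in the updated test fixture (e.g. OasisLMF-PiWind-1-v2). A minimal sketch of the naming scheme implemented in the diff below (class trimmed to the id logic):

    class WorkerDeployment:
        # Trimmed to the naming logic from the diff below.
        def __init__(self, name, supplier_id, model_id, model_version_id, api_version=None):
            self.name = name
            self.supplier_id = supplier_id
            self.model_id = model_id
            self.model_version_id = model_version_id
            self.api_version = api_version

        def id_string(self):
            if self.api_version is None:
                return f'{self.supplier_id}-{self.model_id}-{self.model_version_id}'.lower()
            return f'{self.supplier_id}-{self.model_id}-{self.model_version_id}-{self.api_version}'.lower()

    wd = WorkerDeployment('worker-oasislmf-piwind-1-v2', 'OasisLMF', 'PiWind', '1', 'v2')
    print(wd.id_string())  # oasislmf-piwind-1-v2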
- def __init__(self, name: str, supplier_id: str, model_id: str, model_version_id: str): + def __init__(self, name: str, supplier_id: str, model_id: str, model_version_id: str, api_version: str = None): self.name = name self.supplier_id = supplier_id self.model_id = model_id self.model_version_id = model_version_id + self.api_version = api_version self.oasis_model_id: int = None self.replicas: int = None self.auto_scaling = None def id_string(self): - return f'{self.supplier_id}-{self.model_id}-{self.model_version_id}'.lower() + if self.api_version is None: + return f'{self.supplier_id}-{self.model_id}-{self.model_version_id}'.lower() + return f'{self.supplier_id}-{self.model_id}-{self.model_version_id}-{self.api_version}'.lower() class WorkerDeployments: @@ -40,7 +43,7 @@ def __init__(self, args): self.username = args.username self.password = args.password - async def update_worker(self, name, supplier_id, model_id, model_version_id, replicas: int): + async def update_worker(self, name, supplier_id, model_id, model_version_id, api_version, replicas: int): """ Update the cache of worker deployments. Create a new one if not found or update an already existing one. @@ -48,18 +51,19 @@ async def update_worker(self, name, supplier_id, model_id, model_version_id, rep :param supplier_id: Model supplier id :param model_id: Model id :param model_version_id: Model version + :param api_version: 'V1' or 'V2' depending on the worker type :param replicas: Number of replicas currently in kubernetes. :param auto_scaling: Auto scaling settings for this worker deployment if available. :return: """ new_deployment = False - model_id_string = f'{supplier_id}-{model_id}-{model_version_id}' + model_id_string = f'{supplier_id}-{model_id}-{model_version_id}-{api_version}' - wd: WorkerDeployment = self.get_worker_deployment(supplier_id, model_id, model_version_id) + wd: WorkerDeployment = self.get_worker_deployment(supplier_id, model_id, model_version_id, api_version) if wd is None: new_deployment = True - wd = WorkerDeployment(name, supplier_id, model_id, model_version_id) + wd = WorkerDeployment(name, supplier_id, model_id, model_version_id, api_version) self.worker_deployments.append(wd) logging.info('Deployment %s: New', model_id_string) @@ -69,19 +73,32 @@ async def update_worker(self, name, supplier_id, model_id, model_version_id, rep wd.replicas = replicas - def get_worker_deployment(self, supplier_id, model_id, model_version_id) -> WorkerDeployment: + def get_worker_deployment(self, supplier_id, model_id, model_version_id, api_version=None) -> WorkerDeployment: """ Return the WorkerDeployment for the given supplier, model and version. :param supplier_id: Model supplier id :param model_id: Model id :param model_version_id: Model version + :param api_version: Worker API type 'v1' / 'v2' :return: A WorkerDeployment or None if not found.
""" - - for wd in self.worker_deployments: - if wd.supplier_id.lower() == supplier_id.lower() and wd.model_id.lower() == model_id.lower() and wd.model_version_id.lower() == model_version_id.lower(): - return wd + # Check API version for match + if api_version is not None: + for wd in self.worker_deployments: + if all([wd.supplier_id.lower() == supplier_id.lower(), + wd.model_id.lower() == model_id.lower(), + wd.model_version_id.lower() == model_version_id.lower(), + wd.api_version.lower() == api_version.lower()]): + return wd + + # return any matching model (ignore API version) + else: + for wd in self.worker_deployments: + if all([wd.supplier_id.lower() == supplier_id.lower(), + wd.model_id.lower() == model_id.lower(), + wd.model_version_id.lower() == model_version_id.lower()]): + return wd async def get_worker_deployment_by_name_id(self, name_id: str) -> WorkerDeployment: """ @@ -90,10 +107,11 @@ async def get_worker_deployment_by_name_id(self, name_id: str) -> WorkerDeployme :param name_id: String with format -- :return: A WorkerDeployment or None if not found. """ - split = name_id.split('-') if len(split) == 3: return self.get_worker_deployment(split[0], split[1], split[2]) + if len(split) == 4: + return self.get_worker_deployment(split[0], split[1], split[2], split[3]) async def delete(self, name): """ diff --git a/requirements-server.txt b/requirements-server.txt index 576b40e51..f1901cbda 100644 --- a/requirements-server.txt +++ b/requirements-server.txt @@ -24,7 +24,7 @@ attrs==22.2.0 # jsonschema # service-identity # twisted -autobahn==23.1.2 +autobahn==23.6.2 # via daphne automat==22.10.0 # via twisted @@ -44,7 +44,7 @@ celery==5.3.0 # via # -r requirements-server.in # django-celery-results -certifi==2022.12.7 +certifi==2023.7.22 # via requests cffi==1.15.1 # via cryptography @@ -82,7 +82,9 @@ coreschema==0.0.4 # via # coreapi # drf-yasg -cryptography==40.0.1 +cramjam==2.8.1 + # via fastparquet +cryptography==41.0.2 # via # autobahn # azure-storage-blob @@ -92,7 +94,7 @@ cryptography==40.0.1 # service-identity daphne==2.5.0 # via channels -django==3.2.20 +django==3.2.23 # via # channels # django-celery-results @@ -119,7 +121,9 @@ django-model-utils==4.3.1 django-request-logging==0.7.5 # via -r requirements-server.in django-storages[azure]==1.13.2 - # via -r requirements-server.in + # via + # -r requirements-server.in + # django-storages djangorestframework==3.14.0 # via # -r requirements-server.in @@ -132,6 +136,12 @@ drf-nested-routers==0.93.4 # via -r requirements-server.in drf-yasg==1.21.5 # via -r requirements-server.in +fastparquet==2023.10.1 + # via oasis-data-manager +fsspec==2024.2.0 + # via + # fastparquet + # oasis-data-manager greenlet==2.0.2 # via sqlalchemy gunicorn==20.1.0 @@ -177,6 +187,8 @@ jsonschema==4.17.3 # ods-tools kombu==5.3.0 # via celery +llvmlite==0.41.1 + # via numba markdown==3.4.3 # via -r requirements-server.in markupsafe==2.1.2 @@ -187,19 +199,28 @@ msgpack==0.6.2 # via channels-redis mysqlclient==2.1.1 # via -r requirements-server.in +numba==0.58.0 + # via ods-tools numpy==1.24.2 # via + # fastparquet + # numba # pandas # pyarrow -ods-tools==3.1.0 +oasis-data-manager==0.1.1 + # via ods-tools +ods-tools==3.2.0 # via -r requirements-server.in packaging==23.0 # via # drf-yasg + # fastparquet # ods-tools pandas==1.5.3 # via # -r requirements-server.in + # fastparquet + # oasis-data-manager # ods-tools pathlib2==2.3.7.post1 # via -r requirements-server.in @@ -207,7 +228,7 @@ prompt-toolkit==3.0.38 # via click-repl psycopg2-binary==2.9.6 # via -r 
requirements-server.in -pyarrow==11.0.0 +pyarrow==14.0.1 # via -r requirements-server.in pyasn1==0.4.8 # via @@ -221,7 +242,7 @@ pyjwt==2.6.0 # via djangorestframework-simplejwt pymysql==1.1.0 # via -r requirements-server.in -pyopenssl==23.1.1 +pyopenssl==23.2.0 # via # josepy # twisted @@ -273,13 +294,18 @@ sqlparse==0.4.4 # django # django-debug-toolbar twisted[tls]==22.10.0 - # via daphne + # via + # daphne + # twisted txaio==23.1.1 # via autobahn +typing==3.7.4.3 + # via oasis-data-manager typing-extensions==4.5.0 # via # azure-core # azure-storage-blob + # oasis-data-manager # twisted tzdata==2023.3 # via celery diff --git a/requirements-worker.txt b/requirements-worker.txt index 27c3c79e6..2ea3fbad8 100644 --- a/requirements-worker.txt +++ b/requirements-worker.txt @@ -35,7 +35,7 @@ botocore==1.29.107 # s3transfer celery==5.3.0 # via -r requirements-worker.in -certifi==2022.12.7 +certifi==2023.7.22 # via # fiona # pyproj @@ -72,14 +72,16 @@ configparser==5.3.0 # via -r requirements-worker.in cramjam==2.6.2 # via fastparquet -cryptography==40.0.1 +cryptography==41.0.2 # via azure-storage-blob exceptiongroup==1.1.1 # via pytest fasteners==0.18 # via -r requirements-worker.in fastparquet==2023.2.0 - # via oasislmf + # via + # oasis-data-manager + # oasislmf filelock==3.10.7 # via -r requirements-worker.in fiona==1.9.2 @@ -87,7 +89,9 @@ fiona==1.9.2 forex-python==1.8 # via oasislmf fsspec==2023.3.0 - # via fastparquet + # via + # fastparquet + # oasis-data-manager geopandas==0.12.2 # via oasislmf greenlet==2.0.2 @@ -112,7 +116,7 @@ jsonschema==4.17.3 # via ods-tools kombu==5.3.0 # via celery -llvmlite==0.39.1 +llvmlite==0.41.1 # via numba msgpack==1.0.5 # via oasislmf @@ -120,8 +124,10 @@ munch==2.5.0 # via fiona natsort==8.3.1 # via -r requirements-worker.in -numba==0.56.4 - # via oasislmf +numba==0.58.0 + # via + # oasislmf + # ods-tools numexpr==2.8.4 # via oasislmf numpy==1.22.4 @@ -135,9 +141,15 @@ numpy==1.22.4 # scikit-learn # scipy # shapely -oasislmf[extra]==1.28.1 - # via -r requirements-worker.in -ods-tools==3.1.0 +oasis-data-manager==0.1.1 + # via + # oasislmf + # ods-tools +oasislmf[extra]==2.3.0 + # via + # -r requirements-worker.in + # oasislmf +ods-tools==3.2.0 # via oasislmf packaging==23.0 # via @@ -149,6 +161,7 @@ pandas==1.5.3 # via # fastparquet # geopandas + # oasis-data-manager # oasislmf # ods-tools pathlib2==2.3.7.post1 @@ -159,7 +172,7 @@ prompt-toolkit==3.0.38 # via click-repl psycopg2-binary==2.9.6 # via -r requirements-worker.in -pyarrow==11.0.0 +pyarrow==14.0.1 # via oasislmf pycparser==2.21 # via cffi @@ -229,10 +242,13 @@ tomli==2.0.1 # via pytest tqdm==4.65.0 # via oasislmf +typing==3.7.4.3 + # via oasis-data-manager typing-extensions==4.5.0 # via # azure-core # azure-storage-blob + # oasis-data-manager tzdata==2023.3 # via celery urllib3==1.26.15 @@ -246,6 +262,3 @@ vine==5.0.0 # kombu wcwidth==0.2.6 # via prompt-toolkit - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git a/requirements.txt b/requirements.txt index 88c4577fa..f8bca548d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,7 +33,7 @@ attrs==22.2.0 # pytest # service-identity # twisted -autobahn==23.1.2 +autobahn==23.6.2 # via daphne automat==22.10.0 # via twisted @@ -74,7 +74,7 @@ celery==5.3.0 # -r ./requirements-server.in # -r ./requirements-worker.in # django-celery-results -certifi==2022.12.7 +certifi==2023.7.22 # via # fiona # pyproj @@ -137,7 +137,7 @@ coverage[toml]==7.2.2 # pytest-cov cramjam==2.6.2 # via fastparquet 
-cryptography==40.0.1 +cryptography==41.0.2 # via # autobahn # azure-storage-blob @@ -153,7 +153,7 @@ decorator==5.1.1 # ipython distlib==0.3.6 # via virtualenv -django==3.2.20 +django==3.2.23 # via # channels # django-celery-results @@ -183,7 +183,9 @@ django-model-utils==4.3.1 django-request-logging==0.7.5 # via -r ./requirements-server.in django-storages[azure]==1.13.2 - # via -r ./requirements-server.in + # via + # -r ./requirements-server.in + # django-storages django-webtest==1.9.10 # via -r requirements.in djangorestframework==3.14.0 @@ -209,7 +211,9 @@ fasteners==0.18 # -r ./requirements-worker.in # -r requirements.in fastparquet==2023.2.0 - # via oasislmf + # via + # oasis-data-manager + # oasislmf filelock==3.10.7 # via # -r ./requirements-worker.in @@ -224,7 +228,9 @@ forex-python==1.8 freezegun==1.2.2 # via -r requirements.in fsspec==2023.3.0 - # via fastparquet + # via + # fastparquet + # oasis-data-manager geopandas==0.12.2 # via oasislmf greenlet==2.0.2 @@ -285,7 +291,7 @@ jsonschema==4.17.3 # ods-tools kombu==5.3.0 # via celery -llvmlite==0.39.1 +llvmlite==0.41.1 # via numba markdown==3.4.3 # via -r ./requirements-server.in @@ -311,8 +317,10 @@ mysqlclient==2.1.1 # via -r ./requirements-server.in natsort==8.3.1 # via -r ./requirements-worker.in -numba==0.56.4 - # via oasislmf +numba==0.58.0 + # via + # oasislmf + # ods-tools numexpr==2.8.4 # via oasislmf numpy==1.22.4 @@ -326,9 +334,15 @@ numpy==1.22.4 # scikit-learn # scipy # shapely -oasislmf[extra]==1.28.1 - # via -r ./requirements-worker.in -ods-tools==3.1.0 +oasis-data-manager==0.1.1 + # via + # oasislmf + # ods-tools +oasislmf[extra]==2.3.0 + # via + # -r ./requirements-worker.in + # oasislmf +ods-tools==3.2.0 # via # -r ./requirements-server.in # oasislmf @@ -347,6 +361,7 @@ pandas==1.5.3 # -r ./requirements-server.in # fastparquet # geopandas + # oasis-data-manager # oasislmf # ods-tools parso==0.8.3 @@ -381,7 +396,7 @@ ptyprocess==0.7.0 # via pexpect pure-eval==0.2.2 # via stack-data -pyarrow==11.0.0 +pyarrow==14.0.1 # via # -r ./requirements-server.in # oasislmf @@ -405,7 +420,7 @@ pymysql==1.1.0 # via # -r ./requirements-server.in # -r ./requirements-worker.in -pyopenssl==23.1.1 +pyopenssl==23.2.0 # via # -r requirements.in # josepy @@ -534,13 +549,18 @@ traitlets==5.9.0 # ipython # matplotlib-inline twisted[tls]==22.10.0 - # via daphne + # via + # daphne + # twisted txaio==23.1.1 # via autobahn +typing==3.7.4.3 + # via oasis-data-manager typing-extensions==4.5.0 # via # azure-core # azure-storage-blob + # oasis-data-manager # twisted tzdata==2023.3 # via celery diff --git a/scripts/deploy.sh b/scripts/deploy.sh index dac78e1bd..18ac70af5 100755 --- a/scripts/deploy.sh +++ b/scripts/deploy.sh @@ -1,25 +1,32 @@ #!/bin/bash + SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +PIWIND_PATH_FILE=$SCRIPT_DIR/piwind-path-cfg cd $SCRIPT_DIR; cd .. pwd +# Store PiWind Path +if ! [ -f $PIWIND_PATH_FILE ]; then + touch $PIWIND_PATH_FILE +else + source $PIWIND_PATH_FILE +fi if [[ -z $OASIS_MODEL_DATA_DIR ]]; then - echo -n "Path: " + echo -n "OasisPiwind Repo Path: " read filepath if [ ! 
-d $filepath ]; then echo "Please insert a correct path" sleep 1 $SCRIPT_DIR/deploy.sh fi - export OASIS_MODEL_DATA_DIR=$filepath - #echo "path=$filepath" >> .gitpush - #echo "$Path saved in the following file: $cur/.gitpush" + export OASIS_MODEL_DATA_DIR=$filepath + echo "export OASIS_MODEL_DATA_DIR=$filepath" > $PIWIND_PATH_FILE fi -docker rmi coreoasis/api_server:dev -docker rmi coreoasis/model_worker:dev +#docker rmi coreoasis/api_server:dev +#docker rmi coreoasis/model_worker:dev # Check for prev install and offer to clean wipe if [[ $(docker volume ls | grep OasisData -c) -gt 1 ]]; then @@ -27,6 +34,7 @@ if [[ $(docker volume ls | grep OasisData -c) -gt 1 ]]; then docker volume ls | grep OasisData | awk 'BEGIN { FS = "[ \t\n]+" }{ print $2 }' | xargs -r docker volume rm fi +set -e docker build -f Dockerfile.api_server -t coreoasis/api_server:dev . docker build -f Dockerfile.model_worker -t coreoasis/model_worker:dev . docker-compose up -d diff --git a/scripts/minikube-deploy.sh b/scripts/minikube-deploy.sh new file mode 100755 index 000000000..933c71e11 --- /dev/null +++ b/scripts/minikube-deploy.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +# cd to repo root +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +PIWIND_PATH_FILE=$SCRIPT_DIR/piwind-path-cfg +cd $SCRIPT_DIR/.. + +# Store PiWind Path +if ! [ -f $PIWIND_PATH_FILE ]; then + touch $PIWIND_PATH_FILE +else + source $PIWIND_PATH_FILE +fi + +if [[ -z $OASIS_MODEL_DATA_DIR ]]; then + echo -n "OasisPiwind Repo Path: " + read filepath + if [ ! -d $filepath ]; then + echo "Please insert a correct path" + sleep 1 + $SCRIPT_DIR/deploy.sh + fi + export OASIS_MODEL_DATA_DIR=$filepath + echo "export OASIS_MODEL_DATA_DIR=$filepath" > $PIWIND_PATH_FILE +fi + +## init minikube +# minikube delete +# minikube config set cpus 12 +# minikube config set memory 16000 +# minikube start + +# build images +eval $(minikube docker-env) +set -e + docker build -f Dockerfile.api_server -t coreoasis/api_server:dev . + docker build -f Dockerfile.model_worker -t coreoasis/model_worker:dev . + + pushd kubernetes/worker-controller + docker build -t coreoasis/worker_controller:dev . + popd +set +e + +# Upload piwind data +./kubernetes/scripts/k8s/upload_piwind_model_data.sh $OASIS_MODEL_DATA_DIR + +# update / apply charts +pushd kubernetes/charts + if ! helm status platform; then + helm install platform oasis-platform + else + helm upgrade platform oasis-platform + fi + + if ! 
helm status models; then
+        helm install models oasis-models
+    else
+        helm upgrade models oasis-models
+    fi
+popd
+
+
+# kubectl scale --replicas=0 deployment/oasis-worker-controller
+
+# Open local access to cluster
+#
+#minikube tunnel &
+# kubectl get svc --template="{{range .items}}{{range .status.loadBalancer.ingress}}{{.ip}}{{end}}{{end}}"
+
+# Open single service
+#kubectl port-forward deployment/oasis-websocket 8001:8001 #(forward websocket)
+
+
diff --git a/scripts/update-changelog.py b/scripts/update-changelog.py
index 7a20cc64f..12b59a47c 100755
--- a/scripts/update-changelog.py
+++ b/scripts/update-changelog.py
@@ -346,7 +346,7 @@ def release_plat_header(self, tag_platform=None, tag_oasislmf=None, tag_ods=None
         plat_header.append(f'* [coreoasis/oasisui_proxy:{t_ui}](https://hub.docker.com/r/coreoasis/oasisui_proxy/tags?name={t_ui})\n')
         plat_header.append('## Components\n')
         plat_header.append(f'* [oasislmf {t_lmf}](https://github.com/OasisLMF/OasisLMF/releases/tag/{t_lmf})\n')
-        plat_header.append(f'* [ods-tools {t_ods}](https://github.com/OasisLMF/OasisLMF/releases/tag/{t_ods})\n')
+        plat_header.append(f'* [ods-tools {t_ods}](https://github.com/OasisLMF/ODS_Tools/releases/tag/{t_ods})\n')
         plat_header.append(f'* [ktools {t_ktools}](https://github.com/OasisLMF/ktools/releases/tag/{t_ktools})\n')
         plat_header.append(f'* [Oasis UI {t_ui}](https://github.com/OasisLMF/OasisUI/releases/tag/{t_ui})\n')
         plat_header.append('\n')
@@ -427,7 +427,7 @@ def check_rate_limit(github_token):

 @cli.command()
-@click.option('--repo', type=click.Choice(['ktools', 'OasisLMF', 'OasisPlatform', 'OasisUI'], case_sensitive=True), required=True)
+@click.option('--repo', type=click.STRING, required=True, help="Oasis repository name (case sensitive), e.g. ['ktools', 'OasisLMF', 'OasisUI', ...]")
 @click.option('--output-path', type=click.Path(exists=False), default='./CHANGELOG.rst', help='changelog output path')
 @click.option('--local-repo-path', type=click.Path(exists=False), default=None, help=' Path to local git repository, used to skip clone step (optional) ')
 @click.option('--from-tag', required=True, help='Github tag to track changes from')
@@ -473,7 +473,7 @@ def build_changelog(repo, from_tag, to_tag, github_token, output_path, apply_mil

 @cli.command()
-@click.option('--repo', type=click.Choice(['ktools', 'OasisLMF', 'OasisUI'], case_sensitive=True), required=True)
+@click.option('--repo', type=click.STRING, required=True, help="Oasis repository name (case sensitive), e.g. ['ktools', 'OasisLMF', 'OasisUI', ...]")
 @click.option('--output-path', type=click.Path(exists=False), default='./RELEASE.md', help='Release notes output path')
 @click.option('--local-repo-path', type=click.Path(exists=False), default=None, help=' Path to local git repository, used to skip clone step (optional) ')
 @click.option('--from-tag', required=True, help='Github tag to track changes from')
diff --git a/src/conf/base.py b/src/conf/base.py
new file mode 100644
index 000000000..91e8dcdab
--- /dev/null
+++ b/src/conf/base.py
@@ -0,0 +1,32 @@
+import urllib
+from src.conf.iniconf import settings
+
+#: Celery config - IP address of the server running RabbitMQ and Celery
+BROKER_URL = settings.get(
+    'celery',
+    'broker_url',
+    fallback="amqp://{RABBIT_USER}:{RABBIT_PASS}@{RABBIT_HOST}:{RABBIT_PORT}//".format(
+        RABBIT_USER=settings.get('celery', 'rabbit_user', fallback='rabbit'),
+        RABBIT_PASS=settings.get('celery', 'rabbit_pass', fallback='rabbit'),
+        RABBIT_HOST=settings.get('celery', 'rabbit_host', fallback='127.0.0.1'),
RABBIT_PORT=settings.get('celery', 'rabbit_port', fallback='5672'), + ) +) + +#: Celery config - result backend URI +CELERY_RESULTS_DB_BACKEND = settings.get('celery', 'DB_ENGINE', fallback='db+sqlite') +if CELERY_RESULTS_DB_BACKEND == 'db+sqlite': + CELERY_RESULT_BACKEND = '{DB_ENGINE}:///{DB_NAME}'.format( + DB_ENGINE=CELERY_RESULTS_DB_BACKEND, + DB_NAME=settings.get('celery', 'db_name', fallback='celery.db.sqlite'), + ) +else: + CELERY_RESULT_BACKEND = '{DB_ENGINE}://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}{SSL_MODE}'.format( + DB_ENGINE=settings.get('celery', 'db_engine'), + DB_USER=urllib.parse.quote(settings.get('celery', 'db_user')), + DB_PASS=urllib.parse.quote(settings.get('celery', 'db_pass')), + DB_HOST=settings.get('celery', 'db_host'), + DB_PORT=settings.get('celery', 'db_port'), + DB_NAME=settings.get('celery', 'db_name', fallback='celery'), + SSL_MODE=settings.get('celery', 'db_ssl_mode', fallback='?sslmode=prefer'), + ) diff --git a/src/conf/celeryconf_v1.py b/src/conf/celeryconf_v1.py new file mode 100644 index 000000000..9fc096479 --- /dev/null +++ b/src/conf/celeryconf_v1.py @@ -0,0 +1,38 @@ +from .base import * + +# Default Queue Name +CELERY_DEFAULT_QUEUE = "celery" + +#: Celery config - ignore result? +CELERY_IGNORE_RESULT = False + +#: Celery config - AMQP task result expiration time +CELERY_AMQP_TASK_RESULT_EXPIRES = 1000 + +#: Celery config - task serializer +CELERY_TASK_SERIALIZER = 'json' + +#: Celery config - result serializer +CELERY_RESULT_SERIALIZER = 'json' + +#: Celery config - accept content type +CELERY_ACCEPT_CONTENT = ['json'] + +#: Celery config - timezone (default == UTC) +# CELERY_TIMEZONE = 'Europe/London' + +#: Celery config - enable UTC +CELERY_ENABLE_UTC = True + +#: Celery config - concurrency +CELERYD_CONCURRENCY = 1 + +#: Disable celery task prefetch +#: https://docs.celeryproject.org/en/stable/userguide/configuration.html#std-setting-worker_prefetch_multiplier +CELERYD_PREFETCH_MULTIPLIER = 1 + +worker_task_kwargs = { + 'autoretry_for': (Exception,), + 'max_retries': 2, # The task will be run max_retries + 1 times + 'default_retry_delay': 6, # A small delay to recover from temporary bad states +} diff --git a/src/conf/celeryconf.py b/src/conf/celeryconf_v2.py similarity index 53% rename from src/conf/celeryconf.py rename to src/conf/celeryconf_v2.py index c3c65c2a0..0e6c1f05c 100644 --- a/src/conf/celeryconf.py +++ b/src/conf/celeryconf_v2.py @@ -1,42 +1,15 @@ from celery.schedules import crontab from kombu.common import Broadcast -import urllib from src.conf.iniconf import settings +from .base import * + +# Default Queue Name +CELERY_DEFAULT_QUEUE = "celery-v2" #: Celery config - ignore result? 
CELERY_IGNORE_RESULT = False -#: Celery config - IP address of the server running RabbitMQ and Celery -BROKER_URL = settings.get( - 'celery', - 'broker_url', - fallback="amqp://{RABBIT_USER}:{RABBIT_PASS}@{RABBIT_HOST}:{RABBIT_PORT}//".format( - RABBIT_USER=settings.get('celery', 'rabbit_user', fallback='rabbit'), - RABBIT_PASS=settings.get('celery', 'rabbit_pass', fallback='rabbit'), - RABBIT_HOST=settings.get('celery', 'rabbit_host', fallback='127.0.0.1'), - RABBIT_PORT=settings.get('celery', 'rabbit_port', fallback='5672'), - ) -) - -#: Celery config - result backend URI -CELERY_RESULTS_DB_BACKEND = settings.get('celery', 'DB_ENGINE', fallback='db+sqlite') -if CELERY_RESULTS_DB_BACKEND == 'db+sqlite': - CELERY_RESULT_BACKEND = '{DB_ENGINE}:///{DB_NAME}'.format( - DB_ENGINE=CELERY_RESULTS_DB_BACKEND, - DB_NAME=settings.get('celery', 'db_name', fallback='celery.db.sqlite'), - ) -else: - CELERY_RESULT_BACKEND = '{DB_ENGINE}://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}{SSL_MODE}'.format( - DB_ENGINE=settings.get('celery', 'db_engine'), - DB_USER=urllib.parse.quote(settings.get('celery', 'db_user')), - DB_PASS=urllib.parse.quote(settings.get('celery', 'db_pass')), - DB_HOST=settings.get('celery', 'db_host'), - DB_PORT=settings.get('celery', 'db_port'), - DB_NAME=settings.get('celery', 'db_name', fallback='celery'), - SSL_MODE=settings.get('celery', 'db_ssl_mode', fallback='?sslmode=prefer'), - ) - #: Celery config - AMQP task result expiration time CELERY_AMQP_TASK_RESULT_EXPIRES = 1000 @@ -59,16 +32,12 @@ # CELERYD_CONCURRENCY = 1 #: Disable celery task prefetch -#: https://docs.celeryproject.org/en/stable/userguide/configuration.html#std-setting-worker_prefetch_multiplier CELERYD_PREFETCH_MULTIPLIER = 1 -### Added from Arch2020 branch ### - # setup queues so that tasks aren't removed from the queue until # complete and reschedule if the task worker goes offline CELERY_ACKS_LATE = True CELERY_REJECT_ON_WORKER_LOST = True - CELERY_TASK_QUEUES = (Broadcast('model-worker-broadcast'), ) # Highest priority available @@ -77,8 +46,9 @@ # Set to make internal and subtasks inherit priority CELERY_INHERIT_PARENT_PRIORITY = True - # setup the beat schedule + + def crontab_from_string(s): minute, hour, day_of_week, day_of_month, month_of_year = s.split(' ') return crontab( diff --git a/src/model_execution_worker/distributed_tasks.py b/src/model_execution_worker/distributed_tasks.py index fa61397a1..50c95deee 100644 --- a/src/model_execution_worker/distributed_tasks.py +++ b/src/model_execution_worker/distributed_tasks.py @@ -27,7 +27,7 @@ from pathlib2 import Path from ..common.data import ORIGINAL_FILENAME, STORED_FILENAME -from ..conf import celeryconf as celery_conf +from ..conf import celeryconf_v2 as celery_conf from ..conf.iniconf import settings from .backends.aws_storage import AwsObjectStore from .backends.azure_storage import AzureObjectStore @@ -44,6 +44,8 @@ app = Celery() app.config_from_object(celery_conf) +# print(app._conf) + logging.info("Started worker") debug_worker = settings.getboolean('worker', 'DEBUG', fallback=False) @@ -195,9 +197,9 @@ def notify_api_status(analysis_pk, task_status): task_status )) signature( - 'set_task_status', + 'set_task_status_v2', args=(analysis_pk, task_status, datetime.now().timestamp()), - queue='celery' + queue='celery-v2' ).delay() @@ -255,8 +257,9 @@ def register_worker(sender, **k): logging.info('settings: {}'.format(m_settings)) signature( - 'run_register_worker', + 'run_register_worker_v2', args=(m_supplier, m_name, m_id, m_settings, 
m_version, m_conf),
+        queue='celery-v2',
     ).delay()

 # Required ENV
@@ -479,6 +482,7 @@ def _prepare_directories(params, analysis_id, run_data_uuid, kwargs):
     params.setdefault('target_dir', params['root_run_dir'])
     params.setdefault('user_data_dir', os.path.join(params['root_run_dir'], 'user-data'))
     params.setdefault('lookup_complex_config_json', os.path.join(params['root_run_dir'], 'analysis_settings.json'))
+    params.setdefault('analysis_settings_json', os.path.join(params['root_run_dir'], 'analysis_settings.json'))

     # Generate keys files
     params.setdefault('keys_fp', os.path.join(params['root_run_dir'], 'keys.csv'))
@@ -521,6 +525,7 @@ def _prepare_directories(params, analysis_id, run_data_uuid, kwargs):
         maybe_fetch_file(settings_file, params['lookup_complex_config_json'])
     else:
         params['lookup_complex_config_json'] = None
+        params['analysis_settings_json'] = None
     if complex_data_files:
         maybe_prepare_complex_data_files(complex_data_files, params['user_data_dir'])
     else:
@@ -831,7 +836,7 @@ def cleanup_input_generation(self, params, analysis_id=None, initiator_id=None,
         filestore.delete_file(params.get('pre_scope_file'))
     params['log_location'] = filestore.put(kwargs.get('log_filename'))
-    return params
+    return {'input-location_generate-and-run': params.get('output_location')}

 # --- loss generation tasks ------------------------------------------------ #
@@ -899,7 +904,11 @@ def _prepare_directories(params, analysis_id, run_data_uuid, kwargs):
     params.setdefault('user_data_dir', os.path.join(params['root_run_dir'], 'user-data'))
     params.setdefault('analysis_settings_json', os.path.join(params['root_run_dir'], 'analysis_settings.json'))

+    # case for 'generate_and_run'
+    if params.get('input-location_generate-and-run'):
+        kwargs['input_location'] = params.get('input-location_generate-and-run')
     input_location = kwargs.get('input_location')
+
     if input_location:
         maybe_extract_tar(
             input_location,
@@ -965,6 +974,7 @@ def prepare_losses_generation_params(
     loss_path_vars = [
         'model_data_dir',
         'model_settings_json',
+        'post_analysis_module',
     ]

     for path_val in loss_path_vars:
@@ -978,9 +988,18 @@ def prepare_losses_generation_params(
         else:
             run_params[path_val] = None

-    params = OasisManager()._params_generate_losses(**run_params)
+    gen_losses_params = OasisManager()._params_generate_losses(**run_params)
+    post_hook_params = OasisManager()._params_post_analysis(**run_params)
+    params = {**gen_losses_params, **post_hook_params}
+
     params['log_location'] = filestore.put(kwargs.get('log_filename'))
     params['verbose'] = debug_worker
+
+    # needed in case input_data is missing on another node
+    input_tar_generate_and_run = run_params.get('input-location_generate-and-run')
+    if input_tar_generate_and_run:
+        params['input-location_generate-and-run'] = input_tar_generate_and_run
+
     return params

@@ -1006,7 +1025,6 @@ def generate_losses_chunk(self, params, chunk_idx, num_chunks, analysis_id=None,
         current_chunk_id = None
         max_chunk_id = -1
         work_dir = 'work'
-
     else:
         # Run a single ktools pipe
         current_chunk_id = chunk_idx + 1
@@ -1051,8 +1069,10 @@ def generate_losses_output(self, params, analysis_id=None, slug=None, **kwargs):
         merge_dirs(d, abs_work_dir)

     OasisManager().generate_losses_output(**res)
-    res['bash_trace'] = ""
+    if res.get('post_analysis_module', None):
+        OasisManager().post_analysis(**res)

+    res['bash_trace'] = ""
     return {
         **res,
         'output_location': filestore.put(os.path.join(res['model_run_dir'], 'output'), arcname='output'),
diff --git a/src/model_execution_worker/storage_manager.py
b/src/model_execution_worker/storage_manager.py
index 72ee4783d..d6d68ad31 100755
--- a/src/model_execution_worker/storage_manager.py
+++ b/src/model_execution_worker/storage_manager.py
@@ -178,7 +178,7 @@ def _fetch_file(self, reference, output_path, subdir):
             raise MissingInputsException(fpath)

     def filepath(self, reference):
-        """ return the absolute filepath
+        """ return the absolute filepath 
         """
         fpath = os.path.join(
             self.media_root,
@@ -204,11 +204,15 @@ def extract(self, archive_fp, directory, storage_subdir=''):
         temp_dir = tempfile.TemporaryDirectory()
         try:
             temp_dir_path = temp_dir.__enter__()
-            local_archive_path = self.get(
-                archive_fp,
-                os.path.join(temp_dir_path, os.path.basename(archive_fp)),
-                subdir=storage_subdir
-            )
+            if os.path.isfile(archive_fp):
+                local_archive_path = os.path.abspath(archive_fp)
+            else:
+                local_archive_path = self.get(
+                    archive_fp,
+                    os.path.join(temp_dir_path, os.path.basename(archive_fp)),
+                    subdir=storage_subdir
+                )
+
             with tarfile.open(local_archive_path) as f:
                 f.extractall(directory)
         finally:
diff --git a/src/model_execution_worker/tasks.py b/src/model_execution_worker/tasks.py
index 0718d2c92..882310b88 100755
--- a/src/model_execution_worker/tasks.py
+++ b/src/model_execution_worker/tasks.py
@@ -7,6 +7,7 @@
 import sys
 import shutil
 import subprocess
+import time

 import fasteners
 import tempfile
@@ -20,14 +21,14 @@
 from celery.exceptions import WorkerLostError, Terminated
 from celery.platforms import signals

-from oasislmf.utils.data import get_json
+# from oasislmf.utils.data import get_json
 from oasislmf.utils.exceptions import OasisException
 from oasislmf.utils.log import oasis_log
 from oasislmf.utils.status import OASIS_TASK_STATUS
 from oasislmf import __version__ as mdk_version
 from pathlib2 import Path

-from ..conf import celeryconf as celery_conf
+from ..conf import celeryconf_v1 as celery_conf
 from ..conf.iniconf import settings
 from ..common.data import STORED_FILENAME, ORIGINAL_FILENAME
@@ -45,6 +46,7 @@
 RUNNING_TASK_STATUS = OASIS_TASK_STATUS["running"]["id"]
 app = Celery()
 app.config_from_object(celery_conf)
+# print(app._conf)

 logging.info("Started worker")
 debug_worker = settings.getboolean('worker', 'DEBUG', fallback=False)
@@ -172,19 +174,18 @@ def check_worker_lost(task, analysis_pk):
     )
     task.update_state(state=RUNNING_TASK_STATUS, meta={'analysis_pk': analysis_pk})

-# When a worker connects send a task to the worker-monitor to register a new model
+
+
 @worker_ready.connect
 def register_worker(sender, **k):
+    time.sleep(1)  # Workaround, pause for 1 sec to make sure log messages are printed
     m_supplier = os.environ.get('OASIS_MODEL_SUPPLIER_ID')
     m_name = os.environ.get('OASIS_MODEL_ID')
     m_id = os.environ.get('OASIS_MODEL_VERSION_ID')
     m_version = get_worker_versions()
-    m_conf = get_json(get_oasislmf_config_path(m_id))
-
     logging.info('Worker: SUPPLIER_ID={}, MODEL_ID={}, VERSION_ID={}'.format(m_supplier, m_name, m_id))
     logging.info('versions: {}'.format(m_version))
-    logging.info('oasislmf config: {}'.format(m_conf))

     # Check for 'DISABLE_WORKER_REG' before sending task to API
     if settings.getboolean('worker', 'DISABLE_WORKER_REG', fallback=False):
@@ -198,7 +199,8 @@ def register_worker(sender, **k):

     signature(
         'run_register_worker',
-        args=(m_supplier, m_name, m_id, m_settings, m_version, m_conf),
+        args=(m_supplier, m_name, m_id, m_settings, m_version),
+        queue='celery'
     ).delay()

 # Required ENV
@@ -289,7 +291,7 @@ def notify_api_status(analysis_pk, task_status):
         'set_task_status',
         args=(analysis_pk, task_status),
         queue='celery'
-
).delay({}, priority=analysis_pk) + ).delay() @app.task(name='run_analysis', bind=True, acks_late=True, throws=(Terminated,)) diff --git a/tests/__init__.py b/src/model_execution_worker/tests/__init__.py similarity index 100% rename from tests/__init__.py rename to src/model_execution_worker/tests/__init__.py diff --git a/tests/localstack/create-bucket.sh b/src/model_execution_worker/tests/localstack/create-bucket.sh similarity index 100% rename from tests/localstack/create-bucket.sh rename to src/model_execution_worker/tests/localstack/create-bucket.sh diff --git a/tests/test_settings.py b/src/model_execution_worker/tests/test_settings.py similarity index 100% rename from tests/test_settings.py rename to src/model_execution_worker/tests/test_settings.py diff --git a/tests/test_tasks.py b/src/model_execution_worker/tests/test_tasks.py similarity index 98% rename from tests/test_tasks.py rename to src/model_execution_worker/tests/test_tasks.py index c6030a091..966b909a2 100644 --- a/tests/test_tasks.py +++ b/src/model_execution_worker/tests/test_tasks.py @@ -9,11 +9,12 @@ from hypothesis import given from hypothesis import settings as hypothesis_settings from hypothesis.strategies import text, integers -from mock import patch, Mock, ANY +from mock import patch, Mock +# from mock import ANY from pathlib2 import Path from src.conf.iniconf import SettingsPatcher, settings -from src.model_execution_worker.storage_manager import MissingInputsException +# from src.model_execution_worker.storage_manager import MissingInputsException from src.model_execution_worker.tasks import start_analysis, InvalidInputsException, \ start_analysis_task, get_oasislmf_config_path diff --git a/src/server/oasisapi/analyses/__init__.py b/src/server/oasisapi/analyses/__init__.py index 012f01bc6..e69de29bb 100644 --- a/src/server/oasisapi/analyses/__init__.py +++ b/src/server/oasisapi/analyses/__init__.py @@ -1 +0,0 @@ -default_app_config = 'src.server.oasisapi.analyses.apps.AnalysesAppConfig' diff --git a/src/server/oasisapi/analyses/apps.py b/src/server/oasisapi/analyses/apps.py index 5a57c757d..0c1a5fe7a 100644 --- a/src/server/oasisapi/analyses/apps.py +++ b/src/server/oasisapi/analyses/apps.py @@ -1,12 +1,16 @@ from django.apps import AppConfig -class AnalysesAppConfig(AppConfig): - name = 'src.server.oasisapi.analyses' +class V1_AnalysesAppConfig(AppConfig): + name = 'src.server.oasisapi.analyses.v1_api' + + +class V2_AnalysesAppConfig(AppConfig): + name = 'src.server.oasisapi.analyses.v2_api' def ready(self): from django.db.models.signals import post_save - from .signal_receivers import task_updated + from .v2_api.signal_receivers import task_updated from .models import AnalysisTaskStatus post_save.connect(task_updated, sender=AnalysisTaskStatus) diff --git a/src/server/oasisapi/analyses/management/commands/ws_echo.py b/src/server/oasisapi/analyses/management/commands/ws_echo.py index 358f3b226..5a64b3436 100644 --- a/src/server/oasisapi/analyses/management/commands/ws_echo.py +++ b/src/server/oasisapi/analyses/management/commands/ws_echo.py @@ -22,7 +22,7 @@ def on_error(app, error): class Command(BaseCommand): def add_arguments(self, parser): - parser.add_argument('--url', default='ws://localhost:8001/ws/v1/queue-status/') + parser.add_argument('--url', default='ws://localhost:8001/ws/v2/queue-status/') def handle(self, *args, **options): user = get_user_model().objects.first() diff --git a/src/server/oasisapi/analyses/migrations/0012_analysis_run_mode.py 
b/src/server/oasisapi/analyses/migrations/0012_analysis_run_mode.py new file mode 100644 index 000000000..f6def93ca --- /dev/null +++ b/src/server/oasisapi/analyses/migrations/0012_analysis_run_mode.py @@ -0,0 +1,18 @@ +# Generated by Django 3.2.20 on 2023-12-04 09:01 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('analyses', '0011_auto_20230724_1134'), + ] + + operations = [ + migrations.AddField( + model_name='analysis', + name='run_mode', + field=models.CharField(choices=[('V1', 'Single-Instance Execution'), ('V2', 'Distributed Execution')], default=None, editable=False, max_length=2, null=True), + ), + ] diff --git a/src/server/oasisapi/analyses/migrations/0013_analysis_chunking_options.py b/src/server/oasisapi/analyses/migrations/0013_analysis_chunking_options.py new file mode 100644 index 000000000..9bf77f309 --- /dev/null +++ b/src/server/oasisapi/analyses/migrations/0013_analysis_chunking_options.py @@ -0,0 +1,20 @@ +# Generated by Django 3.2.20 on 2024-01-12 13:17 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('analysis_models', '0008_analysismodel_run_mode'), + ('analyses', '0012_analysis_run_mode'), + ] + + operations = [ + migrations.AddField( + model_name='analysis', + name='chunking_options', + field=models.OneToOneField(auto_created=True, default=None, null=True, on_delete=django.db.models.deletion.CASCADE, to='analysis_models.modelchunkingoptions'), + ), + ] diff --git a/src/server/oasisapi/analyses/models.py b/src/server/oasisapi/analyses/models.py index eea526644..2a78371aa 100644 --- a/src/server/oasisapi/analyses/models.py +++ b/src/server/oasisapi/analyses/models.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, print_function from celery.result import AsyncResult -from django.conf import settings +from django.conf import settings as django_settings from django.core.files.base import File from django.core.validators import MinValueValidator, MaxValueValidator from django.db import models @@ -14,10 +14,11 @@ from rest_framework.exceptions import ValidationError from rest_framework.reverse import reverse -from src.server.oasisapi.celery_app import celery_app +from src.server.oasisapi.celery_app_v1 import v1 as celery_app_v1 +from src.server.oasisapi.celery_app_v2 import v2 as celery_app_v2 from src.server.oasisapi.queues.consumers import send_task_status_message, TaskStatusMessageItem, \ TaskStatusMessageAnalysisItem, build_task_status_message -from ..analysis_models.models import AnalysisModel +from ..analysis_models.models import AnalysisModel, ModelChunkingOptions from ..data_files.models import DataFile from ..files.models import RelatedFile, file_storage_link from ..portfolios.models import Portfolio @@ -25,6 +26,8 @@ from ....common.data import STORED_FILENAME, ORIGINAL_FILENAME from ....conf import iniconf +from .v1_api.tasks import record_generate_input_result, record_run_analysis_result + class AnalysisTaskStatusQuerySet(models.QuerySet): @classmethod @@ -125,11 +128,13 @@ class Meta: ('analysis', 'slug',) ) - def get_output_log_url(self, request=None): - return reverse('analysis-task-status-output-log', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_output_log_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else 'v2-analyses:' + return reverse(f'{override_ns}analysis-task-status-output-log', kwargs={'pk': self.pk}, request=request) - 
def get_error_log_url(self, request=None):
-        return reverse('analysis-task-status-error-log', kwargs={'version': 'v1', 'pk': self.pk}, request=request)
+    def get_error_log_url(self, request=None, namespace=None):
+        override_ns = f'{namespace}:' if namespace else 'v2-analyses:'
+        return reverse(f'{override_ns}analysis-task-status-error-log', kwargs={'pk': self.pk}, request=request)


 class Analysis(TimeStampedModel):
@@ -146,16 +151,23 @@ class Analysis(TimeStampedModel):
         ('RUN_CANCELLED', 'Run cancelled'),
         ('RUN_ERROR', 'Run error'),
     )
+    run_mode_choices = Choices(
+        ('V1', 'Single-Instance Execution'),
+        ('V2', 'Distributed Execution'),
+    )

     input_generation_traceback_file_id = None

-    creator = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE, related_name='analyses')
+    creator = models.ForeignKey(django_settings.AUTH_USER_MODEL, on_delete=models.CASCADE, related_name='analyses')
     portfolio = models.ForeignKey(Portfolio, on_delete=models.CASCADE, related_name='analyses', help_text=_('The portfolio to link the analysis to'))
     model = models.ForeignKey(AnalysisModel, on_delete=models.CASCADE, related_name='analyses', help_text=_('The model to link the analysis to'))
     name = models.CharField(help_text='The name of the analysis', max_length=255)
     status = models.CharField(max_length=max(len(c) for c in status_choices._db_values), choices=status_choices, default=status_choices.NEW, editable=False)
+    run_mode = models.CharField(max_length=max(len(c) for c in run_mode_choices._db_values),
+                                choices=run_mode_choices, default=None, editable=False, null=True)
     task_started = models.DateTimeField(editable=False, null=True, default=None)
+    task_finished = models.DateTimeField(editable=False, null=True, default=None)

     run_task_id = models.CharField(max_length=255, editable=False, default='', blank=True)
     generate_inputs_task_id = models.CharField(max_length=255, editable=False, default='', blank=True)
@@ -186,6 +198,8 @@ class Analysis(TimeStampedModel):
     summary_levels_file = models.ForeignKey(RelatedFile, on_delete=models.CASCADE, blank=True, null=True, default=None, related_name='summary_levels_file_analyses')

+    chunking_options = models.OneToOneField(ModelChunkingOptions, on_delete=models.CASCADE, auto_created=True, default=None, null=True)
+
     class Meta:
         ordering = ['id']
         verbose_name_plural = 'analyses'
@@ -193,65 +207,103 @@ class Meta:
     def __str__(self):
         return self.name

-    def get_absolute_url(self, request=None):
-        return reverse('analysis-detail', kwargs={'version': 'v1', 'pk': self.pk}, request=request)
+    def _update_ns(self, request=None):
+        """ WORKAROUND - this is needed for when a copy request is issued
+            from the portfolio view '/{ver}/portfolios/{id}/create_analysis/'
+
+            The incorrect namespace '{ver}-portfolios' is inherited from the
+            original request.
This needs to be replaced with '{ver}-analyses' + """ + if not request: + return None + ns_ver, ns_view = request.version.split('-') + if ns_view != 'analyses': + request.version = f'{ns_ver}-analyses' + return request - def get_absolute_run_url(self, request=None): - return reverse('analysis-run', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-detail', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_cancel_url(self, request=None): - return reverse('analysis-cancel', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_run_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-run', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_cancel_analysis_url(self, request=None): - return reverse('analysis-cancel-analysis-run', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_cancel_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-cancel', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_generate_inputs_url(self, request=None): - return reverse('analysis-generate-inputs', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_cancel_analysis_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-cancel-analysis-run', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_cancel_inputs_generation_url(self, request=None): - return reverse('analysis-cancel-generate-inputs', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_generate_inputs_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-generate-inputs', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_copy_url(self, request=None): - return reverse('analysis-copy', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_cancel_inputs_generation_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-cancel-generate-inputs', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_settings_file_url(self, request=None): - return reverse('analysis-settings-file', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_copy_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-copy', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_settings_url(self, request=None): - return reverse('analysis-settings', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_settings_file_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-settings-file', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_input_file_url(self, request=None): - return reverse('analysis-input-file', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_settings_url(self, 
request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-settings', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_lookup_errors_file_url(self, request=None): - return reverse('analysis-lookup-errors-file', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_input_file_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-input-file', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_lookup_success_file_url(self, request=None): - return reverse('analysis-lookup-success-file', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_lookup_errors_file_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-lookup-errors-file', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_lookup_validation_file_url(self, request=None): - return reverse('analysis-lookup-validation-file', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_lookup_success_file_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-lookup-success-file', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_summary_levels_file_url(self, request=None): - return reverse('analysis-summary-levels-file', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_lookup_validation_file_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-lookup-validation-file', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_input_generation_traceback_file_url(self, request=None): - return reverse('analysis-input-generation-traceback-file', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_summary_levels_file_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-summary-levels-file', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_output_file_url(self, request=None): - return reverse('analysis-output-file', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_input_generation_traceback_file_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-input-generation-traceback-file', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_run_traceback_file_url(self, request=None): - return reverse('analysis-run-traceback-file', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_output_file_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-output-file', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_run_log_file_url(self, request=None): - return reverse('analysis-run-log-file', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_run_traceback_file_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-run-traceback-file', kwargs={'pk': self.pk}, 
request=self._update_ns(request)) - def get_absolute_storage_url(self, request=None): - return reverse('analysis-storage-links', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_run_log_file_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-run-log-file', kwargs={'pk': self.pk}, request=self._update_ns(request)) - def get_absolute_subtask_list_url(self, request=None): - return reverse('analysis-sub-task-list', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_storage_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-storage-links', kwargs={'pk': self.pk}, request=self._update_ns(request)) + + def get_absolute_subtask_list_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-sub-task-list', kwargs={'pk': self.pk}, request=self._update_ns(request)) + + def get_absolute_chunking_configuration_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}analysis-chunking-configuration', kwargs={'pk': self.pk}, request=self._update_ns(request)) def get_groups(self): groups = [] @@ -261,11 +313,19 @@ def get_groups(self): return groups def get_num_events(self): - selected_strat = self.model.chunking_options.loss_strategy - dynamic_strat = self.model.chunking_options.chunking_types.DYNAMIC_CHUNKS - if selected_strat != dynamic_strat: - return 1 + # Select Chunking opts + if self.chunking_options is None: + chunking_options = self.model.chunking_options + else: + chunking_options = self.chunking_options + + # Esc if not DYNAMIC + DYNAMIC_CHUNKS = chunking_options.chunking_types.DYNAMIC_CHUNKS + if chunking_options.loss_strategy != DYNAMIC_CHUNKS: + return None + + # Return num of selected User events from settings analysis_settings = self.settings_file.read_json() model_settings = self.model.resource_file.read_json() user_selected_events = analysis_settings.get('event_ids', []) @@ -273,6 +333,7 @@ def get_num_events(self): if len(user_selected_events) > 1: return len(user_selected_events) + # Read event_set_size for model settings event_set_options = model_settings.get('model_settings', {}).get('event_set').get('options', []) event_set_sizes = {e['id']: e['number_of_events'] for e in event_set_options if 'number_of_events' in e} if selected_event_set not in event_set_sizes: @@ -280,6 +341,69 @@ def get_num_events(self): f"Failed to read event set size for chunking: selected event_id: {selected_event_set}, options: {event_set_options}") return event_set_sizes.get(selected_event_set) + # --- V1 task signatures ------------------------------------------------ # + + @property + def v1_run_analysis_signature(self): + complex_data_files = self.create_complex_model_data_file_dicts() + input_file = file_storage_link(self.input_file) + settings_file = file_storage_link(self.settings_file) + + return celery_app_v1.signature( + 'run_analysis', + args=(self.pk, input_file, settings_file, complex_data_files), + queue=self.model.queue_name, + ) + + @property + def v1_generate_input_signature(self): + loc_file = file_storage_link(self.portfolio.location_file) + acc_file = file_storage_link(self.portfolio.accounts_file) + info_file = file_storage_link(self.portfolio.reinsurance_info_file) + scope_file = 
file_storage_link(self.portfolio.reinsurance_scope_file)
+        settings_file = file_storage_link(self.settings_file)
+        complex_data_files = self.create_complex_model_data_file_dicts()
+
+        return celery_app_v1.signature(
+            'generate_input',
+            args=(self.pk, loc_file, acc_file, info_file, scope_file, settings_file, complex_data_files),
+            queue=self.model.queue_name,
+        )
+
+    # --- V2 task signatures ------------------------------------------------ #
+
+    @property
+    def v2_run_analysis_signature(self):
+        return celery_app_v2.signature(
+            'start_loss_generation_task',
+            options={'queue': iniconf.settings.get('worker', 'LOSSES_GENERATION_CONTROLLER_QUEUE', fallback='celery-v2')}
+        )
+
+    @property
+    def v2_start_input_and_loss_generation_signature(self):
+        return celery_app_v2.signature(
+            'start_input_and_loss_generation_task',
+            options={'queue': iniconf.settings.get('worker', 'INPUT_GENERATION_CONTROLLER_QUEUE', fallback='celery-v2')}
+        )
+
+    @property
+    def v2_cancel_subtasks_signature(self):
+        return celery_app_v2.signature(
+            'cancel_subtasks',
+            options={
+                'queue': iniconf.settings.get('worker', 'INPUT_GENERATION_CONTROLLER_QUEUE', fallback='celery-v2')
+            }
+        )
+
+    @property
+    def v2_generate_input_signature(self):
+        return celery_app_v2.signature(
+            'start_input_generation_task',
+            options={
+                'queue': iniconf.settings.get('worker', 'INPUT_GENERATION_CONTROLLER_QUEUE', fallback='celery-v2')
+            }
+        )
+
     def validate_run(self):
         valid_choices = [
             self.status_choices.READY,
@@ -303,8 +427,14 @@ def validate_run(self):
         if not self.input_file:
             errors['input_file'] = ['Must not be null']

+        # Get chunking options
+        if self.chunking_options is None:
+            chunking_options = self.model.chunking_options
+        else:
+            chunking_options = self.chunking_options
+
         # Validation for dynamic loss chunks
-        if self.model.chunking_options.loss_strategy == self.model.chunking_options.chunking_types.DYNAMIC_CHUNKS:
+        if chunking_options.loss_strategy == chunking_options.chunking_types.DYNAMIC_CHUNKS:
             if not self.model.resource_file:
                 errors['model_settings_file'] = ['Must not be null for Dynamic chunking']
             elif self.settings_file:
@@ -334,27 +464,46 @@ def run_callback(self, body):
         self.status = self.status_choices.RUN_STARTED
         self.save()

-    @property
-    def run_analysis_signature(self):
-        return celery_app.signature(
-            'start_loss_generation_task',
-            options={'queue': iniconf.settings.get('worker', 'LOSSES_GENERATION_CONTROLLER_QUEUE', fallback='celery')}
-        )
-
-    def run(self, initiator):
+    def run(self, initiator, run_mode_override=None):
         self.validate_run()
+        if (self.model.run_mode is None) and (run_mode_override is None):
+            raise ValidationError({
+                'model': ['Model pk "{}" - "run_mode" must not be null'.format(self.model.id)]
+            })
+        run_mode = run_mode_override if run_mode_override else self.model.run_mode
+        valid_run_modes = [
+            self.run_mode_choices.V1,
+            self.run_mode_choices.V2,
+        ]
+        if run_mode not in valid_run_modes:
+            raise ValidationError(
+                {'run_mode': ['run_mode must be [{}]'.format(', '.join(valid_run_modes))]}
+            )
+
         events_total = self.get_num_events()
         self.status = self.status_choices.RUN_QUEUED
         self.save()

-        task = self.run_analysis_signature
-        task.on_error(celery_app.signature('handle_task_failure', kwargs={
-            'analysis_id': self.pk,
-            'initiator_id': initiator.pk,
-            'traceback_property': 'run_traceback_file',
-            'failure_status': Analysis.status_choices.RUN_ERROR,
-        }))
-        task_id = task.apply_async(args=[self.pk, initiator.pk, events_total], priority=self.priority).id
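+        # Dispatch by run_mode: the V1 signature publishes to the model's own worker queue, the V2 signature to the task-controller queue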
+        if run_mode == self.run_mode_choices.V1:
+            task = self.v1_run_analysis_signature
+            task.link(record_run_analysis_result.s(self.pk, initiator.pk))
+            task.link_error(
+                celery_app_v1.signature('on_error', args=('record_run_analysis_failure', self.pk, initiator.pk), queue=self.model.queue_name)
+            )
+            self.status = self.status_choices.RUN_QUEUED
+            self.run_mode = self.run_mode_choices.V1
+            task_id = task.delay().id
+
+        elif run_mode == self.run_mode_choices.V2:
+            task = self.v2_run_analysis_signature
+            task.on_error(celery_app_v2.signature('handle_task_failure', kwargs={
+                'analysis_id': self.pk,
+                'initiator_id': initiator.pk,
+                'traceback_property': 'run_traceback_file',
+                'failure_status': Analysis.status_choices.RUN_ERROR,
+            }, queue='celery-v2'))
+            self.run_mode = self.run_mode_choices.V2
+            task_id = task.apply_async(args=[self.pk, initiator.pk, events_total], priority=self.priority).id

         self.run_task_id = task_id
         self.task_started = timezone.now()
@@ -368,13 +517,6 @@ def raise_validate_errors(self, errors, error_state=None):
         if errors:
             raise ValidationError(detail=errors)

-    @property
-    def start_input_and_loss_generation_signature(self):
-        return celery_app.signature(
-            'start_input_and_loss_generation_task',
-            options={'queue': iniconf.settings.get('worker', 'INPUT_GENERATION_CONTROLLER_QUEUE', fallback='celery')}
-        )
-
     def generate_and_run(self, initiator):
         valid_choices = [
             self.status_choices.NEW,
@@ -392,10 +534,23 @@ def generate_and_run(self, initiator):
             errors['status'] = ['Analysis status must be one of [{}]'.format(', '.join(valid_choices))]
         if self.model.deleted:
             errors['model'] = ['Model pk "{}" has been deleted'.format(self.model.id)]
+        if self.model.run_mode != self.model.run_mode_choices.V2:
+            errors['model'] = ['Model pk "{}" - Unsupported Operation, "run_mode" must be "V2", not "{}"'.format(self.model.id, self.model.run_mode)]
         if not self.settings_file:
             errors['settings_file'] = ['Must not be null']
         if not self.portfolio.location_file:
             errors['portfolio'] = ['"location_file" must not be null']
+        else:
+            # get loc lines
+            try:
+                loc_lines = self.portfolio.location_file_len()
+            except Exception as e:
+                errors['portfolio'] = [f"Failed to read location file size for chunking: {e}"]
+            if loc_lines < 1:
+                errors['portfolio'] = ['"location_file" must have at least one row']
+
+        # get events
+        events_total = self.get_num_events()

         # Raise for error
         self.raise_validate_errors(errors)
@@ -406,15 +561,17 @@
         self.lookup_validation_file = None
         self.summary_levels_file = None
         self.input_generation_traceback_file_id = None
+        self.input_file = None

-        task = self.start_input_and_loss_generation_signature
-        task.on_error(celery_app.signature('handle_task_failure', kwargs={
+        task = self.v2_start_input_and_loss_generation_signature
+        task.on_error(celery_app_v2.signature('handle_task_failure', kwargs={
             'analysis_id': self.pk,
             'initiator_id': initiator.pk,
             'traceback_property': 'input_generation_traceback_file',
             'failure_status': Analysis.status_choices.INPUTS_GENERATION_ERROR,
         }))
-        task_id = task.apply_async(args=[self.pk, initiator.pk], priority=self.priority).id
+        self.run_mode = self.run_mode_choices.V2
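+        # loc_lines and events_total are passed up front so the controller can size the input and loss chunks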
+        if self.model.run_mode != self.model.run_mode_choices.V2:
+            errors['model'] = ['Model pk "{}" - Unsupported Operation, "run_mode" must be "V2", not "{}"'.format(self.model.id, self.model.run_mode)]
        if not self.settings_file:
            errors['settings_file'] = ['Must not be null']
        if not self.portfolio.location_file:
            errors['portfolio'] = ['"location_file" must not be null']
+        else:
+            # get loc lines
+            try:
+                loc_lines = self.portfolio.location_file_len()
+            except Exception as e:
+                errors['portfolio'] = [f"Failed to read location file size for chunking: {e}"]
+            else:
+                if loc_lines < 1:
+                    errors['portfolio'] = ['"location_file" must have at least one row']
+
+        # get events
+        events_total = self.get_num_events()

        # Raise for error
        self.raise_validate_errors(errors)
@@ -406,15 +561,17 @@
        self.lookup_validation_file = None
        self.summary_levels_file = None
        self.input_generation_traceback_file_id = None
+        self.input_file = None

-        task = self.start_input_and_loss_generation_signature
-        task.on_error(celery_app.signature('handle_task_failure', kwargs={
+        task = self.v2_start_input_and_loss_generation_signature
+        task.on_error(celery_app_v2.signature('handle_task_failure', kwargs={
            'analysis_id': self.pk,
            'initiator_id': initiator.pk,
            'traceback_property': 'input_generation_traceback_file',
            'failure_status': Analysis.status_choices.INPUTS_GENERATION_ERROR,
        }))
-        task_id = task.apply_async(args=[self.pk, initiator.pk], priority=self.priority).id
+        self.run_mode = self.run_mode_choices.V2
+        task_id = task.apply_async(args=[self.pk, initiator.pk, loc_lines, events_total], priority=self.priority).id

        self.generate_inputs_task_id = task_id
        self.task_started = timezone.now()
@@ -451,25 +608,7 @@ def cancel(self):
        self.task_finished = _now
        self.save()

-    @property
-    def generate_input_signature(self):
-        return celery_app.signature(
-            'start_input_generation_task',
-            options={
-                'queue': iniconf.settings.get('worker', 'INPUT_GENERATION_CONTROLLER_QUEUE', fallback='celery')
-            }
-        )
-
-    @property
-    def cancel_subtasks_signature(self):
-        return celery_app.signature(
-            'cancel_subtasks',
-            options={
-                'queue': iniconf.settings.get('worker', 'INPUT_GENERATION_CONTROLLER_QUEUE', fallback='celery')
-            }
-        )
-
-    def generate_inputs(self, initiator):
+    def generate_inputs(self, initiator, run_mode_override=None):
        valid_choices = [
            self.status_choices.NEW,
            self.status_choices.INPUTS_GENERATION_ERROR,
@@ -485,6 +624,8 @@ def generate_inputs(self, initiator):
            errors['status'] = ['Analysis status must be one of [{}]'.format(', '.join(valid_choices))]
        if self.model.deleted:
            errors['model'] = ['Model pk "{}" has been deleted'.format(self.model.id)]
+        if (self.model.run_mode is None) and (run_mode_override is None):
+            errors['model'] = ['Model pk "{}" - "run_mode" must not be null'.format(self.model.id)]
        if not self.portfolio.location_file:
            errors['portfolio'] = ['"location_file" must not be null']
@@ -492,33 +633,53 @@
            try:
                loc_lines = self.portfolio.location_file_len()
            except Exception as e:
-                raise ValidationError(f"Failed to read location file size for chunking: {e}")
-            if not isinstance(loc_lines, int):
-                errors['portfolio'] = [
-                    f'Failed to read "location_file" size, content_type={self.portfolio.location_file.content_type} might not be supported']
-
-            else:
-                if loc_lines < 1:
-                    errors['portfolio'] = ['"location_file" must at least one row']
+                errors['portfolio'] = [f"Failed to read location file size for chunking: {e}"]
+            else:
+                if loc_lines < 1:
+                    errors['portfolio'] = ['"location_file" must have at least one row']
+
        if errors:
            raise ValidationError(errors)

+        valid_run_modes = [
+            self.run_mode_choices.V1,
+            self.run_mode_choices.V2,
+        ]
+        run_mode = run_mode_override if run_mode_override else self.model.run_mode
+        if run_mode not in valid_run_modes:
+            raise ValidationError(
+                {'run_mode': ['run_mode must be [{}]'.format(', '.join(valid_run_modes))]}
+            )
+
        self.status = self.status_choices.INPUTS_GENERATION_QUEUED
        self.lookup_errors_file = None
        self.lookup_success_file = None
        self.lookup_validation_file = None
        self.summary_levels_file = None
        self.input_generation_traceback_file_id = None
+        self.input_file = None
        self.save()

-        task = self.generate_input_signature
-        task.on_error(celery_app.signature('handle_task_failure', kwargs={
-            'analysis_id': self.pk,
-            'initiator_id': initiator.pk,
-            'traceback_property': 'input_generation_traceback_file',
-            'failure_status': Analysis.status_choices.INPUTS_GENERATION_ERROR,
-        }))
-        task_id = task.apply_async(args=[self.pk, initiator.pk, loc_lines], priority=self.priority).id
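+        # As in run(), dispatch on run_mode: V1 links the legacy worker task
+        # to the result/failure recording tasks, while V2 defers error
+        # handling to the task controller's handle_task_failure.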
+        if run_mode == self.run_mode_choices.V1:
+            task = self.v1_generate_input_signature
+            task.link(record_generate_input_result.s(self.pk, initiator.pk))
+            task.link_error(
+                celery_app_v1.signature('on_error', args=('record_generate_input_failure', self.pk, initiator.pk), queue=self.model.queue_name)
+            )
+            self.run_mode = self.run_mode_choices.V1
+            self.status = self.status_choices.INPUTS_GENERATION_QUEUED
+            task_id = task.delay().id
+
+        elif run_mode == self.run_mode_choices.V2:
+            task = self.v2_generate_input_signature
+            task.on_error(celery_app_v2.signature('handle_task_failure', kwargs={
+                'analysis_id': self.pk,
+                'initiator_id': initiator.pk,
+                'traceback_property': 'input_generation_traceback_file',
+                'failure_status': Analysis.status_choices.INPUTS_GENERATION_ERROR,
+            }))
+            self.run_mode = self.run_mode_choices.V2
+            task_id = task.apply_async(args=[self.pk, initiator.pk, loc_lines], priority=self.priority).id
+
        self.generate_inputs_task_id = task_id
        self.task_started = timezone.now()
        self.task_finished = None
@@ -551,17 +712,26 @@ def cancel_analysis(self):
        if self.status not in valid_choices:
            raise ValidationError({'status': ['Analysis execution is not running or queued']})

-        # Kill the task controller job incase its still on queue
-        AsyncResult(self.run_task_id).revoke(
-            signal='SIGKILL',
-            terminate=True,
-        )
+        # Terminate V2 Execution
+        if self.run_mode == self.run_mode_choices.V2:
+            # Kill the task controller job in case it's still on the queue
+            AsyncResult(self.run_task_id).revoke(
+                signal='SIGKILL',
+                terminate=True,
+            )
+            # Send kill chain call to worker-controller
+            cancel_tasks = self.v2_cancel_subtasks_signature
+            task_id = cancel_tasks.apply_async(args=[self.pk], priority=10).id

-        # Send Kill chain call to worker-controller
-        cancel_tasks = self.cancel_subtasks_signature
-        task_id = cancel_tasks.apply_async(args=[self.pk], priority=10).id
+        # Terminate V1 Execution -- assume this option if not set
+        else:
+            AsyncResult(self.run_task_id).revoke(
+                signal='SIGTERM',
+                terminate=True,
+            )

        self.status = self.status_choices.RUN_CANCELLED
+        self.run_mode = None
        self.task_finished = timezone.now()
        self.save()
@@ -573,17 +743,26 @@ def cancel_generate_inputs(self):
        if self.status not in valid_choices:
            raise ValidationError({'status': ['Analysis input generation is not running or queued']})

-        # Kill the task controller job incase its still on queue
-        AsyncResult(self.generate_inputs_task_id).revoke(
-            signal='SIGKILL',
-            terminate=True,
-        )
+        # Terminate V2 Execution
+        if self.run_mode == self.run_mode_choices.V2:
+            # Kill the task controller job in case it's still on the queue
+            AsyncResult(self.generate_inputs_task_id).revoke(
+                signal='SIGKILL',
+                terminate=True,
+            )
+            # Send kill chain call to worker-controller
+            cancel_tasks = self.v2_cancel_subtasks_signature
+            task_id = cancel_tasks.apply_async(args=[self.pk], priority=10).id

-        # Send Kill chain call to worker-controller
-        cancel_tasks = self.cancel_subtasks_signature
-        task_id = cancel_tasks.apply_async(args=[self.pk], priority=10).id
+        # Terminate V1 Execution -- assume this option if not set
+        else:
+            AsyncResult(self.generate_inputs_task_id).revoke(
+                signal='SIGTERM',
+                terminate=True,
+            )

        self.status = self.status_choices.INPUTS_GENERATION_CANCELLED
+        self.run_mode = None
        self.task_finished = timezone.now()
        self.save()
@@ -641,7 +820,7 @@ def copy(self):
def delete_connected_files(sender, instance, **kwargs):
    """ Post delete handler to clear out any dangling analyses files
    """
-    files_for_removal = [
+    for_removal = [
        'settings_file',
        'input_file',
        'input_generation_traceback_file',
@@ -652,8 +831,23 @@
        'lookup_success_file',
        'lookup_validation_file',
        'summary_levels_file',
+        'chunking_options',
+    ]
+    for ref in for_removal:
+        obj_ref = getattr(instance, ref)
+        if obj_ref:
+            obj_ref.delete()
+
+
+@receiver(post_delete, sender=AnalysisTaskStatus)
+def delete_connected_task_logs(sender, instance, **kwargs):
+    """ Post delete handler to clear out any dangling log files
+    """
+    for_removal = [
+        'output_log',
+        'error_log',
    ]
-    for ref in files_for_removal:
-        file_ref = getattr(instance, ref)
-        if file_ref:
-            file_ref.delete()
+    for ref in for_removal:
+        obj_ref = getattr(instance, ref)
+        if obj_ref:
+            obj_ref.delete()
diff --git 
a/src/server/oasisapi/analyses/v1_api/__init__.py b/src/server/oasisapi/analyses/v1_api/__init__.py new file mode 100644 index 000000000..90a921f1a --- /dev/null +++ b/src/server/oasisapi/analyses/v1_api/__init__.py @@ -0,0 +1 @@ +default_app_config = 'src.server.oasisapi.analyses.apps.V1_AnalysesAppConfig' diff --git a/src/server/oasisapi/analyses/v1_api/serializers.py b/src/server/oasisapi/analyses/v1_api/serializers.py new file mode 100644 index 000000000..b5c2b7aca --- /dev/null +++ b/src/server/oasisapi/analyses/v1_api/serializers.py @@ -0,0 +1,307 @@ +from drf_yasg.utils import swagger_serializer_method +from rest_framework import serializers +from rest_framework.exceptions import ValidationError + +from ..models import Analysis +from ...files.models import file_storage_link + + +class AnalysisListSerializer(serializers.Serializer): + """ Read Only Analyses Deserializer for efficiently returning a list of all + Analyses from DB + """ + + # model fields + created = serializers.DateTimeField(read_only=True) + modified = serializers.DateTimeField(read_only=True) + name = serializers.CharField(read_only=True) + id = serializers.IntegerField(read_only=True) + portfolio = serializers.IntegerField(source='portfolio_id', read_only=True) + model = serializers.IntegerField(source='model_id', read_only=True) + status = serializers.CharField(read_only=True) + task_started = serializers.DateTimeField(read_only=True) + task_finished = serializers.DateTimeField(read_only=True) + complex_model_data_files = serializers.PrimaryKeyRelatedField(many=True, read_only=True) + + # file fields + input_file = serializers.SerializerMethodField(read_only=True) + settings_file = serializers.SerializerMethodField(read_only=True) + settings = serializers.SerializerMethodField(read_only=True) + lookup_errors_file = serializers.SerializerMethodField(read_only=True) + lookup_success_file = serializers.SerializerMethodField(read_only=True) + lookup_validation_file = serializers.SerializerMethodField(read_only=True) + summary_levels_file = serializers.SerializerMethodField(read_only=True) + input_generation_traceback_file = serializers.SerializerMethodField(read_only=True) + output_file = serializers.SerializerMethodField(read_only=True) + run_traceback_file = serializers.SerializerMethodField(read_only=True) + run_log_file = serializers.SerializerMethodField(read_only=True) + storage_links = serializers.SerializerMethodField(read_only=True) + + class Meta: + ref_name = __qualname__.split('.')[0] + 'V1' + + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_input_file(self, instance): + request = self.context.get('request') + return instance.get_absolute_input_file_url(request=request) if instance.input_file_id else None + + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_settings_file(self, instance): + request = self.context.get('request') + return instance.get_absolute_settings_file_url(request=request) if instance.settings_file_id else None + + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_settings(self, instance): + request = self.context.get('request') + return instance.get_absolute_settings_url(request=request) if instance.settings_file_id else None + + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_lookup_errors_file(self, instance): + request = self.context.get('request') + return instance.get_absolute_lookup_errors_file_url(request=request) if instance.lookup_errors_file_id else None + + 
@swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_lookup_success_file(self, instance): + request = self.context.get('request') + return instance.get_absolute_lookup_success_file_url(request=request) if instance.lookup_success_file_id else None + + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_lookup_validation_file(self, instance): + request = self.context.get('request') + return instance.get_absolute_lookup_validation_file_url(request=request) if instance.lookup_validation_file_id else None + + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_summary_levels_file(self, instance): + request = self.context.get('request') + return instance.get_absolute_summary_levels_file_url(request=request) if instance.summary_levels_file_id else None + + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_input_generation_traceback_file(self, instance): + request = self.context.get('request') + return instance.get_absolute_input_generation_traceback_file_url(request=request) if instance.input_generation_traceback_file_id else None + + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_output_file(self, instance): + request = self.context.get('request') + return instance.get_absolute_output_file_url(request=request) if instance.output_file_id else None + + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_run_traceback_file(self, instance): + request = self.context.get('request') + return instance.get_absolute_run_traceback_file_url(request=request) if instance.run_traceback_file_id else None + + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_run_log_file(self, instance): + request = self.context.get('request') + return instance.get_absolute_run_log_file_url(request=request) if instance.run_log_file_id else None + + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_storage_links(self, instance): + request = self.context.get('request') + return instance.get_absolute_storage_url(request=request) + + +class AnalysisSerializer(serializers.ModelSerializer): + input_file = serializers.SerializerMethodField() + settings_file = serializers.SerializerMethodField() + settings = serializers.SerializerMethodField() + lookup_errors_file = serializers.SerializerMethodField() + lookup_success_file = serializers.SerializerMethodField() + lookup_validation_file = serializers.SerializerMethodField() + summary_levels_file = serializers.SerializerMethodField() + input_generation_traceback_file = serializers.SerializerMethodField() + output_file = serializers.SerializerMethodField() + run_traceback_file = serializers.SerializerMethodField() + run_log_file = serializers.SerializerMethodField() + storage_links = serializers.SerializerMethodField() + ns = 'v1-analyses' + + class Meta: + ref_name = __qualname__.split('.')[0] + 'V1' + model = Analysis + fields = ( + 'created', + 'modified', + 'name', + 'id', + 'portfolio', + 'model', + 'status', + 'task_started', + 'task_finished', + 'complex_model_data_files', + 'input_file', + 'settings_file', + 'settings', + 'lookup_errors_file', + 'lookup_success_file', + 'lookup_validation_file', + 'summary_levels_file', + 'input_generation_traceback_file', + 'output_file', + 'run_traceback_file', + 'run_log_file', + 'storage_links', + ) + + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_input_file(self, instance): + request = 
self.context.get('request')
+        return instance.get_absolute_input_file_url(request=request, namespace=self.ns) if instance.input_file_id else None
+
+    @swagger_serializer_method(serializer_or_field=serializers.URLField)
+    def get_settings_file(self, instance):
+        request = self.context.get('request')
+        return instance.get_absolute_settings_file_url(request=request, namespace=self.ns) if instance.settings_file_id else None
+
+    @swagger_serializer_method(serializer_or_field=serializers.URLField)
+    def get_settings(self, instance):
+        request = self.context.get('request')
+        return instance.get_absolute_settings_url(request=request, namespace=self.ns) if instance.settings_file_id else None
+
+    @swagger_serializer_method(serializer_or_field=serializers.URLField)
+    def get_lookup_errors_file(self, instance):
+        request = self.context.get('request')
+        return instance.get_absolute_lookup_errors_file_url(request=request, namespace=self.ns) if instance.lookup_errors_file_id else None
+
+    @swagger_serializer_method(serializer_or_field=serializers.URLField)
+    def get_lookup_success_file(self, instance):
+        request = self.context.get('request')
+        return instance.get_absolute_lookup_success_file_url(request=request, namespace=self.ns) if instance.lookup_success_file_id else None
+
+    @swagger_serializer_method(serializer_or_field=serializers.URLField)
+    def get_lookup_validation_file(self, instance):
+        request = self.context.get('request')
+        return instance.get_absolute_lookup_validation_file_url(request=request, namespace=self.ns) if instance.lookup_validation_file_id else None
+
+    @swagger_serializer_method(serializer_or_field=serializers.URLField)
+    def get_summary_levels_file(self, instance):
+        request = self.context.get('request')
+        return instance.get_absolute_summary_levels_file_url(request=request, namespace=self.ns) if instance.summary_levels_file_id else None
+
+    @swagger_serializer_method(serializer_or_field=serializers.URLField)
+    def get_input_generation_traceback_file(self, instance):
+        request = self.context.get('request')
+        return instance.get_absolute_input_generation_traceback_file_url(request=request, namespace=self.ns) if instance.input_generation_traceback_file_id else None
+
+    @swagger_serializer_method(serializer_or_field=serializers.URLField)
+    def get_output_file(self, instance):
+        request = self.context.get('request')
+        return instance.get_absolute_output_file_url(request=request, namespace=self.ns) if instance.output_file_id else None
+
+    @swagger_serializer_method(serializer_or_field=serializers.URLField)
+    def get_run_traceback_file(self, instance):
+        request = self.context.get('request')
+        return instance.get_absolute_run_traceback_file_url(request=request, namespace=self.ns) if instance.run_traceback_file_id else None
+
+    @swagger_serializer_method(serializer_or_field=serializers.URLField)
+    def get_run_log_file(self, instance):
+        request = self.context.get('request')
+        return instance.get_absolute_run_log_file_url(request=request, namespace=self.ns) if instance.run_log_file_id else None
+
+    @swagger_serializer_method(serializer_or_field=serializers.URLField)
+    def get_storage_links(self, instance):
+        request = self.context.get('request')
+        return instance.get_absolute_storage_url(request=request, namespace=self.ns)
+
+    def validate(self, attrs):
+        if not attrs.get('creator') and 'request' in self.context:
+            attrs['creator'] = self.context.get('request').user
+
+        # Check that portfolio has a location file
+        if attrs.get('portfolio'):
+            if not attrs['portfolio'].location_file:
+                
raise ValidationError({'portfolio': '"location_file" must not be null'}) + + # check that model isn't soft-deleted + if attrs.get('model'): + if attrs['model'].deleted: + raise ValidationError({'model': ["Model pk '{}' - has been deleted.".format(attrs['model'].id)]}) + if attrs['model'].run_mode != attrs['model'].run_mode_choices.V1: + raise ValidationError({ + 'model': ["Model pk '{}' - Unsupported Operation, 'run_mode' must be 'V1', not '{}'".format( + attrs['model'].id, + attrs['model'].run_mode, + )] + }) + + return attrs + + +class AnalysisStorageSerializer(serializers.ModelSerializer): + settings_file = serializers.SerializerMethodField() + input_file = serializers.SerializerMethodField() + input_generation_traceback_file = serializers.SerializerMethodField() + output_file = serializers.SerializerMethodField() + run_traceback_file = serializers.SerializerMethodField() + run_log_file = serializers.SerializerMethodField() + lookup_errors_file = serializers.SerializerMethodField() + lookup_success_file = serializers.SerializerMethodField() + lookup_validation_file = serializers.SerializerMethodField() + summary_levels_file = serializers.SerializerMethodField() + + class Meta: + ref_name = __qualname__.split('.')[0] + 'V1' + model = Analysis + fields = ( + 'settings_file', + 'input_file', + 'input_generation_traceback_file', + 'output_file', + 'run_traceback_file', + 'run_log_file', + 'lookup_errors_file', + 'lookup_success_file', + 'lookup_validation_file', + 'summary_levels_file', + ) + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_settings_file(self, instance): + return file_storage_link(instance.settings_file, True) + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_input_file(self, instance): + return file_storage_link(instance.input_file, True) + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_input_generation_traceback_file(self, instance): + return file_storage_link(instance.input_generation_traceback_file, True) + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_output_file(self, instance): + return file_storage_link(instance.output_file, True) + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_run_traceback_file(self, instance): + return file_storage_link(instance.run_traceback_file, True) + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_run_log_file(self, instance): + return file_storage_link(instance.run_log_file, True) + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_lookup_errors_file(self, instance): + return file_storage_link(instance.lookup_errors_file, True) + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_lookup_success_file(self, instance): + return file_storage_link(instance.lookup_success_file, True) + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_lookup_validation_file(self, instance): + return file_storage_link(instance.lookup_validation_file, True) + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_summary_levels_file(self, instance): + return file_storage_link(instance.summary_levels_file, True) + + +class AnalysisCopySerializer(AnalysisSerializer): + + def __init__(self, *args, **kwargs): + super(AnalysisCopySerializer, self).__init__(*args, **kwargs) + + self.fields['portfolio'].required = False + 
self.fields['model'].required = False
+        self.fields['name'].required = False
diff --git a/src/server/oasisapi/analyses/v1_api/tasks.py b/src/server/oasisapi/analyses/v1_api/tasks.py
new file mode 100644
index 000000000..37ceb3f08
--- /dev/null
+++ b/src/server/oasisapi/analyses/v1_api/tasks.py
@@ -0,0 +1,601 @@
+from __future__ import absolute_import
+
+import uuid
+import os
+
+# Remote debugging 'rdb.set_trace()'
+# https://docs.celeryproject.org/en/stable/userguide/debugging.html
+# from celery.contrib import rdb
+
+from celery.utils.log import get_task_logger
+from celery import Task
+from celery import signals
+from django.contrib.auth import get_user_model
+from django.core.exceptions import ObjectDoesNotExist
+from django.core.files import File
+from django.core.files.base import ContentFile
+from django.core.files.storage import default_storage
+from django.conf import settings
+from django.http import HttpRequest
+from django.utils import timezone
+from django.db import transaction
+
+from botocore.exceptions import ClientError as S3_ClientError
+from azure.core.exceptions import ResourceNotFoundError as Blob_ResourceNotFoundError
+from azure.storage.blob import BlobLeaseClient
+from tempfile import TemporaryFile
+from urllib.request import urlopen
+from urllib.parse import urlparse
+
+from src.server.oasisapi.files.models import RelatedFile
+from src.server.oasisapi.files.views import handle_json_data
+from src.server.oasisapi.schemas.serializers import ModelParametersSerializer
+from src.server.oasisapi.files.upload import wait_for_blob_copy
+
+from ...celery_app_v1 import v1 as celery_app_v1
+from .....conf import celeryconf_v1 as celery_conf
+logger = get_task_logger(__name__)
+
+
+def is_valid_url(url):
+    if url:
+        result = urlparse(url)
+        return all([result.scheme, result.netloc]) and result.scheme in ['http', 'https']
+    else:
+        return False
+
+
+def is_in_bucket(object_key):
+    if not hasattr(default_storage, 'bucket'):
+        return False
+    else:
+        try:
+            default_storage.bucket.Object(object_key).load()
+            return True
+        except S3_ClientError as e:
+            if e.response['Error']['Code'] == "404":
+                return False
+            else:
+                raise e
+
+
+def is_in_container(object_key):
+    if not hasattr(default_storage, 'azure_container'):
+        return False
+    else:
+        try:
+            blob = default_storage.client.get_blob_client(object_key)
+            blob.get_blob_properties()
+            return True
+        except Blob_ResourceNotFoundError:
+            return False
+
+
+def store_file(reference, content_type, creator, required=True, filename=None):
+    """ Returns a `RelatedFile` object to store
+
+    :param reference: Storage reference of file (url or file path)
+    :type reference: string
+
+    :param content_type: Mime type of file
+    :type content_type: string
+
+    :param creator: Django user who initiated the request
+    :type creator: User
+
+    :param required: Allow for None returns if set to false
+    :type required: boolean
+
+    :return: Model Object holding a Django file
+    :rtype: RelatedFile
+    """
+
+    # Download data from URL
+    if is_valid_url(reference):
+        response = urlopen(reference)
+        fdata = response.read()
+
+        # Find file name
+        header_fname = response.headers.get('Content-Disposition', '').split('filename=')[-1]
+        ref = header_fname if header_fname else os.path.basename(urlparse(reference).path)
+        fname = filename if filename else ref
+        logger.info('Store file: {}'.format(ref))
+
+        # Create temp file, download content and store
+        with TemporaryFile() as tmp_file:
+            tmp_file.write(fdata)
+            tmp_file.seek(0)
+            return RelatedFile.objects.create(
+                file=File(tmp_file, name=fname),
+                filename=fname,
+                content_type=content_type,
+                creator=creator,
+                store_as_filename=True,
+            )
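+    # The reference may instead name an object already held in S3 or Azure
+    # storage; in that case issue a server-side copy into a new RelatedFile
+    # rather than downloading and re-uploading the data.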
+    # Issue S3 object Copy
+    if is_in_bucket(reference):
+        fname = filename if filename else os.path.basename(reference)
+        new_file = ContentFile(b'')
+        new_file.name = fname
+        new_related_file = RelatedFile.objects.create(
+            file=new_file,
+            filename=fname,
+            content_type=content_type,
+            creator=creator,
+            store_as_filename=True,
+        )
+        stored_file = default_storage.open(new_related_file.file.name)
+        stored_file.obj.copy({"Bucket": default_storage.bucket.name, "Key": reference})
+        stored_file.obj.wait_until_exists()
+        return new_related_file
+
+    # Issue Azure object Copy
+    if is_in_container(reference):
+        new_filename = filename if filename else os.path.basename(reference)
+        fname = default_storage._get_valid_path(new_filename)
+        source_blob = default_storage.client.get_blob_client(reference)
+        dest_blob = default_storage.client.get_blob_client(fname)
+
+        try:
+            lease = BlobLeaseClient(source_blob)
+            lease.acquire()
+            dest_blob.start_copy_from_url(source_blob.url)
+            wait_for_blob_copy(dest_blob)
+            lease.break_lease()
+        except Exception as e:
+            # copy failed, break file lease and re-raise
+            lease.break_lease()
+            raise e
+
+        stored_blob = default_storage.open(os.path.basename(fname))
+        new_related_file = RelatedFile.objects.create(
+            file=File(stored_blob, name=fname),
+            filename=fname,
+            content_type=content_type,
+            creator=creator,
+            store_as_filename=True)
+        return new_related_file
+
+    try:
+        # Copy via shared FS
+        ref = str(os.path.basename(reference))
+        fname = filename if filename else ref
+        return RelatedFile.objects.create(
+            file=ref,
+            filename=fname,
+            content_type=content_type,
+            creator=creator,
+            store_as_filename=True,
+        )
+    except TypeError as e:
+        if not required:
+            logger.warning(f'Failed to store file reference: {reference} - {e}')
+            return None
+        else:
+            raise e
+
+
+def delete_prev_output(object_model, field_list=[]):
+    files_for_removal = list()
+
+    # collect prev attached files
+    for field in field_list:
+        current_file = getattr(object_model, field)
+        if current_file:
+            logger.info('delete {}: {}'.format(field, current_file))
+            setattr(object_model, field, None)
+            files_for_removal.append(current_file)
+
+    # Clear fields
+    object_model.save(update_fields=field_list)
+
+    # delete old files
+    for f in files_for_removal:
+        f.delete()
+
+
+class LogTaskError(Task):
+    # from gist https://gist.github.com/darklow/c70a8d1147f05be877c3
+    def on_failure(self, exc, task_id, args, kwargs, einfo):
+        try:
+            self.handle_task_failure(exc, task_id, args, kwargs, einfo)
+        except Exception as e:
+            logger.info('Unhandled Exception in: {}'.format(self.name))
+            logger.exception(str(e))
+
+        super(LogTaskError, self).on_failure(exc, task_id, args, kwargs, einfo)
+
+    def handle_task_failure(self, exc, task_id, args, kwargs, traceback):
+        logger.info('name: {}'.format(self.name))
+        logger.info('args: {}'.format(args))
+        logger.info('kwargs: {}'.format(kwargs))
+        logger.info('traceback: {}'.format(traceback))
+        files_for_removal = list()
+
+        if self.name in ['record_run_analysis_result', 'record_generate_input_result']:
+            _, analysis_pk, initiator_pk = args
+
+            from ..models import Analysis
+            initiator = get_user_model().objects.get(pk=initiator_pk)
+            analysis = Analysis.objects.get(pk=analysis_pk)
+            random_filename = '{}.txt'.format(uuid.uuid4().hex)
+            traceback_msg = "worker-monitor error:\n {}".format(traceback)
+            analysis.task_finished = timezone.now()
+
+            # Store status first, in case the issue is in file storage
+            if self.name == 
'record_generate_input_result': + analysis.status = Analysis.status_choices.INPUTS_GENERATION_ERROR + if self.name == 'record_run_analysis_result': + analysis.status = Analysis.status_choices.RUN_ERROR + + # Store Error to traceback file + try: + if self.name == 'record_generate_input_result': + with TemporaryFile() as tmp_file: + tmp_file.write(traceback_msg.encode('utf-8')) + analysis.input_generation_traceback_file = RelatedFile.objects.create( + file=File(tmp_file, name=random_filename), + filename=random_filename, + content_type='text/plain', + creator=initiator, + ) + + if self.name == 'record_run_analysis_result': + with TemporaryFile() as tmp_file: + tmp_file.write(traceback_msg.encode('utf-8')) + analysis.run_traceback_file = RelatedFile.objects.create( + file=File(tmp_file, name=random_filename), + filename=random_filename, + content_type='text/plain', + creator=initiator, + ) + delete_prev_output(analysis, ['run_log_file']) + analysis.save() + + except Exception as e: + # ensure error status is stored (if storage fails) + analysis.save() + raise e + + +@signals.worker_ready.connect +def log_worker_monitor(sender, **k): + logger.info('DEBUG: {}'.format(settings.DEBUG)) + logger.info('DB_ENGINE: {}'.format(settings.DB_ENGINE)) + logger.info('STORAGE_TYPE: {}'.format(settings.STORAGE_TYPE)) + logger.info('DEFAULT_FILE_STORAGE: {}'.format(settings.DEFAULT_FILE_STORAGE)) + logger.info('MEDIA_ROOT: {}'.format(settings.MEDIA_ROOT)) + logger.info('AWS_STORAGE_BUCKET_NAME: {}'.format(settings.AWS_STORAGE_BUCKET_NAME)) + logger.info('AWS_LOCATION: {}'.format(settings.AWS_LOCATION)) + logger.info('AWS_LOG_LEVEL: {}'.format(settings.AWS_LOG_LEVEL)) + logger.info('AWS_S3_REGION_NAME: {}'.format(settings.AWS_S3_REGION_NAME)) + logger.info('AWS_QUERYSTRING_AUTH: {}'.format(settings.AWS_QUERYSTRING_AUTH)) + logger.info('AWS_QUERYSTRING_EXPIRE: {}'.format(settings.AWS_QUERYSTRING_EXPIRE)) + logger.info('AWS_SHARED_BUCKET: {}'.format(settings.AWS_SHARED_BUCKET)) + logger.info('AWS_IS_GZIPPED: {}'.format(settings.AWS_IS_GZIPPED)) + + +@transaction.atomic +@celery_app_v1.task(name='run_register_worker', **celery_conf.worker_task_kwargs) +def run_register_worker(m_supplier, m_name, m_id, m_settings, m_version): + logger.info('model_supplier: {}, model_name: {}, model_id: {}'.format(m_supplier, m_name, m_id)) + try: + from django.contrib.auth.models import User + from src.server.oasisapi.analysis_models.models import AnalysisModel + + try: + model = AnalysisModel.all_objects.get( + model_id=m_name, + supplier_id=m_supplier, + version_id=m_id + ) + # Re-enable model if soft deleted + if model.deleted: + model.activate() + + except ObjectDoesNotExist: + user = User.objects.get(username='admin') + model = AnalysisModel.objects.create( + model_id=m_name, + supplier_id=m_supplier, + version_id=m_id, + creator=user + ) + + # Update model settings file + if m_settings: + try: + request = HttpRequest() + request.data = {**m_settings} + request.method = 'post' + request.version = 'v1' + request.user = model.creator + handle_json_data(model, 'resource_file', request, ModelParametersSerializer) + + logger.info('Updated model settings') + except Exception as e: + logger.info('Failed to update model settings:') + logger.exception(str(e)) + if isinstance(e, S3_ClientError): + raise e + + # Update model version info + if m_version: + try: + model.ver_ktools = m_version['ktools'] + model.ver_oasislmf = m_version['oasislmf'] + model.ver_platform = m_version['platform'] + logger.info('Updated model versions') + except 
Exception as e:
+            logger.info('Failed to set model versions:')
+            logger.exception(str(e))
+
+        # Default run_mode to V1 if it has not been set
+        if not model.run_mode:
+            model.run_mode = model.run_mode_choices.V1
+
+        model.save()
+    # Log unhandled exceptions
+    except Exception as e:
+        logger.exception(str(e))
+        logger.exception(model)
+        if isinstance(e, S3_ClientError):
+            raise e
+
+
+@celery_app_v1.task(name='set_task_status')
+def set_task_status(analysis_pk, task_status):
+    try:
+        from ..models import Analysis
+        analysis = Analysis.objects.get(pk=analysis_pk)
+        analysis.status = task_status
+        analysis.task_started = timezone.now()
+        analysis.save(update_fields=["status", "task_started"])
+        logger.info('Task Status Update: analysis_pk: {}, status: {}, time: {}'.format(analysis_pk, task_status, analysis.task_started))
+    except Exception as e:
+        logger.error('Task Status Update: Failed')
+        logger.exception(str(e))
+
+
+@celery_app_v1.task(name='record_run_analysis_result', base=LogTaskError)
+def record_run_analysis_result(res, analysis_pk, initiator_pk):
+    output_location, traceback_location, log_location, return_code = res
+    logger.info('output_location: {}, log_location: {}, traceback_location: {}, status: {}, analysis_pk: {}, initiator_pk: {}'.format(
+        output_location, log_location, traceback_location, return_code, analysis_pk, initiator_pk))
+
+    from ..models import Analysis
+    initiator = get_user_model().objects.get(pk=initiator_pk)
+    analysis = Analysis.objects.get(pk=analysis_pk)
+    analysis.status = Analysis.status_choices.RUN_COMPLETED if return_code == 0 else Analysis.status_choices.RUN_ERROR
+    analysis.task_finished = timezone.now()
+
+    delete_prev_output(analysis, ['output_file', 'run_log_file', 'run_traceback_file'])
+
+    # Store results
+    if return_code == 0:
+        analysis.output_file = store_file(output_location, 'application/gzip', initiator, filename=f'analysis_{analysis_pk}_output.tar.gz')
+    # Store Ktools logs
+    if log_location:
+        analysis.run_log_file = store_file(log_location, 'application/gzip', initiator, filename=f'analysis_{analysis_pk}_logs.tar.gz')
+    # record the error file
+    if traceback_location:
+        analysis.run_traceback_file = store_file(traceback_location, 'text/plain', initiator, filename=f'analysis_{analysis_pk}_run_traceback.txt')
+    analysis.save()
+
+
+@celery_app_v1.task(name='record_generate_input_result', base=LogTaskError)
+def record_generate_input_result(result, analysis_pk, initiator_pk):
+    logger.info('result: {}, analysis_pk: {}, initiator_pk: {}'.format(
+        result, analysis_pk, initiator_pk))
+
+    from ..models import Analysis
+    (
+        input_location,
+        lookup_error_fp,
+        lookup_success_fp,
+        lookup_validation_fp,
+        summary_levels_fp,
+        traceback_fp,
+        return_code,
+    ) = result
+
+    analysis = Analysis.objects.get(pk=analysis_pk)
+    initiator = get_user_model().objects.get(pk=initiator_pk)
+
+    # Remove previous output
+    delete_prev_output(analysis, [
+        'output_file',
+        'input_file',
+        'lookup_errors_file',
+        'lookup_success_file',
+        'lookup_validation_file',
+        'summary_levels_file',
+        'input_generation_traceback_file',
+        'run_traceback_file',
+        'run_log_file',
+    ])
+
+    # SUCCESS
+    if return_code == 0:
+        analysis.status = Analysis.status_choices.READY
+    # FAILED
+    else:
+        analysis.status = Analysis.status_choices.INPUTS_GENERATION_ERROR
+
+    # Add current Output
+    analysis.input_file = store_file(input_location, 'application/gzip', initiator,
+                                     filename=f'analysis_{analysis_pk}_inputs.tar.gz') if input_location else None
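+    # Some lookup outputs are only produced on certain runs, so those are
+    # stored with required=False and fall back to None when missing.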
+    analysis.lookup_success_file = store_file(lookup_success_fp, 'text/csv', initiator,
+                                              filename=f'analysis_{analysis_pk}_gul_summary_map.csv') if lookup_success_fp else None
+    analysis.lookup_errors_file = store_file(lookup_error_fp, 'text/csv', initiator, required=False,
+                                             filename=f'analysis_{analysis_pk}_keys-errors.csv') if lookup_error_fp else None
+    analysis.lookup_validation_file = store_file(lookup_validation_fp, 'application/json', initiator, required=False,
+                                                 filename=f'analysis_{analysis_pk}_exposure_summary_report.json') if lookup_validation_fp else None
+    analysis.summary_levels_file = store_file(summary_levels_fp, 'application/json', initiator, required=False,
+                                              filename=f'analysis_{analysis_pk}_exposure_summary_levels.json') if summary_levels_fp else None
+    analysis.task_finished = timezone.now()
+
+    # always store traceback
+    if traceback_fp:
+        analysis.input_generation_traceback_file = store_file(
+            traceback_fp, 'text/plain', initiator, filename=f'analysis_{analysis_pk}_generation_traceback.txt')
+        logger.info(analysis.input_generation_traceback_file)
+    analysis.save()
+
+
+@celery_app_v1.task(name='record_run_analysis_failure')
+def record_run_analysis_failure(analysis_pk, initiator_pk, traceback):
+    logger.info('analysis_pk: {}, initiator_pk: {}, traceback: {}'.format(
+        analysis_pk, initiator_pk, traceback))
+
+    try:
+        from ..models import Analysis
+
+        analysis = Analysis.objects.get(pk=analysis_pk)
+        analysis.status = Analysis.status_choices.RUN_ERROR
+        analysis.task_finished = timezone.now()
+        analysis.save()
+
+        random_filename = '{}.txt'.format(uuid.uuid4().hex)
+        with TemporaryFile() as tmp_file:
+            tmp_file.write(traceback.encode('utf-8'))
+            analysis.run_traceback_file = RelatedFile.objects.create(
+                file=File(tmp_file, name=random_filename),
+                filename=f'analysis_{analysis_pk}_run_traceback.txt',
+                content_type='text/plain',
+                creator=get_user_model().objects.get(pk=initiator_pk),
+            )
+
+        # remove the current command log file
+        if analysis.run_log_file:
+            analysis.run_log_file.delete()
+            analysis.run_log_file = None
+
+        analysis.save()
+    except Exception as e:
+        logger.exception(str(e))
+
+
+@celery_app_v1.task(name='record_generate_input_failure')
+def record_generate_input_failure(analysis_pk, initiator_pk, traceback):
+    logger.info('analysis_pk: {}, initiator_pk: {}, traceback: {}'.format(
+        analysis_pk, initiator_pk, traceback))
+    try:
+        from ..models import Analysis
+
+        analysis = Analysis.objects.get(pk=analysis_pk)
+        analysis.status = Analysis.status_choices.INPUTS_GENERATION_ERROR
+        analysis.task_finished = timezone.now()
+        analysis.save()
+
+        random_filename = '{}.txt'.format(uuid.uuid4().hex)
+        with TemporaryFile() as tmp_file:
+            tmp_file.write(traceback.encode('utf-8'))
+            analysis.input_generation_traceback_file = RelatedFile.objects.create(
+                file=File(tmp_file, name=random_filename),
+                filename=f'analysis_{analysis_pk}_generation_traceback.txt',
+                content_type='text/plain',
+                creator=get_user_model().objects.get(pk=initiator_pk),
+            )
+
+        analysis.save()
+    except Exception as e:
+        logger.exception(str(e))
+
+## --- Deprecated tasks ---------------------------------------------------- ##
+
+
+@celery_app_v1.task(name='run_analysis_success')
+def run_analysis_success(output_location, analysis_pk, initiator_pk):
+    logger.warning('"run_analysis_success" is deprecated and should only be used to process tasks already on the queue.')
+
+    
logger.info('output_location: {}, analysis_pk: {}, initiator_pk: {}'.format( + output_location, analysis_pk, initiator_pk)) + + try: + from ..models import Analysis + analysis = Analysis.objects.get(pk=analysis_pk) + analysis.status = Analysis.status_choices.RUN_COMPLETED + analysis.task_finished = timezone.now() + + analysis.output_file = RelatedFile.objects.create( + file=str(output_location), + filename=str(output_location), + content_type='application/gzip', + creator=get_user_model().objects.get(pk=initiator_pk), + ) + + # Delete previous error trace + if analysis.run_traceback_file: + traceback = analysis.run_traceback_file + analysis.run_traceback_file = None + traceback.delete() + + analysis.save() + except Exception as e: + logger.exception(str(e)) + + +@celery_app_v1.task(name='generate_input_success') +def generate_input_success(result, analysis_pk, initiator_pk): + logger.warning('"generate_input_success" is deprecated and should only be used to process tasks already on the queue.') + + logger.info('result: {}, analysis_pk: {}, initiator_pk: {}'.format( + result, analysis_pk, initiator_pk)) + try: + from ..models import Analysis + ( + input_location, + lookup_error_fp, + lookup_success_fp, + lookup_validation_fp, + summary_levels_fp, + ) = result + + analysis = Analysis.objects.get(pk=analysis_pk) + analysis.status = Analysis.status_choices.READY + analysis.task_finished = timezone.now() + + analysis.input_file = RelatedFile.objects.create( + file=str(input_location), + filename=str(input_location), + content_type='application/gzip', + creator=get_user_model().objects.get(pk=initiator_pk), + ) + analysis.lookup_errors_file = RelatedFile.objects.create( + file=str(lookup_error_fp), + filename=str('keys-errors.csv'), + content_type='text/csv', + creator=get_user_model().objects.get(pk=initiator_pk), + ) + analysis.lookup_success_file = RelatedFile.objects.create( + file=str(lookup_success_fp), + filename=str('gul_summary_map.csv'), + content_type='text/csv', + creator=get_user_model().objects.get(pk=initiator_pk), + ) + analysis.lookup_validation_file = RelatedFile.objects.create( + file=str(lookup_validation_fp), + filename=str('exposure_summary_report.json'), + content_type='application/json', + creator=get_user_model().objects.get(pk=initiator_pk), + ) + + analysis.summary_levels_file = RelatedFile.objects.create( + file=str(summary_levels_fp), + filename=str('exposure_summary_levels.json'), + content_type='application/json', + creator=get_user_model().objects.get(pk=initiator_pk), + ) + + # Delete previous error trace + if analysis.input_generation_traceback_file: + traceback = analysis.input_generation_traceback_file + analysis.input_generation_traceback_file = None + traceback.delete() + + analysis.save() + except Exception as e: + logger.exception(str(e)) diff --git a/src/server/oasisapi/analyses/tests/__init__.py b/src/server/oasisapi/analyses/v1_api/tests/__init__.py similarity index 100% rename from src/server/oasisapi/analyses/tests/__init__.py rename to src/server/oasisapi/analyses/v1_api/tests/__init__.py diff --git a/src/server/oasisapi/analyses/v1_api/tests/fakes.py b/src/server/oasisapi/analyses/v1_api/tests/fakes.py new file mode 100644 index 000000000..2d9a79ffa --- /dev/null +++ b/src/server/oasisapi/analyses/v1_api/tests/fakes.py @@ -0,0 +1,57 @@ +import six +from celery.states import STARTED +from model_mommy import mommy + +from src.server.oasisapi.files.tests.fakes import fake_related_file +from ...models import Analysis + + +class 
FakeAsyncResultFactory(object):
+    def __init__(self, target_task_id=None, target_status=STARTED, target_result=None, target_traceback=None):
+        self.target_task_id = target_task_id
+        self.revoke_kwargs = {}
+        self.revoke_called = False
+        self.target_status = target_status
+        self.target_result = target_result
+        self.target_traceback = target_traceback
+
+        class FakeAsyncResult(object):
+            def __init__(_self, task_id):
+                if task_id != target_task_id:
+                    raise ValueError()
+
+                _self.id = task_id
+                _self.status = self.target_status
+                _self.result = self.target_result
+                _self.traceback = self.target_traceback
+
+            def revoke(_self, **kwargs):
+                self.revoke_called = True
+                self.revoke_kwargs = kwargs
+
+        self.fake_res = FakeAsyncResult
+
+    def __call__(self, task_id):
+        return self.fake_res(task_id)
+
+
+def fake_analysis(**kwargs):
+    if isinstance(kwargs.get('input_file'), (six.string_types, six.binary_type)):
+        kwargs['input_file'] = fake_related_file(file=kwargs['input_file'])
+
+    if isinstance(kwargs.get('lookup_errors_file'), (six.string_types, six.binary_type)):
+        kwargs['lookup_errors_file'] = fake_related_file(file=kwargs['lookup_errors_file'])
+
+    if isinstance(kwargs.get('lookup_success_file'), (six.string_types, six.binary_type)):
+        kwargs['lookup_success_file'] = fake_related_file(file=kwargs['lookup_success_file'])
+
+    if isinstance(kwargs.get('lookup_validation_file'), (six.string_types, six.binary_type)):
+        kwargs['lookup_validation_file'] = fake_related_file(file=kwargs['lookup_validation_file'])
+
+    if isinstance(kwargs.get('output_file'), (six.string_types, six.binary_type)):
+        kwargs['output_file'] = fake_related_file(file=kwargs['output_file'])
+
+    if isinstance(kwargs.get('settings_file'), (six.string_types, six.binary_type)):
+        kwargs['settings_file'] = fake_related_file(file=kwargs['settings_file'])
+
+    return mommy.make(Analysis, **kwargs)
diff --git a/src/server/oasisapi/analyses/v1_api/tests/test_analysis_api.py b/src/server/oasisapi/analyses/v1_api/tests/test_analysis_api.py
new file mode 100644
index 000000000..0ecb0bacb
--- /dev/null
+++ b/src/server/oasisapi/analyses/v1_api/tests/test_analysis_api.py
@@ -0,0 +1,1379 @@
+import json
+import string
+
+# from tempfile import NamedTemporaryFile
+# from django.conf import settings
+from backports.tempfile import TemporaryDirectory
+from django.test import override_settings
+from django.urls import reverse
+from django_webtest import WebTestMixin
+from hypothesis import given, settings
+from hypothesis.extra.django import TestCase
+from hypothesis.strategies import text, binary, sampled_from
+from mock import patch
+from rest_framework_simplejwt.tokens import AccessToken
+
+from src.server.oasisapi.files.tests.fakes import fake_related_file
+from src.server.oasisapi.analysis_models.v1_api.tests.fakes import fake_analysis_model
+from src.server.oasisapi.portfolios.v1_api.tests.fakes import fake_portfolio
+from src.server.oasisapi.auth.tests.fakes import fake_user
+from src.server.oasisapi.data_files.v1_api.tests.fakes import fake_data_file
+from ...models import Analysis
+from .fakes import fake_analysis
+
+# Override the default hypothesis deadline for all tests (800ms)
+settings.register_profile("ci", deadline=800.0)
+settings.load_profile("ci")
+NAMESPACE = 'v1-analyses'
+
+
+class AnalysisApi(WebTestMixin, TestCase):
+    def test_user_is_not_authenticated___response_is_forbidden(self):
+        analysis = fake_analysis()
+
+        response = self.app.get(analysis.get_absolute_url(namespace=NAMESPACE), expect_errors=True)
+        
self.assertIn(response.status_code, [401, 403]) + + def test_user_is_authenticated_object_does_not_exist___response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.get( + reverse(f'{NAMESPACE}:analysis-detail', kwargs={'pk': analysis.pk + 1}), + expect_errors=True, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + + self.assertEqual(404, response.status_code) + + def test_name_is_not_provided___response_is_400(self): + user = fake_user() + + response = self.app.post( + reverse(f'{NAMESPACE}:analysis-list'), + expect_errors=True, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params={}, + content_type='application/json', + ) + + self.assertEqual(400, response.status_code) + + @given(name=text(alphabet=' \t\n\r', max_size=10)) + def test_cleaned_name_is_empty___response_is_400(self, name): + user = fake_user() + + response = self.app.post( + reverse(f'{NAMESPACE}:analysis-list'), + expect_errors=True, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({'name': name}), + content_type='application/json' + ) + + self.assertEqual(400, response.status_code) + + @given(name=text(alphabet=string.ascii_letters, max_size=10, min_size=1)) + def test_cleaned_name_portfolio_and_model_are_present___object_is_created(self, name): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + self.maxDiff = None + user = fake_user() + portfolio = fake_portfolio(location_file=fake_related_file()) + model = fake_analysis_model() + model.run_mode = model.run_mode_choices.V1 + model.save() + + response = self.app.post( + reverse(f'{NAMESPACE}:analysis-list'), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({'name': name, 'portfolio': portfolio.pk, 'model': model.pk}), + content_type='application/json' + ) + self.assertEqual(201, response.status_code) + + analysis = Analysis.objects.get(pk=response.json['id']) + analysis.settings_file = fake_related_file() + analysis.input_file = fake_related_file() + analysis.lookup_errors_file = fake_related_file() + analysis.lookup_success_file = fake_related_file() + analysis.lookup_validation_file = fake_related_file() + analysis.summary_levels_file = fake_related_file() + analysis.input_generation_traceback_file = fake_related_file() + analysis.output_file = fake_related_file() + analysis.run_traceback_file = fake_related_file() + analysis.run_log_file = fake_related_file() + analysis.save() + + response = self.app.get( + analysis.get_absolute_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(200, response.status_code) + self.assertEqual({ + 'complex_model_data_files': [], + 'created': analysis.created.strftime('%Y-%m-%dT%H:%M:%S.%fZ'), + 'modified': analysis.modified.strftime('%Y-%m-%dT%H:%M:%S.%fZ'), + 'id': analysis.pk, + 'name': name, + 'portfolio': portfolio.pk, + 'model': model.pk, + 'settings_file': response.request.application_url + analysis.get_absolute_settings_file_url(namespace=NAMESPACE), + 'settings': response.request.application_url + analysis.get_absolute_settings_url(namespace=NAMESPACE), + 'input_file': response.request.application_url + analysis.get_absolute_input_file_url(namespace=NAMESPACE), + 'lookup_errors_file': response.request.application_url + analysis.get_absolute_lookup_errors_file_url(namespace=NAMESPACE), + 
'lookup_success_file': response.request.application_url + analysis.get_absolute_lookup_success_file_url(namespace=NAMESPACE), + 'lookup_validation_file': response.request.application_url + analysis.get_absolute_lookup_validation_file_url(namespace=NAMESPACE), + 'input_generation_traceback_file': response.request.application_url + analysis.get_absolute_input_generation_traceback_file_url(namespace=NAMESPACE), + 'output_file': response.request.application_url + analysis.get_absolute_output_file_url(namespace=NAMESPACE), + 'run_log_file': response.request.application_url + analysis.get_absolute_run_log_file_url(namespace=NAMESPACE), + 'run_traceback_file': response.request.application_url + analysis.get_absolute_run_traceback_file_url(namespace=NAMESPACE), + 'status': Analysis.status_choices.NEW, + 'storage_links': response.request.application_url + analysis.get_absolute_storage_url(namespace=NAMESPACE), + 'summary_levels_file': response.request.application_url + analysis.get_absolute_summary_levels_file_url(namespace=NAMESPACE), + 'task_started': None, + 'task_finished': None, + }, response.json) + + @given(name=text(alphabet=string.ascii_letters, max_size=10, min_size=1)) + def test_complex_model_file_present___object_is_created(self, name): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + self.maxDiff = None + user = fake_user() + model = fake_analysis_model() + model.run_mode = model.run_mode_choices.V1 + model.save() + portfolio = fake_portfolio(location_file=fake_related_file()) + + response = self.app.post( + reverse(f'{NAMESPACE}:analysis-list'), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({'name': name, 'portfolio': portfolio.pk, 'model': model.pk}), + content_type='application/json' + ) + self.assertEqual(201, response.status_code) + + analysis = Analysis.objects.get(pk=response.json['id']) + cmf_1 = fake_data_file() + cmf_2 = fake_data_file() + analysis.complex_model_data_files.set([cmf_1, cmf_2]) + analysis.save() + + response = self.app.get( + analysis.get_absolute_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(200, response.status_code) + self.assertEqual({ + 'complex_model_data_files': [cmf_1.pk, cmf_2.pk], + 'created': analysis.created.strftime('%Y-%m-%dT%H:%M:%S.%fZ'), + 'modified': analysis.modified.strftime('%Y-%m-%dT%H:%M:%S.%fZ'), + 'id': analysis.pk, + 'name': name, + 'portfolio': portfolio.pk, + 'model': model.pk, + 'settings_file': None, + 'settings': None, + 'input_file': None, + 'lookup_errors_file': None, + 'lookup_success_file': None, + 'lookup_validation_file': None, + 'input_generation_traceback_file': None, + 'output_file': None, + 'run_log_file': None, + 'run_traceback_file': None, + 'status': Analysis.status_choices.NEW, + 'storage_links': response.request.application_url + analysis.get_absolute_storage_url(namespace=NAMESPACE), + 'summary_levels_file': None, + 'task_started': None, + 'task_finished': None, + }, response.json) + + def test_model_does_not_exist___response_is_400(self): + user = fake_user() + analysis = fake_analysis() + model = fake_analysis_model() + + response = self.app.patch( + analysis.get_absolute_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({'model': model.pk + 1}), + content_type='application/json', + expect_errors=True, + ) + + self.assertEqual(400, response.status_code) + + def 
test_model_run_mode_is_not_V1___response_is_400(self):
+        user = fake_user()
+        analysis = fake_analysis()
+        model = fake_analysis_model()
+        model.run_mode = model.run_mode_choices.V2
+        model.save()
+
+        response = self.app.patch(
+            analysis.get_absolute_url(namespace=NAMESPACE),
+            headers={
+                'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+            },
+            params=json.dumps({'model': model.pk}),
+            content_type='application/json',
+            expect_errors=True,
+        )
+
+        self.assertEqual(400, response.status_code)
+        self.assertEqual(
+            '{"model":["Model pk \'2\' - Unsupported Operation, \'run_mode\' must be \'V1\', not \'V2\'"]}',
+            response.text,
+        )
+
+    def test_model_does_exist___response_is_200(self):
+        user = fake_user()
+        analysis = fake_analysis()
+        model = fake_analysis_model()
+        model.run_mode = model.run_mode_choices.V1
+        model.save()
+
+        response = self.app.patch(
+            analysis.get_absolute_url(namespace=NAMESPACE),
+            headers={
+                'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+            },
+            params=json.dumps({'model': model.pk}),
+            content_type='application/json',
+        )
+
+        analysis.refresh_from_db()
+
+        self.assertEqual(200, response.status_code)
+        self.assertEqual(analysis.model, model)
+
+
+class AnalysisRun(WebTestMixin, TestCase):
+    def test_user_is_not_authenticated___response_is_forbidden(self):
+        analysis = fake_analysis()
+
+        response = self.app.post(analysis.get_absolute_run_url(namespace=NAMESPACE), expect_errors=True)
+        self.assertIn(response.status_code, [401, 403])
+
+    def test_user_is_authenticated_object_does_not_exist___response_is_404(self):
+        user = fake_user()
+        analysis = fake_analysis()
+
+        response = self.app.post(
+            reverse(f'{NAMESPACE}:analysis-run', kwargs={'pk': analysis.pk + 1}),
+            expect_errors=True,
+            headers={
+                'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+            }
+        )
+
+        self.assertEqual(404, response.status_code)
+
+    def test_run_mode_is_not_V1___response_is_400(self):
+        with patch('src.server.oasisapi.analyses.models.Analysis.run', autospec=True) as run_mock:
+            user = fake_user()
+            analysis = fake_analysis()
+            analysis.model.run_mode = analysis.model.run_mode_choices.V2
+            analysis.model.save()
+
+            response = self.app.post(
+                analysis.get_absolute_run_url(namespace=NAMESPACE),
+                expect_errors=True,
+                headers={
+                    'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+                }
+            )
+            self.assertEqual(400, response.status_code)
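+            # NOTE: the expected message hardcodes pk 1, which relies on the
+            # fresh test database assigning this model the first primary key.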
+            self.assertEqual(
+                '{"model":["Model pk \'1\' - Unsupported Operation, \'run_mode\' must be \'V1\', not \'V2\'"]}',
+                response.text
+            )
+
+    def test_user_is_authenticated_object_exists___run_is_called(self):
+        with patch('src.server.oasisapi.analyses.models.Analysis.run', autospec=True) as run_mock:
+            user = fake_user()
+            analysis = fake_analysis()
+            analysis.model.run_mode = analysis.model.run_mode_choices.V1
+            analysis.model.save()
+
+            self.app.post(
+                analysis.get_absolute_run_url(namespace=NAMESPACE),
+                headers={
+                    'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+                }
+            )
+
+            run_mock.assert_called_once_with(analysis, user)
+
+
+class AnalysisCancel(WebTestMixin, TestCase):
+    def test_user_is_not_authenticated___response_is_forbidden(self):
+        analysis = fake_analysis()
+
+        response = self.app.post(analysis.get_absolute_cancel_analysis_url(namespace=NAMESPACE), expect_errors=True)
+        self.assertIn(response.status_code, [401, 403])
+
+    def test_user_is_authenticated_object_does_not_exist___response_is_404(self):
+        user = fake_user()
+        analysis = fake_analysis()
+
+        response = self.app.post(
+            reverse(f'{NAMESPACE}:analysis-cancel', kwargs={'pk': analysis.pk + 1}),
+            expect_errors=True,
+            headers={
+                'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+            }
+        )
+
+        self.assertEqual(404, response.status_code)
+
+    def test_user_is_authenticated_object_exists___cancel_is_called(self):
+        with patch('src.server.oasisapi.analyses.models.Analysis.cancel_analysis', autospec=True) as cancel_mock:
+            user = fake_user()
+            analysis = fake_analysis()
+
+            self.app.post(
+                analysis.get_absolute_cancel_analysis_url(namespace=NAMESPACE),
+                headers={
+                    'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+                }
+            )
+
+            cancel_mock.assert_called_once_with(analysis)
+
+
+class AnalysisGenerateInputs(WebTestMixin, TestCase):
+    def test_user_is_not_authenticated___response_is_forbidden(self):
+        analysis = fake_analysis()
+
+        response = self.app.post(analysis.get_absolute_generate_inputs_url(namespace=NAMESPACE), expect_errors=True)
+        self.assertIn(response.status_code, [401, 403])
+
+    def test_user_is_authenticated_object_does_not_exist___response_is_404(self):
+        user = fake_user()
+        analysis = fake_analysis()
+
+        response = self.app.post(
+            reverse(f'{NAMESPACE}:analysis-generate-inputs', kwargs={'pk': analysis.pk + 1}),
+            expect_errors=True,
+            headers={
+                'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+            }
+        )
+
+        self.assertEqual(404, response.status_code)
+
+    def test_user_is_authenticated_object_exists___generate_inputs_is_called(self):
+        with patch('src.server.oasisapi.analyses.models.Analysis.generate_inputs', autospec=True) as generate_inputs_mock:
+            user = fake_user()
+            analysis = fake_analysis()
+            analysis.model.run_mode = analysis.model.run_mode_choices.V1
+            analysis.model.save()
+
+            self.app.post(
+                analysis.get_absolute_generate_inputs_url(namespace=NAMESPACE),
+                headers={
+                    'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+                }
+            )
+
+            generate_inputs_mock.assert_called_once_with(analysis, user)
+
+    def test_model_run_mode_not_V1___response_is_400(self):
+        with patch('src.server.oasisapi.analyses.models.Analysis.generate_inputs', autospec=True) as generate_inputs_mock:
+            user = fake_user()
+            analysis = fake_analysis()
+            analysis.model.run_mode = analysis.model.run_mode_choices.V2
+            analysis.model.save()
+
+            response = self.app.post(
+                analysis.get_absolute_generate_inputs_url(namespace=NAMESPACE),
+                expect_errors=True,
+                headers={
+                    'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+                }
+            )
+            self.assertEqual(400, response.status_code)
+            self.assertEqual(
+                '{"model":["Model pk \'1\' - Unsupported Operation, \'run_mode\' must be \'V1\', not \'V2\'"]}',
+                response.text
+            )
+
+
+class AnalysisCancelInputsGeneration(WebTestMixin, TestCase):
+    def test_user_is_not_authenticated___response_is_forbidden(self):
+        analysis = fake_analysis()
+
+        response = self.app.post(analysis.get_absolute_cancel_inputs_generation_url(namespace=NAMESPACE), expect_errors=True)
+        self.assertIn(response.status_code, [401, 403])
+
+    def test_user_is_authenticated_object_does_not_exist___response_is_404(self):
+        user = fake_user()
+        analysis = fake_analysis()
+
+        response = self.app.post(
+            reverse(f'{NAMESPACE}:analysis-cancel-generate-inputs', kwargs={'pk': analysis.pk + 1}),
+            expect_errors=True,
+            headers={
+                'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+            }
+        )
+
+        self.assertEqual(404, response.status_code)
+
+    def test_user_is_authenticated_object_exists___generate_inputs_generation_is_called(self):
+        with 
patch('src.server.oasisapi.analyses.models.Analysis.cancel_generate_inputs', autospec=True) as cancel_generate_inputs: + user = fake_user() + analysis = fake_analysis() + + self.app.post( + analysis.get_absolute_cancel_inputs_generation_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + + cancel_generate_inputs.assert_called_once_with(analysis) + + +class AnalysisCopy(WebTestMixin, TestCase): + def test_user_is_not_authenticated___response_is_forbidden(self): + analysis = fake_analysis() + + response = self.app.post(analysis.get_absolute_copy_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_user_is_authenticated_object_does_not_exist___response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.post( + reverse(f'{NAMESPACE}:analysis-copy', kwargs={'pk': analysis.pk + 1}), + expect_errors=True, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + + self.assertEqual(404, response.status_code) + + def test_new_object_is_created(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + + self.assertNotEqual(analysis.pk, response.json['id']) + + @given(name=text(min_size=1, max_size=10, alphabet=string.ascii_letters)) + def test_no_new_name_is_provided___copy_is_appended_to_name(self, name): + user = fake_user() + analysis = fake_analysis(name=name) + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + + self.assertEqual(Analysis.objects.get(pk=response.json['id']).name, '{} - Copy'.format(name)) + + @given(orig_name=text(min_size=1, max_size=10, alphabet=string.ascii_letters), new_name=text(min_size=1, max_size=10, alphabet=string.ascii_letters)) + def test_new_name_is_provided___new_name_is_set_on_new_object(self, orig_name, new_name): + user = fake_user() + analysis = fake_analysis(name=orig_name) + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({'name': new_name}), + content_type='application/json' + ) + + self.assertEqual(Analysis.objects.get(pk=response.json['id']).name, new_name) + + @given(status=sampled_from(list(Analysis.status_choices._db_values))) + def test_state_is_reset(self, status): + user = fake_user() + analysis = fake_analysis(status=status) + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + + self.assertEqual(Analysis.objects.get(pk=response.json['id']).status, Analysis.status_choices.NEW) + + def test_creator_is_set_to_caller(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + + self.assertEqual(Analysis.objects.get(pk=response.json['id']).creator, user) + + @given(task_id=text(min_size=1, max_size=10, alphabet=string.ascii_letters)) + def test_run_task_id_is_reset(self, task_id): + user = fake_user() + analysis = fake_analysis(run_task_id=task_id) + + response = 
self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + + self.assertEqual(Analysis.objects.get(pk=response.json['id']).run_task_id, '') + + @given(task_id=text(min_size=1, max_size=10, alphabet=string.ascii_letters)) + def test_generate_inputs_task_id_is_reset(self, task_id): + user = fake_user() + analysis = fake_analysis(generate_inputs_task_id=task_id) + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + + self.assertEqual(Analysis.objects.get(pk=response.json['id']).generate_inputs_task_id, '') + + def test_portfolio_is_not_supplied___portfolio_is_copied(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + + self.assertEqual(Analysis.objects.get(pk=response.json['id']).portfolio, analysis.portfolio) + + def test_portfolio_is_supplied___portfolio_is_replaced(self): + user = fake_user() + analysis = fake_analysis() + new_portfolio = fake_portfolio(location_file=fake_related_file()) + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({'portfolio': new_portfolio.pk}), + content_type='application/json', + ) + + self.assertEqual(Analysis.objects.get(pk=response.json['id']).portfolio, new_portfolio) + + def test_model_is_not_supplied___model_is_copied(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + + self.assertEqual(Analysis.objects.get(pk=response.json['id']).model, analysis.model) + + def test_model_is_supplied___model_is_replaced(self): + user = fake_user() + analysis = fake_analysis() + new_model = fake_analysis_model() + new_model.run_mode = new_model.run_mode_choices.V1 + new_model.save() + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({'model': new_model.pk}), + content_type='application/json', + ) + + self.assertEqual(Analysis.objects.get(pk=response.json['id']).model, new_model) + + def test_complex_model_file_is_not_supplied___model_is_copied(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + + actual_cmf_pks = [obj.pk for obj in Analysis.objects.get(pk=response.json['id']).complex_model_data_files.all()] + expected_cmf_pks = [obj.pk for obj in analysis.complex_model_data_files.all()] + + self.assertEqual(expected_cmf_pks, actual_cmf_pks) + + def test_complex_model_file_is_supplied___model_is_replaced(self): + user = fake_user() + analysis = fake_analysis() + new_cmf_1 = fake_data_file() + new_cmf_2 = fake_data_file() + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({'complex_model_data_files': [new_cmf_1.pk, new_cmf_2.pk]}), + 
content_type='application/json', + ) + + actual_cmf_pks = [obj.pk for obj in Analysis.objects.get(pk=response.json['id']).complex_model_data_files.all()] + expected_cmf_pks = [new_cmf_1.pk, new_cmf_2.pk] + + self.assertEqual(expected_cmf_pks, actual_cmf_pks) + + def test_settings_file_is_not_supplied___settings_file_is_copied(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis(settings_file=fake_related_file(file='{}')) + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + + self.assertEqual(Analysis.objects.get(pk=response.json['id']).settings_file.pk, Analysis.objects.get(pk=response.json['id']).pk) + + def test_input_file_is_not_supplied___input_file_is_not_copied(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis(input_file=fake_related_file()) + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + + self.assertEqual(Analysis.objects.get(pk=response.json['id']).input_file, None) + + def test_lookup_errors_file_is_cleared(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis(lookup_errors_file=fake_related_file()) + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertIsNone(Analysis.objects.get(pk=response.json['id']).lookup_errors_file) + + def test_lookup_success_file_is_cleared(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis(lookup_success_file=fake_related_file()) + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertIsNone(Analysis.objects.get(pk=response.json['id']).lookup_success_file) + + def test_lookup_validation_file_is_cleared(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis(lookup_validation_file=fake_related_file()) + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertIsNone(Analysis.objects.get(pk=response.json['id']).lookup_validation_file) + + def test_output_file_is_cleared(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis(output_file=fake_related_file()) + + response = self.app.post( + analysis.get_absolute_copy_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertIsNone(Analysis.objects.get(pk=response.json['id']).output_file) + + +class AnalysisSettingsJson(WebTestMixin, TestCase): + def test_user_is_not_authenticated___response_is_forbidden(self): + analysis = fake_analysis() + + response = self.app.get(analysis.get_absolute_settings_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_settings_json_is_not_present___get_response_is_404(self): + user = 
fake_user() + analysis = fake_analysis() + + response = self.app.get( + analysis.get_absolute_settings_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_settings_json_is_not_present___delete_response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.delete( + analysis.get_absolute_settings_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_settings_json_is_not_valid___response_is_400(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis() + json_data = { + "version": "3", + "analysis_tag": "test_analysis", + "model_supplier_id": "OasisIM", + "model_name_id": "1", + "number_of_samples": -1, + "gul_threshold": 0, + "model_settings": { + "use_random_number_file": True, + "event_occurrence_file_id": "1" + }, + "gul_output": True, + "gul_summaries": [ + { + "id": 1, + "summarycalc": True, + "eltcalc": True, + "aalcalc": "Not-A-Boolean", + "pltcalc": True, + "lec_output": False + } + ], + "il_output": False + } + + response = self.app.post( + analysis.get_absolute_settings_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps(json_data), + content_type='application/json', + expect_errors=True, + ) + + validation_error = { + 'number_of_samples': ['-1 is less than the minimum of 0'], + 'gul_summaries-0-aalcalc': ["'Not-A-Boolean' is not of type 'boolean'"] + } + self.assertEqual(400, response.status_code) + self.assertEqual(json.loads(response.body), validation_error) + + def test_settings_json_is_uploaded___can_be_retrieved(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis() + json_data = { + "version": "3", + "source_tag": "test_source", + "analysis_tag": "test_analysis", + "model_supplier_id": "OasisIM", + 'module_supplier_id': 'OasisIM', + "model_name_id": "1", + 'model_version_id': '1', + "number_of_samples": 10, + "gul_threshold": 0, + "model_settings": { + "use_random_number_file": True, + "event_occurrence_file_id": "1" + }, + "gul_output": True, + "gul_summaries": [ + { + "id": 1, + "summarycalc": True, + "eltcalc": True, + "aalcalc": True, + "pltcalc": True, + "lec_output": False + } + ], + "il_output": False + } + + self.app.post( + analysis.get_absolute_settings_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps(json_data), + content_type='application/json' + ) + + response = self.app.get( + analysis.get_absolute_settings_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + self.assertEqual(json.loads(response.body), json_data) + self.assertEqual(response.content_type, 'application/json') + + +class AnalysisSettingsFile(WebTestMixin, TestCase): + def test_user_is_not_authenticated___response_is_forbidden(self): + analysis = fake_analysis() + + response = self.app.get(analysis.get_absolute_settings_file_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_settings_file_is_not_present___get_response_is_404(self): + user = fake_user() + analysis = 
fake_analysis() + + response = self.app.get( + analysis.get_absolute_settings_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_settings_file_is_not_present___delete_response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.delete( + analysis.get_absolute_settings_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_settings_file_is_not_a_valid_format___response_is_400(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis() + + response = self.app.post( + analysis.get_absolute_settings_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file.tar', b'content'), + ), + expect_errors=True, + ) + + self.assertEqual(400, response.status_code) + + @given(file_content=binary(min_size=1)) + def test_settings_file_is_uploaded___file_can_be_retrieved(self, file_content): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis() + + self.app.post( + analysis.get_absolute_settings_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file.json', file_content), + ), + ) + + response = self.app.get( + analysis.get_absolute_settings_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(response.body, file_content) + self.assertEqual(response.content_type, 'application/json') + + +class AnalysisInputFile(WebTestMixin, TestCase): + def test_user_is_not_authenticated___response_is_forbidden(self): + analysis = fake_analysis() + + response = self.app.get(analysis.get_absolute_input_file_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_input_file_is_not_present___get_response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.get( + analysis.get_absolute_input_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + @given(file_content=binary(min_size=1), content_type=sampled_from(['application/x-gzip', 'application/gzip', 'application/x-tar', 'application/tar'])) + def test_input_file_is_present___file_can_be_retrieved(self, file_content, content_type): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis(input_file=fake_related_file(file=file_content, content_type=content_type)) + + response = self.app.get( + analysis.get_absolute_input_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(response.body, file_content) + self.assertEqual(response.content_type, content_type) + + +class AnalysisLookupErrorsFile(WebTestMixin, TestCase): + def test_user_is_not_authenticated___response_is_forbidden(self): + analysis = fake_analysis() + + response = 
self.app.get(analysis.get_absolute_lookup_errors_file_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_lookup_errors_file_is_not_present___get_response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.get( + analysis.get_absolute_lookup_errors_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + ''' + def test_lookup_errors_file_is_not_present___delete_response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.delete( + analysis.get_absolute_lookup_errors_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + ''' + + @given(file_content=binary(min_size=1), content_type=sampled_from(['text/csv'])) + def test_lookup_errors_file_is_present___file_can_be_retrieved(self, file_content, content_type): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis(lookup_errors_file=fake_related_file(file=file_content, content_type=content_type)) + + response = self.app.get( + analysis.get_absolute_lookup_errors_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(response.body, file_content) + self.assertEqual(response.content_type, content_type) + + +class AnalysisLookupSuccessFile(WebTestMixin, TestCase): + def test_user_is_not_authenticated___response_is_forbidden(self): + analysis = fake_analysis() + + response = self.app.get(analysis.get_absolute_lookup_success_file_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_lookup_success_file_is_not_present___get_response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.get( + analysis.get_absolute_lookup_success_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + ''' + def test_lookup_success_file_is_not_present___delete_response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.delete( + analysis.get_absolute_lookup_success_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + ''' + + @given(file_content=binary(min_size=1), content_type=sampled_from(['text/csv'])) + def test_lookup_success_file_is_present___file_can_be_retrieved(self, file_content, content_type): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis(lookup_success_file=fake_related_file(file=file_content, content_type=content_type)) + + response = self.app.get( + analysis.get_absolute_lookup_success_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(response.body, file_content) + self.assertEqual(response.content_type, content_type) + + +class AnalysisLookupValidationFile(WebTestMixin, TestCase): + def test_user_is_not_authenticated___response_is_forbidden(self): + analysis = 
fake_analysis() + + response = self.app.get(analysis.get_absolute_lookup_validation_file_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_lookup_validation_file_is_not_present___get_response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.get( + analysis.get_absolute_lookup_validation_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + ''' + def test_lookup_validation_file_is_not_present___delete_response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.delete( + analysis.get_absolute_lookup_validation_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + ''' + + @given(file_content=binary(min_size=1), content_type=sampled_from(['application/json'])) + def test_lookup_validation_file_is_present___file_can_be_retrieved(self, file_content, content_type): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis(lookup_validation_file=fake_related_file(file=file_content, content_type=content_type)) + + response = self.app.get( + analysis.get_absolute_lookup_validation_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(response.body, file_content) + self.assertEqual(response.content_type, content_type) + + +class AnalysisInputGenerationTracebackFile(WebTestMixin, TestCase): + def test_user_is_not_authenticated___response_is_forbidden(self): + analysis = fake_analysis() + + response = self.app.get(analysis.get_absolute_input_generation_traceback_file_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_input_generation_traceback_file_is_not_present___get_response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.get( + analysis.get_absolute_input_generation_traceback_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_input_generation_traceback_file_is_not_present___delete_response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.delete( + analysis.get_absolute_input_generation_traceback_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + @given(file_content=binary(min_size=1)) + def test_input_generation_traceback_file_is_present___file_can_be_retrieved(self, file_content): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis(input_generation_traceback_file=fake_related_file(file=file_content, content_type='text/plain')) + + response = self.app.get( + analysis.get_absolute_input_generation_traceback_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(response.body, file_content) + self.assertEqual(response.content_type, 'text/plain') + + +class 
AnalysisOutputFile(WebTestMixin, TestCase): + def test_user_is_not_authenticated___response_is_forbidden(self): + analysis = fake_analysis() + + response = self.app.get(analysis.get_absolute_output_file_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_output_file_is_not_present___get_response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.get( + analysis.get_absolute_output_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_output_file_is_not_present___delete_response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.delete( + analysis.get_absolute_output_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_output_file_is_not_valid_format___post_response_is_405(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis() + + response = self.app.post( + analysis.get_absolute_output_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file.csv', b'content'), + ), + expect_errors=True, + ) + + self.assertEqual(405, response.status_code) + + @given(file_content=binary(min_size=1), content_type=sampled_from(['application/x-gzip', 'application/gzip', 'application/x-tar', 'application/tar'])) + def test_output_file_is_present___file_can_be_retrieved(self, file_content, content_type): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis(output_file=fake_related_file(file=file_content, content_type=content_type)) + + response = self.app.get( + analysis.get_absolute_output_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(response.body, file_content) + self.assertEqual(response.content_type, content_type) + + +class AnalysisRunTracebackFile(WebTestMixin, TestCase): + def test_user_is_not_authenticated___response_is_forbidden(self): + analysis = fake_analysis() + + response = self.app.get(analysis.get_absolute_run_traceback_file_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_run_traceback_file_is_not_present___get_response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.get( + analysis.get_absolute_run_traceback_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_run_traceback_file_is_not_present___delete_response_is_404(self): + user = fake_user() + analysis = fake_analysis() + + response = self.app.delete( + analysis.get_absolute_run_traceback_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_run_traceback_file_is_not_valid_format___post_response_is_405(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis 
= fake_analysis() + + response = self.app.post( + analysis.get_absolute_run_traceback_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file.csv', b'content'), + ), + expect_errors=True, + ) + + self.assertEqual(405, response.status_code) + + @given(file_content=binary(min_size=1), content_type=sampled_from(['application/x-gzip', 'application/gzip', 'application/x-tar', 'application/tar'])) + def test_run_traceback_file_is_present___file_can_be_retrieved(self, file_content, content_type): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + analysis = fake_analysis(run_traceback_file=fake_related_file(file=file_content, content_type=content_type)) + + response = self.app.get( + analysis.get_absolute_run_traceback_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(response.body, file_content) + self.assertEqual(response.content_type, content_type) diff --git a/src/server/oasisapi/analyses/tests/OLD_test_analysis_model.py b/src/server/oasisapi/analyses/v1_api/tests/test_analysis_model.py similarity index 90% rename from src/server/oasisapi/analyses/tests/OLD_test_analysis_model.py rename to src/server/oasisapi/analyses/v1_api/tests/test_analysis_model.py index e8afc700b..fd1d8ead2 100644 --- a/src/server/oasisapi/analyses/tests/OLD_test_analysis_model.py +++ b/src/server/oasisapi/analyses/v1_api/tests/test_analysis_model.py @@ -10,10 +10,10 @@ from mock import patch, PropertyMock, Mock from rest_framework.exceptions import ValidationError -from ...portfolios.tests.fakes import fake_portfolio -from ...files.tests.fakes import fake_related_file -from ...auth.tests.fakes import fake_user -from ..models import Analysis +from src.server.oasisapi.portfolios.v1_api.tests.fakes import fake_portfolio +from src.server.oasisapi.files.tests.fakes import fake_related_file +from src.server.oasisapi.auth.tests.fakes import fake_user +from ...models import Analysis from ..tasks import record_run_analysis_result, record_generate_input_result from .fakes import fake_analysis, FakeAsyncResultFactory @@ -76,8 +76,8 @@ def test_state_is_ready___run_is_started(self, status, task_id): sig_res = Mock() sig_res.delay.return_value = res_factory(task_id) - with patch('src.server.oasisapi.analyses.models.Analysis.run_analysis_signature', PropertyMock(return_value=sig_res)): - analysis.run(initiator) + with patch('src.server.oasisapi.analyses.models.Analysis.v1_run_analysis_signature', PropertyMock(return_value=sig_res)): + analysis.run(initiator, run_mode_override='V1') sig_res.link.assert_called_once_with(record_run_analysis_result.s(analysis.pk, initiator.pk)) sig_res.link_error.assert_called_once_with( @@ -100,23 +100,23 @@ def test_state_is_running_or_generating_inputs___validation_error_is_raised_revo initiator = fake_user() sig_res = Mock() - with patch('src.server.oasisapi.analyses.models.Analysis.run_analysis_signature', PropertyMock(return_value=sig_res)): + with patch('src.server.oasisapi.analyses.models.Analysis.v1_run_analysis_signature', PropertyMock(return_value=sig_res)): analysis = fake_analysis(status=status, run_task_id=task_id) with self.assertRaises(ValidationError) as ex: - analysis.run(initiator) + analysis.run(initiator, run_mode_override='V1') self.assertEqual( {'status': ['Analysis must be in one of the following states [READY, RUN_COMPLETED, RUN_ERROR, RUN_CANCELLED]']}, 
ex.exception.detail) self.assertEqual(status, analysis.status) self.assertFalse(res_factory.revoke_called) - def test_run_analysis_signature_is_correct(self): + def test_v1_run_analysis_signature_is_correct(self): with TemporaryDirectory() as d: with override_settings(MEDIA_ROOT=d): analysis = fake_analysis(input_file=fake_related_file(), settings_file=fake_related_file()) - sig = analysis.run_analysis_signature + sig = analysis.v1_run_analysis_signature self.assertEqual(sig.task, 'run_analysis') self.assertEqual(sig.args, (analysis.id, analysis.input_file.file.name, analysis.settings_file.file.name, [])) @@ -177,8 +177,8 @@ def test_state_is_not_running___run_is_started(self, status, task_id): sig_res = Mock() sig_res.delay.return_value = res_factory(task_id) - with patch('src.server.oasisapi.analyses.models.Analysis.generate_input_signature', PropertyMock(return_value=sig_res)): - analysis.generate_inputs(initiator) + with patch('src.server.oasisapi.analyses.models.Analysis.v1_generate_input_signature', PropertyMock(return_value=sig_res)): + analysis.generate_inputs(initiator, run_mode_override='V1') sig_res.link.assert_called_once_with(record_generate_input_result.s(analysis.pk, initiator.pk)) sig_res.link_error.assert_called_once_with( @@ -200,11 +200,11 @@ def test_state_is_running_or_generating_inputs___validation_error_is_raised_revo initiator = fake_user() sig_res = Mock() - with patch('src.server.oasisapi.analyses.models.Analysis.generate_input_signature', PropertyMock(return_value=sig_res)): + with patch('src.server.oasisapi.analyses.models.Analysis.v1_generate_input_signature', PropertyMock(return_value=sig_res)): analysis = fake_analysis(status=status, run_task_id=task_id, portfolio=fake_portfolio(location_file=fake_related_file())) with self.assertRaises(ValidationError) as ex: - analysis.generate_inputs(initiator) + analysis.generate_inputs(initiator, run_mode_override='V1') self.assertEqual({'status': [ 'Analysis status must be one of [NEW, INPUTS_GENERATION_ERROR, INPUTS_GENERATION_CANCELLED, READY, RUN_COMPLETED, RUN_CANCELLED, RUN_ERROR]' @@ -220,23 +220,23 @@ def test_portfolio_has_no_location_file___validation_error_is_raised_revoke_is_n initiator = fake_user() sig_res = Mock() - with patch('src.server.oasisapi.analyses.models.Analysis.generate_input_signature', PropertyMock(return_value=sig_res)): + with patch('src.server.oasisapi.analyses.models.Analysis.v1_generate_input_signature', PropertyMock(return_value=sig_res)): analysis = fake_analysis(status=Analysis.status_choices.NEW, run_task_id=task_id) with self.assertRaises(ValidationError) as ex: - analysis.generate_inputs(initiator) + analysis.generate_inputs(initiator, run_mode_override='V1') self.assertEqual({'portfolio': ['"location_file" must not be null']}, ex.exception.detail) self.assertEqual(Analysis.status_choices.NEW, analysis.status) self.assertFalse(res_factory.revoke_called) - def test_generate_input_signature_is_correct(self): + def test_v1_generate_input_signature_is_correct(self): with TemporaryDirectory() as d: with override_settings(MEDIA_ROOT=d): analysis = fake_analysis(portfolio=fake_portfolio(location_file=fake_related_file())) - sig = analysis.generate_input_signature + sig = analysis.v1_generate_input_signature self.assertEqual(sig.task, 'generate_input') self.assertEqual(sig.args, (analysis.id, analysis.portfolio.location_file.file.name, None, None, None, None, [])) diff --git a/src/server/oasisapi/analyses/tests/OLD_test_analysis_tasks.py 
b/src/server/oasisapi/analyses/v1_api/tests/test_analysis_tasks.py
similarity index 98%
rename from src/server/oasisapi/analyses/tests/OLD_test_analysis_tasks.py
rename to src/server/oasisapi/analyses/v1_api/tests/test_analysis_tasks.py
index 472e92a65..2144d6447 100644
--- a/src/server/oasisapi/analyses/tests/OLD_test_analysis_tasks.py
+++ b/src/server/oasisapi/analyses/v1_api/tests/test_analysis_tasks.py
@@ -14,8 +14,8 @@ except ModuleNotFoundError:
     from hypothesis.strategies import sampled_from
 
-from ..models import Analysis
-from ...auth.tests.fakes import fake_user
+from ...models import Analysis
+from src.server.oasisapi.auth.tests.fakes import fake_user
 from ..tasks import record_run_analysis_result, record_run_analysis_failure, record_generate_input_result, record_generate_input_failure
 from .fakes import fake_analysis
diff --git a/src/server/oasisapi/analyses/v1_api/urls.py b/src/server/oasisapi/analyses/v1_api/urls.py
new file mode 100644
index 000000000..b87a41fcf
--- /dev/null
+++ b/src/server/oasisapi/analyses/v1_api/urls.py
@@ -0,0 +1,20 @@
+from django.conf.urls import include, url
+from rest_framework.routers import SimpleRouter
+from .viewsets import AnalysisViewSet, AnalysisSettingsView
+
+
+app_name = 'analyses'
+v1_api_router = SimpleRouter()
+v1_api_router.include_root_view = False
+v1_api_router.register('analyses', AnalysisViewSet, basename='analysis')
+
+analyses_settings = AnalysisSettingsView.as_view({
+    'get': 'analysis_settings',
+    'post': 'analysis_settings',
+    'delete': 'analysis_settings'
+})
+
+urlpatterns = [
+    url(r'analyses/(?P<pk>\d+)/settings/', analyses_settings, name='analysis-settings'),
+    url(r'', include(v1_api_router.urls)),
+]
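As a usage note (illustrative only), the settings route registered above resolves by name, mirroring the reverse() calls in the tests; 'v1' here stands in for whatever namespace the root urlconf mounts this app under:

# Illustrative: resolving the settings route by name inside a Django
# project that includes these urlpatterns; namespace 'v1' is assumed.
from django.urls import reverse

settings_url = reverse('v1:analysis-settings', kwargs={'pk': 1})
# e.g. '/v1/analyses/1/settings/' depending on the root urlconf

diff --git a/src/server/oasisapi/analyses/v1_api/viewsets.py b/src/server/oasisapi/analyses/v1_api/viewsets.py
new file mode 100644
index 000000000..45da4224a
--- /dev/null
+++ b/src/server/oasisapi/analyses/v1_api/viewsets.py
@@ -0,0 +1,474 @@
+from __future__ import absolute_import
+
+from django.utils.translation import gettext_lazy as _
+from django.utils.decorators import method_decorator
+from django.conf import settings
+
+from rest_framework import viewsets
+from rest_framework import permissions
+from rest_framework.decorators import action
+from rest_framework.parsers import MultiPartParser
+from rest_framework.response import Response
+from rest_framework.settings import api_settings
+from rest_framework.serializers import Serializer
+from rest_framework.exceptions import APIException
+
+from drf_yasg.utils import swagger_auto_schema
+from django_filters import rest_framework as filters
+from django_filters import NumberFilter
+
+from ..models import Analysis
+from .serializers import AnalysisSerializer, AnalysisCopySerializer, AnalysisStorageSerializer, AnalysisListSerializer
+
+from ...analysis_models.models import AnalysisModel
+from ...data_files.v1_api.serializers import DataFileSerializer
+from ...filters import TimeStampedFilter, CsvMultipleChoiceFilter, CsvModelMultipleChoiceFilter
+from ...files.views import handle_related_file, handle_json_data
+from ...files.serializers import RelatedFileSerializer
+from ...schemas.custom_swagger import FILE_RESPONSE
+from ...schemas.serializers import AnalysisSettingsSerializer
+
+
+class LogAccessDenied(APIException):
+    status_code = 403
+    default_detail = 'Only accounts with staff access are allowed to view system logs.'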
+    default_code = 'system logs disabled by admin'
+
+
+class check_log_permission(permissions.BasePermission):
+    RESTRICTED_ACTIONS = [
+        'input_generation_traceback_file',
+        'run_traceback_file',
+        'run_log_file'
+    ]
+
+    def has_permission(self, request, view):
+        if not settings.RESTRICT_SYSTEM_LOGS:  # are analysis logs restricted at all?
+            return True
+        if request.user.is_staff:  # is the user an admin?
+            return True
+        # otherwise only non-log actions are allowed
+        if view.action not in self.RESTRICTED_ACTIONS:
+            return True
+        else:
+            raise LogAccessDenied
+
+
+class AnalysisFilter(TimeStampedFilter):
+    name = filters.CharFilter(
+        help_text=_('Filter results by case insensitive names equal to the given string'),
+        lookup_expr='iexact'
+    )
+    name__contains = filters.CharFilter(
+        help_text=_('Filter results by case insensitive name containing the given string'),
+        lookup_expr='icontains',
+        field_name='name'
+    )
+    status = filters.ChoiceFilter(
+        help_text=_('Filter results by the current analysis status, one of [{}]'.format(
+            ', '.join(Analysis.status_choices._db_values))
+        ),
+        choices=Analysis.status_choices,
+    )
+    status__in = CsvMultipleChoiceFilter(
+        help_text=_(
+            'Filter results where the current analysis status '
+            'is one of a given set (provide multiple parameters or a comma separated list), '
+            'from [{}]'.format(', '.join(Analysis.status_choices._db_values))
+        ),
+        choices=Analysis.status_choices,
+        field_name='status',
+        label=_('Status in')
+    )
+    model = NumberFilter(
+        help_text=_('Filter results by the id of the model the analysis belongs to'),
+        field_name='model'
+    )
+    model__in = CsvModelMultipleChoiceFilter(
+        help_text=_('Filter results by the id of the model the analysis belongs to'),
+        field_name='model',
+        label=_('Model in'),
+        queryset=AnalysisModel.objects.all(),
+    )
+    user = filters.CharFilter(
+        help_text=_('Filter results by case insensitive `user` equal to the given string'),
+        lookup_expr='iexact',
+        field_name='creator__username'
+    )
+
+    class Meta:
+        model = Analysis
+        fields = [
+            'name',
+            'name__contains',
+            'status',
+            'status__in',
+            'model',
+            'model__in',
+            'user',
+        ]
+
+    def __init__(self, *args, **kwargs):
+        super(AnalysisFilter, self).__init__(*args, **kwargs)
+
+
+@method_decorator(name='list', decorator=swagger_auto_schema(responses={200: AnalysisSerializer(many=True)}))
+class AnalysisViewSet(viewsets.ModelViewSet):
+    """
+    list:
+    Returns a list of Analysis objects.
+
+    ### Examples
+
+    To get all analyses with 'foo' in their name
+
+        /analyses/?name__contains=foo
+
+    To get all analyses created on 1970-01-01
+
+        /analyses/?created__date=1970-01-01
+
+    To get all analyses updated before 2000-01-01
+
+        /analyses/?modified__lt=2000-01-01
+
+    To get all analyses in the `NEW` state
+
+        /analyses/?status=NEW
+
+    To get all started and pending analyses
+
+        /analyses/?status__in=PENDING&status__in=STARTED
+
+    To get all analyses for model `1`
+
+        /analyses/?model=1
+
+    To get all analyses for models `2` and `3`
+
+        /analyses/?model__in=2&model__in=3
+
+    retrieve:
+    Returns the specific analysis entry.
+
+    create:
+    Creates an analysis based on the input data
+
+    update:
+    Updates the specified analysis
+
+    partial_update:
+    Partially updates the specified analysis (only provided fields are updated)
+    """
+    file_action_types = ['settings_file',
+                         'input_file',
+                         'lookup_errors_file',
+                         'lookup_success_file',
+                         'lookup_validation_file',
+                         'summary_levels_file',
+                         'input_generation_traceback_file',
+                         'run_traceback_file',
+                         'output_file']
+
+    task_action_types = ['run',
+                         'cancel',
+                         'generate_inputs',
+                         'cancel_generate_inputs']
+
+    queryset = Analysis.objects.all().select_related(*file_action_types).prefetch_related('complex_model_data_files')
+    serializer_class = AnalysisSerializer
+    filterset_class = AnalysisFilter
+    permission_classes = (permissions.IsAuthenticated, check_log_permission)
+
+    file_action_types.append('set_settings_file')
+
+    def get_serializer_class(self):
+        if self.action in ['create', 'options', 'update', 'partial_update', 'retrieve']:
+            return super(AnalysisViewSet, self).get_serializer_class()
+        elif self.action in ['list']:
+            return AnalysisListSerializer
+        elif self.action == 'copy':
+            return AnalysisCopySerializer
+        elif self.action == 'data_files':
+            return DataFileSerializer
+        elif self.action == 'storage_links':
+            return AnalysisStorageSerializer
+        elif self.action in self.file_action_types:
+            return RelatedFileSerializer
+        else:
+            return Serializer
+
+    @property
+    def parser_classes(self):
+        if getattr(self, 'action', None) in ['set_settings_file']:
+            return [MultiPartParser]
+        else:
+            return api_settings.DEFAULT_PARSER_CLASSES
+
+    @swagger_auto_schema(responses={200: AnalysisSerializer})
+    @action(methods=['post'], detail=True)
+    def run(self, request, pk=None, version=None):
+        """
+        Runs the analysis. The analysis must have one of the following
+        statuses, `READY`, `RUN_COMPLETED`, `RUN_CANCELLED` or
+        `RUN_ERROR`
+        """
+        obj = self.get_object()
+        if obj.model.run_mode != obj.model.run_mode_choices.V1:
+            obj.raise_validate_errors(
+                {'model': [f"Model pk '{obj.model.id}' - Unsupported Operation, 'run_mode' must be 'V1', not '{obj.model.run_mode}'"]}
+            )
+        else:
+            obj.run(request.user)
+        return Response(AnalysisSerializer(instance=obj, context=self.get_serializer_context()).data)
+
+    @swagger_auto_schema(responses={200: AnalysisSerializer})
+    @action(methods=['post'], detail=True)
+    def cancel(self, request, pk=None, version=None):
+        """
+        Cancels either input generation or analysis execution depending on the active stage.
+        The analysis must have one of the following statuses, `INPUTS_GENERATION_QUEUED`, `INPUTS_GENERATION_STARTED`, `RUN_QUEUED` or `RUN_STARTED`
+        """
+        obj = self.get_object()
+        obj.cancel_any()
+        return Response(AnalysisSerializer(instance=obj, context=self.get_serializer_context()).data)
+
+    @swagger_auto_schema(responses={200: AnalysisSerializer})
+    @action(methods=['post'], detail=True)
+    def cancel_analysis_run(self, request, pk=None, version=None):
+        """
+        Cancels a running analysis execution. The analysis must have one of the following statuses, `RUN_QUEUED` or `RUN_STARTED`
+        """
+        obj = self.get_object()
+        obj.cancel_analysis()
+        return Response(AnalysisSerializer(instance=obj, context=self.get_serializer_context()).data)
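To illustrate the run action from a client's point of view (hypothetical host and token; the path is assumed from the router registration above), including the 400 returned when the model's run_mode is not V1:

# Illustrative client call -- not part of this patch.
import requests

resp = requests.post(
    'https://example.com/v1/analyses/42/run/',           # path assumed
    headers={'Authorization': 'Bearer <access-token>'},  # placeholder token
)
if resp.status_code == 400:
    # e.g. {'model': ["Model pk '1' - Unsupported Operation, ..."]}
    print(resp.json())

+    @swagger_auto_schema(responses={200: AnalysisSerializer})
+    @action(methods=['post'], detail=True)
+    def generate_inputs(self, request, pk=None, version=None):
+        """
+        Generates the inputs for the analysis based on the portfolio.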
+        The analysis must have one of the following statuses, `NEW`, `INPUTS_GENERATION_ERROR`,
+        `INPUTS_GENERATION_CANCELLED`, `READY`, `RUN_COMPLETED`, `RUN_CANCELLED` or `RUN_ERROR`
+        """
+        obj = self.get_object()
+        # Check run_mode == V1 before dispatch
+        if obj.model.run_mode != obj.model.run_mode_choices.V1:
+            obj.raise_validate_errors(
+                {'model': [f"Model pk '{obj.model.id}' - Unsupported Operation, 'run_mode' must be 'V1', not '{obj.model.run_mode}'"]}
+            )
+        else:
+            obj.generate_inputs(request.user)
+        return Response(AnalysisSerializer(instance=obj, context=self.get_serializer_context()).data)
+
+    @swagger_auto_schema(responses={200: AnalysisSerializer})
+    @action(methods=['post'], detail=True)
+    def cancel_generate_inputs(self, request, pk=None, version=None):
+        """
+        Cancels a currently running inputs generation. The analysis status must be `INPUTS_GENERATION_STARTED`
+        """
+        obj = self.get_object()
+        obj.cancel_generate_inputs()
+        return Response(AnalysisSerializer(instance=obj, context=self.get_serializer_context()).data)
+
+    @action(methods=['post'], detail=True)
+    def copy(self, request, pk=None, version=None):
+        """
+        Copies an existing analysis, copying the associated input files and model, modifying
+        its name (if none is provided) and resetting the status, input errors and outputs
+        """
+        obj = self.get_object()
+        new_obj = obj.copy()
+
+        new_obj.save()
+        new_obj.creator = None
+
+        serializer = self.get_serializer(instance=new_obj, data=request.data, context=self.get_serializer_context(), partial=True)
+        serializer.is_valid(raise_exception=True)
+        serializer.save()
+
+        return Response(serializer.data)
+
+    @swagger_auto_schema(methods=['get'], responses={200: FILE_RESPONSE})
+    @action(methods=['get', 'delete'], detail=True)
+    def settings_file(self, request, pk=None, version=None):
+        """
+        get:
+        Gets the analysis `settings_file` contents
+
+        delete:
+        Disassociates the analysis `settings_file` contents
+        """
+        return handle_related_file(self.get_object(), 'settings_file', request, ['application/json'])
+
+    @settings_file.mapping.post
+    def set_settings_file(self, request, pk=None, version=None):
+        """
+        post:
+        Sets the analysis `settings_file` contents
+        """
+        return handle_related_file(self.get_object(), 'settings_file', request, ['application/json'])
+
+    @swagger_auto_schema(methods=['get'], responses={200: FILE_RESPONSE})
+    @action(methods=['get'], detail=True)
+    def input_file(self, request, pk=None, version=None):
+        """
+        get:
+        Gets the analysis `input_file` contents
+
+        delete:
+        Disassociates the analysis `input_file` contents
+        """
+        return handle_related_file(self.get_object(), 'input_file', request, ['application/x-gzip', 'application/gzip', 'application/x-tar', 'application/tar'])
+
+    @swagger_auto_schema(methods=['get'], responses={200: FILE_RESPONSE})
+    @action(methods=['get'], detail=True)
+    def lookup_errors_file(self, request, pk=None, version=None):
+        """
+        get:
+        Gets the analysis `lookup_errors_file` contents
+
+        post:
+        Sets the analysis `lookup_errors_file` contents
+
+        delete:
+        Disassociates the analysis `lookup_errors_file` contents
+        """
+        return handle_related_file(self.get_object(), 'lookup_errors_file', request, ['text/csv'])
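Because set_settings_file switches to MultiPartParser (see the parser_classes property above), a client upload looks roughly like this (hypothetical host and token, illustrative only):

# Illustrative multipart upload of a settings file -- not patch code.
import requests

with open('analysis_settings.json', 'rb') as f:
    resp = requests.post(
        'https://example.com/v1/analyses/42/settings_file/',  # path assumed
        headers={'Authorization': 'Bearer <access-token>'},
        files={'file': ('analysis_settings.json', f, 'application/json')},
    )
resp.raise_for_status()

+    @swagger_auto_schema(methods=['get'], responses={200: FILE_RESPONSE})
+    @action(methods=['get'], detail=True)
+    def lookup_success_file(self, request, pk=None, version=None):
+        """
+        get:
+        Gets the analysis `lookup_success_file` contents
+
+        post:
+        Sets the analysis `lookup_success_file` contents
+
+        delete:
+        Disassociates the analysis `lookup_success_file` contents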
""" + return handle_related_file(self.get_object(), 'lookup_success_file', request, ['text/csv']) + + @swagger_auto_schema(methods=['get'], responses={200: FILE_RESPONSE}) + @action(methods=['get'], detail=True) + def lookup_validation_file(self, request, pk=None, version=None): + """ + get: + Gets the portfolios `lookup_validation_file` contents + + post: + Sets the portfolios `lookup_validation_file` contents + + delete: + Disassociates the portfolios `lookup_validation_file` contents + """ + return handle_related_file(self.get_object(), 'lookup_validation_file', request, ['application/json']) + + @swagger_auto_schema(methods=['get'], responses={200: FILE_RESPONSE}) + @action(methods=['get'], detail=True) + def summary_levels_file(self, request, pk=None, version=None): + """ + get: + Gets the portfolios `summary_levels_file` contents + + post: + Sets the portfolios `summary_levels_file` contents + + delete: + Disassociates the portfolios `summary_levels_file` contents + """ + return handle_related_file(self.get_object(), 'summary_levels_file', request, ['application/json']) + + @swagger_auto_schema(methods=['get'], responses={200: FILE_RESPONSE}) + @action(methods=['get', 'delete'], detail=True) + def input_generation_traceback_file(self, request, pk=None, version=None): + """ + get: + Gets the portfolios `input_generation_traceback_file` contents + + delete: + Disassociates the portfolios `input_generation_traceback_file` contents + """ + return handle_related_file(self.get_object(), 'input_generation_traceback_file', request, ['text/plain']) + + @swagger_auto_schema(methods=['get'], responses={200: FILE_RESPONSE}) + @action(methods=['get', 'delete'], detail=True) + def output_file(self, request, pk=None, version=None): + """ + get: + Gets the portfolios `output_file` contents + + delete: + Disassociates the portfolios `output_file` contents + """ + return handle_related_file(self.get_object(), 'output_file', request, ['application/x-gzip', 'application/gzip', 'application/x-tar', 'application/tar']) + + @swagger_auto_schema(methods=['get'], responses={200: FILE_RESPONSE}) + @action(methods=['get', 'delete'], detail=True) + def run_traceback_file(self, request, pk=None, version=None): + """ + get: + Gets the portfolios `run_traceback_file` contents + + delete: + Disassociates the portfolios `run_traceback_file` contents + """ + return handle_related_file(self.get_object(), 'run_traceback_file', request, ['text/plain']) + + @swagger_auto_schema(methods=['get'], responses={200: FILE_RESPONSE}) + @action(methods=['get', 'delete'], detail=True) + def run_log_file(self, request, pk=None, version=None): + """ + get: + Gets the portfolios `run_log_file` contents + + delete: + Disassociates the portfolios `run_log_file` contents + """ + return handle_related_file(self.get_object(), 'run_log_file', request, ['text/plain']) + + @swagger_auto_schema(responses={200: DataFileSerializer(many=True)}) + @action(methods=['get'], detail=True) + def data_files(self, request, pk=None, version=None): + df = self.get_object().complex_model_data_files.all() + context = {'request': request} + + df_serializer = DataFileSerializer(df, many=True, context=context) + return Response(df_serializer.data) + + @action(methods=['get'], detail=True) + def storage_links(self, request, pk=None, version=None): + """ + get: + Gets the analyses storage backed link references, `object keys` or `file paths` + """ + serializer = self.get_serializer(self.get_object()) + return Response(serializer.data) + + +class 
+class AnalysisSettingsView(viewsets.ModelViewSet):
+    """
+    list:
+    Returns the settings of an Analysis object.
+    """
+    queryset = Analysis.objects.all()
+    serializer_class = AnalysisSerializer
+    filterset_class = AnalysisFilter
+
+    @swagger_auto_schema(methods=['get'], responses={200: AnalysisSettingsSerializer})
+    @swagger_auto_schema(methods=['post'], request_body=AnalysisSettingsSerializer, responses={201: RelatedFileSerializer})
+    @action(methods=['get', 'post', 'delete'], detail=True)
+    def analysis_settings(self, request, pk=None, version=None):
+        """
+        get:
+        Gets the analysis `settings` contents
+
+        post:
+        Sets the analysis `settings` contents
+
+        delete:
+        Disassociates the analysis `settings_file` contents
+        """
+        return handle_json_data(self.get_object(), 'settings_file', request, AnalysisSettingsSerializer)
diff --git a/src/server/oasisapi/analyses/v2_api/__init__.py b/src/server/oasisapi/analyses/v2_api/__init__.py
new file mode 100644
index 000000000..01cf3aaef
--- /dev/null
+++ b/src/server/oasisapi/analyses/v2_api/__init__.py
@@ -0,0 +1 @@
+default_app_config = 'src.server.oasisapi.analyses.apps.V2_AnalysesAppConfig'
diff --git a/src/server/oasisapi/analyses/serializers.py b/src/server/oasisapi/analyses/v2_api/serializers.py
similarity index 87%
rename from src/server/oasisapi/analyses/serializers.py
rename to src/server/oasisapi/analyses/v2_api/serializers.py
index 8be6408e6..8ca72777c 100644
--- a/src/server/oasisapi/analyses/serializers.py
+++ b/src/server/oasisapi/analyses/v2_api/serializers.py
@@ -4,10 +4,17 @@ from rest_framework import serializers
 from rest_framework.exceptions import ValidationError
 
-from ....conf import iniconf
-from .models import Analysis, AnalysisTaskStatus
-from ..files.models import file_storage_link
-from ..permissions.group_auth import verify_and_get_groups, validate_data_files
+from .....conf import iniconf
+from ..models import Analysis, AnalysisTaskStatus
+from ...files.models import file_storage_link
+from ...permissions.group_auth import verify_and_get_groups, validate_data_files
+
+from ...schemas.serializers import (
+    GroupNameSerializer,
+    QueueNameSerializer,
+    TaskCountSerializer,
+    TaskErrorSerializer,
+)
 
 
 class AnalysisTaskStatusSerializer(serializers.ModelSerializer):
@@ -54,6 +61,7 @@ class AnalysisListSerializer(serializers.Serializer):
     portfolio = serializers.IntegerField(source='portfolio_id', read_only=True)
     model = serializers.IntegerField(source='model_id', read_only=True)
     status = serializers.CharField(read_only=True)
+    run_mode = serializers.CharField(read_only=True)
     task_started = serializers.DateTimeField(read_only=True)
     task_finished = serializers.DateTimeField(read_only=True)
     complex_model_data_files = serializers.PrimaryKeyRelatedField(many=True, read_only=True)
@@ -82,6 +90,7 @@ class AnalysisListSerializer(serializers.Serializer):
     run_traceback_file = serializers.SerializerMethodField(read_only=True)
     run_log_file = serializers.SerializerMethodField(read_only=True)
     storage_links = serializers.SerializerMethodField(read_only=True)
+    chunking_configuration = serializers.SerializerMethodField(read_only=True)
 
     @swagger_serializer_method(serializer_or_field=serializers.URLField)
     def get_input_file(self, instance):
@@ -143,7 +152,12 @@ def get_storage_links(self, instance):
         request = self.context.get('request')
         return instance.get_absolute_storage_url(request=request)
 
-    @swagger_serializer_method(serializer_or_field=serializers.CharField)
+    @swagger_serializer_method(serializer_or_field=serializers.URLField)
+    def 
get_chunking_configuration(self, instance): + request = self.context.get('request') + return instance.get_absolute_chunking_configuration_url(request=request) + + @swagger_serializer_method(serializer_or_field=GroupNameSerializer) def get_groups(self, instance): return instance.get_groups() @@ -152,14 +166,17 @@ def get_sub_task_list(self, instance): request = self.context.get('request') return instance.get_absolute_subtask_list_url(request=request) + @swagger_serializer_method(serializer_or_field=serializers.IntegerField) def get_sub_task_count(self, instance): subtask_queryset = instance.sub_task_statuses.get_queryset() return subtask_queryset.count() + @swagger_serializer_method(serializer_or_field=TaskErrorSerializer) def get_sub_task_error_ids(self, instance): subtask_queryset = instance.sub_task_statuses.get_queryset() return subtask_queryset.filter(status='ERROR').values_list('pk', flat=True) + @swagger_serializer_method(serializer_or_field=TaskCountSerializer()) def get_status_count(self, instance): # request = self.context.get('request') subtask_queryset = instance.sub_task_statuses.get_queryset() @@ -190,6 +207,8 @@ class AnalysisSerializer(serializers.ModelSerializer): run_traceback_file = serializers.SerializerMethodField() run_log_file = serializers.SerializerMethodField() storage_links = serializers.SerializerMethodField() + chunking_configuration = serializers.SerializerMethodField() + ns = 'v2-analyses' # Groups - inherited from portfolio groups = serializers.SerializerMethodField(read_only=True) @@ -210,6 +229,7 @@ class Meta: 'portfolio', 'model', 'status', + 'run_mode', 'task_started', 'task_finished', 'complex_model_data_files', @@ -225,6 +245,7 @@ class Meta: 'run_traceback_file', 'run_log_file', 'storage_links', + 'chunking_configuration', 'lookup_chunks', 'analysis_chunks', 'sub_task_count', @@ -238,80 +259,88 @@ class Meta: @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_input_file(self, instance): request = self.context.get('request') - return instance.get_absolute_input_file_url(request=request) if instance.input_file_id else None + return instance.get_absolute_input_file_url(request=request, namespace=self.ns) if instance.input_file_id else None @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_settings_file(self, instance): request = self.context.get('request') - return instance.get_absolute_settings_file_url(request=request) if instance.settings_file_id else None + return instance.get_absolute_settings_file_url(request=request, namespace=self.ns) if instance.settings_file_id else None @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_settings(self, instance): request = self.context.get('request') - return instance.get_absolute_settings_url(request=request) if instance.settings_file_id else None + return instance.get_absolute_settings_url(request=request, namespace=self.ns) if instance.settings_file_id else None @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_lookup_errors_file(self, instance): request = self.context.get('request') - return instance.get_absolute_lookup_errors_file_url(request=request) if instance.lookup_errors_file_id else None + return instance.get_absolute_lookup_errors_file_url(request=request, namespace=self.ns) if instance.lookup_errors_file_id else None @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_lookup_success_file(self, instance): request = self.context.get('request') - return 
instance.get_absolute_lookup_success_file_url(request=request) if instance.lookup_success_file_id else None + return instance.get_absolute_lookup_success_file_url(request=request, namespace=self.ns) if instance.lookup_success_file_id else None @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_lookup_validation_file(self, instance): request = self.context.get('request') - return instance.get_absolute_lookup_validation_file_url(request=request) if instance.lookup_validation_file_id else None + return instance.get_absolute_lookup_validation_file_url(request=request, namespace=self.ns) if instance.lookup_validation_file_id else None @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_summary_levels_file(self, instance): request = self.context.get('request') - return instance.get_absolute_summary_levels_file_url(request=request) if instance.summary_levels_file_id else None + return instance.get_absolute_summary_levels_file_url(request=request, namespace=self.ns) if instance.summary_levels_file_id else None @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_input_generation_traceback_file(self, instance): request = self.context.get('request') - return instance.get_absolute_input_generation_traceback_file_url(request=request) if instance.input_generation_traceback_file_id else None + return instance.get_absolute_input_generation_traceback_file_url(request=request, namespace=self.ns) if instance.input_generation_traceback_file_id else None @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_output_file(self, instance): request = self.context.get('request') - return instance.get_absolute_output_file_url(request=request) if instance.output_file_id else None + return instance.get_absolute_output_file_url(request=request, namespace=self.ns) if instance.output_file_id else None @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_run_traceback_file(self, instance): request = self.context.get('request') - return instance.get_absolute_run_traceback_file_url(request=request) if instance.run_traceback_file_id else None + return instance.get_absolute_run_traceback_file_url(request=request, namespace=self.ns) if instance.run_traceback_file_id else None @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_run_log_file(self, instance): request = self.context.get('request') - return instance.get_absolute_run_log_file_url(request=request) if instance.run_log_file_id else None + return instance.get_absolute_run_log_file_url(request=request, namespace=self.ns) if instance.run_log_file_id else None @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_storage_links(self, instance): request = self.context.get('request') - return instance.get_absolute_storage_url(request=request) + return instance.get_absolute_storage_url(request=request, namespace=self.ns) - @swagger_serializer_method(serializer_or_field=serializers.CharField) + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_chunking_configuration(self, instance): + request = self.context.get('request') + return instance.get_absolute_chunking_configuration_url(request=request, namespace=self.ns) + + @swagger_serializer_method(serializer_or_field=GroupNameSerializer) def get_groups(self, instance): return instance.get_groups() @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_sub_task_list(self, instance): request = 
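The `ns = 'v2-analyses'` class attribute above threads a URL namespace through every `get_absolute_*_url` helper, so the same serializer logic can emit either v1 or v2 routes. The helpers themselves live on the `Analysis` model and are not part of this diff; a minimal sketch of the presumed pattern, assuming they wrap Django's `reverse`:

    # Hypothetical sketch -- the real helper lives on the Analysis model, outside this diff.
    from django.urls import reverse

    def get_absolute_settings_url(self, request=None, namespace=None):
        # Prefix the route name with the namespace when one is given, e.g.
        # 'v2-analyses:analysis-settings' rather than the bare 'analysis-settings'.
        # (The v1 routes additionally take a 'version' kwarg, omitted here.)
        route = 'analysis-settings' if namespace is None else f'{namespace}:analysis-settings'
        url = reverse(route, kwargs={'pk': self.pk})
        # Absolute URL when a request is in scope, a plain path otherwise.
        return request.build_absolute_uri(url) if request is not None else url

The v2 tests further down rely on the same convention: `reverse(f'{NAMESPACE}:analysis-detail', kwargs={'pk': ...})` drops the `version` kwarg that the v1 routes required, because the namespace now selects the API version.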
self.context.get('request') - return instance.get_absolute_subtask_list_url(request=request) + return instance.get_absolute_subtask_list_url(request=request, namespace=self.ns) + @swagger_serializer_method(serializer_or_field=serializers.IntegerField) def get_sub_task_count(self, instance): subtask_queryset = instance.sub_task_statuses.get_queryset() return subtask_queryset.count() + @swagger_serializer_method(serializer_or_field=TaskErrorSerializer) def get_sub_task_error_ids(self, instance): subtask_queryset = instance.sub_task_statuses.get_queryset() return subtask_queryset.filter(status='ERROR').values_list('pk', flat=True) + @swagger_serializer_method(serializer_or_field=TaskCountSerializer()) def get_status_count(self, instance): # request = self.context.get('request') subtask_queryset = instance.sub_task_statuses.get_queryset() @@ -361,8 +390,14 @@ def validate(self, attrs): # check that model isn't soft-deleted if attrs.get('model'): if attrs['model'].deleted: - error = {'model': ["Model pk \"{}\" - has been deleted.".format(attrs['model'].id)]} - raise ValidationError(detail=error) + raise ValidationError({ + 'model': ["Model pk \"{}\" - has been deleted.".format(attrs['model'].id)] + }) + if attrs['model'].run_mode is None: + raise ValidationError({ + 'model': ["Model pk \"{}\" - 'run_mode' must not be null".format(attrs['model'].id)] + }) + return attrs @@ -384,15 +419,18 @@ class AnalysisSerializerWebSocket(serializers.Serializer): queue_names = serializers.SerializerMethodField(read_only=True) status_count = serializers.SerializerMethodField(read_only=True) + @swagger_serializer_method(serializer_or_field=serializers.IntegerField) def get_sub_task_count(self, instance): subtask_queryset = instance.sub_task_statuses.get_queryset() return subtask_queryset.count() + @swagger_serializer_method(serializer_or_field=QueueNameSerializer) def get_queue_names(self, instance): subtask_queryset = instance.sub_task_statuses.get_queryset() running_subtasks_queryset = subtask_queryset.filter(status__in=['PENDING', 'QUEUED', 'STARTED']) return list(running_subtasks_queryset.order_by().values_list('queue_name', flat=True).distinct()) + @swagger_serializer_method(serializer_or_field=TaskCountSerializer()) def get_status_count(self, instance): # request = self.context.get('request') subtask_queryset = instance.sub_task_statuses.get_queryset() diff --git a/src/server/oasisapi/analyses/signal_receivers.py b/src/server/oasisapi/analyses/v2_api/signal_receivers.py similarity index 93% rename from src/server/oasisapi/analyses/signal_receivers.py rename to src/server/oasisapi/analyses/v2_api/signal_receivers.py index d1719a630..d7a039977 100644 --- a/src/server/oasisapi/analyses/signal_receivers.py +++ b/src/server/oasisapi/analyses/v2_api/signal_receivers.py @@ -1,6 +1,6 @@ from src.server.oasisapi.queues.consumers import send_task_status_message, TaskStatusMessageAnalysisItem, \ TaskStatusMessageItem, build_task_status_message -from ..queues.utils import filter_queues_info +from ...queues.utils import filter_queues_info def task_updated(instance, *args, **kwargs): diff --git a/src/server/oasisapi/analyses/task_controller.py b/src/server/oasisapi/analyses/v2_api/task_controller.py similarity index 91% rename from src/server/oasisapi/analyses/task_controller.py rename to src/server/oasisapi/analyses/v2_api/task_controller.py index 7f403c0d2..fd946c84c 100644 --- a/src/server/oasisapi/analyses/task_controller.py +++ b/src/server/oasisapi/analyses/v2_api/task_controller.py @@ -9,7 +9,7 @@ from 
django.contrib.auth.models import User from src.conf.iniconf import settings -from ..files.models import file_storage_link +from ...files.models import file_storage_link if TYPE_CHECKING: from src.server.oasisapi.analyses.models import Analysis, AnalysisTaskStatus @@ -38,7 +38,7 @@ def get_subtask_signature(cls, task_name, analysis, initiator, run_data_uuid, sl :param params: The parameters to send to the task :return: Signature representing the task """ - from src.server.oasisapi.analyses.tasks import record_sub_task_success, record_sub_task_failure + from src.server.oasisapi.analyses.v2_api.tasks import record_sub_task_success, record_sub_task_failure sig = signature( task_name, queue=queue, @@ -244,7 +244,7 @@ def get_generate_inputs_queue(cls, analysis: 'Analysis', initiator: User) -> str :return: The name of the queue """ - return str(analysis.model) + return str(analysis.model) + '-v2' @classmethod def get_inputs_generation_tasks( @@ -341,7 +341,7 @@ def get_inputs_generation_tasks( run_data_uuid, 'Record input files', 'record-input-files', - 'celery', + 'celery-v2', TaskParams(analysis_finish_status=analysis_finish_status), ), cls.get_subtask_statuses_and_signature( @@ -369,11 +369,7 @@ def generate_inputs(cls, analysis: 'Analysis', initiator: User, loc_lines: int) from src.server.oasisapi.analyses.models import Analysis # fetch the number of lookup chunks and store in analysis - if analysis.model.chunking_options.lookup_strategy == 'FIXED_CHUNKS': - num_chunks = min(analysis.model.chunking_options.fixed_lookup_chunks, loc_lines) - elif analysis.model.chunking_options.lookup_strategy == 'DYNAMIC_CHUNKS': - loc_lines_per_chunk = analysis.model.chunking_options.dynamic_locations_per_lookup - num_chunks = min(ceil(loc_lines / loc_lines_per_chunk), analysis.model.chunking_options.dynamic_chunks_max) + num_chunks = cls._get_inputs_generation_chunks(analysis, loc_lines) run_data_uuid = uuid.uuid4().hex statuses, tasks = cls.get_inputs_generation_tasks(analysis, initiator, run_data_uuid, num_chunks) @@ -394,13 +390,22 @@ def generate_inputs(cls, analysis: 'Analysis', initiator: User, loc_lines: int) return chain @classmethod - def _get_inputs_generation_chunks(cls, analysis): - if analysis.model.chunking_options.lookup_strategy == 'FIXED_CHUNKS': - num_chunks = analysis.model.chunking_options.fixed_lookup_chunks - elif analysis.model.chunking_options.lookup_strategy == 'DYNAMIC_CHUNKS': - loc_lines = sum(1 for line in analysis.portfolio.location_file.read()) - loc_lines_per_chunk = analysis.model.chunking_options.dynamic_locations_per_lookup - num_chunks = ceil(loc_lines / loc_lines_per_chunk) + def _get_inputs_generation_chunks(cls, analysis, loc_lines): + # loc_lines = sum(1 for line in analysis.portfolio.location_file.read()) + + # Get options + if analysis.chunking_options is not None: + chunking_options = analysis.chunking_options # Use options from Analysis + else: + chunking_options = analysis.model.chunking_options # Use defaults set on model + + # Set chunks + if chunking_options.lookup_strategy == 'FIXED_CHUNKS': + num_chunks = min(chunking_options.fixed_lookup_chunks, loc_lines) + elif chunking_options.lookup_strategy == 'DYNAMIC_CHUNKS': + loc_lines_per_chunk = chunking_options.dynamic_locations_per_lookup + num_chunks = min(ceil(loc_lines / loc_lines_per_chunk), chunking_options.dynamic_chunks_max) + return num_chunks @classmethod @@ -413,7 +418,7 @@ def get_generate_losses_queue(cls, analysis: 'Analysis', initiator: User) -> str :return: The name of the queue """ - return 
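Both `get_generate_inputs_queue` and the recording subtasks now carry a `-v2` suffix (`str(analysis.model) + '-v2'`, `celery-v2`), so v2 servers and v2 workers pair up on dedicated queues and never steal work from v1 deployments. A small illustration of the routing idea with a Celery signature, mirroring `get_subtask_signature` above (task, queue, and payload names are made up):

    # Illustrative only: pinning a subtask to a version-suffixed queue.
    from celery import signature

    def build_subtask(task_name, model_key, payload):
        # A v2 worker consumes '<supplier>-<model>-<id>-v2', so a v1 worker
        # listening on '<supplier>-<model>-<id>' never receives this task.
        return signature(task_name, kwargs=payload, queue=f'{model_key}-v2')

    sig = build_subtask('record-input-files', 'OasisLMF-PiWind-1', {'analysis_id': 42})
    print(sig.options['queue'])  # 'OasisLMF-PiWind-1-v2'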
str(analysis.model) + return str(analysis.model) + '-v2' @classmethod def get_loss_generation_tasks(cls, analysis: 'Analysis', initiator: User, run_data_uuid: str, num_chunks: int): @@ -485,7 +490,7 @@ def get_loss_generation_tasks(cls, analysis: 'Analysis', initiator: User, run_da run_data_uuid, 'Record losses files', 'record-losses-files', - 'celery', + 'celery-v2', TaskParams(**base_kwargs), ), cls.get_subtask_statuses_and_signature( @@ -512,13 +517,7 @@ def generate_losses(cls, analysis: 'Analysis', initiator: User, events_total: in """ from src.server.oasisapi.analyses.models import Analysis - # fetch number of event chunks - if analysis.model.chunking_options.loss_strategy == 'FIXED_CHUNKS': - num_chunks = analysis.model.chunking_options.fixed_analysis_chunks - elif analysis.model.chunking_options.loss_strategy == 'DYNAMIC_CHUNKS': - events_per_chunk = analysis.model.chunking_options.dynamic_events_per_analysis - num_chunks = min(ceil(events_total / events_per_chunk), analysis.model.chunking_options.dynamic_chunks_max) - + num_chunks = cls._get_loss_generation_chunks(analysis, events_total) run_data_uuid = uuid.uuid4().hex statuses, tasks = cls.get_loss_generation_tasks(analysis, initiator, run_data_uuid, num_chunks) @@ -538,15 +537,24 @@ def generate_losses(cls, analysis: 'Analysis', initiator: User, events_total: in return chain @classmethod - def _get_loss_generation_chunks(cls, analysis): - if analysis.model.chunking_options.loss_strategy == 'FIXED_CHUNKS': - num_chunks = analysis.model.chunking_options.fixed_analysis_chunks - elif analysis.model.chunking_options.loss_strategy == 'DYNAMIC_CHUNKS': - raise notimplementederror("FEATURE NOT AVALIBLE -- need event set size from worker") + def _get_loss_generation_chunks(cls, analysis, events_total): + # Get options + if analysis.chunking_options is not None: + chunking_options = analysis.chunking_options # Use options from Analysis + else: + chunking_options = analysis.model.chunking_options # Use defaults set on model + + # fetch number of event chunks + if chunking_options.loss_strategy == 'FIXED_CHUNKS': + num_chunks = chunking_options.fixed_analysis_chunks + elif chunking_options.loss_strategy == 'DYNAMIC_CHUNKS': + events_per_chunk = chunking_options.dynamic_events_per_analysis + num_chunks = min(ceil(events_total / events_per_chunk), chunking_options.dynamic_chunks_max) + return num_chunks @classmethod - def generate_input_and_losses(cls, analysis: 'Analysis', initiator: User): + def generate_input_and_losses(cls, analysis: 'Analysis', initiator: User, loc_lines: int, events_total: int): """ Starts the input generation chain @@ -567,9 +575,9 @@ def generate_input_and_losses(cls, analysis: 'Analysis', initiator: User): from src.server.oasisapi.analyses.models import Analysis # fetch the number of lookup chunks and store in analysis - input_num_chunks = cls._get_inputs_generation_chunks(analysis) + input_num_chunks = cls._get_inputs_generation_chunks(analysis, loc_lines) # fetch number of event chunks - loss_num_chunks = cls._get_loss_generation_chunks(analysis) + loss_num_chunks = cls._get_loss_generation_chunks(analysis, events_total) input_run_data_uuid = uuid.uuid4().hex loss_run_data_uuid = uuid.uuid4().hex @@ -632,7 +640,7 @@ def get_analysis_task_controller() -> Type[Controller]: controller_path = settings.get( 'worker', 'ANALYSIS_TASK_CONTROLLER', - fallback='src.server.oasisapi.analyses.task_controller.Controller' + fallback='src.server.oasisapi.analyses.v2_api.task_controller.Controller' ) controller_module, 
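The two chunking helpers are now symmetric: each prefers `analysis.chunking_options`, falls back to the model defaults, and then derives the chunk count. The arithmetic is worth a worked example (all numbers invented):

    from math import ceil

    # DYNAMIC_CHUNKS: one chunk per N locations, capped at dynamic_chunks_max.
    loc_lines = 25_000
    dynamic_locations_per_lookup = 1_000
    dynamic_chunks_max = 20
    print(min(ceil(loc_lines / dynamic_locations_per_lookup), dynamic_chunks_max))  # 20 -> the cap beats ceil() == 25

    # FIXED_CHUNKS: never request more chunks than there are location rows.
    fixed_lookup_chunks = 8
    loc_lines = 5
    print(min(fixed_lookup_chunks, loc_lines))  # 5

Note that `generate_input_and_losses` now takes `loc_lines` and `events_total` explicitly, so both counts can be computed up front instead of re-reading the location file inside the controller.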
controller_class = controller_path.rsplit('.', maxsplit=1) diff --git a/src/server/oasisapi/analyses/tasks.py b/src/server/oasisapi/analyses/v2_api/tasks.py similarity index 93% rename from src/server/oasisapi/analyses/tasks.py rename to src/server/oasisapi/analyses/v2_api/tasks.py index cf4d14628..721479e60 100644 --- a/src/server/oasisapi/analyses/tasks.py +++ b/src/server/oasisapi/analyses/v2_api/tasks.py @@ -10,8 +10,8 @@ from urllib.parse import urlparse from urllib.request import urlopen -from ....conf import celeryconf as celery_conf -from ....conf.iniconf import settings as worker_settings +from .....conf import celeryconf_v2 as celery_conf +from .....conf.iniconf import settings as worker_settings from botocore.exceptions import ClientError as S3_ClientError from azure.core.exceptions import ResourceNotFoundError as Blob_ResourceNotFoundError @@ -26,6 +26,7 @@ from django.contrib.auth import get_user_model from django.core.exceptions import ObjectDoesNotExist from django.core.files import File +from django.db import transaction from django.db.models import When, Case, Value, F from django.core.files.base import ContentFile from django.core.files.storage import default_storage @@ -33,7 +34,7 @@ from django.http import HttpRequest from django.utils import timezone -from src.conf.iniconf import settings + from botocore.exceptions import ClientError as S3_ClientError from tempfile import TemporaryFile from urllib.request import urlopen @@ -44,12 +45,9 @@ from src.server.oasisapi.schemas.serializers import ModelParametersSerializer from src.server.oasisapi.files.upload import wait_for_blob_copy -from .models import AnalysisTaskStatus +from ..models import AnalysisTaskStatus, Analysis from .task_controller import get_analysis_task_controller -from ..celery_app import celery_app -from src.server.oasisapi.files.views import handle_json_data -from src.server.oasisapi.schemas.serializers import ModelParametersSerializer -from .models import Analysis +from ...celery_app_v2 import v2 as celery_app_v2 logger = get_task_logger(__name__) @@ -238,7 +236,7 @@ def handle_task_failure(self, exc, task_id, args, kwargs, traceback): if self.name in ['record_run_analysis_result', 'record_generate_input_result']: _, analysis_pk, initiator_pk = args - from .models import Analysis + from ..models import Analysis initiator = get_user_model().objects.get(pk=initiator_pk) analysis = Analysis.objects.get(pk=analysis_pk) random_filename = '{}.txt'.format(uuid.uuid4().hex) @@ -298,8 +296,9 @@ def log_worker_monitor(sender, **k): logger.info('AWS_IS_GZIPPED: {}'.format(settings.AWS_IS_GZIPPED)) -@celery_app.task(name='run_register_worker', **celery_conf.worker_task_kwargs) -def run_register_worker(m_supplier, m_name, m_id, m_settings, m_version, m_conf): +@celery_app_v2.task(name='run_register_worker_v2', **celery_conf.worker_task_kwargs) +@transaction.atomic +def run_register_worker_v2(m_supplier, m_name, m_id, m_settings, m_version, m_conf): logger.info('model_supplier: {}, model_name: {}, model_id: {}'.format(m_supplier, m_name, m_id)) try: from django.contrib.auth.models import User @@ -330,6 +329,7 @@ def run_register_worker(m_supplier, m_name, m_id, m_settings, m_version, m_conf) request = HttpRequest() request.data = {**m_settings} request.method = 'post' + request.version = 'v2' request.user = model.creator handle_json_data(model, 'resource_file', request, ModelParametersSerializer) logger.info('Updated model settings')
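Stacking `transaction.atomic` with the Celery task decorator is order-sensitive: the task decorator must sit on top (be applied last), otherwise Celery registers the undecorated function and the atomic wrapper never runs inside the worker, while the module-level name stops being a `Task` (so `.delay()` breaks). A minimal sketch with invented app and task names:

    # Decorator order sketch (illustrative names; Django settings assumed configured).
    from celery import Celery
    from django.db import transaction

    app = Celery('example')

    @app.task(name='register_model')  # outermost: the registered task wraps the atomic body
    @transaction.atomic               # innermost: all ORM writes commit or roll back together
    def register_model(payload):
        ...  # model registration / settings updates happen in one transaction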
@@ -349,12 +349,16 @@ def run_register_worker(m_supplier, m_name, m_id, m_settings, m_version, m_conf) model.ver_ktools = m_version['ktools'] model.ver_oasislmf = m_version['oasislmf'] model.ver_platform = m_version['platform'] - model.save() logger.info('Updated model versions') except Exception as e: logger.info('Failed to set model versions:') logger.exception(str(e)) + # check current value of run_mode -> set to V2 if null (a model already registered as 'V1' keeps its existing mode) + if not model.run_mode: + model.run_mode = model.run_mode_choices.V2 + + model.save() # Log unhandled exceptions except Exception as e: logger.exception(str(e)) @@ -386,7 +390,7 @@ def _find_celery_queue_reference(active_queues, queue_name): return None -@celery_app.task(bind=True, name='cancel_subtasks') +@celery_app_v2.task(bind=True, name='cancel_subtasks') def cancel_subtasks(self, analysis_pk): """ This is needed because AsyncResult().revoke() does not work correctly when called from the server container; using `app.control.revoke( .. )` does work @@ -405,7 +409,7 @@ logger.info(i.reserved()) """ - from .models import Analysis + from ..models import Analysis analysis = Analysis.objects.get(pk=analysis_pk) _now = timezone.now() @@ -426,37 +430,35 @@ def cancel_subtasks(self, analysis_pk): subtask_qs.update(status=AnalysisTaskStatus.status_choices.CANCELLED, end_time=_now) -@celery_app.task(name='start_input_generation_task', **celery_conf.worker_task_kwargs) +@celery_app_v2.task(name='start_input_generation_task', **celery_conf.worker_task_kwargs) def start_input_generation_task(analysis_pk, initiator_pk, loc_lines): - from .models import Analysis + from ..models import Analysis analysis = Analysis.objects.get(pk=analysis_pk) initiator = get_user_model().objects.get(pk=initiator_pk) get_analysis_task_controller().generate_inputs(analysis, initiator, loc_lines) analysis.save() -@celery_app.task(name='start_loss_generation_task') +@celery_app_v2.task(name='start_loss_generation_task') def start_loss_generation_task(analysis_pk, initiator_pk, events_total): - from .models import Analysis + from ..models import Analysis analysis = Analysis.objects.get(pk=analysis_pk) initiator = get_user_model().objects.get(pk=initiator_pk) get_analysis_task_controller().generate_losses(analysis, initiator, events_total) analysis.save() -@celery_app.task(name='start_input_and_loss_generation_task') -def start_input_and_loss_generation_task(analysis_pk, initiator_pk): - from .models import Analysis +@celery_app_v2.task(name='start_input_and_loss_generation_task') +def start_input_and_loss_generation_task(analysis_pk, initiator_pk, loc_lines, events_total): + from ..models import Analysis analysis = Analysis.objects.get(pk=analysis_pk) initiator = get_user_model().objects.get(pk=initiator_pk) - - get_analysis_task_controller().generate_input_and_losses(analysis, initiator) - + get_analysis_task_controller().generate_input_and_losses(analysis, initiator, loc_lines, events_total) analysis.status = Analysis.status_choices.INPUTS_GENERATION_STARTED analysis.save() -@celery_app.task(bind=True, name='record_input_files') +@celery_app_v2.task(bind=True, name='record_input_files') def record_input_files(self, result, analysis_id=None, initiator_id=None, run_data_uuid=None, slug=None, analysis_finish_status=Analysis.status_choices.READY): record_sub_task_start.delay(analysis_id=analysis_id, task_slug=slug, task_id=self.request.id, dt=datetime.now().timestamp()) @@ -498,9 +500,9 @@ def record_input_files(self, result, analysis_id=None, initiator_id=None, run_da return result -@celery_app.task(bind=True,
name='record_losses_files') +@celery_app_v2.task(bind=True, name='record_losses_files') def record_losses_files(self, result, analysis_id=None, initiator_id=None, slug=None, **kwargs): - from .models import Analysis + from ..models import Analysis record_sub_task_start.delay(analysis_id=analysis_id, task_slug=slug, task_id=self.request.id, dt=datetime.now().timestamp()) @@ -525,7 +527,7 @@ def record_losses_files(self, result, analysis_id=None, initiator_id=None, slug= return result -@celery_app.task(bind=True, name='record_sub_task_start') +@celery_app_v2.task(bind=True, name='record_sub_task_start') def record_sub_task_start(self, analysis_id=None, task_slug=None, task_id=None, dt=None): _now = timezone.now() if not dt else datetime.fromtimestamp(dt, tz=timezone.utc) @@ -551,7 +553,7 @@ def record_sub_task_start(self, analysis_id=None, task_slug=None, task_id=None, ) -@celery_app.task(bind=True, name='record_sub_task_success') +@celery_app_v2.task(bind=True, name='record_sub_task_success') def record_sub_task_success(self, res, analysis_id=None, initiator_id=None, task_slug=None): log_location = res.get('log_location') error_location = res.get('error_location') @@ -579,7 +581,7 @@ def record_sub_task_success(self, res, analysis_id=None, initiator_id=None, task ) -@celery_app.task(bind=True, name='record_sub_task_failure') +@celery_app_v2.task(bind=True, name='record_sub_task_failure') def record_sub_task_failure(self, *args, analysis_id=None, initiator_id=None, task_slug=None): tb = _traceback_from_errback_args(*args) task_id = self.request.parent_id @@ -602,7 +604,7 @@ def record_sub_task_failure(self, *args, analysis_id=None, initiator_id=None, ta ) -@celery_app.task(bind=True, name='chord_error_callback') +@celery_app_v2.task(bind=True, name='chord_error_callback') def chord_error_callback(self, analysis_id): unfinished_statuses = [AnalysisTaskStatus.status_choices.QUEUED, AnalysisTaskStatus.status_choices.STARTED] ids_to_revoke = AnalysisTaskStatus.objects.filter( @@ -621,7 +623,7 @@ def chord_error_callback(self, analysis_id): ) -@celery_app.task(name='handle_task_failure') +@celery_app_v2.task(name='handle_task_failure') def handle_task_failure( *args, analysis_id=None, @@ -634,7 +636,7 @@ def handle_task_failure( logger.info('analysis_pk: {}, initiator_pk: {}, traceback: {}, run_data_uuid: {}, failure_status: {}'.format( analysis_id, initiator_id, tb, run_data_uuid, failure_status)) try: - from .models import Analysis + from ..models import Analysis analysis = Analysis.objects.get(pk=analysis_id) analysis.status = failure_status @@ -676,7 +678,7 @@ def mark_task_as_queued_receiver(*args, headers=None, body=None, **kwargs): mark_task_as_queued(analysis_id, slug, headers['id'], timezone.now().timestamp()) -@celery_app.task(name='mark_task_as_queued') +@celery_app_v2.task(name='mark_task_as_queued') def mark_task_as_queued(analysis_id, slug, task_id, dt): AnalysisTaskStatus.objects.filter( analysis_id=analysis_id, @@ -688,7 +690,7 @@ def mark_task_as_queued(analysis_id, slug, task_id, dt): ) -@celery_app.task(name='subtask_error_log') +@celery_app_v2.task(name='subtask_error_log') def subtask_error_log(analysis_id, initiator_id, slug, task_id, log_file): AnalysisTaskStatus.objects.filter( analysis_id=analysis_id, @@ -703,10 +705,10 @@ def subtask_error_log(analysis_id, initiator_id, slug, task_id, log_file): ) -@celery_app.task(name='set_task_status') +@celery_app_v2.task(name='set_task_status_v2') def set_task_status(analysis_pk, task_status, dt): try: - from .models import Analysis 
+ from ..models import Analysis analysis = Analysis.objects.get(pk=analysis_pk) analysis.status = task_status analysis.task_started = datetime.fromtimestamp(dt, tz=timezone.utc) @@ -717,7 +719,7 @@ def set_task_status(analysis_pk, task_status, dt): logger.exception(str(e)) -@celery_app.task(name='update_task_id') +@celery_app_v2.task(name='update_task_id') def update_task_id(task_update_list): for task in task_update_list: task_id, analysis_id, slug = task diff --git a/src/server/oasisapi/analysis_models/tests/__init__.py b/src/server/oasisapi/analyses/v2_api/tests/__init__.py similarity index 100% rename from src/server/oasisapi/analysis_models/tests/__init__.py rename to src/server/oasisapi/analyses/v2_api/tests/__init__.py diff --git a/src/server/oasisapi/analyses/tests/fakes.py b/src/server/oasisapi/analyses/v2_api/tests/fakes.py similarity index 95% rename from src/server/oasisapi/analyses/tests/fakes.py rename to src/server/oasisapi/analyses/v2_api/tests/fakes.py index f9fb62970..f5292939f 100644 --- a/src/server/oasisapi/analyses/tests/fakes.py +++ b/src/server/oasisapi/analyses/v2_api/tests/fakes.py @@ -2,8 +2,8 @@ from celery.states import STARTED from model_mommy import mommy -from ...files.tests.fakes import fake_related_file -from ..models import Analysis, AnalysisTaskStatus +from src.server.oasisapi.files.tests.fakes import fake_related_file +from ...models import Analysis, AnalysisTaskStatus class FakeAsyncResultFactory(object): diff --git a/src/server/oasisapi/analyses/tests/test_analysis_api.py b/src/server/oasisapi/analyses/v2_api/tests/test_analysis_api.py similarity index 83% rename from src/server/oasisapi/analyses/tests/test_analysis_api.py rename to src/server/oasisapi/analyses/v2_api/tests/test_analysis_api.py index 4cb5f6930..c72600f41 100644 --- a/src/server/oasisapi/analyses/tests/test_analysis_api.py +++ b/src/server/oasisapi/analyses/v2_api/tests/test_analysis_api.py @@ -12,24 +12,25 @@ from mock import patch from rest_framework_simplejwt.tokens import AccessToken -from ...files.tests.fakes import fake_related_file -from ...analysis_models.tests.fakes import fake_analysis_model -from ...portfolios.tests.fakes import fake_portfolio -from ...auth.tests.fakes import fake_user, add_fake_group -from ...data_files.tests.fakes import fake_data_file -from ..models import Analysis +from src.server.oasisapi.files.tests.fakes import fake_related_file +from src.server.oasisapi.analysis_models.v2_api.tests.fakes import fake_analysis_model +from src.server.oasisapi.portfolios.v2_api.tests.fakes import fake_portfolio +from src.server.oasisapi.auth.tests.fakes import fake_user, add_fake_group +from src.server.oasisapi.data_files.v2_api.tests.fakes import fake_data_file +from ...models import Analysis from .fakes import fake_analysis # Override default deadline for all tests to 8s settings.register_profile("ci", deadline=800.0) settings.load_profile("ci") +NAMESPACE = 'v2-analyses' class AnalysisApi(WebTestMixin, TestCase): def test_user_is_not_authenticated___response_is_forbidden(self): analysis = fake_analysis() - response = self.app.get(analysis.get_absolute_url(), expect_errors=True) + response = self.app.get(analysis.get_absolute_url(namespace=NAMESPACE), expect_errors=True) self.assertIn(response.status_code, [401, 403]) def test_user_is_authenticated_object_does_not_exist___response_is_404(self): @@ -37,7 +38,7 @@ def test_user_is_authenticated_object_does_not_exist___response_is_404(self): analysis = fake_analysis() response = self.app.get( - reverse('analysis-detail', 
kwargs={'version': 'v1', 'pk': analysis.pk + 1}), + reverse(f'{NAMESPACE}:analysis-detail', kwargs={'pk': analysis.pk + 1}), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -50,7 +51,7 @@ def test_name_is_not_provided___response_is_400(self): user = fake_user() response = self.app.post( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -66,7 +67,7 @@ def test_cleaned_name_is_empty___response_is_400(self, name): user = fake_user() response = self.app.post( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -77,6 +78,26 @@ def test_cleaned_name_is_empty___response_is_400(self, name): self.assertEqual(400, response.status_code) + @given(name=text(alphabet=string.ascii_letters, max_size=10, min_size=1)) + def test_cleaned_name_portfolio_and_model_are_present___run_mode_null_response_is_400(self, name): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + self.maxDiff = None + user = fake_user() + model = fake_analysis_model() + portfolio = fake_portfolio(location_file=fake_related_file()) + + response = self.app.post( + reverse(f'{NAMESPACE}:analysis-list'), + expect_errors=True, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({'name': name, 'portfolio': portfolio.pk, 'model': model.pk}), + content_type='application/json' + ) + self.assertEqual(400, response.status_code) + @given(name=text(alphabet=string.ascii_letters, max_size=10, min_size=1)) def test_cleaned_name_portfolio_and_model_are_present___object_is_created(self, name): with TemporaryDirectory() as d: @@ -84,10 +105,12 @@ def test_cleaned_name_portfolio_and_model_are_present___object_is_created(self, self.maxDiff = None user = fake_user() model = fake_analysis_model() + model.run_mode = model.run_mode_choices.V2 + model.save() portfolio = fake_portfolio(location_file=fake_related_file()) response = self.app.post( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -107,10 +130,11 @@ def test_cleaned_name_portfolio_and_model_are_present___object_is_created(self, analysis.output_file = fake_related_file() analysis.run_traceback_file = fake_related_file() analysis.run_log_file = fake_related_file() + analysis.run_mode = Analysis.run_mode_choices.V2 analysis.save() response = self.app.get( - analysis.get_absolute_url(), + analysis.get_absolute_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -125,16 +149,17 @@ def test_cleaned_name_portfolio_and_model_are_present___object_is_created(self, 'name': name, 'portfolio': portfolio.pk, 'model': model.pk, - 'settings_file': response.request.application_url + analysis.get_absolute_settings_file_url(), - 'settings': response.request.application_url + analysis.get_absolute_settings_url(), - 'input_file': response.request.application_url + analysis.get_absolute_input_file_url(), - 'lookup_errors_file': response.request.application_url + analysis.get_absolute_lookup_errors_file_url(), - 'lookup_success_file': response.request.application_url + analysis.get_absolute_lookup_success_file_url(), - 'lookup_validation_file': 
response.request.application_url + analysis.get_absolute_lookup_validation_file_url(), - 'input_generation_traceback_file': response.request.application_url + analysis.get_absolute_input_generation_traceback_file_url(), - 'output_file': response.request.application_url + analysis.get_absolute_output_file_url(), - 'run_log_file': response.request.application_url + analysis.get_absolute_run_log_file_url(), - 'run_traceback_file': response.request.application_url + analysis.get_absolute_run_traceback_file_url(), + 'settings_file': response.request.application_url + analysis.get_absolute_settings_file_url(namespace=NAMESPACE), + 'settings': response.request.application_url + analysis.get_absolute_settings_url(namespace=NAMESPACE), + 'input_file': response.request.application_url + analysis.get_absolute_input_file_url(namespace=NAMESPACE), + 'lookup_errors_file': response.request.application_url + analysis.get_absolute_lookup_errors_file_url(namespace=NAMESPACE), + 'lookup_success_file': response.request.application_url + analysis.get_absolute_lookup_success_file_url(namespace=NAMESPACE), + 'lookup_validation_file': response.request.application_url + analysis.get_absolute_lookup_validation_file_url(namespace=NAMESPACE), + 'input_generation_traceback_file': response.request.application_url + analysis.get_absolute_input_generation_traceback_file_url(namespace=NAMESPACE), + 'output_file': response.request.application_url + analysis.get_absolute_output_file_url(namespace=NAMESPACE), + 'run_log_file': response.request.application_url + analysis.get_absolute_run_log_file_url(namespace=NAMESPACE), + 'run_mode': Analysis.run_mode_choices.V2, + 'run_traceback_file': response.request.application_url + analysis.get_absolute_run_traceback_file_url(namespace=NAMESPACE), 'status': Analysis.status_choices.NEW, 'status_count': {'CANCELLED': 0, 'COMPLETED': 0, @@ -144,15 +169,16 @@ def test_cleaned_name_portfolio_and_model_are_present___object_is_created(self, 'STARTED': 0, 'TOTAL': 0, 'TOTAL_IN_QUEUE': 0}, - 'storage_links': 'http://testserver/v1/analyses/1/storage_links/', + 'storage_links': 'http://testserver/v2/analyses/1/storage_links/', 'sub_task_count': 0, 'sub_task_error_ids': [], - 'sub_task_list': 'http://testserver/v1/analyses/1/sub_task_list/', - 'summary_levels_file': response.request.application_url + analysis.get_absolute_summary_levels_file_url(), + 'sub_task_list': 'http://testserver/v2/analyses/1/sub_task_list/', + 'summary_levels_file': response.request.application_url + analysis.get_absolute_summary_levels_file_url(namespace=NAMESPACE), 'task_started': None, 'task_finished': None, 'groups': [], 'analysis_chunks': None, + 'chunking_configuration': 'http://testserver/v2/analyses/1/chunking_configuration/', 'lookup_chunks': None, 'priority': 4, }, response.json) @@ -166,10 +192,12 @@ def test_complex_model_file_present___object_is_created(self, name): self.maxDiff = None user = fake_user() model = fake_analysis_model() + model.run_mode = model.run_mode_choices.V2 + model.save() portfolio = fake_portfolio(location_file=fake_related_file()) response = self.app.post( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -185,7 +213,7 @@ def test_complex_model_file_present___object_is_created(self, name): analysis.save() response = self.app.get( - analysis.get_absolute_url(), + analysis.get_absolute_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer 
{}'.format(AccessToken.for_user(user)) }, @@ -209,6 +237,7 @@ def test_complex_model_file_present___object_is_created(self, name): 'input_generation_traceback_file': None, 'output_file': None, 'run_log_file': None, + 'run_mode': None, 'run_traceback_file': None, 'status': 'NEW', 'status_count': {'CANCELLED': 0, @@ -219,16 +248,17 @@ def test_complex_model_file_present___object_is_created(self, name): 'STARTED': 0, 'TOTAL': 0, 'TOTAL_IN_QUEUE': 0}, - 'storage_links': 'http://testserver/v1/analyses/1/storage_links/', + 'storage_links': 'http://testserver/v2/analyses/1/storage_links/', 'sub_task_count': 0, 'sub_task_error_ids': [], - 'sub_task_list': 'http://testserver/v1/analyses/1/sub_task_list/', - 'storage_links': response.request.application_url + analysis.get_absolute_storage_url(), + 'sub_task_list': 'http://testserver/v2/analyses/1/sub_task_list/', + 'storage_links': response.request.application_url + analysis.get_absolute_storage_url(namespace=NAMESPACE), 'summary_levels_file': None, 'groups': [], 'task_started': None, 'task_finished': None, 'analysis_chunks': None, + 'chunking_configuration': 'http://testserver/v2/analyses/1/chunking_configuration/', 'lookup_chunks': None, 'priority': 4, }, response.json) @@ -239,7 +269,7 @@ def test_model_does_not_exist___response_is_400(self): model = fake_analysis_model() response = self.app.patch( - analysis.get_absolute_url(), + analysis.get_absolute_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -254,9 +284,11 @@ def test_model_does_exist___response_is_200(self): user = fake_user() analysis = fake_analysis() model = fake_analysis_model() + model.run_mode = model.run_mode_choices.V2 + model.save() response = self.app.patch( - analysis.get_absolute_url(), + analysis.get_absolute_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -279,11 +311,13 @@ def test_portfolio_group_inheritance___same_groups_as_portfolio(self, name, grou user = fake_user() group = add_fake_group(user, group_name) model = fake_analysis_model() + model.run_mode = model.run_mode_choices.V2 + model.save() portfolio = fake_portfolio(location_file=fake_related_file()) # Deny due to not in the same group as model response = self.app.post( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -299,7 +333,7 @@ def test_portfolio_group_inheritance___same_groups_as_portfolio(self, name, grou # Deny due to not in the same group as portfolio response = self.app.post( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -315,7 +349,7 @@ def test_portfolio_group_inheritance___same_groups_as_portfolio(self, name, grou # Successfully create response = self.app.post( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -327,7 +361,7 @@ def test_portfolio_group_inheritance___same_groups_as_portfolio(self, name, grou analysis = Analysis.objects.get(pk=response.json['id']) response = self.app.get( - analysis.get_absolute_url(), + analysis.get_absolute_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -362,6 +396,7 @@ def 
test_multiple_analyses_with_different_groups___user_should_not_see_each_othe group3.save() model = fake_analysis_model() + model.run_mode = model.run_mode_choices.V1 model.groups.add(group1) model.groups.add(group2) model.groups.add(group3) @@ -377,7 +412,7 @@ def test_multiple_analyses_with_different_groups___user_should_not_see_each_othe # Create an analysis with portfolio1 - group2 response = self.app.post( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user1)) }, @@ -389,7 +424,7 @@ def test_multiple_analyses_with_different_groups___user_should_not_see_each_othe # Create an analysis with portfolio2 - group3 response = self.app.post( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user2)) }, @@ -401,7 +436,7 @@ def test_multiple_analyses_with_different_groups___user_should_not_see_each_othe # User1 should only se analysis 1 with groups2 response = self.app.get( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user1)) }, @@ -413,7 +448,7 @@ def test_multiple_analyses_with_different_groups___user_should_not_see_each_othe # User2 should only se analysis2 with groups3 response = self.app.get( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user2)) }, @@ -440,6 +475,7 @@ def test_multiple_analyses_with_different_groups___user_without_group_should_not user2 = fake_user() model = fake_analysis_model() + model.run_mode = model.run_mode_choices.V2 model.groups.add(group1) model.save() @@ -449,7 +485,7 @@ def test_multiple_analyses_with_different_groups___user_without_group_should_not # Create an analysis response = self.app.post( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user1)) }, @@ -460,7 +496,7 @@ def test_multiple_analyses_with_different_groups___user_without_group_should_not # User2 should not see any analysis response = self.app.get( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user2)) }, @@ -484,6 +520,7 @@ def test_user_with_and_without_group_can_access_portfolio_without_group(self, na user_without_group = fake_user() model = fake_analysis_model() + model.run_mode = model.run_mode_choices.V1 model.save() portfolio1 = fake_portfolio(location_file=fake_related_file()) @@ -491,7 +528,7 @@ def test_user_with_and_without_group_can_access_portfolio_without_group(self, na # Create an analysis response = self.app.post( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user_without_group)) }, @@ -502,7 +539,7 @@ def test_user_with_and_without_group_can_access_portfolio_without_group(self, na # user_with_group1 should see the analysis response = self.app.get( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user_with_group1)) }, @@ -513,7 +550,7 @@ def 
test_user_with_and_without_group_can_access_portfolio_without_group(self, na # user_without_group should see the analysis response = self.app.get( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user_without_group)) }, @@ -537,6 +574,7 @@ def test_modify_analysis_without_group___successfully(self, name, group_name1): user_without_group = fake_user() model = fake_analysis_model() + model.run_mode = model.run_mode_choices.V2 model.save() portfolio1 = fake_portfolio(location_file=fake_related_file()) @@ -544,7 +582,7 @@ def test_modify_analysis_without_group___successfully(self, name, group_name1): # Create an analysis response = self.app.post( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user_without_group)) }, @@ -555,7 +593,7 @@ def test_modify_analysis_without_group___successfully(self, name, group_name1): # user_with_group1 should see the analysis response = self.app.get( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user_with_group1)) }, @@ -567,7 +605,7 @@ def test_modify_analysis_without_group___successfully(self, name, group_name1): # user_with_group1 should be allowed to write response = self.app.post( - reverse('analysis-settings', kwargs={'version': 'v1', 'pk': analysis_id}), + reverse(f'{NAMESPACE}:analysis-settings', kwargs={'pk': analysis_id}), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user_with_group1)) }, @@ -578,7 +616,7 @@ def test_modify_analysis_without_group___successfully(self, name, group_name1): # user_without_group should see the analysis response = self.app.get( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user_without_group)) }, @@ -589,7 +627,7 @@ def test_modify_analysis_without_group___successfully(self, name, group_name1): # user_without_group should be allowed to write response = self.app.post( - reverse('analysis-settings', kwargs={'version': 'v1', 'pk': analysis_id}), + reverse(f'{NAMESPACE}:analysis-settings', kwargs={'pk': analysis_id}), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user_without_group)) }, @@ -604,14 +642,17 @@ def test_modify_analysis_without_group___successfully(self, name, group_name1): def test_create_no_priority___successfully_set_default(self, name): portfolio = fake_portfolio(location_file=fake_related_file()) + model = fake_analysis_model() + model.run_mode = model.run_mode_choices.V2 + model.save() # Create an analysis response = self.app.post( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(fake_user())) }, - params=json.dumps({'name': name, 'portfolio': portfolio.pk, 'model': fake_analysis_model().pk}), + params=json.dumps({'name': name, 'portfolio': portfolio.pk, 'model': model.pk}), content_type='application/json', ) self.assertEqual(201, response.status_code) @@ -628,13 +669,17 @@ def test_create_as_admin_low_priority___successfully(self, name): user.save() portfolio = fake_portfolio(location_file=fake_related_file()) + model = fake_analysis_model() + model.run_mode = model.run_mode_choices.V2 + model.save() + # 
Create an analysis response = self.app.post( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, - params=json.dumps({'name': name, 'portfolio': portfolio.pk, 'model': fake_analysis_model().pk, 'priority': 1}), + params=json.dumps({'name': name, 'portfolio': portfolio.pk, 'model': model.pk, 'priority': 1}), content_type='application/json', ) self.assertEqual(201, response.status_code) @@ -649,7 +694,7 @@ def test_create_as_no_admin_low_priority___rejected(self, name): # Create an analysis response = self.app.post( - reverse('analysis-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(fake_user())) }, @@ -665,7 +710,7 @@ class AnalysisRun(WebTestMixin, TestCase): def test_user_is_not_authenticated___response_is_forbidden(self): analysis = fake_analysis() - response = self.app.post(analysis.get_absolute_run_url(), expect_errors=True) + response = self.app.post(analysis.get_absolute_run_url(namespace=NAMESPACE), expect_errors=True) self.assertIn(response.status_code, [401, 403]) def test_user_is_authenticated_object_does_not_exist___response_is_404(self): @@ -673,7 +718,7 @@ def test_user_is_authenticated_object_does_not_exist___response_is_404(self): analysis = fake_analysis() response = self.app.post( - reverse('analysis-run', kwargs={'version': 'v1', 'pk': analysis.pk + 1}), + reverse(f'{NAMESPACE}:analysis-run', kwargs={'pk': analysis.pk + 1}), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -688,13 +733,27 @@ def test_user_is_authenticated_object_exists___run_is_called(self): analysis = fake_analysis() self.app.post( - analysis.get_absolute_run_url(), + analysis.get_absolute_run_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) } ) - run_mock.assert_called_once_with(analysis, user) + run_mock.assert_called_once_with(analysis, user, run_mode_override=None) + + def test_user_is_authenticated_object_exists___run_is_called__with_override(self): + with patch('src.server.oasisapi.analyses.models.Analysis.run', autospec=True) as run_mock: + user = fake_user() + analysis = fake_analysis() + url_param = '?run_mode_override=V2' + + self.app.post( + analysis.get_absolute_run_url(namespace=NAMESPACE) + url_param, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + run_mock.assert_called_once_with(analysis, user, run_mode_override='V2') def test_user_is_not_in_same_model_group___run_is_denied(self): user = fake_user() @@ -704,7 +763,7 @@ def test_user_is_not_in_same_model_group___run_is_denied(self): analysis.model.save() response = self.app.post( - analysis.get_absolute_run_url(), + analysis.get_absolute_run_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -719,7 +778,7 @@ class AnalysisCancel(WebTestMixin, TestCase): def test_user_is_not_authenticated___response_is_forbidden(self): analysis = fake_analysis() - response = self.app.post(analysis.get_absolute_cancel_analysis_url(), expect_errors=True) + response = self.app.post(analysis.get_absolute_cancel_analysis_url(namespace=NAMESPACE), expect_errors=True) self.assertIn(response.status_code, [401, 403]) def test_user_is_authenticated_object_does_not_exist___response_is_404(self): @@ -727,7 +786,7 @@ def 
test_user_is_authenticated_object_does_not_exist___response_is_404(self): analysis = fake_analysis() response = self.app.post( - reverse('analysis-cancel', kwargs={'version': 'v1', 'pk': analysis.pk + 1}), + reverse(f'{NAMESPACE}:analysis-cancel', kwargs={'pk': analysis.pk + 1}), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -742,7 +801,7 @@ def test_user_is_authenticated_object_exists___cancel_is_called(self): analysis = fake_analysis() self.app.post( - analysis.get_absolute_cancel_analysis_url(), + analysis.get_absolute_cancel_analysis_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) } @@ -758,7 +817,7 @@ def test_user_is_not_in_same_model_group___cancel_is_denied(self): analysis.model.save() response = self.app.post( - analysis.get_absolute_cancel_analysis_url(), + analysis.get_absolute_cancel_analysis_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -773,7 +832,7 @@ class AnalysisGenerateInputs(WebTestMixin, TestCase): def test_user_is_not_authenticated___response_is_forbidden(self): analysis = fake_analysis() - response = self.app.post(analysis.get_absolute_generate_inputs_url(), expect_errors=True) + response = self.app.post(analysis.get_absolute_generate_inputs_url(namespace=NAMESPACE), expect_errors=True) self.assertIn(response.status_code, [401, 403]) def test_user_is_authenticated_object_does_not_exist___response_is_404(self): @@ -781,7 +840,7 @@ def test_user_is_authenticated_object_does_not_exist___response_is_404(self): analysis = fake_analysis() response = self.app.post( - reverse('analysis-generate-inputs', kwargs={'version': 'v1', 'pk': analysis.pk + 1}), + reverse(f'{NAMESPACE}:analysis-generate-inputs', kwargs={'pk': analysis.pk + 1}), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -796,13 +855,26 @@ def test_user_is_authenticated_object_exists___generate_inputs_is_called(self): analysis = fake_analysis() self.app.post( - analysis.get_absolute_generate_inputs_url(), + analysis.get_absolute_generate_inputs_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) } ) + generate_inputs_mock.assert_called_once_with(analysis, user, run_mode_override=None) + + def test_user_is_authenticated_object_exists___generate_inputs_is_called__with_override(self): + with patch('src.server.oasisapi.analyses.models.Analysis.generate_inputs', autospec=True) as generate_inputs_mock: + user = fake_user() + analysis = fake_analysis() + url_param = '?run_mode_override=V1' - generate_inputs_mock.assert_called_once_with(analysis, user) + self.app.post( + analysis.get_absolute_generate_inputs_url(namespace=NAMESPACE) + url_param, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + generate_inputs_mock.assert_called_once_with(analysis, user, run_mode_override='V1') def test_user_is_not_in_same_model_group___run_is_denied(self): user = fake_user() @@ -812,7 +884,7 @@ def test_user_is_not_in_same_model_group___run_is_denied(self): analysis.model.save() response = self.app.post( - analysis.get_absolute_generate_inputs_url(), + analysis.get_absolute_generate_inputs_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -827,7 +899,7 @@ class AnalysisCancelInputsGeneration(WebTestMixin, TestCase): def test_user_is_not_authenticated___response_is_forbidden(self): analysis 
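The two `__with_override` tests above assert that a `?run_mode_override=` query parameter reaches `Analysis.run` / `Analysis.generate_inputs` as a keyword argument, defaulting to `None` when absent. The v2 view code is not shown in this diff; a plausible sketch of the forwarding, with parameter and method names taken from the tests:

    # Hypothetical view-side forwarding -- the real v2 viewset is outside this diff.
    from rest_framework import viewsets
    from rest_framework.decorators import action
    from rest_framework.response import Response

    class AnalysisViewSet(viewsets.ModelViewSet):
        @action(methods=['post'], detail=True)
        def run(self, request, pk=None, version=None):
            analysis = self.get_object()
            # '?run_mode_override=V2' arrives as a query-string value; when the
            # parameter is missing this yields None, matching
            # run_mock.assert_called_once_with(analysis, user, run_mode_override=None).
            run_mode_override = request.query_params.get('run_mode_override')
            analysis.run(request.user, run_mode_override=run_mode_override)
            return Response(self.get_serializer(instance=analysis).data)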
= fake_analysis()

-        response = self.app.post(analysis.get_absolute_cancel_inputs_generation_url(), expect_errors=True)
+        response = self.app.post(analysis.get_absolute_cancel_inputs_generation_url(namespace=NAMESPACE), expect_errors=True)

         self.assertIn(response.status_code, [401, 403])

     def test_user_is_authenticated_object_does_not_exist___response_is_404(self):
@@ -835,7 +907,7 @@ def test_user_is_authenticated_object_does_not_exist___response_is_404(self):
         analysis = fake_analysis()

         response = self.app.post(
-            reverse('analysis-cancel-generate-inputs', kwargs={'version': 'v1', 'pk': analysis.pk + 1}),
+            reverse(f'{NAMESPACE}:analysis-cancel-generate-inputs', kwargs={'pk': analysis.pk + 1}),
             expect_errors=True,
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
@@ -850,7 +922,7 @@ def test_user_is_authenticated_object_exists___generate_inputs_generation_is_cal
         analysis = fake_analysis()

         self.app.post(
-            analysis.get_absolute_cancel_inputs_generation_url(),
+            analysis.get_absolute_cancel_inputs_generation_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             }
@@ -866,7 +938,7 @@ def test_user_is_not_in_same_model_group___cancel_is_denied(self):
         analysis.model.save()

         response = self.app.post(
-            analysis.get_absolute_cancel_inputs_generation_url(),
+            analysis.get_absolute_cancel_inputs_generation_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -881,7 +953,7 @@ class AnalysisCopy(WebTestMixin, TestCase):
     def test_user_is_not_authenticated___response_is_forbidden(self):
         analysis = fake_analysis()

-        response = self.app.post(analysis.get_absolute_copy_url(), expect_errors=True)
+        response = self.app.post(analysis.get_absolute_copy_url(namespace=NAMESPACE), expect_errors=True)

         self.assertIn(response.status_code, [401, 403])

     def test_user_is_authenticated_object_does_not_exist___response_is_404(self):
@@ -889,7 +961,7 @@ def test_user_is_authenticated_object_does_not_exist___response_is_404(self):
         user = fake_user()
         analysis = fake_analysis()

         response = self.app.post(
-            reverse('analysis-copy', kwargs={'version': 'v1', 'pk': analysis.pk + 1}),
+            reverse(f'{NAMESPACE}:analysis-copy', kwargs={'pk': analysis.pk + 1}),
             expect_errors=True,
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
@@ -903,7 +975,7 @@ def test_new_object_is_created(self):
         analysis = fake_analysis()

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             }
@@ -917,7 +989,7 @@ def test_no_new_name_is_provided___copy_is_appended_to_name(self, name):
         analysis = fake_analysis(name=name)

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             }
@@ -931,7 +1003,7 @@ def test_new_name_is_provided___new_name_is_set_on_new_object(self, orig_name, n
         analysis = fake_analysis(name=orig_name)

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -947,7 +1019,7 @@ def test_state_is_reset(self, status):
         analysis = fake_analysis(status=status)

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             }
@@ -960,7 +1032,7 @@ def test_creator_is_set_to_caller(self):
         analysis = fake_analysis()

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             }
@@ -974,7 +1046,7 @@ def test_run_task_id_is_reset(self, task_id):
         analysis = fake_analysis(run_task_id=task_id)

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             }
@@ -988,7 +1060,7 @@ def test_generate_inputs_task_id_is_reset(self, task_id):
         analysis = fake_analysis(generate_inputs_task_id=task_id)

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             }
@@ -1001,7 +1073,7 @@ def test_portfolio_is_not_supplied___portfolio_is_copied(self):
         analysis = fake_analysis()

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             }
@@ -1015,7 +1087,7 @@ def test_portfolio_is_supplied___portfolio_is_replaced(self):
         new_portfolio = fake_portfolio(location_file=fake_related_file())

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1030,7 +1102,7 @@ def test_model_is_not_supplied___model_is_copied(self):
         analysis = fake_analysis()

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             }
@@ -1042,9 +1114,11 @@ def test_model_is_supplied___model_is_replaced(self):
         user = fake_user()
         analysis = fake_analysis()
         new_model = fake_analysis_model()
+        new_model.run_mode = new_model.run_mode_choices.V2
+        new_model.save()

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1059,7 +1133,7 @@ def test_complex_model_file_is_not_supplied___model_is_copied(self):
         analysis = fake_analysis()

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             }
@@ -1077,7 +1151,7 @@ def test_complex_model_file_is_supplied___model_is_replaced(self):
         new_cmf_2 = fake_data_file()

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1097,7 +1171,7 @@ def test_settings_file_is_not_supplied___settings_file_is_copied(self):
         analysis = fake_analysis(settings_file=fake_related_file(file='{}'))

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             }
@@ -1112,7 +1186,7 @@ def test_input_file_is_not_supplied___input_file_is_not_copied(self):
         analysis = fake_analysis(input_file=fake_related_file())

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             }
@@ -1127,7 +1201,7 @@ def test_lookup_errors_file_is_cleared(self):
         analysis = fake_analysis(lookup_errors_file=fake_related_file())

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1142,7 +1216,7 @@ def test_lookup_success_file_is_cleared(self):
         analysis = fake_analysis(lookup_success_file=fake_related_file())

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1157,7 +1231,7 @@ def test_lookup_validation_file_is_cleared(self):
         analysis = fake_analysis(lookup_validation_file=fake_related_file())

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1172,7 +1246,7 @@ def test_output_file_is_cleared(self):
         analysis = fake_analysis(output_file=fake_related_file())

         response = self.app.post(
-            analysis.get_absolute_copy_url(),
+            analysis.get_absolute_copy_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1185,7 +1259,7 @@ class AnalysisSettingsJson(WebTestMixin, TestCase):
     def test_user_is_not_authenticated___response_is_forbidden(self):
         analysis = fake_analysis()

-        response = self.app.get(analysis.get_absolute_settings_url(), expect_errors=True)
+        response = self.app.get(analysis.get_absolute_settings_url(namespace=NAMESPACE), expect_errors=True)

         self.assertIn(response.status_code, [401, 403])

     def test_settings_json_is_not_present___get_response_is_404(self):
@@ -1193,7 +1267,7 @@ def test_settings_json_is_not_present___get_response_is_404(self):
         analysis = fake_analysis()

         response = self.app.get(
-            analysis.get_absolute_settings_url(),
+            analysis.get_absolute_settings_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1207,7 +1281,7 @@ def test_settings_json_is_not_present___delete_response_is_404(self):
         analysis = fake_analysis()

         response = self.app.delete(
-            analysis.get_absolute_settings_url(),
+            analysis.get_absolute_settings_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1247,7 +1321,7 @@ def test_settings_json_is_not_valid___response_is_400(self):
                 }

                 response = self.app.post(
-                    analysis.get_absolute_settings_url(),
+                    analysis.get_absolute_settings_url(namespace=NAMESPACE),
                     headers={
                         'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
                     },
@@ -1297,7 +1371,7 @@ def test_settings_json_is_uploaded___can_be_retrieved(self):
                 }

                 self.app.post(
-                    analysis.get_absolute_settings_url(),
+                    analysis.get_absolute_settings_url(namespace=NAMESPACE),
                     headers={
                         'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
                     },
@@ -1306,7 +1380,7 @@ def test_settings_json_is_uploaded___can_be_retrieved(self):
                 )

                 response = self.app.get(
-                    analysis.get_absolute_settings_url(),
+                    analysis.get_absolute_settings_url(namespace=NAMESPACE),
                     headers={
                         'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
                     },
@@ -1319,7 +1393,7 @@ class AnalysisSettingsFile(WebTestMixin, TestCase):
     def test_user_is_not_authenticated___response_is_forbidden(self):
         analysis = fake_analysis()

-        response = self.app.get(analysis.get_absolute_settings_file_url(), expect_errors=True)
+        response = self.app.get(analysis.get_absolute_settings_file_url(namespace=NAMESPACE), expect_errors=True)

         self.assertIn(response.status_code, [401, 403])

     def test_settings_file_is_not_present___get_response_is_404(self):
@@ -1327,7 +1401,7 @@ def test_settings_file_is_not_present___get_response_is_404(self):
         analysis = fake_analysis()

         response = self.app.get(
-            analysis.get_absolute_settings_file_url(),
+            analysis.get_absolute_settings_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1341,7 +1415,7 @@ def test_settings_file_is_not_present___delete_response_is_404(self):
         analysis = fake_analysis()

         response = self.app.delete(
-            analysis.get_absolute_settings_file_url(),
+            analysis.get_absolute_settings_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1357,7 +1431,7 @@ def test_settings_file_is_not_a_valid_format___response_is_400(self):
                 analysis = fake_analysis()

                 response = self.app.post(
-                    analysis.get_absolute_settings_file_url(),
+                    analysis.get_absolute_settings_file_url(namespace=NAMESPACE),
                     headers={
                         'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
                     },
@@ -1377,7 +1451,7 @@ def test_settings_file_is_uploaded___file_can_be_retrieved(self, file_content):
                 analysis = fake_analysis()

                 self.app.post(
-                    analysis.get_absolute_settings_file_url(),
+                    analysis.get_absolute_settings_file_url(namespace=NAMESPACE),
                     headers={
                         'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
                     },
@@ -1387,7 +1461,7 @@ def test_settings_file_is_uploaded___file_can_be_retrieved(self, file_content):
                 )

                 response = self.app.get(
-                    analysis.get_absolute_settings_file_url(),
+                    analysis.get_absolute_settings_file_url(namespace=NAMESPACE),
                     headers={
                         'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
                     },
@@ -1401,7 +1475,7 @@ class AnalysisInputFile(WebTestMixin, TestCase):
     def test_user_is_not_authenticated___response_is_forbidden(self):
         analysis = fake_analysis()

-        response = self.app.get(analysis.get_absolute_input_file_url(), expect_errors=True)
+        response = self.app.get(analysis.get_absolute_input_file_url(namespace=NAMESPACE), expect_errors=True)

         self.assertIn(response.status_code, [401, 403])

     def test_input_file_is_not_present___get_response_is_404(self):
@@ -1409,7 +1483,7 @@ def test_input_file_is_not_present___get_response_is_404(self):
         analysis = fake_analysis()

         response = self.app.get(
-            analysis.get_absolute_input_file_url(),
+            analysis.get_absolute_input_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1426,7 +1500,7 @@ def test_input_file_is_present___file_can_be_retrieved(self, file_content, conte
         analysis = fake_analysis(input_file=fake_related_file(file=file_content, content_type=content_type))

         response = self.app.get(
-            analysis.get_absolute_input_file_url(),
+            analysis.get_absolute_input_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1440,7 +1514,7 @@ class AnalysisLookupErrorsFile(WebTestMixin, TestCase):
     def test_user_is_not_authenticated___response_is_forbidden(self):
         analysis = fake_analysis()

-        response = self.app.get(analysis.get_absolute_lookup_errors_file_url(), expect_errors=True)
+        response = self.app.get(analysis.get_absolute_lookup_errors_file_url(namespace=NAMESPACE), expect_errors=True)

         self.assertIn(response.status_code, [401, 403])

     def test_lookup_errors_file_is_not_present___get_response_is_404(self):
@@ -1448,7 +1522,7 @@ def test_lookup_errors_file_is_not_present___get_response_is_404(self):
         analysis = fake_analysis()

         response = self.app.get(
-            analysis.get_absolute_lookup_errors_file_url(),
+            analysis.get_absolute_lookup_errors_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1463,7 +1537,7 @@ def test_lookup_errors_file_is_not_present___delete_response_is_404(self):
         analysis = fake_analysis()

         response = self.app.delete(
-            analysis.get_absolute_lookup_errors_file_url(),
+            analysis.get_absolute_lookup_errors_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1481,7 +1555,7 @@ def test_lookup_errors_file_is_present___file_can_be_retrieved(self, file_conten
         analysis = fake_analysis(lookup_errors_file=fake_related_file(file=file_content, content_type=content_type))

         response = self.app.get(
-            analysis.get_absolute_lookup_errors_file_url(),
+            analysis.get_absolute_lookup_errors_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1495,7 +1569,7 @@ class AnalysisLookupSuccessFile(WebTestMixin, TestCase):
     def test_user_is_not_authenticated___response_is_forbidden(self):
         analysis = fake_analysis()

-        response = self.app.get(analysis.get_absolute_lookup_success_file_url(), expect_errors=True)
+        response = self.app.get(analysis.get_absolute_lookup_success_file_url(namespace=NAMESPACE), expect_errors=True)

         self.assertIn(response.status_code, [401, 403])

     def test_lookup_success_file_is_not_present___get_response_is_404(self):
@@ -1503,7 +1577,7 @@ def test_lookup_success_file_is_not_present___get_response_is_404(self):
         analysis = fake_analysis()

         response = self.app.get(
-            analysis.get_absolute_lookup_success_file_url(),
+            analysis.get_absolute_lookup_success_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1518,7 +1592,7 @@ def test_lookup_success_file_is_not_present___delete_response_is_404(self):
         analysis = fake_analysis()

         response = self.app.delete(
-            analysis.get_absolute_lookup_success_file_url(),
+            analysis.get_absolute_lookup_success_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1536,7 +1610,7 @@ def test_lookup_success_file_is_present___file_can_be_retrieved(self, file_conte
         analysis = fake_analysis(lookup_success_file=fake_related_file(file=file_content, content_type=content_type))

         response = self.app.get(
-            analysis.get_absolute_lookup_success_file_url(),
+            analysis.get_absolute_lookup_success_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1550,7 +1624,7 @@ class AnalysisLookupValidationFile(WebTestMixin, TestCase):
     def test_user_is_not_authenticated___response_is_forbidden(self):
         analysis = fake_analysis()

-        response = self.app.get(analysis.get_absolute_lookup_validation_file_url(), expect_errors=True)
+        response = self.app.get(analysis.get_absolute_lookup_validation_file_url(namespace=NAMESPACE), expect_errors=True)

         self.assertIn(response.status_code, [401, 403])

     def test_lookup_validation_file_is_not_present___get_response_is_404(self):
@@ -1558,7 +1632,7 @@ def test_lookup_validation_file_is_not_present___get_response_is_404(self):
         analysis = fake_analysis()

         response = self.app.get(
-            analysis.get_absolute_lookup_validation_file_url(),
+            analysis.get_absolute_lookup_validation_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1573,7 +1647,7 @@ def test_lookup_validation_file_is_not_present___delete_response_is_404(self):
         analysis = fake_analysis()

         response = self.app.delete(
-            analysis.get_absolute_lookup_validation_file_url(),
+            analysis.get_absolute_lookup_validation_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1591,7 +1665,7 @@ def test_lookup_validation_file_is_present___file_can_be_retrieved(self, file_co
         analysis = fake_analysis(lookup_validation_file=fake_related_file(file=file_content, content_type=content_type))

         response = self.app.get(
-            analysis.get_absolute_lookup_validation_file_url(),
+            analysis.get_absolute_lookup_validation_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1605,7 +1679,7 @@ class AnalysisInputGenerationTracebackFile(WebTestMixin, TestCase):
     def test_user_is_not_authenticated___response_is_forbidden(self):
         analysis = fake_analysis()

-        response = self.app.get(analysis.get_absolute_input_generation_traceback_file_url(), expect_errors=True)
+        response = self.app.get(analysis.get_absolute_input_generation_traceback_file_url(namespace=NAMESPACE), expect_errors=True)

         self.assertIn(response.status_code, [401, 403])

     def test_input_generation_traceback_file_is_not_present___get_response_is_404(self):
@@ -1613,7 +1687,7 @@ def test_input_generation_traceback_file_is_not_present___get_response_is_404(se
         analysis = fake_analysis()

         response = self.app.get(
-            analysis.get_absolute_input_generation_traceback_file_url(),
+            analysis.get_absolute_input_generation_traceback_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1627,7 +1701,7 @@ def test_input_generation_traceback_file_is_not_present___delete_response_is_404
         analysis = fake_analysis()

         response = self.app.delete(
-            analysis.get_absolute_input_generation_traceback_file_url(),
+            analysis.get_absolute_input_generation_traceback_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1644,7 +1718,7 @@ def test_input_generation_traceback_file_is_present___file_can_be_retrieved(self
         analysis = fake_analysis(input_generation_traceback_file=fake_related_file(file=file_content, content_type='text/plain'))

         response = self.app.get(
-            analysis.get_absolute_input_generation_traceback_file_url(),
+            analysis.get_absolute_input_generation_traceback_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1658,7 +1732,7 @@ class AnalysisOutputFile(WebTestMixin, TestCase):
     def test_user_is_not_authenticated___response_is_forbidden(self):
         analysis = fake_analysis()

-        response = self.app.get(analysis.get_absolute_output_file_url(), expect_errors=True)
+        response = self.app.get(analysis.get_absolute_output_file_url(namespace=NAMESPACE), expect_errors=True)

         self.assertIn(response.status_code, [401, 403])

     def test_output_file_is_not_present___get_response_is_404(self):
@@ -1666,7 +1740,7 @@ def test_output_file_is_not_present___get_response_is_404(self):
         analysis = fake_analysis()

         response = self.app.get(
-            analysis.get_absolute_output_file_url(),
+            analysis.get_absolute_output_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1680,7 +1754,7 @@ def test_output_file_is_not_present___delete_response_is_404(self):
         analysis = fake_analysis()

         response = self.app.delete(
-            analysis.get_absolute_output_file_url(),
+            analysis.get_absolute_output_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1696,7 +1770,7 @@ def test_output_file_is_not_valid_format___post_response_is_405(self):
         analysis = fake_analysis()

         response = self.app.post(
-            analysis.get_absolute_output_file_url(),
+            analysis.get_absolute_output_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1716,7 +1790,7 @@ def test_output_file_is_present___file_can_be_retrieved(self, file_content, cont
         analysis = fake_analysis(output_file=fake_related_file(file=file_content, content_type=content_type))

         response = self.app.get(
-            analysis.get_absolute_output_file_url(),
+            analysis.get_absolute_output_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1730,7 +1804,7 @@ class AnalysisRunTracebackFile(WebTestMixin, TestCase):
     def test_user_is_not_authenticated___response_is_forbidden(self):
         analysis = fake_analysis()

-        response = self.app.get(analysis.get_absolute_run_traceback_file_url(), expect_errors=True)
+        response = self.app.get(analysis.get_absolute_run_traceback_file_url(namespace=NAMESPACE), expect_errors=True)

         self.assertIn(response.status_code, [401, 403])

     def test_run_traceback_file_is_not_present___get_response_is_404(self):
@@ -1738,7 +1812,7 @@ def test_run_traceback_file_is_not_present___get_response_is_404(self):
         analysis = fake_analysis()

         response = self.app.get(
-            analysis.get_absolute_run_traceback_file_url(),
+            analysis.get_absolute_run_traceback_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1752,7 +1826,7 @@ def test_run_traceback_file_is_not_present___delete_response_is_404(self):
         analysis = fake_analysis()

         response = self.app.delete(
-            analysis.get_absolute_run_traceback_file_url(),
+            analysis.get_absolute_run_traceback_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1768,7 +1842,7 @@ def test_run_traceback_file_is_not_valid_format___post_response_is_405(self):
         analysis = fake_analysis()

         response = self.app.post(
-            analysis.get_absolute_run_traceback_file_url(),
+            analysis.get_absolute_run_traceback_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
@@ -1788,7 +1862,7 @@ def test_run_traceback_file_is_present___file_can_be_retrieved(self, file_conten
         analysis = fake_analysis(run_traceback_file=fake_related_file(file=file_content, content_type=content_type))

         response = self.app.get(
-            analysis.get_absolute_run_traceback_file_url(),
+            analysis.get_absolute_run_traceback_file_url(namespace=NAMESPACE),
             headers={
                 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
             },
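Every hunk above makes the same substitution, so a minimal sketch of the two reversal styles may help; `NAMESPACE` is assumed to be a module-level constant in each test module (e.g. `'v2-analyses'`), and the snippet assumes it runs inside a test with the project urlconf loaded:

```python
from django.urls import reverse

NAMESPACE = 'v2-analyses'  # assumption: per-module constant used by the tests

# Old style: a single un-namespaced route, with the API version passed
# as a URL kwarg that every call site had to repeat.
old_url = reverse('analysis-copy', kwargs={'version': 'v1', 'pk': 1})

# New style: the version lives in the URL namespace, so the 'version'
# kwarg disappears and the route name is namespace-qualified.
new_url = reverse(f'{NAMESPACE}:analysis-copy', kwargs={'pk': 1})
```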
diff --git a/src/server/oasisapi/analyses/tests/test_analysis_model.py b/src/server/oasisapi/analyses/v2_api/tests/test_analysis_model.py
similarity index 82%
rename from src/server/oasisapi/analyses/tests/test_analysis_model.py
rename to src/server/oasisapi/analyses/v2_api/tests/test_analysis_model.py
index fed953c07..18ea768ef 100644
--- a/src/server/oasisapi/analyses/tests/test_analysis_model.py
+++ b/src/server/oasisapi/analyses/v2_api/tests/test_analysis_model.py
@@ -16,11 +16,11 @@ from unittest.mock import ANY, MagicMock
 from src.conf import iniconf
-from ...portfolios.tests.fakes import fake_portfolio
-from ...files.tests.fakes import fake_related_file
-from ...auth.tests.fakes import fake_user
+from src.server.oasisapi.portfolios.v2_api.tests.fakes import fake_portfolio
+from src.server.oasisapi.files.tests.fakes import fake_related_file
+from src.server.oasisapi.auth.tests.fakes import fake_user
 # from ..models import AnalysisTaskStatus
-from ..models import Analysis
+from ...models import Analysis

 # Override default deadline for all tests to 8s
 settings.register_profile("ci", deadline=800.0)
@@ -177,15 +177,20 @@ def test_state_is_ready___run_is_started(self, status, task_gen_id, task_run_id)
         task_sig = Mock()
         task_chain = Mock()

+        analysis.model.run_mode = analysis.model.run_mode_choices.V2
+        analysis.model.save()
+
         res_factory = FakeAsyncResultFactory(target_task_id=task_gen_id)
         task_sig.apply_async.return_value = res_factory(task_gen_id)

-        with patch('src.server.oasisapi.analyses.models.Analysis.start_input_and_loss_generation_signature', PropertyMock(return_value=task_sig)):
+        with patch('src.server.oasisapi.analyses.models.Analysis.v2_start_input_and_loss_generation_signature', PropertyMock(return_value=task_sig)):
             analysis.generate_and_run(initiator)

+        loc_lines = 4
+        events = None
         task_sig.on_error.assert_called_once()
         task_sig.apply_async.assert_called_with(
-            args=[analysis.pk, initiator.pk],
+            args=[analysis.pk, initiator.pk, loc_lines, events],
             priority=4
         )
@@ -205,14 +210,51 @@ def test_state_is_running_or_generating_inputs___validation_error_is_raised_revo
         task_sig = Mock()
         task_sig.delay.return_value = res_factory(task_id)

-        with patch('src.server.oasisapi.analyses.models.Analysis.start_input_and_loss_generation_signature', PropertyMock(return_value=task_sig)):
+        with patch('src.server.oasisapi.analyses.models.Analysis.v2_start_input_and_loss_generation_signature', PropertyMock(return_value=task_sig)):
+            analysis = fake_analysis(status=status, run_task_id=task_id)
+            analysis.model.run_mode = analysis.model.run_mode_choices.V2
+            analysis.model.save()
+
+            with self.assertRaises(ValidationError) as ex:
+                analysis.generate_and_run(initiator)
+
+            self.maxDiff = None
+            self.assertEqual({
+                'portfolio': ['"location_file" must not be null'],
+                'settings_file': ['Must not be null'],
+                'status': ['Analysis status must be one of [NEW, INPUTS_GENERATION_ERROR, INPUTS_GENERATION_CANCELLED, READY, RUN_COMPLETED, RUN_CANCELLED, RUN_ERROR]'],
+            }, ex.exception.detail)
+
+            self.assertEqual(status, analysis.status)
+            self.assertFalse(res_factory.revoke_called)
+
+    @given(
+        status=sampled_from([
+            Analysis.status_choices.INPUTS_GENERATION_QUEUED,
+            Analysis.status_choices.INPUTS_GENERATION_STARTED,
+            Analysis.status_choices.RUN_QUEUED,
+            Analysis.status_choices.RUN_STARTED,
+        ]),
+        task_id=text(min_size=1, max_size=10, alphabet=string.ascii_letters),
+    )
+    def test_state_is_running_or_generating_inputs___run_mode_invalid__validation_error_raised(self, status, task_id):
+        res_factory = FakeAsyncResultFactory(target_task_id=task_id)
+        initiator = fake_user()
+
+        task_sig = Mock()
+        task_sig.delay.return_value = res_factory(task_id)
+
+        with patch('src.server.oasisapi.analyses.models.Analysis.v2_start_input_and_loss_generation_signature', PropertyMock(return_value=task_sig)):
             analysis = fake_analysis(status=status, run_task_id=task_id)
+            analysis.model.run_mode = analysis.model.run_mode_choices.V1
+            analysis.model.save()

             with self.assertRaises(ValidationError) as ex:
                 analysis.generate_and_run(initiator)

             self.maxDiff = None
             self.assertEqual({
+                'model': ['Model pk "1" - Unsupported Operation, "run_mode" must be "V2", not "V1"'],
                 'portfolio': ['"location_file" must not be null'],
                 'settings_file': ['Must not be null'],
                 'status': ['Analysis status must be one of [NEW, INPUTS_GENERATION_ERROR, INPUTS_GENERATION_CANCELLED, READY, RUN_COMPLETED, RUN_CANCELLED, RUN_ERROR]'],
@@ -243,10 +285,10 @@ def test_state_is_ready___run_is_started(self, status, task_id):
         task_obj.id = task_id
         mock_task = MagicMock(return_value=task_obj)

-        with patch('src.server.oasisapi.analyses.models.celery_app.send_task', new=mock_task):
-            analysis.run(initiator)
-            mock_task.assert_called_once_with('start_loss_generation_task', (analysis.pk, initiator.pk, 1),
-                                              {}, queue='celery', link_error=ANY, priority=4)
+        with patch('src.server.oasisapi.analyses.models.celery_app_v2.send_task', new=mock_task):
+            analysis.run(initiator, run_mode_override='V2')
+            mock_task.assert_called_once_with('start_loss_generation_task', (analysis.pk, initiator.pk, None),
+                                              {}, queue='celery-v2', link_error=ANY, priority=4)

     @given(
         status=sampled_from([
@@ -263,11 +305,11 @@ def test_state_is_running_or_generating_inputs___validation_error_is_raised_revo
         initiator = fake_user()

         sig_res = Mock()
-        with patch('src.server.oasisapi.analyses.models.Analysis.run_analysis_signature', PropertyMock(return_value=sig_res)):
+        with patch('src.server.oasisapi.analyses.models.Analysis.v2_run_analysis_signature', PropertyMock(return_value=sig_res)):
             analysis = fake_analysis(status=status, run_task_id=task_id)

             with self.assertRaises(ValidationError) as ex:
-                analysis.run(initiator)
+                analysis.run(initiator, run_mode_override='V2')

             self.assertEqual(
                 {'status': ['Analysis must be in one of the following states [READY, RUN_COMPLETED, RUN_ERROR, RUN_CANCELLED]']},
                 ex.exception.detail)
@@ -279,12 +321,12 @@ def test_run_analysis_signature_is_correct(self):
         with override_settings(MEDIA_ROOT=d):
             analysis = fake_analysis(input_file=fake_related_file(), settings_file=fake_related_file())

-            sig = analysis.run_analysis_signature
+            sig = analysis.v2_run_analysis_signature

             self.assertEqual(sig.task, 'start_loss_generation_task')
             self.assertEqual(
                 sig.options['queue'],
-                iniconf.settings.get('worker', 'LOSSES_GENERATION_CONTROLLER_QUEUE', fallback='celery')
+                iniconf.settings.get('worker', 'LOSSES_GENERATION_CONTROLLER_QUEUE', fallback='celery-v2')
             )
@@ -307,10 +349,10 @@ def test_state_is_not_running___run_is_started(self, status, task_id):
         task_obj = type('', (), {})()
         task_obj.id = task_id
         mock_task = MagicMock(return_value=task_obj)
-        with patch('src.server.oasisapi.analyses.models.celery_app.send_task', new=mock_task):
-            analysis.generate_inputs(initiator)
+        with patch('src.server.oasisapi.analyses.models.celery_app_v2.send_task', new=mock_task):
+            analysis.generate_inputs(initiator, run_mode_override='V2')
             mock_task.assert_called_once_with('start_input_generation_task', (analysis.pk, initiator.pk,
-                                              4), {}, queue='celery', link_error=ANY, priority=4)
+                                              4), {}, queue='celery-v2', link_error=ANY, priority=4)

     @given(
         status=sampled_from([
@@ -326,11 +368,11 @@ def test_state_is_running_or_generating_inputs___validation_error_is_raised_revo
         initiator = fake_user()

         sig_res = Mock()
-        with patch('src.server.oasisapi.analyses.models.Analysis.generate_input_signature', PropertyMock(return_value=sig_res)):
+        with patch('src.server.oasisapi.analyses.models.Analysis.v2_generate_input_signature', PropertyMock(return_value=sig_res)):
             analysis = fake_analysis(status=status, run_task_id=task_id, portfolio=fake_portfolio(location_file=fake_related_file()))

             with self.assertRaises(ValidationError) as ex:
-                analysis.generate_inputs(initiator)
+                analysis.generate_inputs(initiator, run_mode_override='V2')

             self.assertEqual({'status': [
                 'Analysis status must be one of [NEW, INPUTS_GENERATION_ERROR, INPUTS_GENERATION_CANCELLED, READY, RUN_COMPLETED, RUN_CANCELLED, RUN_ERROR]'
@@ -346,11 +388,11 @@ def test_portfolio_has_no_location_file___validation_error_is_raised_revoke_is_n
         initiator = fake_user()

         sig_res = Mock()
-        with patch('src.server.oasisapi.analyses.models.Analysis.generate_input_signature', PropertyMock(return_value=sig_res)):
+        with patch('src.server.oasisapi.analyses.models.Analysis.v2_generate_input_signature', PropertyMock(return_value=sig_res)):
             analysis = fake_analysis(status=Analysis.status_choices.NEW, run_task_id=task_id)

             with self.assertRaises(ValidationError) as ex:
-                analysis.generate_inputs(initiator)
+                analysis.generate_inputs(initiator, run_mode_override='V2')

             self.assertEqual({'portfolio': ['"location_file" must not be null']}, ex.exception.detail)
@@ -362,10 +404,10 @@ def test_generate_input_signature_is_correct(self):
         with override_settings(MEDIA_ROOT=d):
             analysis = fake_analysis(portfolio=fake_portfolio(location_file=fake_related_file()))

-            sig = analysis.generate_input_signature
+            sig = analysis.v2_generate_input_signature

             self.assertEqual(sig.task, 'start_input_generation_task')
             self.assertEqual(
                 sig.options['queue'],
-                iniconf.settings.get('worker', 'INPUT_GENERATION_CONTROLLER_QUEUE', fallback='celery')
+                iniconf.settings.get('worker', 'INPUT_GENERATION_CONTROLLER_QUEUE', fallback='celery-v2')
             )
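The mocked assertions above all expect v2 dispatches to go through a second Celery app and the `celery-v2` queue. A sketch of that dispatch shape, using only names taken from the test expectations (not from a documented API), with a stand-in in-memory broker:

```python
from celery import Celery

celery_app_v2 = Celery(broker='memory://')  # assumption: stand-in broker for illustration

def start_v2_loss_generation(analysis_pk, initiator_pk, events_total=None):
    # Mirrors the call the tests assert on: v2 loss generation is routed to
    # the distributed 'celery-v2' queue at priority 4.
    return celery_app_v2.send_task(
        'start_loss_generation_task',
        (analysis_pk, initiator_pk, events_total),
        {},
        queue='celery-v2',
        priority=4,
    )
```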
diff --git a/src/server/oasisapi/analyses/tests/test_analysis_task_status_model.py b/src/server/oasisapi/analyses/v2_api/tests/test_analysis_task_status_model.py
similarity index 100%
rename from src/server/oasisapi/analyses/tests/test_analysis_task_status_model.py
rename to src/server/oasisapi/analyses/v2_api/tests/test_analysis_task_status_model.py
diff --git a/src/server/oasisapi/analyses/tests/test_analysis_tasks.py b/src/server/oasisapi/analyses/v2_api/tests/test_analysis_tasks.py
similarity index 100%
rename from src/server/oasisapi/analyses/tests/test_analysis_tasks.py
rename to src/server/oasisapi/analyses/v2_api/tests/test_analysis_tasks.py
diff --git a/src/server/oasisapi/analyses/tests/test_task_controller.py b/src/server/oasisapi/analyses/v2_api/tests/test_task_controller.py
similarity index 82%
rename from src/server/oasisapi/analyses/tests/test_task_controller.py
rename to src/server/oasisapi/analyses/v2_api/tests/test_task_controller.py
index 8a626dddf..452dd6033 100644
--- a/src/server/oasisapi/analyses/tests/test_task_controller.py
+++ b/src/server/oasisapi/analyses/v2_api/tests/test_task_controller.py
@@ -1,12 +1,12 @@
 # from unittest.mock import patch
-from unittest.mock import Mock
+# from unittest.mock import Mock, patch
 from uuid import uuid4

 # from django.test import TestCase

 # from src.server.oasisapi.analyses.models import AnalysisTaskStatus
-# from src.server.oasisapi.analyses.task_controller import TaskParams
-# from src.server.oasisapi.analyses.tests.fakes import fake_analysis
+# from src.server.oasisapi.analyses.v2_api.task_controller import TaskParams, Controller
+# from src.server.oasisapi.analyses.v2_api.tests.fakes import fake_analysis
 # from src.server.oasisapi.auth.tests.fakes import fake_user

@@ -14,12 +14,15 @@ class FakeChord:
     def __init__(self):
         self.head = None
         self.body = None
+        self.queue = None
         self.task_ids = None
         self.delay_called = False
+        self.link_callback = None

-    def __call__(self, head, body=None):
+    def __call__(self, head, body=None, queue=None):
         self.head = head
         self.body = body
+        self.queue = queue
         self.task_ids = [uuid4().hex for h in head]
         return self
@@ -28,6 +31,10 @@ def _make_fake_task(self, _id):
         res.id = _id
         return res

+    def link_error(self, sig):
+        self.link_callback = sig
+        return self
+
     def delay(self):
         self.delay_called = True
         res = Mock()
@@ -35,15 +42,12 @@ def delay(self):
         res.parent.children = [self._make_fake_task(_id) for _id in self.task_ids]
         return res

-###########################################
-# No BaseController is available in this branch
-#
 # class TaskController(TestCase):
 #     def test_get_subtask_signature_returns_the_correct_task_name_with_params_and_linked_tasks(self):
 #         analysis = fake_analysis()
 #         initiator = fake_user()
 #
-#         sig = BaseController.get_subtask_signature(
+#         sig = Controller.get_subtask_signature(
 #             analysis,
 #             initiator,
 #             'the_task',
@@ -75,7 +79,7 @@ def delay(self):
 #         analysis = fake_analysis()
 #         initiator = fake_user()
 #
-#         sig = BaseController.get_generate_inputs_results_callback(analysis, initiator)
+#         sig = Controller.get_generate_inputs_results_callback(analysis, initiator)
 #
 #         self.assertEqual(sig.task, 'generate_input_success')
 #         self.assertEqual(sig.args, (analysis.pk, initiator.pk, ))
@@ -86,7 +90,7 @@ def delay(self):
 #         analysis = fake_analysis()
 #         initiator = fake_user()
 #
-#         sig = BaseController.get_generate_inputs_error_callback(analysis, initiator)
+#         sig = Controller.get_generate_inputs_error_callback(analysis, initiator)
 #
 #         self.assertEqual(sig.task, 'record_generate_inputs_failure')
 #         self.assertEqual(sig.args, (analysis.pk, initiator.pk, ))
@@ -97,7 +101,7 @@ def delay(self):
 #         analysis = fake_analysis()
 #         initiator = fake_user()
 #
-#         sig = BaseController.get_generate_losses_results_callback(analysis, initiator)
+#         sig = Controller.get_generate_losses_results_callback(analysis, initiator)
 #
 #         self.assertEqual(sig.task, 'record_run_analysis_result')
 #         self.assertEqual(sig.args, (analysis.pk, initiator.pk, ))
@@ -108,7 +112,7 @@ def delay(self):
 #         analysis = fake_analysis()
 #         initiator = fake_user()
 #
-#         sig = BaseController.get_generate_losses_error_callback(analysis, initiator)
+#         sig = Controller.get_generate_losses_error_callback(analysis, initiator)
 #
 #         self.assertEqual(sig.task, 'record_run_analysis_failure')
 #         self.assertEqual(sig.args, (analysis.pk, initiator.pk, ))
@@ -126,13 +130,13 @@ def delay(self):
 #         subtask = Mock()
 #         subtask.delay.return_value = delay_res
 #
-#         class MockController(BaseController):
+#         class MockController(Controller):
 #             @classmethod
 #             def get_generate_inputs_tasks_params(cls, analysis, initiator):
 #                 return TaskParams('foo', bar='boo')
 #
 #             @classmethod
-#             def get_subtask_signature(cls, analysis, initiator, task_name, params, queue):
+#             def get_subtask_signature(cls, task_name, analysis, initiator, run_data_uuid, slug, queue, params):
 #                 return subtask
 #
 #             @classmethod
@@ -167,13 +171,13 @@ def delay(self):
 #         subtask = Mock()
 #         subtask.delay.return_value = delay_res
 #
-#         class MockController(BaseController):
+#         class MockController(Controller):
 #             @classmethod
 #             def get_generate_losses_tasks_params(cls, analysis, initiator):
 #                 return TaskParams('foo', bar='boo')
 #
 #             @classmethod
-#             def get_subtask_signature(cls, analysis, initiator, task_name, params, queue):
+#             def get_subtask_signature(cls, task_name, analysis, initiator, run_data_uuid, slug, queue, params):
 #                 return subtask
 #
 #             @classmethod
@@ -204,22 +208,31 @@ def delay(self):
 #         subtasks = [
 #             Mock(),
 #             Mock(),
+#             Mock(),
+#             Mock(),
+#             Mock(),
+#             Mock(),
+#             Mock(),
 #         ]
 #
 #         body = Mock()
-#
+#         chain = Mock()
 #         chord = FakeChord()
 #
-#         class MockController(BaseController):
+#         class MockController(Controller):
 #             @classmethod
-#             def get_generate_inputs_tasks_params(cls, analysis, initiator):
-#                 return [
-#                     TaskParams('foo', bar='boo'),
-#                     TaskParams('far', boo='bar'),
-#                 ]
+#             #def get_generate_inputs_tasks_params(cls, analysis, initiator):
+#             def get_inputs_generation_tasks(cls, analysis, initiator, run_data_uuid, num_chunks):
+#                 return tuple(zip(
+#                     [
+#                         TaskParams('foo', bar='boo'),
+#                         TaskParams('far', boo='bar'),
+#                     ]
+#                 ))
+#
 #
 #             @classmethod
-#             def get_subtask_signature(cls, analysis, initiator, task_name, params, queue):
+#             def get_subtask_signature(cls, task_name, analysis, initiator, run_data_uuid, slug, queue, params):
 #                 return subtasks.pop()
 #
 #             @classmethod
@@ -234,8 +247,11 @@ def delay(self):
 #             def get_chord_error_callback(cls):
 #                 return 'chord_error_callback'
 #
-#         with patch('src.server.oasisapi.analyses.task_controller.chord', chord):
-#             MockController.generate_inputs(analysis, initiator)
+#         with patch('src.server.oasisapi.analyses.v2_api.task_controller.chord', chord):
+#         #with patch('src.server.oasisapi.analyses.v2_api.task_controller.chord', chord), \
+#         #     patch('src.server.oasisapi.analyses.v2_api.task_controller.chain', chain):
+#
+#             MockController.generate_inputs(analysis, initiator, loc_lines=10)
 #
 #         self.assertEqual(body.link_error.call_count, 2)
 #         body.link_error.assert_any_call('generate_inputs_error_callback')
@@ -272,7 +288,7 @@ def delay(self):
 #
 #         chord = FakeChord()
 #
-#         class MockController(BaseController):
+#         class MockController(Controller):
 #             @classmethod
 #             def get_generate_losses_tasks_params(cls, analysis, initiator):
 #                 return [
@@ -281,7 +297,7 @@ def delay(self):
 #             ]
 #
 #             @classmethod
-#             def get_subtask_signature(cls, analysis, initiator, task_name, params, queue):
+#             def get_subtask_signature(cls, task_name, analysis, initiator, run_data_uuid, slug, queue, params):
 #                 return subtasks.pop()
 #
 #             @classmethod
@@ -296,7 +312,7 @@ def delay(self):
 #             def get_chord_error_callback(cls):
 #                 return 'chord_error_callback'
 #
-#         with patch('src.server.oasisapi.analyses.task_controller.chord', chord):
+#         with patch('src.server.oasisapi.analyses.v2_api.task_controller.chord', chord):
 #             MockController.generate_losses(analysis, initiator)
 #
 #         self.assertEqual(body.link_error.call_count, 2)
diff --git a/src/server/oasisapi/analyses/v2_api/urls.py b/src/server/oasisapi/analyses/v2_api/urls.py
new file mode 100644
index 000000000..50c1adb5b
--- /dev/null
+++ b/src/server/oasisapi/analyses/v2_api/urls.py
@@ -0,0 +1,22 @@
+from django.conf.urls import include, url
+from rest_framework.routers import SimpleRouter
+from .viewsets import AnalysisViewSet, AnalysisSettingsView, AnalysisTaskStatusViewSet
+
+
+app_name = 'analyses'
+v2_api_router = SimpleRouter()
+v2_api_router.include_root_view = False
+v2_api_router.register('analyses', AnalysisViewSet, basename='analysis')
+v2_api_router.register('analysis-task-statuses', AnalysisTaskStatusViewSet, basename='analysis-task-status')
+
+analyses_settings = AnalysisSettingsView.as_view({
+    'get': 'analysis_settings',
+    'post': 'analysis_settings',
+    'delete': 'analysis_settings'
+})
+
+
+urlpatterns = [
+    url(r'analyses/(?P<pk>\d+)/settings/', analyses_settings, name='analysis-settings'),
+    url(r'', include(v2_api_router.urls)),
+]
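A sketch of how the new `urls.py` above might be mounted so route names pick up the `v2-analyses` namespace the tests reverse against; the include path and namespace string are assumptions based on the names used elsewhere in this diff (the module's `app_name = 'analyses'` is what makes a namespaced include valid):

```python
from django.conf.urls import include, url

urlpatterns = [
    # assumption: the project urlconf mounts the v2 analyses routes like this
    url(r'^v2/', include('src.server.oasisapi.analyses.v2_api.urls',
                         namespace='v2-analyses')),
]
```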
diff --git a/src/server/oasisapi/analyses/viewsets.py b/src/server/oasisapi/analyses/v2_api/viewsets.py
similarity index 89%
rename from src/server/oasisapi/analyses/viewsets.py
rename to src/server/oasisapi/analyses/v2_api/viewsets.py
index ac45fac90..4dbc2f31a 100644
--- a/src/server/oasisapi/analyses/viewsets.py
+++ b/src/server/oasisapi/analyses/v2_api/viewsets.py
@@ -13,18 +13,24 @@ from rest_framework.serializers import Serializer
 from rest_framework.settings import api_settings

-from .models import Analysis, AnalysisTaskStatus
+from ..models import Analysis, AnalysisTaskStatus
 from .serializers import AnalysisSerializer, AnalysisCopySerializer, AnalysisTaskStatusSerializer, \
     AnalysisStorageSerializer, AnalysisListSerializer
-from ..analysis_models.models import AnalysisModel
-from ..data_files.serializers import DataFileSerializer
-from ..files.serializers import RelatedFileSerializer
-from ..files.views import handle_related_file, handle_json_data
-from ..filters import TimeStampedFilter, CsvMultipleChoiceFilter, CsvModelMultipleChoiceFilter
-from ..permissions.group_auth import VerifyGroupAccessModelViewSet, verify_user_is_in_obj_groups
-from ..portfolios.models import Portfolio
-from ..schemas.custom_swagger import FILE_RESPONSE, SUBTASK_STATUS_PARAM, SUBTASK_SLUG_PARAM
-from ..schemas.serializers import AnalysisSettingsSerializer
+from ...analysis_models.models import AnalysisModel
+from ...analysis_models.v2_api.serializers import ModelChunkingConfigSerializer
+from ...data_files.v2_api.serializers import DataFileSerializer
+from ...files.serializers import RelatedFileSerializer
+from ...files.views import handle_related_file, handle_json_data
+from ...filters import TimeStampedFilter, CsvMultipleChoiceFilter, CsvModelMultipleChoiceFilter
+from ...permissions.group_auth import VerifyGroupAccessModelViewSet, verify_user_is_in_obj_groups
+from ...portfolios.models import Portfolio
+from ...schemas.serializers import AnalysisSettingsSerializer
+from ...schemas.custom_swagger import (
+    FILE_RESPONSE,
+    RUN_MODE_PARAM,
+    SUBTASK_STATUS_PARAM,
+    SUBTASK_SLUG_PARAM,
+)

 class AnalysisFilter(TimeStampedFilter):
@@ -53,6 +59,22 @@ class AnalysisFilter(TimeStampedFilter):
         field_name='status',
         label=_('Status in')
     )
+    run_mode = filters.ChoiceFilter(
+        help_text=_('Filter results by the run mode of the analysis, one of [{}]'.format(
+            ', '.join(Analysis.run_mode_choices._db_values))
+        ),
+        choices=Analysis.run_mode_choices,
+    )
+    run_mode__in = CsvMultipleChoiceFilter(
+        help_text=_(
+            'Filter results by results where the run mode '
+            'is one of a given set (provide multiple parameters or comma separated list), '
+            'from [{}]'.format(', '.join(Analysis.run_mode_choices._db_values))
+        ),
+        choices=Analysis.run_mode_choices,
+        field_name='run_mode',
+        label=_('Run mode in')
+    )
     model = NumberFilter(
         help_text=_('Filter results by the id of the model the analysis belongs to'),
         field_name='model'
@@ -158,6 +180,8 @@ class Meta:
         ]

+# https://stackoverflow.com/questions/62572389/django-drf-yasg-how-to-add-description-to-tags
+
 @method_decorator(name='list', decorator=swagger_auto_schema(responses={200: AnalysisSerializer(many=True)}))
 class AnalysisViewSet(VerifyGroupAccessModelViewSet):
     """
@@ -250,6 +274,8 @@ def get_serializer_class(self):
             return AnalysisStorageSerializer
         elif self.action in self.file_action_types_with_settings_file:
             return RelatedFileSerializer
+        elif self.action in ['chunking_configuration']:
+            return ModelChunkingConfigSerializer
         else:
             return Serializer
@@ -260,7 +286,7 @@ def parser_classes(self):
         else:
             return api_settings.DEFAULT_PARSER_CLASSES

-    @swagger_auto_schema(responses={200: AnalysisSerializer})
+    @swagger_auto_schema(responses={200: AnalysisSerializer}, manual_parameters=[RUN_MODE_PARAM])
     @action(methods=['post'], detail=True)
     def run(self, request, pk=None, version=None):
         """
@@ -269,8 +295,9 @@ def run(self, request, pk=None, version=None):
         `RUN_ERROR`
         """
         obj = self.get_object()
+        run_mode_override = request.GET.get('run_mode_override', None)
         verify_user_is_in_obj_groups(request.user, obj.model, 'You are not allowed to run this model')
-        obj.run(request.user)
+        obj.run(request.user, run_mode_override=run_mode_override)
         return Response(AnalysisSerializer(instance=obj, context=self.get_serializer_context()).data)

     @swagger_auto_schema(responses={200: AnalysisSerializer})
@@ -309,7 +336,7 @@ def cancel_analysis_run(self, request, pk=None, version=None):
         obj.cancel_analysis()
         return Response(AnalysisSerializer(instance=obj, context=self.get_serializer_context()).data)

-    @swagger_auto_schema(responses={200: AnalysisSerializer})
+    @swagger_auto_schema(responses={200: AnalysisSerializer}, manual_parameters=[RUN_MODE_PARAM])
     @action(methods=['post'], detail=True)
     def generate_inputs(self, request, pk=None, version=None):
         """
@@ -317,8 +344,9 @@ def generate_inputs(self, request, pk=None, version=None):
        The analysis must have one of the following statuses, `INPUTS_GENERATION_QUEUED` or `INPUTS_GENERATION_STARTED`
         """
         obj = self.get_object()
+        run_mode_override = request.GET.get('run_mode_override', None)
         verify_user_is_in_obj_groups(request.user, obj.model, 'You are not allowed to run this model')
-        obj.generate_inputs(request.user)
+        obj.generate_inputs(request.user, run_mode_override=run_mode_override)
         return Response(AnalysisSerializer(instance=obj, context=self.get_serializer_context()).data)

     @swagger_auto_schema(responses={200: AnalysisSerializer})
@@ -526,6 +554,19 @@ def sub_task_list(self, request, pk=None, version=None):
         serializer = AnalysisTaskStatusSerializer(sub_task_queryset, many=True, context=context)
         return Response(serializer.data)

+    @action(methods=['get', 'post'], detail=True)
+    def chunking_configuration(self, request, pk=None, version=None):
+        method = request.method.lower()
+        obj = self.get_object()
+        if method == 'get':
+            serializer = self.get_serializer(obj.chunking_options)
+        else:
+            serializer = self.get_serializer(obj.chunking_options, data=request.data)
+            serializer.is_valid(raise_exception=True)
+            obj.chunking_options = serializer.save()
+            obj.save()
+        return Response(serializer.data)
+

 class AnalysisSettingsView(VerifyGroupAccessModelViewSet):
     """
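A hypothetical client exchange against the new per-analysis `chunking_configuration` action added above; the host, token, and payload field values are illustrative assumptions (the accepted fields come from `ModelChunkingConfigSerializer`):

```python
import requests

base = 'http://localhost:8000/v2/analyses/42/chunking_configuration/'
auth = {'Authorization': 'Bearer <access-token>'}

current = requests.get(base, headers=auth).json()  # read the stored options

# POST overwrites the analysis' chunking options via the serializer;
# the concrete values below are assumed, not taken from the diff.
updated = requests.post(base, headers=auth, json={
    'dynamic_locations_per_lookup': 10000,
    'fixed_lookup_chunks': 1,
})
updated.raise_for_status()
```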
diff --git a/src/server/oasisapi/analysis_models/admin.py b/src/server/oasisapi/analysis_models/admin.py
index 8a492324d..7ff3d6bca 100644
--- a/src/server/oasisapi/analysis_models/admin.py
+++ b/src/server/oasisapi/analysis_models/admin.py
@@ -1,5 +1,5 @@
 from django.contrib import admin
-from .models import AnalysisModel, SettingsTemplate
+from .models import AnalysisModel, SettingsTemplate, ModelScalingOptions, ModelChunkingOptions
 from django.contrib.admin.actions import delete_selected as delete_selected_
@@ -28,8 +28,13 @@ def activate_model(modeladmin, request, queryset):
 @admin.register(AnalysisModel)
 class CatModelAdmin(admin.ModelAdmin):
     actions = [delete_hard, activate_model]
-
-    list_display = ['model_id', 'supplier_id', 'version_id', 'creator', 'deleted']
+    list_display = [
+        'model_id',
+        'supplier_id',
+        'version_id',
+        'creator',
+        'deleted'
+    ]

     def get_queryset(self, request):
         return self.model.all_objects
@@ -41,4 +46,33 @@ def get_queryset(self, request):

 @admin.register(SettingsTemplate)
 class SettingsTemplateAdmin(admin.ModelAdmin):
-    list_display = ['file', 'name', 'creator']
+    list_display = [
+        'file',
+        'name',
+        'creator'
+    ]
+
+
+@admin.register(ModelScalingOptions)
+class ModelScalingOptionsAdmin(admin.ModelAdmin):
+    list_display = [
+        'scaling_types',
+        'scaling_strategy',
+        'worker_count_fixed',
+        'worker_count_max',
+        'worker_count_min',
+        'chunks_per_worker',
+    ]
+
+
+@admin.register(ModelChunkingOptions)
+class ModelChunkingTemplateAdmin(admin.ModelAdmin):
+    list_display = [
+        'lookup_strategy',
+        'loss_strategy',
+        'dynamic_locations_per_lookup',
+        'dynamic_events_per_analysis',
+        'dynamic_chunks_max',
+        'fixed_analysis_chunks',
+        'fixed_lookup_chunks',
+    ]
diff --git a/src/server/oasisapi/analysis_models/migrations/0007_modelscalingoptions_worker_count_min.py b/src/server/oasisapi/analysis_models/migrations/0007_modelscalingoptions_worker_count_min.py
new file mode 100644
index 000000000..d022ca5f5
--- /dev/null
+++ b/src/server/oasisapi/analysis_models/migrations/0007_modelscalingoptions_worker_count_min.py
@@ -0,0 +1,18 @@
+# Generated by Django 3.2.20 on 2023-09-25 08:18
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('analysis_models', '0006_auto_20230724_1134'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='modelscalingoptions',
+            name='worker_count_min',
+            field=models.PositiveSmallIntegerField(default=0),
+        ),
+    ]
diff --git a/src/server/oasisapi/analysis_models/migrations/0008_analysismodel_run_mode.py b/src/server/oasisapi/analysis_models/migrations/0008_analysismodel_run_mode.py
new file mode 100644
index 000000000..e53570d81
--- /dev/null
+++ b/src/server/oasisapi/analysis_models/migrations/0008_analysismodel_run_mode.py
@@ -0,0 +1,18 @@
+# Generated by Django 3.2.23 on 2023-12-06 21:44
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('analysis_models', '0007_modelscalingoptions_worker_count_min'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='analysismodel',
+            name='run_mode',
+            field=models.CharField(choices=[('BOTH', 'Works on both Execution modes'), ('V1', 'Available for Single-Instance Execution'), ('V2', 'Available for Distributed Execution')], default=None, help_text='Execution modes Available, v1 = Single-Instance, v2 = Distributed Execution', max_length=4, null=True),
+        ),
+    ]
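Migration 0009 below narrows `max_length` from 4 to 2; the reason is that the field sizes itself from the longest remaining choice key once the `'BOTH'` option is dropped. A short sketch of that calculation, using the same `model_utils.Choices` helper the models module imports:

```python
from model_utils import Choices

run_mode_choices = Choices(
    ('V1', 'Available for Single-Instance Execution'),
    ('V2', 'Available for Distributed Execution'),
)

# 'BOTH' (length 4) was removed in 0009, so the longest stored value is 'V2'.
max_length = max(len(c) for c in run_mode_choices._db_values)
assert max_length == 2
```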
diff --git a/src/server/oasisapi/analysis_models/migrations/0009_alter_analysismodel_run_mode.py b/src/server/oasisapi/analysis_models/migrations/0009_alter_analysismodel_run_mode.py
new file mode 100644
index 000000000..5902ba362
--- /dev/null
+++ b/src/server/oasisapi/analysis_models/migrations/0009_alter_analysismodel_run_mode.py
@@ -0,0 +1,18 @@
+# Generated by Django 3.2.20 on 2024-01-22 10:44
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('analysis_models', '0008_analysismodel_run_mode'),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name='analysismodel',
+            name='run_mode',
+            field=models.CharField(choices=[('V1', 'Available for Single-Instance Execution'), ('V2', 'Available for Distributed Execution')], default=None, help_text='Execution modes Available, v1 = Single-Instance, v2 = Distributed Execution', max_length=2, null=True),
+        ),
+    ]
diff --git a/src/server/oasisapi/analysis_models/models.py b/src/server/oasisapi/analysis_models/models.py
index 41107388d..84b4b0572 100644
--- a/src/server/oasisapi/analysis_models/models.py
+++ b/src/server/oasisapi/analysis_models/models.py
@@ -51,6 +51,7 @@ class ModelScalingOptions(models.Model):
                                         choices=scaling_types, default=scaling_types.FIXED_WORKERS, editable=True)
     worker_count_fixed = models.PositiveSmallIntegerField(default=1, null=False)
     worker_count_max = models.PositiveSmallIntegerField(default=10, null=False)
+    worker_count_min = models.PositiveSmallIntegerField(default=0, null=False)
     chunks_per_worker = models.PositiveIntegerField(default=10, null=False)
@@ -111,17 +112,23 @@ def get_filestore(self):
         else:
             return None

-    def get_absolute_settings_template_url(self, model_pk, request=None):
-        return reverse('models-setting_templates-content', kwargs={'version': 'v1', 'pk': self.pk, 'models_pk': model_pk}, request=request)
+    def get_absolute_settings_template_url(self, model_pk, request=None, namespace=None):
+        override_ns = f'{namespace}:' if namespace else ''
+        return reverse(f'{override_ns}models-setting_templates-content', kwargs={'pk': self.pk, 'models_pk': model_pk}, request=self._update_ns(request))

 class AnalysisModel(TimeStampedModel):
+    run_mode_choices = Choices(
+        ('V1', 'Available for Single-Instance Execution'),
+        ('V2', 'Available for Distributed Execution'),
+    )
+
     supplier_id = models.CharField(max_length=255, help_text=_('The supplier ID for the model.'))
     model_id = models.CharField(max_length=255, help_text=_('The model ID for the model.'))
     version_id = models.CharField(max_length=255, help_text=_('The version ID for the model.'))
     resource_file = models.ForeignKey(RelatedFile, on_delete=models.CASCADE, null=True, default=None, related_name='analysis_model_resource_file')
     creator = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE)
-    groups = models.ManyToManyField(Group, blank=True, null=False, default=None, help_text='Groups allowed to access this object')
+    groups = models.ManyToManyField(Group, blank=True, default=None, help_text='Groups allowed to access this object')
     data_files = models.ManyToManyField(DataFile, blank=True, related_name='analyses_model_data_files')
     template_files = models.ManyToManyField(SettingsTemplate, blank=True, related_name='analyses_model_settings_template')
     ver_ktools = models.CharField(max_length=255, null=True, default=None, help_text=_('The worker ktools version.'))
@@ -129,6 +136,8 @@ class AnalysisModel(TimeStampedModel):
     ver_platform = models.CharField(max_length=255, null=True, default=None, help_text=_('The worker platform version.'))
     oasislmf_config = models.TextField(default='')
     deleted = models.BooleanField(default=False, editable=False)
+    run_mode = models.CharField(max_length=max(len(c) for c in run_mode_choices._db_values),
+                                choices=run_mode_choices, default=None, null=True, help_text=_('Execution modes Available, v1 = Single-Instance, v2 = Distributed Execution'))
     scaling_options = models.OneToOneField(ModelScalingOptions, on_delete=models.CASCADE, auto_created=True, default=None, null=True)
     chunking_options = models.OneToOneField(ModelChunkingOptions, on_delete=models.CASCADE, auto_created=True, default=None, null=True)
@@ -144,6 +153,20 @@ class Meta:
     def __str__(self):
         return '{}-{}-{}'.format(self.supplier_id, self.model_id, self.version_id)

+    def _update_ns(self, request=None):
+        """ WORKAROUND - this is needed for when a copy request is issued
+        from the portfolio view '/{ver}/portfolios/{id}/create_analysis/'
+
+        The incorrect namespace '{ver}-portfolios' is inherited from the
+        original request. This needs to be replaced with '{ver}-analyses'
+        """
+        if not request:
+            return None
+        ns_ver, ns_view = request.version.split('-')
+        if ns_view != 'models':
+            request.version = f'{ns_ver}-models'
+        return request
+
     @property
     def queue_name(self):
         return str(self)
@@ -172,20 +195,33 @@ def activate(self, request=None):
                 pass
         self.save()

-    def get_absolute_resources_file_url(self, request=None):
-        return reverse('analysis-model-resource-file', kwargs={'version': 'v1', 'pk': self.pk}, request=request)
-
-    def get_absolute_versions_url(self, request=None):
-        return reverse('analysis-model-versions', kwargs={'version': 'v1', 'pk': self.pk}, request=request)
-
-    def get_absolute_settings_url(self, request=None):
-        return reverse('model-settings', kwargs={'version': 'v1', 'pk': self.pk}, request=request)
-
-    def get_absolute_scaling_configuration_url(self, request=None):
-        return reverse('analysis-model-scaling-configuration', kwargs={'version': 'v1', 'pk': self.pk}, request=request)
-
-    def get_absolute_chunking_configuration_url(self, request=None):
-        return reverse('analysis-model-chunking-configuration', kwargs={'version': 'v1', 'pk': self.pk}, request=request)
+    def get_absolute_resources_file_url(self, request=None, namespace=None):
+        override_ns = f'{namespace}:' if namespace else ''
+        return reverse(f'{override_ns}analysis-model-resource-file', kwargs={'pk': self.pk}, request=self._update_ns(request))
+
+    def get_absolute_versions_url(self, request=None, namespace=None):
+        override_ns = f'{namespace}:' if namespace else ''
+        return reverse(f'{override_ns}analysis-model-versions', kwargs={'pk': self.pk}, request=self._update_ns(request))
+
+    def get_absolute_settings_url(self, request=None, namespace=None):
+        override_ns = f'{namespace}:' if namespace else ''
+        return reverse(f'{override_ns}model-settings', kwargs={'pk': self.pk}, request=self._update_ns(request))
+
+    def get_absolute_scaling_configuration_url(self, request=None, namespace=None):
+        override_ns = f'{namespace}:' if namespace else ''
+        return reverse(f'{override_ns}analysis-model-scaling-configuration', kwargs={'pk': self.pk}, request=self._update_ns(request))
+
+    def get_absolute_chunking_configuration_url(self, request=None, namespace=None):
+        override_ns = f'{namespace}:' if namespace else ''
+        return reverse(f'{override_ns}analysis-model-chunking-configuration', kwargs={'pk': self.pk}, request=self._update_ns(request))
+
+    def update_run_mode(self, namespace=None):
+        if self.resource_file:
+            model_settings = self.resource_file.read_json()
+            run_mode = model_settings.get('model_run_mode', '').upper()
+            if run_mode in (self.run_mode_choices.V1, self.run_mode_choices.V2):
+                self.run_mode = run_mode
+                self.save(update_fields=["run_mode"])

 class QueueModelAssociation(models.Model):
diff --git a/src/server/oasisapi/data_files/tests/__init__.py b/src/server/oasisapi/analysis_models/v1_api/__init__.py
similarity index 100%
rename from src/server/oasisapi/data_files/tests/__init__.py
rename to src/server/oasisapi/analysis_models/v1_api/__init__.py
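A small sketch of the settings-driven behaviour of `update_run_mode()` above: when the uploaded model settings declare `"model_run_mode"`, the field is synced from it case-insensitively. The JSON shape is an assumption inferred from the `.get('model_run_mode')` call, not from a documented schema:

```python
model_settings = {
    "model_run_mode": "v2",  # assumed key; accepted as 'V1' or 'V2' after .upper()
}

run_mode = model_settings.get('model_run_mode', '').upper()
if run_mode in ('V1', 'V2'):
    # In the model method this assignment is persisted with
    # save(update_fields=["run_mode"]).
    print(f"run_mode would be set to {run_mode}")
```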
diff --git a/src/server/oasisapi/analysis_models/v1_api/serializers.py b/src/server/oasisapi/analysis_models/v1_api/serializers.py
new file mode 100644
index 000000000..6e39398c1
--- /dev/null
+++ b/src/server/oasisapi/analysis_models/v1_api/serializers.py
@@ -0,0 +1,123 @@
+from drf_yasg.utils import swagger_serializer_method
+from django.core.exceptions import ObjectDoesNotExist
+from rest_framework import serializers
+from rest_framework.exceptions import ValidationError
+
+from ..models import AnalysisModel, SettingsTemplate
+from ...analyses.models import Analysis
+
+
+class AnalysisModelSerializer(serializers.ModelSerializer):
+    settings = serializers.SerializerMethodField()
+    versions = serializers.SerializerMethodField()
+    ns = 'v1-models'
+
+    class Meta:
+        ref_name = __qualname__.split('.')[0] + 'V1'
+        model = AnalysisModel
+        fields = (
+            'id',
+            'supplier_id',
+            'model_id',
+            'version_id',
+            'created',
+            'modified',
+            'data_files',
+            'settings',
+            'versions',
+            'run_mode',
+        )
+
+    def create(self, validated_data):
+        data = validated_data.copy()
+        if 'request' in self.context:
+            data['creator'] = self.context.get('request').user
+        return super(AnalysisModelSerializer, self).create(data)
+
+    @swagger_serializer_method(serializer_or_field=serializers.URLField)
+    def get_settings(self, instance):
+        request = self.context.get('request')
+        return instance.get_absolute_settings_url(request=request, namespace=self.ns)
+
+    @swagger_serializer_method(serializer_or_field=serializers.URLField)
+    def get_versions(self, instance):
+        request = self.context.get('request')
+        return instance.get_absolute_versions_url(request=request, namespace=self.ns)
+
+
+class TemplateSerializer(serializers.ModelSerializer):
+    """ Catch-all Analysis settings Template Serializer,
+    intended to be called from a nested ViewSet
+    """
+    file_url = serializers.SerializerMethodField()
+
+    class Meta:
+        ref_name = __qualname__.split('.')[0] + 'V1'
+        model = SettingsTemplate
+        fields = (
+            'id',
+            'name',
+            'description',
+            'file_url',
+        )
+
+    @swagger_serializer_method(serializer_or_field=serializers.URLField)
+    def get_file_url(self, instance):
+        request = self.context.get('request')
+        model_pk = request.parser_context.get('kwargs', {}).get('models_pk')
+
+        if model_pk and instance.file:
+            return instance.get_absolute_settings_template_url(model_pk, request=request)
+        else:
+            return None
+
+
+class CreateTemplateSerializer(serializers.ModelSerializer):
+    """ Used for creating a new template with an option to copy an existing
+    analysis_settings.json file from an analysis via the 'analysis_id' param.
+    """
+    analysis_id = serializers.IntegerField(required=False)
+
+    class Meta:
+        ref_name = __qualname__.split('.')[0] + 'V1'
+        model = SettingsTemplate
+        fields = (
+            'id',
+            'name',
+            'description',
+            'analysis_id',
+        )
+
+    def validate(self, attrs):
+        analysis_id = attrs.pop('analysis_id', None)
+        if analysis_id:
+            try:
+                analysis = Analysis.objects.get(id=analysis_id)
+            except ObjectDoesNotExist:
+                raise ValidationError({"Detail": f"analysis_id = {analysis_id} not found"})
+            if not analysis.settings_file:
+                raise ValidationError({"Detail": f"analysis_id = {analysis_id} has no attached settings file"})
+
+            new_settings = analysis.copy_file(analysis.settings_file)
+            new_settings.name = attrs.get('name')
+            new_settings.save()
+            attrs['file'] = new_settings
+
+        return super(CreateTemplateSerializer, self).validate(attrs)
+
+    def create(self, validated_data):
+        data = dict(validated_data)
+        if not data.get('creator') and 'request' in self.context:
+            data['creator'] = self.context.get('request').user
+        return super(CreateTemplateSerializer, self).create(data)
+
+
+class ModelVersionsSerializer(serializers.ModelSerializer):
+    class Meta:
+        ref_name = __qualname__.split('.')[0] + 'V1'
+        model = AnalysisModel
+        fields = (
+            'ver_ktools',
+            'ver_oasislmf',
+            'ver_platform',
+        )
diff --git a/src/server/oasisapi/portfolios/tests/__init__.py b/src/server/oasisapi/analysis_models/v1_api/tests/__init__.py
similarity index 100%
rename from src/server/oasisapi/portfolios/tests/__init__.py
rename to src/server/oasisapi/analysis_models/v1_api/tests/__init__.py
diff --git a/src/server/oasisapi/analysis_models/tests/fakes.py b/src/server/oasisapi/analysis_models/v1_api/tests/fakes.py
similarity index 76%
rename from src/server/oasisapi/analysis_models/tests/fakes.py
rename to src/server/oasisapi/analysis_models/v1_api/tests/fakes.py
index ea3b39777..75033b921 100644
--- a/src/server/oasisapi/analysis_models/tests/fakes.py
+++ b/src/server/oasisapi/analysis_models/v1_api/tests/fakes.py
@@ -1,6 +1,6 @@
 from model_mommy import mommy

-from ..models import AnalysisModel
+from ...models import AnalysisModel

 def fake_analysis_model(**kwargs):
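A hypothetical request showing `CreateTemplateSerializer`'s `analysis_id` flow from above: the settings file already attached to an analysis is copied onto the new template. The endpoint path and payload values are assumptions for illustration:

```python
import requests

resp = requests.post(
    'http://localhost:8000/v1/models/3/setting_templates/',  # assumed nested route
    headers={'Authorization': 'Bearer <access-token>'},
    json={
        'name': 'Default wind settings',
        'description': 'Copied from a validated analysis',
        'analysis_id': 7,  # must reference an analysis with a settings_file
    },
)
resp.raise_for_status()  # a missing analysis or settings file yields a 400
```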
expect_errors=True, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({ + 'supplier_id': supplier_id, + 'version_id': version_id, + }), + content_type='application/json', + ) + + self.assertEqual(400, response.status_code) + self.assertFalse(AnalysisModel.objects.exists()) + + @given( + supplier_id=text(alphabet=string.whitespace, min_size=0, max_size=10), + version_id=text(alphabet=string.ascii_letters, min_size=1, max_size=10), + ) + def test_supplier_id_is_missing___response_is_400(self, supplier_id, version_id): + user = fake_user() + + response = self.app.post( + reverse(f'{NAMESPACE}:analysis-model-list'), + expect_errors=True, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({ + 'supplier_id': supplier_id, + 'version_id': version_id, + }), + content_type='application/json', + ) + + self.assertEqual(400, response.status_code) + self.assertFalse(AnalysisModel.objects.exists()) + + @given( + supplier_id=text(alphabet=string.ascii_letters, min_size=1, max_size=10), + model_id=text(alphabet=string.ascii_letters, min_size=1, max_size=10), + version_id=text(alphabet=string.ascii_letters, min_size=1, max_size=10), + ) + def test_data_is_valid___object_is_created(self, supplier_id, model_id, version_id): + user = fake_user() + + response = self.app.post( + reverse(f'{NAMESPACE}:analysis-model-list'), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({ + 'supplier_id': supplier_id, + 'model_id': model_id, + 'version_id': version_id, + }), + content_type='application/json', + ) + + model = AnalysisModel.objects.first() + + self.assertEqual(201, response.status_code) + self.assertEqual(model.supplier_id, supplier_id) + self.assertEqual(model.version_id, version_id) + self.assertEqual(model.model_id, model_id) + + +class ModelSettingsJson(WebTestMixin, TestCase): + def test_user_is_not_authenticated___response_is_forbidden(self): + models = fake_analysis_model() + settings_url = reverse(f'{NAMESPACE}:model-settings', kwargs={'pk': models.pk}) + + response = self.app.get(settings_url, expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + """ Add these check back in once models auto-update their settings fields + """ + + def test_settings_json_is_not_present___get_response_is_404(self): + user = fake_user() + models = fake_analysis_model() + settings_url = reverse(f'{NAMESPACE}:model-settings', kwargs={'pk': models.pk}) + + response = self.app.get( + settings_url, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_settings_json_is_not_present___delete_response_is_404(self): + user = fake_user() + models = fake_analysis_model() + settings_url = reverse(f'{NAMESPACE}:model-settings', kwargs={'pk': models.pk}) + + response = self.app.delete( + settings_url, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_settings_json_is_not_valid___response_is_400(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + models = fake_analysis_model() + settings_url = reverse(f'{NAMESPACE}:model-settings', kwargs={'pk': models.pk}) + + json_data = { + "model_settings": { + "event_set": { + "name": "Event Set", + "default": "P", + "options": [ + {"id": "P", 
"desc": "Proabilistic"}, + {"id": "H", "desc": "Historic"} + ] + }, + "event_occurrence_id": { + "name": "Occurrence Set", + "desc": "PiWind Occurrence selection", + "default": 1, + "options": [ + {"id": "1", "desc": "Long Term"} + ] + }, + "boolean_parameters": [ + {"name": "peril_wind", "desc": "Boolean option", "default": 1.1}, + {"name": "peril_surge", "desc": "Boolean option", "default": True} + ], + "float_parameter": [ + {"name": "float_1", "desc": "Some float value", "default": False, "max": 1.0, "min": 0.0}, + {"name": "float_2", "desc": "Some float value", "default": 0.3, "max": 1.0, "min": 0.0} + ] + }, + "lookup_settings": { + "supported_perils": [ + {"i": "WSS", "desc": "Single Peril: Storm Surge"}, + {"id": "WTC", "des": "Single Peril: Tropical Cyclone"}, + {"id": "WW11", "desc": "Group Peril: Windstorm with storm surge"}, + {"id": "WW2", "desc": "Group Peril: Windstorm w/o storm surge"} + ] + } + } + + response = self.app.post( + settings_url, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps(json_data), + content_type='application/json', + expect_errors=True, + ) + + validation_error = { + 'model_settings': ["Additional properties are not allowed ('float_parameter' was unexpected)"], + 'model_settings-event_set': ["'desc' is a required property"], + 'model_settings-event_occurrence_id-default': ["1 is not of type 'string'"], + 'model_settings-boolean_parameters-0-default': ["1.1 is not of type 'boolean'"], + 'lookup_settings-supported_perils-0': ["Additional properties are not allowed ('i' was unexpected)", "'id' is a required property"], + 'lookup_settings-supported_perils-1': ["Additional properties are not allowed ('des' was unexpected)", "'desc' is a required property"], + 'lookup_settings-supported_perils-2-id': ["'WW11' is too long"] + } + + self.assertEqual(400, response.status_code) + self.assertDictEqual.__self__.maxDiff = None + self.assertDictEqual(json.loads(response.body), validation_error) + + def test_settings_json_is_uploaded___can_be_retrieved(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + models = fake_analysis_model() + settings_url = reverse(f'{NAMESPACE}:model-settings', kwargs={'pk': models.pk}) + + json_data = { + "model_settings": { + "event_set": { + "name": "Event Set", + "desc": "Either Probablistic or Historic", + "default": "P", + "options": [ + {"id": "P", "desc": "Proabilistic"}, + {"id": "H", "desc": "Historic"} + ] + }, + "event_occurrence_id": { + "name": "Occurrence Set", + "desc": "PiWind Occurrence selection", + "default": "1", + "options": [ + {"id": "1", "desc": "Long Term"} + ] + }, + "boolean_parameters": [ + {"name": "peril_wind", "desc": "Boolean option", "default": False}, + {"name": "peril_surge", "desc": "Boolean option", "default": True} + ], + "float_parameters": [ + {"name": "float_1", "desc": "Some float value", "default": 0.5, "max": 1.0, "min": 0.0}, + {"name": "float_2", "desc": "Some float value", "default": 0.3, "max": 1.0, "min": 0.0} + ] + }, + "lookup_settings": { + "supported_perils": [ + {"id": "WSS", "desc": "Single Peril: Storm Surge"}, + {"id": "WTC", "desc": "Single Peril: Tropical Cyclone"}, + {"id": "WW1", "desc": "Group Peril: Windstorm with storm surge"}, + {"id": "WW2", "desc": "Group Peril: Windstorm w/o storm surge"} + ] + } + } + + self.app.post( + settings_url, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps(json_data), + 
content_type='application/json' + ) + + response = self.app.get( + settings_url, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + self.assertDictEqual.__self__.maxDiff = None + self.assertDictEqual(json.loads(response.body), json_data) + self.assertEqual(response.content_type, 'application/json') diff --git a/src/server/oasisapi/analysis_models/v1_api/urls.py b/src/server/oasisapi/analysis_models/v1_api/urls.py new file mode 100644 index 000000000..0fcb9c083 --- /dev/null +++ b/src/server/oasisapi/analysis_models/v1_api/urls.py @@ -0,0 +1,23 @@ +from django.conf.urls import include, url +from rest_framework_nested import routers +from .viewsets import AnalysisModelViewSet, ModelSettingsView, SettingsTemplateViewSet + + +app_name = 'models' +v1_api_router = routers.SimpleRouter() +v1_api_router.register('models', AnalysisModelViewSet, basename='analysis-model') + +v1_templates_router = routers.NestedSimpleRouter(v1_api_router, r'models', lookup='models') +v1_templates_router.register('setting_templates', SettingsTemplateViewSet, basename='models-setting_templates') + +model_settings = ModelSettingsView.as_view({ + 'get': 'model_settings', + 'post': 'model_settings', + 'delete': 'model_settings' +}) + +urlpatterns = [ + url(r'models/(?P\d+)/settings/', model_settings, name='model-settings'), + url(r'', include(v1_api_router.urls)), + url(r'', include(v1_templates_router.urls)), +] diff --git a/src/server/oasisapi/analysis_models/viewsets.py b/src/server/oasisapi/analysis_models/v1_api/viewsets.py similarity index 74% rename from src/server/oasisapi/analysis_models/viewsets.py rename to src/server/oasisapi/analysis_models/v1_api/viewsets.py index 1a3112cce..5196e262a 100644 --- a/src/server/oasisapi/analysis_models/viewsets.py +++ b/src/server/oasisapi/analysis_models/v1_api/viewsets.py @@ -11,23 +11,14 @@ from rest_framework.response import Response from rest_framework.settings import api_settings -from .models import AnalysisModel, SettingsTemplate -from .serializers import ( - AnalysisModelSerializer, - ModelVersionsSerializer, - CreateTemplateSerializer, - TemplateSerializer, - ModelScalingConfigSerializer, - ModelChunkingConfigSerializer, -) - -from ..data_files.serializers import DataFileSerializer -from ..filters import TimeStampedFilter -from ..files.views import handle_related_file, handle_json_data -from ..files.serializers import RelatedFileSerializer -from ..permissions.group_auth import VerifyGroupAccessModelViewSet -from ..schemas.custom_swagger import FILE_RESPONSE -from ..schemas.serializers import ModelParametersSerializer, AnalysisSettingsSerializer +from ..models import AnalysisModel, SettingsTemplate +from .serializers import AnalysisModelSerializer, ModelVersionsSerializer, CreateTemplateSerializer, TemplateSerializer + +from ...data_files.v1_api.serializers import DataFileSerializer +from ...filters import TimeStampedFilter +from ...files.views import handle_json_data +from ...files.serializers import RelatedFileSerializer +from ...schemas.serializers import ModelParametersSerializer, AnalysisSettingsSerializer class AnalysisModelFilter(TimeStampedFilter): @@ -152,7 +143,7 @@ def content(self, request, pk=None, models_pk=None, version=None): return handle_json_data(self.get_object(), 'file', request, AnalysisSettingsSerializer) -class AnalysisModelViewSet(VerifyGroupAccessModelViewSet): +class AnalysisModelViewSet(viewsets.ModelViewSet): """ list: Returns a list of Model objects. 
@@ -188,9 +179,11 @@ class AnalysisModelViewSet(VerifyGroupAccessModelViewSet): Partially updates the specified model (only provided fields are updated) """ + # queryset = AnalysisModel.objects.all() + queryset = AnalysisModel.objects.exclude(run_mode='V2') serializer_class = AnalysisModelSerializer filterset_class = AnalysisModelFilter - group_access_model = AnalysisModel + # lookup_field = 'id' def get_serializer_class(self): if self.action in ['resource_file', 'set_resource_file']: @@ -199,10 +192,6 @@ def get_serializer_class(self): return DataFileSerializer elif self.action in ['versions']: return ModelVersionsSerializer - elif self.action in ['scaling_configuration']: - return ModelScalingConfigSerializer - elif self.action in ['chunking_configuration']: - return ModelChunkingConfigSerializer else: return super(AnalysisModelViewSet, self).get_serializer_class() @@ -236,48 +225,6 @@ def versions(self, request, pk=None, version=None): obj = self.get_object() return Response(ModelVersionsSerializer(instance=obj, context=self.get_serializer_context()).data) - @action(methods=['get', 'post'], detail=True) - def scaling_configuration(self, request, pk=None, version=None): - method = request.method.lower() - if method == 'get': - serializer = self.get_serializer(self.get_object().scaling_options) - else: - serializer = self.get_serializer(self.get_object().scaling_options, data=request.data) - serializer.is_valid(raise_exception=True) - serializer.save() - return Response(serializer.data) - - @action(methods=['get', 'post'], detail=True) - def chunking_configuration(self, request, pk=None, version=None): - method = request.method.lower() - if method == 'get': - serializer = self.get_serializer(self.get_object().chunking_options) - else: - serializer = self.get_serializer(self.get_object().chunking_options, data=request.data) - serializer.is_valid(raise_exception=True) - serializer.save() - return Response(serializer.data) - - @swagger_auto_schema(methods=['get'], responses={200: FILE_RESPONSE}) - @action(methods=['get', 'delete'], detail=True) - def resource_file(self, request, pk=None, version=None): - """ - get: - Gets the models `resource_file` contents - - delete: - Disassociates the moodels `resource_file` contents - """ - return handle_related_file(self.get_object(), 'resource_file', request, ['application/json']) - - @resource_file.mapping.post - def set_resource_file(self, request, pk=None, version=None): - """ - post: - Sets the models `resource_file` contents - """ - return handle_related_file(self.get_object(), 'resource_file', request, ['application/json']) - @swagger_auto_schema(responses={200: DataFileSerializer(many=True)}) @action(methods=['get'], detail=True) def data_files(self, request, pk=None, version=None): @@ -297,4 +244,9 @@ class ModelSettingsView(viewsets.ModelViewSet): @swagger_auto_schema(method='post', request_body=ModelParametersSerializer, responses={201: RelatedFileSerializer}) @action(methods=['get', 'post', 'delete'], detail=True) def model_settings(self, request, pk=None, version=None): - return handle_json_data(self.get_object(), 'resource_file', request, ModelParametersSerializer) + obj = self.get_object() + response = handle_json_data(obj, 'resource_file', request, ModelParametersSerializer) + # Update Model's execution mode if 'model_run_mode' is in model_settings.json + if request.method.lower() == 'post': + obj.update_run_mode() + return response diff --git a/mysqlclient b/src/server/oasisapi/analysis_models/v2_api/__init__.py similarity index 100% 
rename from mysqlclient rename to src/server/oasisapi/analysis_models/v2_api/__init__.py diff --git a/src/server/oasisapi/analysis_models/serializers.py b/src/server/oasisapi/analysis_models/v2_api/serializers.py similarity index 85% rename from src/server/oasisapi/analysis_models/serializers.py rename to src/server/oasisapi/analysis_models/v2_api/serializers.py index a7ae59904..5e85cd4f5 100644 --- a/src/server/oasisapi/analysis_models/serializers.py +++ b/src/server/oasisapi/analysis_models/v2_api/serializers.py @@ -4,20 +4,20 @@ from rest_framework import serializers from rest_framework.exceptions import ValidationError -from .models import AnalysisModel, SettingsTemplate -from ..analyses.models import Analysis +from ..models import AnalysisModel, SettingsTemplate +from ...analyses.models import Analysis -from .models import AnalysisModel, ModelScalingOptions, ModelChunkingOptions -from ..permissions.group_auth import validate_and_update_groups, validate_data_files +from ..models import AnalysisModel, ModelScalingOptions, ModelChunkingOptions +from ...permissions.group_auth import validate_and_update_groups, validate_data_files class AnalysisModelSerializer(serializers.ModelSerializer): - resource_file = serializers.SerializerMethodField() settings = serializers.SerializerMethodField() versions = serializers.SerializerMethodField() scaling_configuration = serializers.SerializerMethodField() chunking_configuration = serializers.SerializerMethodField() groups = serializers.SlugRelatedField(many=True, read_only=False, slug_field='name', required=False, queryset=Group.objects.all()) + namespace = 'v2-models' class Meta: model = AnalysisModel @@ -29,12 +29,12 @@ class Meta: 'created', 'modified', 'data_files', - 'resource_file', 'settings', 'versions', 'scaling_configuration', 'chunking_configuration', 'groups', + 'run_mode', ) def validate(self, attrs): @@ -53,30 +53,25 @@ def create(self, validated_data): data['creator'] = self.context.get('request').user return super(AnalysisModelSerializer, self).create(data) - @swagger_serializer_method(serializer_or_field=serializers.URLField) - def get_resource_file(self, instance): - request = self.context.get('request') - return instance.get_absolute_resources_file_url(request=request) - @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_settings(self, instance): request = self.context.get('request') - return instance.get_absolute_settings_url(request=request) + return instance.get_absolute_settings_url(request=request, namespace=self.namespace) @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_versions(self, instance): request = self.context.get('request') - return instance.get_absolute_versions_url(request=request) + return instance.get_absolute_versions_url(request=request, namespace=self.namespace) @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_scaling_configuration(self, instance): request = self.context.get('request') - return instance.get_absolute_scaling_configuration_url(request=request) + return instance.get_absolute_scaling_configuration_url(request=request, namespace=self.namespace) @swagger_serializer_method(serializer_or_field=serializers.URLField) def get_chunking_configuration(self, instance): request = self.context.get('request') - return instance.get_absolute_chunking_configuration_url(request=request) + return instance.get_absolute_chunking_configuration_url(request=request, namespace=self.namespace) class 
TemplateSerializer(serializers.ModelSerializer): @@ -204,6 +199,7 @@ class Meta: 'scaling_strategy', 'worker_count_fixed', 'worker_count_max', + 'worker_count_min', 'chunks_per_worker' ) @@ -211,9 +207,12 @@ def validate(self, attrs): non_neg_fields = [ 'worker_count_fixed', 'worker_count_max', + 'worker_count_min', 'chunks_per_worker' ] errors = dict() + + # check for negative values for k in non_neg_fields: value = self.initial_data.get(k) if value is not None: @@ -227,6 +226,18 @@ def validate(self, attrs): continue # Data is valid attrs[k] = value + + # Check that `worker_count_min` < `worker_count_max` + m_id = self.context['request'].parser_context['kwargs']['pk'] + current_val = ModelScalingOptions.objects.get(id=m_id) + + wrk_min = self.initial_data.get('worker_count_min', current_val.worker_count_min) + wrk_max = self.initial_data.get('worker_count_max', current_val.worker_count_max) + if 'worker_count_min' in self.initial_data and (wrk_min > wrk_max): + errors['worker_count_min'] = f"Value '{wrk_min}' must be less than 'worker_count_max: {wrk_max}'" + if 'worker_count_max' in self.initial_data and (wrk_max < wrk_min): + errors['worker_count_max'] = f"Value '{wrk_max}' must be greater than 'worker_count_min: {wrk_min}'" + if errors: raise serializers.ValidationError(errors) return super(ModelScalingConfigSerializer, self).validate(attrs) diff --git a/tests/integration/.gitkeep b/src/server/oasisapi/analysis_models/v2_api/tests/__init__.py similarity index 100% rename from tests/integration/.gitkeep rename to src/server/oasisapi/analysis_models/v2_api/tests/__init__.py diff --git a/src/server/oasisapi/analysis_models/v2_api/tests/fakes.py b/src/server/oasisapi/analysis_models/v2_api/tests/fakes.py new file mode 100644 index 000000000..69a27ddaf --- /dev/null +++ b/src/server/oasisapi/analysis_models/v2_api/tests/fakes.py @@ -0,0 +1,7 @@ +from model_mommy import mommy + +from src.server.oasisapi.analysis_models.models import AnalysisModel + + +def fake_analysis_model(**kwargs): + return mommy.make(AnalysisModel, **kwargs) diff --git a/src/server/oasisapi/analysis_models/tests/test_analysis_model.py b/src/server/oasisapi/analysis_models/v2_api/tests/test_analysis_model.py similarity index 86% rename from src/server/oasisapi/analysis_models/tests/test_analysis_model.py rename to src/server/oasisapi/analysis_models/v2_api/tests/test_analysis_model.py index c69729c6d..6eb190fc1 100644 --- a/src/server/oasisapi/analysis_models/tests/test_analysis_model.py +++ b/src/server/oasisapi/analysis_models/v2_api/tests/test_analysis_model.py @@ -8,7 +8,7 @@ from django_webtest import WebTest, WebTestMixin from hypothesis import given, settings from hypothesis.extra.django import TestCase -from hypothesis.strategies import text +from hypothesis.strategies import text, sampled_from from rest_framework_simplejwt.tokens import AccessToken from .fakes import fake_analysis_model @@ -18,6 +18,7 @@ # Override default deadline for all tests to 8s settings.register_profile("ci", deadline=800.0) settings.load_profile("ci") +NAMESPACE = 'v2-models' class AnalysisModelApi(WebTest, TestCase): @@ -29,7 +30,7 @@ def test_version_id_is_missing___response_is_400(self, supplier_id, version_id): user = fake_user() response = self.app.post( - reverse('analysis-model-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-model-list'), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -52,7 +53,7 @@ def test_supplier_id_is_missing___response_is_400(self, 
supplier_id, version_id) user = fake_user() response = self.app.post( - reverse('analysis-model-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-model-list'), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -76,7 +77,7 @@ def test_data_is_valid___object_is_created(self, supplier_id, model_id, version_ user = fake_user() response = self.app.post( - reverse('analysis-model-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-model-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -106,7 +107,7 @@ def test_create_with_default_groups___response_is_201(self, supplier_id, model_i add_fake_group(user, group_name) response = self.app.post( - reverse('analysis-model-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-model-list'), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -136,7 +137,7 @@ def test_create_with_valid_groups___response_is_201(self, supplier_id, model_id, add_fake_group(user, group_name) response = self.app.post( - reverse('analysis-model-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-model-list'), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -165,7 +166,7 @@ def test_create_with_invalid_groups___response_is_403(self, supplier_id, model_i user = fake_user() response = self.app.post( - reverse('analysis-model-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-model-list'), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -197,7 +198,7 @@ def test_create_with_default_groups_and_get_with_other_groups___response_is_empt # List models as the user that created it response = self.app.get( - reverse('analysis-model-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-model-list'), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user_with_group)) @@ -210,7 +211,7 @@ def test_create_with_default_groups_and_get_with_other_groups___response_is_empt # Test with a user not a member of the group the model was created with user_without_group = fake_user() response = self.app.get( - reverse('analysis-model-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-model-list'), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user_without_group)) @@ -223,13 +224,13 @@ def test_create_with_default_groups_and_get_with_other_groups___response_is_empt @given( group_name=text(alphabet=string.ascii_letters, min_size=1, max_size=10), ) - def test_as_admin_create_with_non_existing_groups___successfully(self, group_name): + def test_as_admin_create_with_non_existing_groups___successfully_1(self, group_name): admin_user = fake_user() admin_user.is_staff = True admin_user.save() response = self.app.post( - reverse('analysis-model-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-model-list'), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(admin_user)) @@ -249,11 +250,11 @@ def test_as_admin_create_with_non_existing_groups___successfully(self, group_nam @given( group_name=text(alphabet=string.ascii_letters, min_size=1, max_size=10), ) - def test_as_admin_create_with_non_existing_groups___successfully(self, group_name): + def test_as_admin_create_with_non_existing_groups___successfully_2(self, group_name): ordinary_user = fake_user() response 
= self.app.post( - reverse('analysis-model-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:analysis-model-list'), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(ordinary_user)) @@ -273,8 +274,8 @@ def test_as_admin_create_with_non_existing_groups___successfully(self, group_nam class ModelSettingsJson(WebTestMixin, TestCase): def test_user_is_not_authenticated___response_is_forbidden(self): models = fake_analysis_model() - - response = self.app.get(models.get_absolute_settings_url(), expect_errors=True) + settings_url = reverse(f'{NAMESPACE}:model-settings', kwargs={'pk': models.pk}) + response = self.app.get(settings_url, expect_errors=True) self.assertIn(response.status_code, [401, 403]) """ Add these check back in once models auto-update their settings fields @@ -284,8 +285,9 @@ def test_settings_json_is_not_present___get_response_is_404(self): user = fake_user() models = fake_analysis_model() + settings_url = reverse(f'{NAMESPACE}:model-settings', kwargs={'pk': models.pk}) response = self.app.get( - models.get_absolute_settings_url(), + settings_url, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -298,8 +300,9 @@ def test_settings_json_is_not_present___delete_response_is_404(self): user = fake_user() models = fake_analysis_model() + settings_url = reverse(f'{NAMESPACE}:model-settings', kwargs={'pk': models.pk}) response = self.app.delete( - models.get_absolute_settings_url(), + settings_url, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -350,8 +353,9 @@ def test_settings_json_is_not_valid___response_is_400(self): } } + settings_url = reverse(f'{NAMESPACE}:model-settings', kwargs={'pk': models.pk}) response = self.app.post( - models.get_absolute_settings_url(), + settings_url, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -417,8 +421,9 @@ def test_settings_json_is_uploaded___can_be_retrieved(self): } } + settings_url = reverse(f'{NAMESPACE}:model-settings', kwargs={'pk': models.pk}) self.app.post( - models.get_absolute_settings_url(), + settings_url, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -427,7 +432,7 @@ def test_settings_json_is_uploaded___can_be_retrieved(self): ) response = self.app.get( - models.get_absolute_settings_url(), + settings_url, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -435,3 +440,38 @@ def test_settings_json_is_uploaded___can_be_retrieved(self): self.assertDictEqual.__self__.maxDiff = None self.assertDictEqual(json.loads(response.body), json_data) self.assertEqual(response.content_type, 'application/json') + + @given( + run_mode_requested=sampled_from([ + AnalysisModel.run_mode_choices.V1, + AnalysisModel.run_mode_choices.V2, + ]) + ) + def test_settings_json_is_uploaded___run_mode_is_set(self, run_mode_requested): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + models = fake_analysis_model() + json_data = { + "model_run_mode": run_mode_requested, + "model_settings": {}, + "lookup_settings": {} + } + + settings_url = reverse(f'{NAMESPACE}:model-settings', kwargs={'pk': models.pk}) + self.app.post( + settings_url, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps(json_data), + content_type='application/json' + ) + + response = self.app.get( + reverse(f'{NAMESPACE}:analysis-model-detail', kwargs={'pk': models.id}), + headers={ + 'Authorization': 
'Bearer {}'.format(AccessToken.for_user(user))
+                    },
+                )
+                self.assertEqual(run_mode_requested, response.json.get('run_mode'))
diff --git a/src/server/oasisapi/analysis_models/v2_api/urls.py b/src/server/oasisapi/analysis_models/v2_api/urls.py
new file mode 100644
index 000000000..1706ddf6e
--- /dev/null
+++ b/src/server/oasisapi/analysis_models/v2_api/urls.py
@@ -0,0 +1,23 @@
+from django.conf.urls import include, url
+from rest_framework_nested import routers
+from .viewsets import AnalysisModelViewSet, ModelSettingsView, SettingsTemplateViewSet
+
+
+app_name = 'models'
+v2_api_router = routers.SimpleRouter()
+v2_api_router.register('models', AnalysisModelViewSet, basename='analysis-model')
+
+v2_templates_router = routers.NestedSimpleRouter(v2_api_router, r'models', lookup='models')
+v2_templates_router.register('setting_templates', SettingsTemplateViewSet, basename='models-setting_templates')
+
+model_settings = ModelSettingsView.as_view({
+    'get': 'model_settings',
+    'post': 'model_settings',
+    'delete': 'model_settings'
+})
+
+urlpatterns = [
+    url(r'models/(?P<pk>\d+)/settings/', model_settings, name='model-settings'),
+    url(r'', include(v2_api_router.urls)),
+    url(r'', include(v2_templates_router.urls)),
+]
diff --git a/src/server/oasisapi/analysis_models/v2_api/viewsets.py b/src/server/oasisapi/analysis_models/v2_api/viewsets.py
new file mode 100644
index 000000000..26acf546a
--- /dev/null
+++ b/src/server/oasisapi/analysis_models/v2_api/viewsets.py
@@ -0,0 +1,444 @@
+from __future__ import absolute_import
+
+from django.utils.translation import gettext_lazy as _
+from django.utils.decorators import method_decorator
+from django_filters import rest_framework as filters
+from django.http import Http404
+from drf_yasg.utils import swagger_auto_schema
+from rest_framework import viewsets
+from rest_framework.decorators import action
+from rest_framework.parsers import MultiPartParser
+from rest_framework.response import Response
+from rest_framework.settings import api_settings
+
+from ..models import AnalysisModel, SettingsTemplate
+from .serializers import (
+    AnalysisModelSerializer,
+    ModelVersionsSerializer,
+    CreateTemplateSerializer,
+    TemplateSerializer,
+    ModelScalingConfigSerializer,
+    ModelChunkingConfigSerializer,
+)
+
+from ...data_files.v2_api.serializers import DataFileSerializer
+from ...filters import TimeStampedFilter
+from ...files.views import handle_json_data
+from ...files.serializers import RelatedFileSerializer
+from ...permissions.group_auth import VerifyGroupAccessModelViewSet
+from ...schemas.serializers import ModelParametersSerializer, AnalysisSettingsSerializer
+
+
+class AnalysisModelFilter(TimeStampedFilter):
+    supplier_id = filters.CharFilter(
+        help_text=_('Filter results by case insensitive `supplier_id` equal to the given string'),
+        lookup_expr='iexact'
+    )
+    supplier_id__contains = filters.CharFilter(
+        help_text=_('Filter results by case insensitive `supplier_id` containing the given string'),
+        lookup_expr='icontains',
+        field_name='supplier_id'
+    )
+    model_id = filters.CharFilter(
+        help_text=_('Filter results by case insensitive `model_id` equal to the given string'),
+        lookup_expr='iexact'
+    )
+    model_id__contains = filters.CharFilter(
+        help_text=_('Filter results by case insensitive `model_id` containing the given string'),
+        lookup_expr='icontains',
+        field_name='model_id'
+    )
+    version_id = filters.CharFilter(
+        help_text=_('Filter results by case insensitive `version_id` equal to the given string'),
+        lookup_expr='iexact'
+    )
+    version_id__contains = filters.CharFilter(
+        help_text=_('Filter results by case insensitive `version_id` containing the given string'),
+        lookup_expr='icontains',
+        field_name='version_id'
+    )
+    user = filters.CharFilter(
+        help_text=_('Filter results by case insensitive `user` equal to the given string'),
+        lookup_expr='iexact',
+        field_name='creator__username'
+    )
+
+    class Meta:
+        model = AnalysisModel
+        fields = [
+            'supplier_id',
+            'supplier_id__contains',
+            'model_id',
+            'model_id__contains',
+            'version_id',
+            'version_id__contains',
+            'user',
+        ]
+
+
+@method_decorator(name='create', decorator=swagger_auto_schema(request_body=CreateTemplateSerializer))
+class SettingsTemplateViewSet(viewsets.ModelViewSet):
+    """
+    list:
+    Returns a list of analysis_settings files stored under a model as templates.
+
+    retrieve:
+    Returns the specific template entry.
+
+    create:
+    Creates an analysis_settings template with an option to copy the settings from an analysis.
+
+    update:
+    Updates the specified template
+
+    partial_update:
+    Partially updates the template
+    """
+    queryset = SettingsTemplate.objects.all()
+    serializer_class = TemplateSerializer
+
+    def get_queryset(self):
+        models_pk = self.kwargs.get('models_pk')
+        if models_pk:
+            if not models_pk.isnumeric():
+                raise Http404
+            try:
+                template_queryset = AnalysisModel.objects.get(id=models_pk).template_files.all()
+            except AnalysisModel.DoesNotExist:
+                raise Http404
+            return template_queryset
+        else:
+            return AnalysisModel.objects.none()
+
+    def get_serializer_class(self):
+        if self.action in ['create']:
+            return CreateTemplateSerializer
+        else:
+            return super(SettingsTemplateViewSet, self).get_serializer_class()
+
+    def list(self, request, models_pk=None, **kwargs):
+        context = {'request': request}
+        template_list = self.get_queryset()
+        serializer = TemplateSerializer(template_list, many=True, context=context)
+        return Response(serializer.data)
+
+    def create(self, request, models_pk=None, **kwargs):
+        request_data = self.request.data
+        context = {'request': request}
+        serializer = self.get_serializer(data=request.data)
+        serializer.is_valid(raise_exception=True)
+
+        new_template = serializer.create(serializer.validated_data)
+        new_template.save()
+        model = AnalysisModel.objects.get(id=models_pk)
+        model.template_files.add(new_template)
+        return Response(TemplateSerializer(new_template, context=context).data)
+
+    @swagger_auto_schema(methods=['get'], responses={200: AnalysisSettingsSerializer})
+    @swagger_auto_schema(methods=['post'], request_body=AnalysisSettingsSerializer, responses={201: RelatedFileSerializer})
+    @action(methods=['get', 'post', 'delete'], detail=True)
+    def content(self, request, pk=None, models_pk=None, version=None):
+        """
+        get:
+        Gets the analysis template `settings` contents
+
+        post:
+        Sets the analysis template `settings` contents
+
+        delete:
+        Disassociates the analysis template `settings_file` contents
+        """
+        return handle_json_data(self.get_object(), 'file', request, AnalysisSettingsSerializer)
+
+
+class AnalysisModelViewSet(VerifyGroupAccessModelViewSet):
+    """
+    list:
+    Returns a list of Model objects.
+
+    ### Examples
+
+    To get all models with 'foo' in their supplier id
+
+        /models/?supplier_id__contains=foo
+
+    To get all models with 'bar' in their version id
+
+        /models/?version_id__contains=bar
+
+    To get all models created on 1970-01-01
+
+        /models/?created__date=1970-01-01
+
+    To get all models updated before 2000-01-01
+
+        /models/?modified__lt=2000-01-01
+
+    retrieve:
+    Returns the specific model entry.
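+
+    (Illustrative request: `GET /models/7/` returns the model with `id=7`; the
+    URL prefix depends on where this v2 router is mounted in the root URLconf.)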
+
+    create:
+    Creates a model based on the input data
+
+    update:
+    Updates the specified model
+
+    partial_update:
+    Partially updates the specified model (only provided fields are updated)
+    """
+
+    serializer_class = AnalysisModelSerializer
+    filterset_class = AnalysisModelFilter
+    group_access_model = AnalysisModel
+
+    def get_serializer_class(self):
+        if self.action in ['resource_file', 'set_resource_file']:
+            return RelatedFileSerializer
+        elif self.action in ['data_files']:
+            return DataFileSerializer
+        elif self.action in ['versions']:
+            return ModelVersionsSerializer
+        elif self.action in ['scaling_configuration']:
+            return ModelScalingConfigSerializer
+        elif self.action in ['chunking_configuration']:
+            return ModelChunkingConfigSerializer
+        else:
+            return super(AnalysisModelViewSet, self).get_serializer_class()
+
+    @property
+    def parser_classes(self):
+        if getattr(self, 'action', None) in ['resource_file']:
+            return [MultiPartParser]
+        else:
+            return api_settings.DEFAULT_PARSER_CLASSES
+
+    def create(self, *args, **kwargs):
+        request_data = self.request.data
+        unique_keys = ["supplier_id", "model_id", "version_id"]
+
+        # check if the model is Soft-deleted
+        if all(k in request_data for k in unique_keys):
+            keys = {k: request_data[k] for k in unique_keys}
+            model = AnalysisModel.all_objects.filter(**keys)
+            if model.exists():
+                model = model.first()
+                if model.deleted:
+                    # If yes, then 'restore' and update
+                    model.activate(self.request)
+                    return Response(AnalysisModelSerializer(instance=model,
+                                                            context=self.get_serializer_context()).data)
+
+        return super(AnalysisModelViewSet, self).create(self.request)
+
+    @action(methods=['get'], detail=True)
+    def versions(self, request, pk=None, version=None):
+        obj = self.get_object()
+        return Response(ModelVersionsSerializer(instance=obj, context=self.get_serializer_context()).data)
+
+    @action(methods=['get', 'post'], detail=True)
+    def scaling_configuration(self, request, pk=None, version=None):
+        """
+        get:
+        Configuration for creating parallel workers to process an analysis
+
+        There are three scaling strategies to select from, `FIXED_WORKERS`, `QUEUE_LOAD`, and `DYNAMIC_TASKS`. These set the replica sets for this
+        model deployment, but only when an analysis is submitted.
+
+        When the model queue is inactive the number of workers is set to `worker_count_min` (these workers are always running).
+
+        ### Example 1 - Fixed workers
+        * When **any analyses** are either queued or running, create 5 workers.
+        * When **no analyses** are queued, spin down to `0` workers.
+        ```
+        {
+            scaling_strategy: 'FIXED_WORKERS',
+            worker_count_fixed: 5,
+            worker_count_min: 0
+        }
+        ```
+
+        ### Example 2 - Queue Load
+        The number of workers started depends on how many analyses are added to a model's queue.
+        * One worker created per analysis
+        * When inactive spin down to either `worker_count_min` (if set) or default to 0.
+        ```
+        {
+            scaling_strategy: 'QUEUE_LOAD',
+            worker_count_max: 10
+        }
+        ```
+
+        ### Example 3 - Dynamic
+        The number of workers started is controlled by the total sub-tasks, or chunks, on the model's queue.
+
+        * If an analysis is split into 12 parts, then `12 (chunks on queue) / 4 (chunks_per_worker) = 3`, so three workers are added.
+        * The total number of sub-tasks is summed across all running analyses.
+        * The scaling is capped by `worker_count_max`, no more than 15 workers will be created.
+        * When idle spin the workers down to `worker_count_min`, a single worker is always 'warm' and waiting to process requests.
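+        * Worked example (illustrative figures): two running analyses of 12 chunks each put 24 chunks on the queue, so `24 / 4 (chunks_per_worker) = 6` workers are requested, still under the `worker_count_max` cap of 15.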
+        ```
+        {
+            scaling_strategy: 'DYNAMIC_TASKS',
+            worker_count_max: 15,
+            chunks_per_worker: 4,
+            worker_count_min: 1
+        }
+        ```
+
+        post:
+        Configuration for creating parallel workers to process an analysis
+
+        There are three scaling strategies to select from, `FIXED_WORKERS`, `QUEUE_LOAD`, and `DYNAMIC_TASKS`. These set the replica sets for this
+        model deployment, but only when an analysis is submitted.
+
+        When the model queue is inactive the number of workers is set to `worker_count_min` (these workers are always running).
+
+        ### Example 1 - Fixed workers
+        * When **any analyses** are either queued or running, create 5 workers.
+        * When **no analyses** are queued, spin down to `0` workers.
+        ```
+        {
+            scaling_strategy: 'FIXED_WORKERS',
+            worker_count_fixed: 5,
+            worker_count_min: 0
+        }
+        ```
+
+        ### Example 2 - Queue Load
+        The number of workers started depends on how many analyses are added to a model's queue.
+        * One worker created per analysis
+        * When inactive spin down to either `worker_count_min` (if set) or default to 0.
+        ```
+        {
+            scaling_strategy: 'QUEUE_LOAD',
+            worker_count_max: 10
+        }
+        ```
+
+        ### Example 3 - Dynamic
+        The number of workers started is controlled by the total sub-tasks, or chunks, on the model's queue.
+
+        * If an analysis is split into 12 parts, then `12 (chunks on queue) / 4 (chunks_per_worker) = 3`, so three workers are added.
+        * The total number of sub-tasks is summed across all running analyses.
+        * The scaling is capped by `worker_count_max`, no more than 15 workers will be created.
+        * When idle spin the workers down to `worker_count_min`, a single worker is always 'warm' and waiting to process requests.
+        ```
+        {
+            scaling_strategy: 'DYNAMIC_TASKS',
+            worker_count_max: 15,
+            chunks_per_worker: 4,
+            worker_count_min: 1
+        }
+        ```
+        """
+        method = request.method.lower()
+        if method == 'get':
+            serializer = self.get_serializer(self.get_object().scaling_options)
+        else:
+            serializer = self.get_serializer(self.get_object().scaling_options, data=request.data)
+            serializer.is_valid(raise_exception=True)
+            serializer.save()
+        return Response(serializer.data)
+
+    @action(methods=['get', 'post'], detail=True)
+    def chunking_configuration(self, request, pk=None, version=None):
+        """
+        get:
+        Configuration used to split and distribute an analysis execution
+
+        Each phase of the execution can either be set to `fixed` or `dynamic`, where fixed will split
+        all jobs into a fixed number of parts, and dynamic scales depending on the input size.
+
+
+        ### Example 1 - Fixed chunking
+        Every analysis is broken into `10` lookup chunks, and `20` event batches.
+        ```
+        {
+            lookup_strategy: `FIXED_CHUNKS`,
+            fixed_lookup_chunks: 10,
+            loss_strategy: `FIXED_CHUNKS`,
+            fixed_analysis_chunks: 20,
+        }
+        ```
+
+        ### Example 2 - Dynamic chunking
+        * A lookup chunk is generated for each `1000` rows in the location file.
+        * An event batch is generated for each `1000` events in the selected event set. (analysis settings)
+        * Each execution phase is capped by `dynamic_chunks_max`, the number of chunks cannot exceed 250.
+
+        ```
+        {
+            lookup_strategy: `DYNAMIC_CHUNKS`,
+            dynamic_locations_per_lookup: 1000,
+            loss_strategy: `DYNAMIC_CHUNKS`,
+            dynamic_events_per_analysis: 1000,
+            dynamic_chunks_max: 250,
+        }
+        ```
+
+        post:
+        Configuration used to split and distribute an analysis execution
+
+        Each phase of the execution can either be set to `fixed` or `dynamic`, where fixed will split
+        all jobs into a fixed number of parts, and dynamic scales depending on the input size.
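+
+        For instance (illustrative figures): with `dynamic_locations_per_lookup: 1000`, a location file of 25,000 rows is split into `25000 / 1000 = 25` lookup chunks, subject to the `dynamic_chunks_max` cap.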
+
+
+        ### Example 1 - Fixed chunking
+        Every analysis is broken into `10` lookup chunks, and `20` event batches.
+        ```
+        {
+            lookup_strategy: `FIXED_CHUNKS`,
+            fixed_lookup_chunks: 10,
+            loss_strategy: `FIXED_CHUNKS`,
+            fixed_analysis_chunks: 20,
+        }
+        ```
+
+        ### Example 2 - Dynamic chunking
+        * A lookup chunk is generated for each `1000` rows in the location file.
+        * An event batch is generated for each `1000` events in the selected event set. (analysis settings)
+        * Each execution phase is capped by `dynamic_chunks_max`, the number of chunks cannot exceed 250.
+
+        ```
+        {
+            lookup_strategy: `DYNAMIC_CHUNKS`,
+            dynamic_locations_per_lookup: 1000,
+            loss_strategy: `DYNAMIC_CHUNKS`,
+            dynamic_events_per_analysis: 1000,
+            dynamic_chunks_max: 250,
+        }
+        ```
+        """
+        method = request.method.lower()
+        if method == 'get':
+            serializer = self.get_serializer(self.get_object().chunking_options)
+        else:
+            serializer = self.get_serializer(self.get_object().chunking_options, data=request.data)
+            serializer.is_valid(raise_exception=True)
+            serializer.save()
+        return Response(serializer.data)
+
+    @swagger_auto_schema(responses={200: DataFileSerializer(many=True)})
+    @action(methods=['get'], detail=True)
+    def data_files(self, request, pk=None, version=None):
+        df = self.get_object().data_files.all()
+        context = {'request': request}
+
+        df_serializer = DataFileSerializer(df, many=True, context=context)
+        return Response(df_serializer.data)
+
+
+class ModelSettingsView(viewsets.ModelViewSet):
+    queryset = AnalysisModel.objects.all()
+    serializer_class = AnalysisModelSerializer
+    filterset_class = AnalysisModelFilter
+
+    @swagger_auto_schema(method='get', responses={200: ModelParametersSerializer})
+    @swagger_auto_schema(method='post', request_body=ModelParametersSerializer, responses={201: RelatedFileSerializer})
+    @action(methods=['get', 'post', 'delete'], detail=True)
+    def model_settings(self, request, pk=None, version=None):
+        obj = self.get_object()
+        response = handle_json_data(obj, 'resource_file', request, ModelParametersSerializer)
+        # Update Model's execution mode if 'model_run_mode' is in model_settings.json
+        if request.method.lower() == 'post':
+            obj.update_run_mode()
+        return response
diff --git a/src/server/oasisapi/asgi.py b/src/server/oasisapi/asgi.py
index 219f81b6e..18b55e45c 100644
--- a/src/server/oasisapi/asgi.py
+++ b/src/server/oasisapi/asgi.py
@@ -13,5 +13,6 @@
 application = get_default_application()
 
 # ONLY run the websocket from here (add safeguard to remove HTTP router)
-# if 'http' in application.application_mapping:
-#     del application.application_mapping['http']
+if os.getenv('OASIS_DISABLE_HTTP', 'true').lower() not in ('0', 'false'):
+    if 'http' in application.application_mapping:
+        del application.application_mapping['http']
diff --git a/src/server/oasisapi/auth/views.py b/src/server/oasisapi/auth/views.py
index 60125fc4d..c2c07cb03 100644
--- a/src/server/oasisapi/auth/views.py
+++ b/src/server/oasisapi/auth/views.py
@@ -26,7 +26,8 @@ class TokenRefreshView(BaseTokenRefreshView):
     @swagger_auto_schema(
         manual_parameters=[TOKEN_REFRESH_HEADER],
         responses={status.HTTP_200_OK: TokenRefreshResponseSerializer},
-        security=[])
+        security=[],
+        tags=['authentication'])
     def post(self, request, *args, **kwargs):
         return super().post(request, *args, **kwargs)
 
@@ -39,6 +40,7 @@ class TokenObtainPairView(BaseTokenObtainPairView):
     @swagger_auto_schema(
         responses={status.HTTP_200_OK: TokenObtainPairResponseSerializer},
-        security=[])
+        security=[],
+        tags=['authentication'])
     def post(self, request, *args, **kwargs):
         return super().post(request, *args, **kwargs)
diff --git a/src/server/oasisapi/base_urls.py b/src/server/oasisapi/base_urls.py
new file mode 100644
index 000000000..c4cb4aab1
--- /dev/null
+++ b/src/server/oasisapi/base_urls.py
@@ -0,0 +1,17 @@
+from django.conf.urls import include, url
+from django.contrib import admin
+
+from .info.views import PerilcodesView
+from .info.views import ServerInfoView
+from .healthcheck.views import HealthcheckView
+
+# app_name = 'base'
+
+urlpatterns = [
+    url(r'^healthcheck/$', HealthcheckView.as_view(), name='healthcheck'),
+    url(r'^oed_peril_codes/$', PerilcodesView.as_view(), name='perilcodes'),
+    url(r'^server_info/$', ServerInfoView.as_view(), name='serverinfo'),
+    url(r'^auth/', include('rest_framework.urls')),
+    url(r'^', include('src.server.oasisapi.auth.urls', namespace='auth')),
+    url(r'^admin/', admin.site.urls),
+]
diff --git a/src/server/oasisapi/celery_app.py b/src/server/oasisapi/celery_app.py
deleted file mode 100644
index fe5cee43a..000000000
--- a/src/server/oasisapi/celery_app.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from __future__ import absolute_import
-
-import os
-
-from celery import Celery
-from django.conf import settings
-
-os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'src.server.oasisapi.settings')
-
-celery_app = Celery('oasisapi')
-celery_app.config_from_object('django.conf:settings')
-celery_app.autodiscover_tasks(lambda: settings.INSTALLED_APPS)
diff --git a/src/server/oasisapi/celery_app_v1.py b/src/server/oasisapi/celery_app_v1.py
new file mode 100644
index 000000000..db93cf833
--- /dev/null
+++ b/src/server/oasisapi/celery_app_v1.py
@@ -0,0 +1,11 @@
+from __future__ import absolute_import
+
+import os
+from celery import Celery
+from django.conf import settings
+
+os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'src.server.oasisapi.settings.v1')
+v1 = Celery('v1', include=['src.server.oasisapi.analyses.v1_api'])
+v1.config_from_object('django.conf:settings')
+v1.autodiscover_tasks(lambda: settings.INSTALLED_APPS)
+v1.conf.update(CELERY_QUEUE_MAX_PRIORITY=None)
diff --git a/src/server/oasisapi/celery_app_v2.py b/src/server/oasisapi/celery_app_v2.py
new file mode 100644
index 000000000..c4199d056
--- /dev/null
+++ b/src/server/oasisapi/celery_app_v2.py
@@ -0,0 +1,27 @@
+from __future__ import absolute_import
+
+import os
+from celery import Celery
+from django.conf import settings
+from kombu import Queue, Exchange
+from ...conf.celeryconf_v2 import CELERY_QUEUE_MAX_PRIORITY
+
+
+os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'src.server.oasisapi.settings.v2')
+v2 = Celery('v2', include=['src.server.oasisapi.analyses.v2_api'])
+v2.config_from_object('django.conf:settings')
+v2.autodiscover_tasks(lambda: settings.INSTALLED_APPS)
+
+# CONFIG WARNING:
+#
+# Both Celery apps (v1 & v2) share a single 'conf' object when running in the HTTP server.
+# The only way to support both the old and new dispatch options is to default `CELERY_QUEUE_MAX_PRIORITY` to `None`,
+# then manually set the v2 queues to priority n.
+#
+# Once a V2 task hits either the "v2-worker-monitor" or "v2-task-controller" the 'src.server.oasisapi.settings.v2' settings will be active,
+# which will load with the correct default "CELERY_QUEUE_MAX_PRIORITY=10"
+v2.conf.task_queues = [
+    Queue('celery-v2', Exchange('celery-v2'), routing_key='celery-v2', queue_arguments={'x-max-priority': CELERY_QUEUE_MAX_PRIORITY}),
+    Queue('task-controller', Exchange('task-controller'), routing_key='task-controller',
+          queue_arguments={'x-max-priority':
CELERY_QUEUE_MAX_PRIORITY}), +] diff --git a/src/server/oasisapi/data_files/models.py b/src/server/oasisapi/data_files/models.py index 496b02122..ded78d107 100644 --- a/src/server/oasisapi/data_files/models.py +++ b/src/server/oasisapi/data_files/models.py @@ -33,7 +33,7 @@ class DataFile(TimeStampedModel): default=None, related_name="content_data_file" ) - groups = models.ManyToManyField(Group, blank=True, null=False, default=None, help_text='Groups allowed to access this object') + groups = models.ManyToManyField(Group, blank=True, default=None, help_text='Groups allowed to access this object') class Meta: ordering = ['id'] @@ -59,5 +59,6 @@ def get_content_type(self): else: return None - def get_absolute_data_file_url(self, request=None): - return reverse('data-file-content', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_data_file_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}data-file-content', kwargs={'pk': self.pk}, request=request) diff --git a/src/server/oasisapi/data_files/v1_api/__init__.py b/src/server/oasisapi/data_files/v1_api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/server/oasisapi/data_files/v1_api/serializers.py b/src/server/oasisapi/data_files/v1_api/serializers.py new file mode 100644 index 000000000..e7788dcfb --- /dev/null +++ b/src/server/oasisapi/data_files/v1_api/serializers.py @@ -0,0 +1,82 @@ +from drf_yasg.utils import swagger_serializer_method +from rest_framework import serializers + +from ..models import DataFile + + +class DataFileListSerializer(serializers.Serializer): + """ Read Only DataFile Deserializer for efficiently returning a list of all + DataFile from DB + """ + # model fields + id = serializers.IntegerField(read_only=True) + file_description = serializers.CharField(read_only=True) + file_category = serializers.CharField(read_only=True) + created = serializers.DateTimeField(read_only=True) + modified = serializers.DateTimeField(read_only=True) + + # File fields + file = serializers.SerializerMethodField(read_only=True) + filename = serializers.SerializerMethodField(read_only=True) + stored = serializers.SerializerMethodField(read_only=True) + content_type = serializers.SerializerMethodField(read_only=True) + + class Meta: + ref_name = __qualname__.split('.')[0] + 'V1' + + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_file(self, instance): + request = self.context.get('request') + return instance.get_absolute_data_file_url(request=request) if instance.file_id else None + + def get_filename(self, instance): + return instance.get_filename() + + def get_stored(self, instance): + return instance.get_filestore() + + def get_content_type(self, instance): + return instance.get_content_type() + + +class DataFileSerializer(serializers.ModelSerializer): + file = serializers.SerializerMethodField() + filename = serializers.SerializerMethodField() + stored = serializers.SerializerMethodField() + content_type = serializers.SerializerMethodField() + + class Meta: + ref_name = __qualname__.split('.')[0] + 'V1' + model = DataFile + fields = ( + 'id', + 'file_description', + 'file_category', + 'created', + 'modified', + 'file', + 'filename', + 'stored', + 'content_type', + ) + + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_file(self, instance): + request = self.context.get('request') + return instance.get_absolute_data_file_url(request=request) if 
instance.file else None
+
+    def get_filename(self, instance):
+        return instance.get_filename()
+
+    def get_stored(self, instance):
+        return instance.get_filestore()
+
+    def get_content_type(self, instance):
+        return instance.get_content_type()
+
+    def create(self, validated_data):
+        data = dict(validated_data)
+        # file_rsp = handle_related_file(self.get_object(), 'file', request, None)
+        if not data.get('creator') and 'request' in self.context:
+            data['creator'] = self.context.get('request').user
+        return super(DataFileSerializer, self).create(data)
diff --git a/src/server/oasisapi/data_files/v1_api/tests/__init__.py b/src/server/oasisapi/data_files/v1_api/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/server/oasisapi/data_files/tests/fakes.py b/src/server/oasisapi/data_files/v1_api/tests/fakes.py
similarity index 83%
rename from src/server/oasisapi/data_files/tests/fakes.py
rename to src/server/oasisapi/data_files/v1_api/tests/fakes.py
index e39639501..a9c74126d 100644
--- a/src/server/oasisapi/data_files/tests/fakes.py
+++ b/src/server/oasisapi/data_files/v1_api/tests/fakes.py
@@ -1,16 +1,16 @@
-from model_mommy import mommy
-
-from ..models import DataFile
-
-
-def fake_data_file(**kwargs):
-    """Create a fake DataFile for test purposes.
-
-    Args:
-        **kwargs: Keyword Arguments passed to DataFile
-
-    Returns:
-        ComplexModelDataFile: A faked DataFile
-
-    """
-    return mommy.make(DataFile, **kwargs)
+from model_mommy import mommy
+
+from src.server.oasisapi.data_files.models import DataFile
+
+
+def fake_data_file(**kwargs):
+    """Create a fake DataFile for test purposes.
+
+    Args:
+        **kwargs: Keyword Arguments passed to DataFile
+
+    Returns:
+        ComplexModelDataFile: A faked DataFile
+
+    """
+    return mommy.make(DataFile, **kwargs)
diff --git a/src/server/oasisapi/data_files/v1_api/tests/test_data_files.py b/src/server/oasisapi/data_files/v1_api/tests/test_data_files.py
new file mode 100644
index 000000000..cfbc330bb
--- /dev/null
+++ b/src/server/oasisapi/data_files/v1_api/tests/test_data_files.py
@@ -0,0 +1,128 @@
+import json
+import mimetypes
+import string
+from tempfile import TemporaryDirectory
+
+from django.test import override_settings
+from django.urls import reverse
+from django_webtest import WebTestMixin
+from hypothesis import given, settings
+from hypothesis.extra.django import TestCase
+from hypothesis.strategies import text, binary, sampled_from
+
+from rest_framework_simplejwt.tokens import AccessToken
+
+from src.server.oasisapi.auth.tests.fakes import fake_user
+from src.server.oasisapi.data_files.models import DataFile
+from .fakes import fake_data_file
+
+# Override default deadline for all tests to 800ms
+settings.register_profile("ci", deadline=800.0)
+settings.load_profile("ci")
+NAMESPACE = 'v1-files'
+
+
+class ComplexModelFilesApi(WebTestMixin, TestCase):
+
+    @given(
+        file_description=text(alphabet=string.ascii_letters, min_size=1, max_size=10),
+    )
+    def test_data_is_valid___object_is_created(self, file_description):
+        user = fake_user()
+
+        response = self.app.post(
+            reverse(f'{NAMESPACE}:data-file-list'),
+            headers={
+                'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+            },
+            params=json.dumps({
+                'file_description': file_description,
+            }),
+            content_type='application/json',
+        )
+
+        model = DataFile.objects.first()
+
+        self.assertEqual(201, response.status_code)
+        self.assertEqual(model.file_description, file_description)
+
+
+class ComplexModelFileDataFile(WebTestMixin, TestCase):
+    def
test_user_is_not_authenticated___response_is_forbidden(self): + cmf = fake_data_file() + + response = self.app.get(cmf.get_absolute_data_file_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_data_file_is_not_present___get_response_is_404(self): + user = fake_user() + cmf = fake_data_file() + + response = self.app.get( + cmf.get_absolute_data_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_data_file_is_not_present___delete_response_is_404(self): + user = fake_user() + cmf = fake_data_file() + + response = self.app.delete( + cmf.get_absolute_data_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_data_file_is_unknown_format___response_is_200(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + cmf = fake_data_file() + + response = self.app.post( + cmf.get_absolute_data_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file.tar', b'an-unknown-mime-format'), + ), + ) + + self.assertEqual(200, response.status_code) + + @given(file_content=binary(min_size=1), content_type=sampled_from(['text/csv', 'application/json', 'application/octet-stream', 'image/tiff'])) + def test_data_file_is_uploaded___file_can_be_retrieved(self, file_content, content_type): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + cmf = fake_data_file() + + self.app.post( + cmf.get_absolute_data_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content), + ), + ) + + response = self.app.get( + cmf.get_absolute_data_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(response.body, file_content) + self.assertEqual(response.content_type, content_type) diff --git a/src/server/oasisapi/data_files/v1_api/urls.py b/src/server/oasisapi/data_files/v1_api/urls.py new file mode 100644 index 000000000..3e4316191 --- /dev/null +++ b/src/server/oasisapi/data_files/v1_api/urls.py @@ -0,0 +1,13 @@ +from django.conf.urls import include, url +from rest_framework.routers import SimpleRouter +from .viewsets import DataFileViewset + + +app_name = 'data_files' +v1_api_router = SimpleRouter() +v1_api_router.include_root_view = False +v1_api_router.register('data_files', DataFileViewset, basename='data-file') + +urlpatterns = [ + url(r'', include(v1_api_router.urls)), +] diff --git a/src/server/oasisapi/data_files/v1_api/viewsets.py b/src/server/oasisapi/data_files/v1_api/viewsets.py new file mode 100644 index 000000000..f1c2b0467 --- /dev/null +++ b/src/server/oasisapi/data_files/v1_api/viewsets.py @@ -0,0 +1,118 @@ +from django.utils.translation import gettext_lazy as _ +from django_filters import rest_framework as filters +from drf_yasg.utils import swagger_auto_schema +from rest_framework import viewsets +from rest_framework.decorators import action +from rest_framework.parsers import MultiPartParser +from rest_framework.settings import api_settings + +from 
...files.serializers import RelatedFileSerializer +from ...files.views import handle_related_file +from ...filters import TimeStampedFilter +from ..models import DataFile +from ...schemas.custom_swagger import FILE_RESPONSE +from .serializers import DataFileSerializer, DataFileListSerializer + + +class DataFileFilter(TimeStampedFilter): + filename = filters.CharFilter( + help_text=_('Filter results by case insensitive `filename` equal to the given string'), + lookup_expr='iexact', + field_name='file__filename' + ) + filename__contains = filters.CharFilter( + help_text=_('Filter results by case insensitive `filename` containing the given string'), + lookup_expr='icontains', + field_name='file__filename' + ) + content_type = filters.CharFilter( + help_text=_('Filter results by case insensitive `content_type` equal to the given string'), + lookup_expr='iexact', + field_name='file__content_type' + ) + content_type__contains = filters.CharFilter( + help_text=_('Filter results by case insensitive `content_type` containing the given string'), + lookup_expr='icontains', + field_name='file__content_type' + ) + file_description = filters.CharFilter( + help_text=_('Filter results by case insensitive `file_description` equal to the given string'), + lookup_expr='iexact', + field_name='file_description' + ) + file_description__contains = filters.CharFilter( + help_text=_('Filter results by case insensitive `file_description` containing the given string'), + lookup_expr='icontains', + field_name='file_description' + ) + file_category = filters.CharFilter( + help_text=_('Filter results by case insensitive `file_category` equal to the given string'), + lookup_expr='iexact', + field_name='file_category' + ) + file_category__contains = filters.CharFilter( + help_text=_('Filter results by case insensitive `file_category` containing the given string'), + lookup_expr='icontains', + field_name='file_category' + ) + user = filters.CharFilter( + help_text=_('Filter results by case insensitive `user` equal to the given string'), + lookup_expr='iexact', + field_name='creator__username' + ) + + class Meta: + model = DataFile + fields = [ + 'filename', + 'filename__contains', + 'file_description', + 'file_description__contains', + 'file_category', + 'file_category__contains', + 'user', + ] + + +class DataFileViewset(viewsets.ModelViewSet): + queryset = DataFile.objects.all().select_related('file') + serializer_class = DataFileSerializer + filterset_class = DataFileFilter + + def get_serializer_class(self): + if self.action in ['content', 'set_content']: + return RelatedFileSerializer + elif self.action in ['list']: + return DataFileListSerializer + else: + return super(DataFileViewset, self).get_serializer_class() + + @property + def parser_classes(self): + if getattr(self, 'action', None) in ['set_content']: + return [MultiPartParser] + else: + return api_settings.DEFAULT_PARSER_CLASSES + + @swagger_auto_schema(methods=['get'], responses={200: FILE_RESPONSE}) + @action(methods=['get', 'delete'], detail=True) + def content(self, request, pk=None, version=None): + """ + get: + Gets the data file's file contents + + delete: + Deletes the data file. 
+ """ + + file_response = handle_related_file(self.get_object(), 'file', request, None) + return file_response + + @content.mapping.post + def set_content(self, request, pk=None, version=None): + """ + post: + Sets the data file's `file` contents + """ + + return handle_related_file(self.get_object(), 'file', request, None) diff --git a/src/server/oasisapi/data_files/serializers.py b/src/server/oasisapi/data_files/v2_api/serializers.py similarity index 97% rename from src/server/oasisapi/data_files/serializers.py rename to src/server/oasisapi/data_files/v2_api/serializers.py index 18d1ac9b7..4f85a0881 100644 --- a/src/server/oasisapi/data_files/serializers.py +++ b/src/server/oasisapi/data_files/v2_api/serializers.py @@ -2,15 +2,14 @@ from drf_yasg.utils import swagger_serializer_method from rest_framework import serializers -from .models import DataFile -from ..permissions.group_auth import validate_and_update_groups +from ..models import DataFile +from ...permissions.group_auth import validate_and_update_groups class DataFileListSerializer(serializers.Serializer): """ Read Only DataFile Deserializer for efficiently returning a list of all DataFile from DB """ - # model fields id = serializers.IntegerField(read_only=True) file_description = serializers.CharField(read_only=True) diff --git a/src/server/oasisapi/data_files/v2_api/tests/__init__.py b/src/server/oasisapi/data_files/v2_api/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/server/oasisapi/data_files/v2_api/tests/fakes.py b/src/server/oasisapi/data_files/v2_api/tests/fakes.py new file mode 100644 index 000000000..76d23574d --- /dev/null +++ b/src/server/oasisapi/data_files/v2_api/tests/fakes.py @@ -0,0 +1,16 @@ +from model_mommy import mommy + +from src.server.oasisapi.data_files.models import DataFile + + +def fake_data_file(**kwargs): + """Create a fake DataFile for test purposes. 
+ + Args: + **kwargs: Keyword Arguments passed to DataFile + + Returns: + ComplexModelDataFile: A faked DataFile + + """ + return mommy.make(DataFile, **kwargs) diff --git a/src/server/oasisapi/data_files/tests/test_data_files.py b/src/server/oasisapi/data_files/v2_api/tests/test_data_files.py similarity index 84% rename from src/server/oasisapi/data_files/tests/test_data_files.py rename to src/server/oasisapi/data_files/v2_api/tests/test_data_files.py index 058488e0d..67deffd64 100644 --- a/src/server/oasisapi/data_files/tests/test_data_files.py +++ b/src/server/oasisapi/data_files/v2_api/tests/test_data_files.py @@ -12,13 +12,14 @@ from rest_framework_simplejwt.tokens import AccessToken -from ...auth.tests.fakes import fake_user, add_fake_group -from ..models import DataFile +from src.server.oasisapi.auth.tests.fakes import fake_user, add_fake_group +from src.server.oasisapi.data_files.models import DataFile from .fakes import fake_data_file # Override default deadline for all tests to 8s settings.register_profile("ci", deadline=800.0) settings.load_profile("ci") +NAMESPACE = 'v2-files' class ComplexModelFilesApi(WebTestMixin, TestCase): @@ -32,7 +33,7 @@ def test_data_is_valid___object_is_created(self, file_description, group_name): add_fake_group(user, group_name) response = self.app.post( - reverse('data-file-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:data-file-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -53,7 +54,7 @@ class ComplexModelFileDataFile(WebTestMixin, TestCase): def test_user_is_not_authenticated___response_is_forbidden(self): cmf = fake_data_file() - response = self.app.get(cmf.get_absolute_data_file_url(), expect_errors=True) + response = self.app.get(cmf.get_absolute_data_file_url(namespace=NAMESPACE), expect_errors=True) self.assertIn(response.status_code, [401, 403]) def test_data_file_is_not_present___get_response_is_404(self): @@ -61,7 +62,7 @@ def test_data_file_is_not_present___get_response_is_404(self): cmf = fake_data_file() response = self.app.get( - cmf.get_absolute_data_file_url(), + cmf.get_absolute_data_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -75,7 +76,7 @@ def test_data_file_is_not_present___delete_response_is_404(self): cmf = fake_data_file() response = self.app.delete( - cmf.get_absolute_data_file_url(), + cmf.get_absolute_data_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -91,7 +92,7 @@ def test_data_file_is_unknown_format___response_is_200(self): cmf = fake_data_file() response = self.app.post( - cmf.get_absolute_data_file_url(), + cmf.get_absolute_data_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -110,7 +111,7 @@ def test_data_file_is_uploaded___file_can_be_retrieved(self, file_content, conte cmf = fake_data_file() self.app.post( - cmf.get_absolute_data_file_url(), + cmf.get_absolute_data_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -120,7 +121,7 @@ def test_data_file_is_uploaded___file_can_be_retrieved(self, file_content, conte ) response = self.app.get( - cmf.get_absolute_data_file_url(), + cmf.get_absolute_data_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, diff --git a/src/server/oasisapi/data_files/v2_api/urls.py b/src/server/oasisapi/data_files/v2_api/urls.py new 
file mode 100644 index 000000000..0feb9528f --- /dev/null +++ b/src/server/oasisapi/data_files/v2_api/urls.py @@ -0,0 +1,13 @@ +from django.conf.urls import include, url +from rest_framework.routers import SimpleRouter +from .viewsets import DataFileViewset + + +app_name = 'data_files' +v2_api_router = SimpleRouter() +v2_api_router.include_root_view = False +v2_api_router.register('data_files', DataFileViewset, basename='data-file') + +urlpatterns = [ + url(r'', include(v2_api_router.urls)), +] diff --git a/src/server/oasisapi/data_files/viewsets.py b/src/server/oasisapi/data_files/v2_api/viewsets.py similarity index 93% rename from src/server/oasisapi/data_files/viewsets.py rename to src/server/oasisapi/data_files/v2_api/viewsets.py index a05c77a4d..c65eb3341 100644 --- a/src/server/oasisapi/data_files/viewsets.py +++ b/src/server/oasisapi/data_files/v2_api/viewsets.py @@ -6,12 +6,12 @@ from rest_framework.parsers import MultiPartParser from rest_framework.settings import api_settings -from ..files.serializers import RelatedFileSerializer -from ..files.views import handle_related_file -from ..filters import TimeStampedFilter -from .models import DataFile -from ..permissions.group_auth import VerifyGroupAccessModelViewSet -from ..schemas.custom_swagger import FILE_RESPONSE +from ...files.serializers import RelatedFileSerializer +from ...files.views import handle_related_file +from ...filters import TimeStampedFilter +from ..models import DataFile +from ...permissions.group_auth import VerifyGroupAccessModelViewSet +from ...schemas.custom_swagger import FILE_RESPONSE from .serializers import DataFileSerializer, DataFileListSerializer diff --git a/src/server/oasisapi/files/models.py b/src/server/oasisapi/files/models.py index b2b29e506..75a52b12a 100644 --- a/src/server/oasisapi/files/models.py +++ b/src/server/oasisapi/files/models.py @@ -87,7 +87,7 @@ class RelatedFile(TimeStampedModel): content_type = models.CharField(max_length=255) objects = RelatedFileManager() # ARCH2020 -- Is this actually used?? 
store_as_filename = models.BooleanField(default=False, blank=True, null=True) - groups = models.ManyToManyField(Group, blank=True, null=False, default=None, help_text='Groups allowed to access this object') + groups = models.ManyToManyField(Group, blank=True, default=None, help_text='Groups allowed to access this object') oed_validated = models.BooleanField(default=False, editable=False) def __str__(self): diff --git a/src/server/oasisapi/files/serializers.py b/src/server/oasisapi/files/serializers.py index 5f1644bfc..eb6ae8b56 100644 --- a/src/server/oasisapi/files/serializers.py +++ b/src/server/oasisapi/files/serializers.py @@ -4,7 +4,6 @@ from pathlib import Path from ods_tools.oed.exposure import OedExposure -from ods_tools.oed.common import OdsException from django.contrib.auth.models import Group from rest_framework import serializers @@ -71,8 +70,12 @@ def validate(self, attrs): EXPOSURE_ARGS[self.oed_field]: attrs['file'], 'validation_config': django_settings.PORTFOLIO_VALIDATION_CONFIG }) - except OdsException as e: - raise ValidationError('Failed to read exposure data, file is corrupted or set with incorrect format', e) + except Exception as e: + raise ValidationError({ + 'error': 'Failed to validate exposure data', + 'detail': str(e), + 'exception': type(e).__name__ + }) # Run OED Validation if run_validation: diff --git a/src/server/oasisapi/healthcheck/tests/test_healthcheck.py b/src/server/oasisapi/healthcheck/tests/test_healthcheck.py index 5d88e01d2..24735cd92 100644 --- a/src/server/oasisapi/healthcheck/tests/test_healthcheck.py +++ b/src/server/oasisapi/healthcheck/tests/test_healthcheck.py @@ -7,12 +7,15 @@ class Healthcheck(WebTest): - def test_user_is_not_authenticated___response_is_ok(self): - with patch.object(HealthcheckView, 'celery_is_ok', return_value=True) as mock_method: - response = self.app.get(reverse('healthcheck')) - self.assertEqual(200, response.status_code) - def test_user_is_authenticated___response_is_ok(self): - with patch.object(HealthcheckView, 'celery_is_ok', return_value=True) as mock_method: - response = self.app.get(reverse('healthcheck'), user=fake_user()) - self.assertEqual(200, response.status_code) + @patch.object(HealthcheckView, 'celery_v1_is_ok', return_value=True) + @patch.object(HealthcheckView, 'celery_v2_is_ok', return_value=True) + def test_user_is_not_authenticated___response_is_ok(self, mock_celery_v1, mock_celery_v2): + response = self.app.get(reverse('healthcheck')) + self.assertEqual(200, response.status_code) + + @patch.object(HealthcheckView, 'celery_v1_is_ok', return_value=True) + @patch.object(HealthcheckView, 'celery_v2_is_ok', return_value=True) + def test_user_is_authenticated___response_is_ok(self, mock_celery_v1, mock_celery_v2): + response = self.app.get(reverse('healthcheck'), user=fake_user()) + self.assertEqual(200, response.status_code) diff --git a/src/server/oasisapi/healthcheck/views.py b/src/server/oasisapi/healthcheck/views.py index 727f59e71..f9e7cb6f1 100644 --- a/src/server/oasisapi/healthcheck/views.py +++ b/src/server/oasisapi/healthcheck/views.py @@ -7,7 +7,8 @@ from rest_framework import views, status from rest_framework.response import Response -from ..celery_app import celery_app +from ..celery_app_v1 import v1 as celery_app_v1 +from ..celery_app_v2 import v2 as celery_app_v2 from ..schemas.custom_swagger import HEALTHCHECK @@ -20,7 +21,7 @@ class HealthcheckView(views.APIView): authentication_classes = [] permission_classes = [] - @swagger_auto_schema(responses={200: HEALTHCHECK}) + 
@swagger_auto_schema(responses={200: HEALTHCHECK}, tags=['info']) def get(self, request): """ Check db and celery connectivity and return a 200 if healthy, 503 if not. @@ -29,8 +30,11 @@ status_text = 'OK' if not django_settings.CONSOLE_DEBUG: - if not self.celery_is_ok(): - status_text = 'ERROR - celery down' + if not self.celery_v1_is_ok(): + status_text = 'ERROR - celery v1 down' + code = status.HTTP_503_SERVICE_UNAVAILABLE + elif not self.celery_v2_is_ok(): + status_text = 'ERROR - celery v2 down' code = status.HTTP_503_SERVICE_UNAVAILABLE elif not self.db_is_ok(): status_text = 'ERROR - db down' @@ -38,14 +42,30 @@ return Response({'status': status_text}, code) - def celery_is_ok(self) -> bool: + def celery_v1_is_ok(self) -> bool: + """ + Verify a healthy celery connection. + + :return: True if healthy, False if not. + """ + try: + i = celery_app_v1.control.inspect() + availability = i.ping() + if not availability: + return False + except Exception as e: + logging.error('Celery error: %s', e) + return False + return True + + def celery_v2_is_ok(self) -> bool: """ Verify a healthy celery connection. :return: True if healthy, False it not. """ try: - i = celery_app.control.inspect() + i = celery_app_v2.control.inspect() availability = i.ping() if not availability: return False
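Both healthcheck methods rely on celery's broadcast ping. As a standalone illustration of the pattern (not part of the patch), a minimal sketch assuming a configured Celery app; the app name and broker URL are placeholders:

    from celery import Celery

    app = Celery('oasis', broker='amqp://localhost//')  # placeholder broker URL

    def celery_is_reachable(celery_app) -> bool:
        """Return True if at least one worker answers a broadcast ping."""
        try:
            replies = celery_app.control.inspect().ping()  # None when no worker replies
            return bool(replies)
        except Exception:
            return False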
diff --git a/src/server/oasisapi/info/views.py b/src/server/oasisapi/info/views.py index b1c7a6cc8..8a92f888d 100644 --- a/src/server/oasisapi/info/views.py +++ b/src/server/oasisapi/info/views.py @@ -14,6 +14,7 @@ class PerilcodesView(views.APIView): authentication_classes = [] permission_classes = [] + @swagger_auto_schema(tags=['info']) def get(self, request): peril_codes = {PERILS[p]['id']: {'desc': PERILS[p]['desc']} for p in PERILS.keys()} peril_groups = { @@ -34,7 +35,7 @@ class ServerInfoView(views.APIView): Return a list of all support OED peril codes in the oasislmf package """ - @swagger_auto_schema(responses={200: SERVER_INFO}) + @swagger_auto_schema(responses={200: SERVER_INFO}, tags=['info']) def get(self, request): server_version = "" server_config = dict() diff --git a/src/server/oasisapi/portfolios/models.py b/src/server/oasisapi/portfolios/models.py index 2db547bd1..631947ea5 100644 --- a/src/server/oasisapi/portfolios/models.py +++ b/src/server/oasisapi/portfolios/models.py @@ -18,7 +18,7 @@ class Portfolio(TimeStampedModel): name = models.CharField(max_length=255, help_text=_('The name of the portfolio')) creator = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE, related_name='portfolios') - groups = models.ManyToManyField(Group, blank=True, null=False, default=None, help_text='Groups allowed to access this object') + groups = models.ManyToManyField(Group, blank=True, default=None, help_text='Groups allowed to access this object') accounts_file = models.ForeignKey(RelatedFile, on_delete=models.CASCADE, blank=True, null=True, default=None, related_name='accounts_file_portfolios') @@ -35,26 +35,33 @@ class Meta: def __str__(self): return self.name - def get_absolute_url(self, request=None): - return reverse('portfolio-detail', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}portfolio-detail', kwargs={'pk': self.pk}, request=request) - def get_absolute_create_analysis_url(self, request=None): - return reverse('portfolio-create-analysis', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_create_analysis_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}portfolio-create-analysis', kwargs={'pk': self.pk}, request=request) - def get_absolute_accounts_file_url(self, request=None): - return reverse('portfolio-accounts-file', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_accounts_file_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}portfolio-accounts-file', kwargs={'pk': self.pk}, request=request) - def get_absolute_location_file_url(self, request=None): - return reverse('portfolio-location-file', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_location_file_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}portfolio-location-file', kwargs={'pk': self.pk}, request=request) - def get_absolute_reinsurance_info_file_url(self, request=None): - return reverse('portfolio-reinsurance-info-file', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_reinsurance_info_file_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}portfolio-reinsurance-info-file', kwargs={'pk': self.pk}, request=request) - def get_absolute_reinsurance_scope_file_url(self, request=None): - return reverse('portfolio-reinsurance-scope-file', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_reinsurance_scope_file_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}portfolio-reinsurance-scope-file', kwargs={'pk': self.pk}, request=request) - def get_absolute_storage_url(self, request=None): - return reverse('portfolio-storage-links', kwargs={'version': 'v1', 'pk': self.pk}, request=request) + def get_absolute_storage_url(self, request=None, namespace=None): + override_ns = f'{namespace}:' if namespace else '' + return reverse(f'{override_ns}portfolio-storage-links', kwargs={'pk': self.pk}, request=request) def location_file_len(self): csv_compression_types = { @@ -84,13 +91,20 @@ def set_portolio_valid(self): file_ref.save() def run_oed_validation(self): - portfolio_exposure = OedExposure( - location=getattr(self.location_file, 'file', None), - account=getattr(self.accounts_file, 'file', None), - ri_info=getattr(self.reinsurance_info_file, 'file', None), - ri_scope=getattr(self.reinsurance_scope_file, 'file', None), - validation_config=settings.PORTFOLIO_VALIDATION_CONFIG) - validation_errors = portfolio_exposure.check() + try: + portfolio_exposure = OedExposure( + location=getattr(self.location_file, 'file', None), + account=getattr(self.accounts_file, 'file', None), + ri_info=getattr(self.reinsurance_info_file, 'file', None), + ri_scope=getattr(self.reinsurance_scope_file, 'file', None), + validation_config=settings.PORTFOLIO_VALIDATION_CONFIG) + validation_errors = portfolio_exposure.check() + except Exception as e: + raise ValidationError({ + 'error': 'Failed to validate portfolio', + 'detail': str(e), + 'exception': type(e).__name__ + }) # Set validation fields to true or raise exception if validation_errors: diff --git a/src/server/oasisapi/portfolios/v1_api/__init__.py b/src/server/oasisapi/portfolios/v1_api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/src/server/oasisapi/portfolios/v1_api/serializers.py b/src/server/oasisapi/portfolios/v1_api/serializers.py new file mode 100644 index 000000000..aeaf20b01 --- /dev/null +++ b/src/server/oasisapi/portfolios/v1_api/serializers.py @@ -0,0 +1,436 @@ +from os import path +import mimetypes + +from drf_yasg.utils import swagger_serializer_method +from rest_framework import serializers +from rest_framework.exceptions import ValidationError +from django.core.files.storage import default_storage +from django.core.files import File +from django.core.files.base import ContentFile +from django.core.exceptions import ObjectDoesNotExist +from botocore.exceptions import ClientError as S3_ClientError +from azure.core.exceptions import ResourceNotFoundError as Blob_ResourceNotFoundError +from azure.storage.blob import BlobLeaseClient + +from ...analyses.v1_api.serializers import AnalysisSerializer +from ...files.models import file_storage_link +from ...files.models import RelatedFile +from ...files.upload import wait_for_blob_copy +from ..models import Portfolio + +from ...schemas.serializers import ( + LocFileSerializer, + AccFileSerializer, + ReinsInfoFileSerializer, + ReinsScopeFileSerializer, +) + + +class PortfolioListSerializer(serializers.Serializer): + """ Read Only Portfolio Deserializer for efficiently returning a list of all + Portfolios in DB + """ + class Meta: + ref_name = __qualname__.split('.')[0] + 'V1' + + id = serializers.IntegerField(read_only=True) + name = serializers.CharField(read_only=True) + created = serializers.DateTimeField(read_only=True) + modified = serializers.DateTimeField(read_only=True) + accounts_file = serializers.SerializerMethodField(read_only=True) + location_file = serializers.SerializerMethodField(read_only=True) + reinsurance_info_file = serializers.SerializerMethodField(read_only=True) + reinsurance_scope_file = serializers.SerializerMethodField(read_only=True) + storage_links = serializers.SerializerMethodField(read_only=True) + + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_storage_links(self, instance): + request = self.context.get('request') + return instance.get_absolute_storage_url(request=request) + + @swagger_serializer_method(serializer_or_field=LocFileSerializer) + def get_location_file(self, instance): + if instance.location_file_id is None: + return None + request = self.context.get('request') + return { + "uri": instance.get_absolute_location_file_url(request=request), + "name": instance.location_file.filename, + "stored": str(instance.location_file.file) + } + + @swagger_serializer_method(serializer_or_field=AccFileSerializer) + def get_accounts_file(self, instance): + if instance.accounts_file_id is None: + return None + request = self.context.get('request') + return { + "uri": instance.get_absolute_accounts_file_url(request=request), + "name": instance.accounts_file.filename, + "stored": str(instance.accounts_file.file) + } + + @swagger_serializer_method(serializer_or_field=ReinsInfoFileSerializer) + def get_reinsurance_info_file(self, instance): + if instance.reinsurance_info_file_id is None: + return None + + request = self.context.get('request') + return { + "uri": instance.get_absolute_reinsurance_info_file_url(request=request), + "name": instance.reinsurance_info_file.filename, + "stored": str(instance.reinsurance_info_file.file) + } + + @swagger_serializer_method(serializer_or_field=ReinsScopeFileSerializer) + def get_reinsurance_scope_file(self, instance): + if instance.reinsurance_scope_file_id is None: 
+ return None + request = self.context.get('request') + return { + "uri": instance.get_absolute_reinsurance_scope_file_url(request=request), + "name": instance.reinsurance_scope_file.filename, + "stored": str(instance.reinsurance_scope_file.file) + } + + +class PortfolioSerializer(serializers.ModelSerializer): + accounts_file = serializers.SerializerMethodField() + location_file = serializers.SerializerMethodField() + reinsurance_info_file = serializers.SerializerMethodField() + reinsurance_scope_file = serializers.SerializerMethodField() + storage_links = serializers.SerializerMethodField() + + class Meta: + ref_name = __qualname__.split('.')[0] + 'V1' + model = Portfolio + fields = ( + 'id', + 'name', + 'created', + 'modified', + 'location_file', + 'accounts_file', + 'reinsurance_info_file', + 'reinsurance_scope_file', + 'storage_links', + ) + + def create(self, validated_data): + data = dict(validated_data) + if not data.get('creator') and 'request' in self.context: + data['creator'] = self.context.get('request').user + return super(PortfolioSerializer, self).create(data) + + @swagger_serializer_method(serializer_or_field=serializers.URLField) + def get_storage_links(self, instance): + request = self.context.get('request') + return instance.get_absolute_storage_url(request=request) + + @swagger_serializer_method(serializer_or_field=LocFileSerializer) + def get_location_file(self, instance): + if not instance.location_file: + return None + else: + request = self.context.get('request') + return { + "uri": instance.get_absolute_location_file_url(request=request), + "name": instance.location_file.filename, + "stored": str(instance.location_file.file) + } + + @swagger_serializer_method(serializer_or_field=AccFileSerializer) + def get_accounts_file(self, instance): + if not instance.accounts_file: + return None + else: + request = self.context.get('request') + return { + "uri": instance.get_absolute_accounts_file_url(request=request), + "name": instance.accounts_file.filename, + "stored": str(instance.accounts_file.file) + } + + @swagger_serializer_method(serializer_or_field=ReinsInfoFileSerializer) + def get_reinsurance_info_file(self, instance): + if not instance.reinsurance_info_file: + return None + else: + request = self.context.get('request') + return { + "uri": instance.get_absolute_reinsurance_info_file_url(request=request), + "name": instance.reinsurance_info_file.filename, + "stored": str(instance.reinsurance_info_file.file) + } + + @swagger_serializer_method(serializer_or_field=ReinsScopeFileSerializer) + def get_reinsurance_scope_file(self, instance): + if not instance.reinsurance_scope_file: + return None + else: + request = self.context.get('request') + return { + "uri": instance.get_absolute_reinsurance_scope_file_url(request=request), + "name": instance.reinsurance_scope_file.filename, + "stored": str(instance.reinsurance_scope_file.file) + } + + +class PortfolioStorageSerializer(serializers.ModelSerializer): + accounts_file = serializers.SerializerMethodField() + location_file = serializers.SerializerMethodField() + reinsurance_info_file = serializers.SerializerMethodField() + reinsurance_scope_file = serializers.SerializerMethodField() + + class Meta: + ref_name = __qualname__.split('.')[0] + 'V1' + model = Portfolio + fields = ( + 'location_file', + 'accounts_file', + 'reinsurance_info_file', + 'reinsurance_scope_file', + ) + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_location_file(self, instance): + return 
file_storage_link(instance.location_file, True) + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_accounts_file(self, instance): + return file_storage_link(instance.accounts_file, True) + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_reinsurance_info_file(self, instance): + return file_storage_link(instance.reinsurance_info_file, True) + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_reinsurance_scope_file(self, instance): + return file_storage_link(instance.reinsurance_scope_file, True) + + def is_in_storage(self, value): + # Check AWS storage + if hasattr(default_storage, 'bucket'): + try: + default_storage.bucket.Object(value).load() + return True + except S3_ClientError as e: + if e.response['Error']['Code'] == "404": + return False + else: + raise e + # Check Azure Blob storage + elif hasattr(default_storage, 'azure_container'): + try: + blob = default_storage.client.get_blob_client(value) + blob.get_blob_properties() + return True + except Blob_ResourceNotFoundError: + return False + else: + return default_storage.exists(value)
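The S3 branch above follows the usual boto3 idiom for existence checks: HEAD the key, treat a ClientError with code "404" as "missing", and re-raise anything else. A minimal standalone sketch of that idiom, assuming boto3 is installed and credentials are configured; the bucket name is illustrative:

    import boto3
    from botocore.exceptions import ClientError

    def s3_key_exists(bucket_name: str, key: str) -> bool:
        """HEAD the object; a 404 means the key is absent, anything else is a real error."""
        try:
            boto3.resource('s3').Bucket(bucket_name).Object(key).load()
            return True
        except ClientError as e:
            if e.response['Error']['Code'] == '404':
                return False
            raise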
+ + def validate(self, attrs): + file_keys = [k for k in self.fields.keys()] + + # Check for at least one entry + file_values = [v for k, v in self.initial_data.items() if k in file_keys] + if len(file_values) == 0: + raise serializers.ValidationError('At least one file field reference required from [{}]'.format(', '.join(file_keys))) + + errors = dict() + for k in file_keys: + value = self.initial_data.get(k) + if value is not None: + + # Check type is string + if not isinstance(value, str): + errors[k] = "Value is not type string, found {}".format(type(value)) + continue + # Check string is not empty + elif len(value.strip()) < 1: + errors[k] = "Value is empty or whitespace string." + continue + # Check that the file exists + elif not self.is_in_storage(value): + errors[k] = "File '{}' not found in default storage".format(value) + continue + + # Data is valid + attrs[k] = value + if errors: + raise serializers.ValidationError(errors) + return super(PortfolioStorageSerializer, self).validate(attrs) + + def inferr_content_type(self, stored_filename): + inferred_type = mimetypes.MimeTypes().guess_type(stored_filename)[0] + if not inferred_type and stored_filename.lower().endswith('parquet'): + # mimetypes doesn't work for parquet so handle that here + inferred_type = 'application/octet-stream' + if not inferred_type: + inferred_type = default_storage.default_content_type + return inferred_type + + def get_content_type(self, stored_filename): + try: # fetch content_type stored in Django's DB + return RelatedFile.objects.get(file=path.basename(stored_filename)).content_type + except ObjectDoesNotExist: + # Find content_type from S3 Object header + if hasattr(default_storage, 'bucket'): + try: + object_header = default_storage.connection.meta.client.head_object( + Bucket=default_storage.bucket_name, + Key=stored_filename) + return object_header['ContentType'] + except S3_ClientError: + return self.inferr_content_type(stored_filename) + + # Find content_type from Blob Storage + elif hasattr(default_storage, 'azure_container'): + blob_client = default_storage.client.get_blob_client(stored_filename) + blob_properties = blob_client.get_blob_properties() + return blob_properties.content_settings.content_type + + else: + return self.inferr_content_type(stored_filename)
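The parquet fallback exists because the stdlib mimetypes module has no mapping for parquet files; a quick illustration of the behaviour it papers over:

    import mimetypes

    mimetypes.MimeTypes().guess_type('exposure.csv')[0]      # 'text/csv'
    mimetypes.MimeTypes().guess_type('exposure.parquet')[0]  # None, so the serializer falls
                                                             # back to 'application/octet-stream'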
+ + def update(self, instance, validated_data): + files_for_removal = list() + + for field in validated_data: + old_file_name = validated_data[field] + content_type = self.get_content_type(old_file_name) + fname = path.basename(old_file_name) + new_file_name = default_storage.get_alternative_name(fname, '') + + # S3 storage - File copy needed + if hasattr(default_storage, 'bucket'): + new_file = ContentFile(b'') + new_file.name = new_file_name + new_related_file = RelatedFile.objects.create( + file=new_file, + filename=fname, + content_type=content_type, + creator=self.context['request'].user, + store_as_filename=True, + ) + bucket = default_storage.bucket + stored_file = default_storage.open(new_related_file.file.name) + stored_file.obj.copy({"Bucket": bucket.name, "Key": old_file_name}) + stored_file.obj.wait_until_exists() + + elif hasattr(default_storage, 'azure_container'): + # https://docs.microsoft.com/en-us/azure/storage/blobs/storage-blob-copy?tabs=python + new_file_name = default_storage.get_alternative_name(old_file_name, '') + new_blobname = '/'.join([default_storage.location, path.basename(new_file_name)]) + + # Copies a blob asynchronously. + source_blob = default_storage.client.get_blob_client(old_file_name) + dest_blob = default_storage.client.get_blob_client(new_blobname) + + try: + lease = BlobLeaseClient(source_blob) + lease.acquire() + dest_blob.start_copy_from_url(source_blob.url) + wait_for_blob_copy(dest_blob) + lease.break_lease() + except Exception as e: + # copy failed, break file lease and re-raise + lease.break_lease() + raise e + + stored_blob = default_storage.open(new_file_name) + new_related_file = RelatedFile.objects.create( + file=File(stored_blob, name=new_file_name), + filename=fname, + content_type=content_type, + creator=self.context['request'].user, + store_as_filename=True, + ) + + # Shared-fs + else: + stored_file = default_storage.open(old_file_name) + new_file = File(stored_file, name=new_file_name) + new_related_file = RelatedFile.objects.create( + file=new_file, + filename=fname, + content_type=content_type, + creator=self.context['request'].user, + store_as_filename=True, + ) + + # Mark prev ref for deletion if it exists + if hasattr(instance, field): + prev_file = getattr(instance, field) + if prev_file: + files_for_removal.append(prev_file) + + # Set new file ref + setattr(instance, field, new_related_file) + + # Update & delete prev linked files + instance.save(update_fields=[k for k in validated_data]) + for f in files_for_removal: + f.delete() + return instance
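Taken together, validate and update let a client relink files that already sit in the server's storage rather than re-uploading them. A hedged sketch of such a call, assuming a running server; the URL, token value, and storage key are illustrative:

    import requests

    token = '<JWT access token>'  # obtained from the API's auth endpoint
    resp = requests.post(
        'http://localhost:8000/v1/portfolios/1/storage_links/',  # assumed route for portfolio-storage-links
        headers={'Authorization': f'Bearer {token}'},
        json={'location_file': 'uploaded_loc_file.csv'},  # key must already exist in default storage
    )
    resp.raise_for_status()  # a 400 response lists per-field errors for missing or invalid keys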
+ + +class CreateAnalysisSerializer(AnalysisSerializer): + class Meta(AnalysisSerializer.Meta): + ref_name = __qualname__.split('.')[0] + 'V1' + fields = ['name', 'model'] + + def __init__(self, portfolio=None, *args, **kwargs): + self.portfolio = portfolio + super(CreateAnalysisSerializer, self).__init__(*args, **kwargs) + + def validate(self, attrs): + attrs['portfolio'] = self.portfolio + if not self.portfolio.location_file: + raise ValidationError({'portfolio': '"location_file" must not be null'}) + + return attrs + + def create(self, validated_data): + data = dict(validated_data) + if 'request' in self.context: + data['creator'] = self.context.get('request').user + return super(CreateAnalysisSerializer, self).create(data) + + +class PortfolioValidationSerializer(serializers.ModelSerializer): + accounts_validated = serializers.SerializerMethodField() + location_validated = serializers.SerializerMethodField() + reinsurance_info_validated = serializers.SerializerMethodField() + reinsurance_scope_validated = serializers.SerializerMethodField() + + class Meta: + ref_name = __qualname__.split('.')[0] + 'V1' + model = Portfolio + fields = ( + 'location_validated', + 'accounts_validated', + 'reinsurance_info_validated', + 'reinsurance_scope_validated', + ) + + @swagger_serializer_method(serializer_or_field=serializers.CharField) # should it be BooleanField ? + def get_location_validated(self, instance): + if instance.location_file: + return instance.location_file.oed_validated + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_accounts_validated(self, instance): + if instance.accounts_file: + return instance.accounts_file.oed_validated + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_reinsurance_info_validated(self, instance): + if instance.reinsurance_info_file: + return instance.reinsurance_info_file.oed_validated + + @swagger_serializer_method(serializer_or_field=serializers.CharField) + def get_reinsurance_scope_validated(self, instance): + if instance.reinsurance_scope_file: + return instance.reinsurance_scope_file.oed_validated diff --git a/src/server/oasisapi/portfolios/v1_api/tests/__init__.py b/src/server/oasisapi/portfolios/v1_api/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/server/oasisapi/portfolios/tests/fakes.py b/src/server/oasisapi/portfolios/v1_api/tests/fakes.py similarity index 63% rename from src/server/oasisapi/portfolios/tests/fakes.py rename to src/server/oasisapi/portfolios/v1_api/tests/fakes.py index 65e90ddbe..5d69077cf 100644 --- a/src/server/oasisapi/portfolios/tests/fakes.py +++ b/src/server/oasisapi/portfolios/v1_api/tests/fakes.py @@ -1,6 +1,6 @@ from model_mommy import mommy -from ..models import Portfolio +from src.server.oasisapi.portfolios.models import Portfolio def fake_portfolio(**kwargs): diff --git a/src/server/oasisapi/portfolios/v1_api/tests/test_portfolio.py b/src/server/oasisapi/portfolios/v1_api/tests/test_portfolio.py new file mode 100644 index 000000000..015cbb28e --- /dev/null +++ b/src/server/oasisapi/portfolios/v1_api/tests/test_portfolio.py @@ -0,0 +1,1339 @@ +import json +import mimetypes +import string +import io +import pandas as pd + +from backports.tempfile import TemporaryDirectory +from django.test import override_settings +from django.urls import reverse +from django_webtest import WebTestMixin +from hypothesis import given, settings +from hypothesis.extra.django import TestCase +from hypothesis.strategies import text, binary, sampled_from +from mock import patch +from rest_framework_simplejwt.tokens import AccessToken +from ods_tools.oed.exposure import OedExposure + +from src.server.oasisapi.files.tests.fakes import fake_related_file +from src.server.oasisapi.analysis_models.v1_api.tests.fakes import fake_analysis_model +from src.server.oasisapi.analyses.models import Analysis +from src.server.oasisapi.auth.tests.fakes import fake_user +from src.server.oasisapi.portfolios.models import Portfolio +from .fakes import fake_portfolio + +# Override default deadline for all tests to 800ms +settings.register_profile("ci", deadline=800.0) +settings.load_profile("ci") +NAMESPACE = 'v1-portfolios' + + +class PortfolioApi(WebTestMixin, TestCase): + def test_user_is_not_authenticated___response_is_forbidden(self): + portfolio = fake_portfolio() + + response = self.app.get(portfolio.get_absolute_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_user_is_authenticated_object_does_not_exist___response_is_404(self): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.get( + reverse(f'{NAMESPACE}:portfolio-detail', kwargs={'pk': portfolio.pk + 1}), + expect_errors=True, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + + self.assertEqual(404, response.status_code) + + def 
test_name_is_not_provided___response_is_400(self): + user = fake_user() + + response = self.app.post( + reverse(f'{NAMESPACE}:portfolio-list'), + expect_errors=True, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params={}, + content_type='application/json', + ) + + self.assertEqual(400, response.status_code) + + @given(name=text(alphabet=' \t\n\r', max_size=10)) + def test_cleaned_name_is_empty___response_is_400(self, name): + user = fake_user() + + response = self.app.post( + reverse(f'{NAMESPACE}:portfolio-list'), + expect_errors=True, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({'name': name}), + content_type='application/json' + ) + + self.assertEqual(400, response.status_code) + + @given(name=text(alphabet=string.ascii_letters, max_size=10, min_size=1)) + def test_cleaned_name_is_present___object_is_created(self, name): + self.maxDiff = None + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + user = fake_user() + + response = self.app.post( + reverse(f'{NAMESPACE}:portfolio-list'), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({'name': name}), + content_type='application/json' + ) + self.assertEqual(201, response.status_code) + + portfolio = Portfolio.objects.get(pk=response.json['id']) + portfolio.accounts_file = fake_related_file() + portfolio.location_file = fake_related_file() + portfolio.reinsurance_scope_file = fake_related_file() + portfolio.reinsurance_info_file = fake_related_file() + portfolio.save() + + response = self.app.get( + portfolio.get_absolute_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(200, response.status_code) + self.assertEqual({ + 'id': portfolio.pk, + 'name': name, + 'created': portfolio.created.strftime('%Y-%m-%dT%H:%M:%S.%fZ'), + 'modified': portfolio.modified.strftime('%Y-%m-%dT%H:%M:%S.%fZ'), + 'accounts_file': { + "uri": response.request.application_url + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), + "name": portfolio.accounts_file.filename, + "stored": str(portfolio.accounts_file.file) + }, + 'location_file': { + "uri": response.request.application_url + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), + "name": portfolio.location_file.filename, + "stored": str(portfolio.location_file.file) + }, + 'reinsurance_info_file': { + "uri": response.request.application_url + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), + "name": portfolio.reinsurance_info_file.filename, + "stored": str(portfolio.reinsurance_info_file.file) + }, + 'reinsurance_scope_file': { + "uri": response.request.application_url + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), + "name": portfolio.reinsurance_scope_file.filename, + "stored": str(portfolio.reinsurance_scope_file.file) + }, + 'storage_links': response.request.application_url + portfolio.get_absolute_storage_url(namespace=NAMESPACE) + }, response.json) + + +class PortfolioApiCreateAnalysis(WebTestMixin, TestCase): + def test_user_is_not_authenticated___response_is_forbidden(self): + portfolio = fake_portfolio() + + response = self.app.get(portfolio.get_absolute_create_analysis_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_user_is_authenticated_object_does_not_exist___response_is_404(self): + user = fake_user() + 
portfolio = fake_portfolio() + + response = self.app.post( + reverse(f'{NAMESPACE}:portfolio-create-analysis', kwargs={'pk': portfolio.pk + 1}), + expect_errors=True, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + } + ) + + self.assertEqual(404, response.status_code) + + def test_name_is_not_provided___response_is_400(self): + user = fake_user() + model = fake_analysis_model() + portfolio = fake_portfolio() + + response = self.app.post( + portfolio.get_absolute_create_analysis_url(namespace=NAMESPACE), + expect_errors=True, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params={'model': model.pk}, + content_type='application/json', + ) + + self.assertEqual(400, response.status_code) + + def test_portfolio_does_not_have_location_file_set___response_is_400(self): + user = fake_user() + model = fake_analysis_model() + portfolio = fake_portfolio() + + response = self.app.post( + portfolio.get_absolute_create_analysis_url(namespace=NAMESPACE), + expect_errors=True, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({'name': 'name', 'model': model.pk}), + content_type='application/json', + ) + + self.assertEqual(400, response.status_code) + self.assertIn('"location_file" must not be null', response.json['portfolio']) + + def test_model_is_not_provided___response_is_400(self): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.post( + portfolio.get_absolute_create_analysis_url(namespace=NAMESPACE), + expect_errors=True, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({'name': 'name'}), + content_type='application/json', + ) + + self.assertEqual(400, response.status_code) + + def test_model_does_not_exist___response_is_400(self): + user = fake_user() + portfolio = fake_portfolio() + model = fake_analysis_model() + + response = self.app.post( + portfolio.get_absolute_create_analysis_url(namespace=NAMESPACE), + expect_errors=True, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({'name': 'name', 'model': model.pk + 1}), + content_type='application/json', + ) + + self.assertEqual(400, response.status_code) + + @given(name=text(alphabet=' \t\n\r', max_size=10)) + def test_cleaned_name_is_empty___response_is_400(self, name): + user = fake_user() + model = fake_analysis_model() + portfolio = fake_portfolio() + + response = self.app.post( + portfolio.get_absolute_create_analysis_url(namespace=NAMESPACE), + expect_errors=True, + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({'name': name, 'model': model.pk}), + content_type='application/json' + ) + + self.assertEqual(400, response.status_code) + + @given(name=text(alphabet=string.ascii_letters, max_size=10, min_size=1)) + def test_cleaned_name_and_model_are_present___object_is_created_inputs_are_generated(self, name): + with patch('src.server.oasisapi.analyses.models.Analysis.generate_inputs', autospec=True) as generate_mock: + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d): + self.maxDiff = None + + user = fake_user() + model = fake_analysis_model() + model.run_mode = model.run_mode_choices.V1 + model.save() + portfolio = fake_portfolio(location_file=fake_related_file()) + + response = self.app.post( + portfolio.get_absolute_create_analysis_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer 
{}'.format(AccessToken.for_user(user)) + }, + params=json.dumps({'name': name, 'model': model.pk}), + content_type='application/json' + ) + self.assertEqual(201, response.status_code) + + analysis = Analysis.objects.get(pk=response.json['id']) + analysis.settings_file = fake_related_file() + analysis.input_file = fake_related_file() + analysis.lookup_errors_file = fake_related_file() + analysis.lookup_success_file = fake_related_file() + analysis.lookup_validation_file = fake_related_file() + analysis.input_generation_traceback_file = fake_related_file() + analysis.output_file = fake_related_file() + analysis.run_traceback_file = fake_related_file() + analysis.save() + + ANALYSES_NAMESPACE = 'v1-analyses' + response = self.app.get( + analysis.get_absolute_url(namespace=ANALYSES_NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(200, response.status_code) + self.assertEqual(response.json['id'], analysis.pk) + self.assertEqual(response.json['created'], analysis.created.strftime('%Y-%m-%dT%H:%M:%S.%fZ')) + self.assertEqual(response.json['modified'], analysis.modified.strftime('%Y-%m-%dT%H:%M:%S.%fZ')) + self.assertEqual(response.json['name'], name) + self.assertEqual(response.json['portfolio'], portfolio.pk) + self.assertEqual(response.json['model'], model.pk) + self.assertEqual(response.json['settings_file'], response.request.application_url + + analysis.get_absolute_settings_file_url(namespace=ANALYSES_NAMESPACE)) + self.assertEqual(response.json['input_file'], response.request.application_url + + analysis.get_absolute_input_file_url(namespace=ANALYSES_NAMESPACE)) + self.assertEqual(response.json['lookup_errors_file'], response.request.application_url + + analysis.get_absolute_lookup_errors_file_url(namespace=ANALYSES_NAMESPACE)) + self.assertEqual(response.json['lookup_success_file'], response.request.application_url + + analysis.get_absolute_lookup_success_file_url(namespace=ANALYSES_NAMESPACE)) + self.assertEqual(response.json['lookup_validation_file'], response.request.application_url + + analysis.get_absolute_lookup_validation_file_url(namespace=ANALYSES_NAMESPACE)) + self.assertEqual(response.json['input_generation_traceback_file'], response.request.application_url + + analysis.get_absolute_input_generation_traceback_file_url(namespace=ANALYSES_NAMESPACE)) + self.assertEqual(response.json['output_file'], response.request.application_url + + analysis.get_absolute_output_file_url(namespace=ANALYSES_NAMESPACE)) + self.assertEqual(response.json['run_traceback_file'], response.request.application_url + + analysis.get_absolute_run_traceback_file_url(namespace=ANALYSES_NAMESPACE)) + generate_mock.assert_called_once_with(analysis, user, run_mode_override='V1') + + +class PortfolioAccountsFile(WebTestMixin, TestCase): + def test_user_is_not_authenticated___response_is_forbidden(self): + portfolio = fake_portfolio() + + response = self.app.get(portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_accounts_file_is_not_present___get_response_is_404(self): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.get( + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_accounts_file_is_not_present___delete_response_is_404(self): + user = 
fake_user() + portfolio = fake_portfolio() + + response = self.app.delete( + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_accounts_file_is_not_a_valid_format___response_is_400(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.post( + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file.tar', b'content'), + ), + expect_errors=True, + ) + + self.assertEqual(400, response.status_code) + + @given(file_content=binary(min_size=1), content_type=sampled_from(['text/csv', 'application/json'])) + def test_accounts_file_is_uploaded___file_can_be_retrieved(self, file_content, content_type): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_PARQUET_STORAGE=False, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + self.app.post( + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content), + ), + ) + + response = self.app.get( + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(response.body, file_content) + self.assertEqual(response.content_type, content_type) + + def test_accounts_file_invalid_uploaded___parquet_exception_raised(self): + content_type = 'text/csv' + file_content = b'\xf2hb\xca\xd2\xe6\xf3\xb0\xc1\xc7' + + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_PARQUET_STORAGE=True, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.post( + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content), + ), + expect_errors=True + ) + self.assertEqual(400, response.status_code) + + def test_accounts_file_is_uploaded_as_parquet___file_can_be_retrieved(self): + content_type = 'text/csv' + test_data = pd.DataFrame.from_dict({"A": [1, 2, 3], "B": [4, 5, 6]}) + file_content = test_data.to_csv(index=False).encode('utf-8') + + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_PARQUET_STORAGE=True, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + self.app.post( + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content), + ), + ) + + csv_response = self.app.get( + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE) + '?file_format=csv', + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + parquet_response = self.app.get( + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE) + 
'?file_format=parquet', + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + csv_obj = io.StringIO(csv_response.text) + prq_obj = io.BytesIO(parquet_response.content) + setattr(prq_obj, 'name', 'account.parquet') + + input_data = OedExposure(account=test_data) + return_csv = OedExposure(account=csv_obj) + return_prq = OedExposure(account=prq_obj) + + pd.testing.assert_frame_equal( + return_csv.account.dataframe, + input_data.account.dataframe) + pd.testing.assert_frame_equal( + return_prq.account.dataframe, + input_data.account.dataframe) + + +class PortfolioLocationFile(WebTestMixin, TestCase): + def test_user_is_not_authenticated___response_is_forbidden(self): + portfolio = fake_portfolio() + + response = self.app.get(portfolio.get_absolute_location_file_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_location_file_is_not_present___get_response_is_404(self): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.get( + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_location_file_is_not_present___delete_response_is_404(self): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.delete( + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_location_file_is_not_a_valid_format___response_is_400(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.post( + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file.tar', b'content'), + ), + expect_errors=True, + ) + + self.assertEqual(400, response.status_code) + + @given(file_content=binary(min_size=1), content_type=sampled_from(['text/csv', 'application/json'])) + def test_location_file_is_uploaded___file_can_be_retrieved(self, file_content, content_type): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_PARQUET_STORAGE=False, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + self.app.post( + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content), + ), + ) + + response = self.app.get( + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(response.body, file_content) + self.assertEqual(response.content_type, content_type) + + def test_location_file_invalid_uploaded___parquet_exception_raised(self): + content_type = 'text/csv' + file_content = b'\xf2hb\xca\xd2\xe6\xf3\xb0\xc1\xc7' + + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_PARQUET_STORAGE=True, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.post( + 
portfolio.get_absolute_location_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content), + ), + expect_errors=True + ) + self.assertEqual(400, response.status_code) + + def test_location_file_is_uploaded_as_parquet___file_can_be_retrieved(self): + content_type = 'text/csv' + test_data = pd.DataFrame.from_dict({"A": [1, 2, 3], "B": [4, 5, 6]}) + file_content = test_data.to_csv(index=False).encode('utf-8') + + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_PARQUET_STORAGE=True, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + self.app.post( + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content), + ), + ) + + csv_response = self.app.get( + portfolio.get_absolute_location_file_url(namespace=NAMESPACE) + '?file_format=csv', + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + parquet_response = self.app.get( + portfolio.get_absolute_location_file_url(namespace=NAMESPACE) + '?file_format=parquet', + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + csv_obj = io.StringIO(csv_response.text) + prq_obj = io.BytesIO(parquet_response.content) + setattr(prq_obj, 'name', 'location.parquet') + + input_data = OedExposure(location=test_data) + return_csv = OedExposure(location=csv_obj) + return_prq = OedExposure(location=prq_obj) + + pd.testing.assert_frame_equal( + return_csv.location.dataframe, + input_data.location.dataframe) + pd.testing.assert_frame_equal( + return_prq.location.dataframe, + input_data.location.dataframe) + + +class PortfolioReinsuranceSourceFile(WebTestMixin, TestCase): + def test_user_is_not_authenticated___response_is_forbidden(self): + portfolio = fake_portfolio() + + response = self.app.get(portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_reinsurance_scope_file_is_not_present___get_response_is_404(self): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.get( + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_reinsurance_scope_file_is_not_present___delete_response_is_404(self): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.delete( + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_reinsurance_scope_file_is_not_a_valid_format___response_is_400(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.post( + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file.tar', b'content'), + ), + 
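+                    # a tar archive is not an accepted exposure upload format, so a 400 is expected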
expect_errors=True, + ) + + self.assertEqual(400, response.status_code) + + @given(file_content=binary(min_size=1), content_type=sampled_from(['text/csv', 'application/json'])) + def test_reinsurance_scope_file_is_uploaded___file_can_be_retrieved(self, file_content, content_type): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_PARQUET_STORAGE=False, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + self.app.post( + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content), + ), + ) + + response = self.app.get( + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(response.body, file_content) + self.assertEqual(response.content_type, content_type) + + def test_reinsurance_scope_file_invalid_uploaded___parquet_exception_raised(self): + content_type = 'text/csv' + file_content = b'\xf2hb\xca\xd2\xe6\xf3\xb0\xc1\xc7' + + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_PARQUET_STORAGE=True, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.post( + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content), + ), + expect_errors=True + ) + self.assertEqual(400, response.status_code) + + def test_reinsurance_scope_file_is_uploaded_as_parquet___file_can_be_retrieved(self): + content_type = 'text/csv' + test_data = pd.DataFrame.from_dict({"A": [1, 2, 3], "B": [4, 5, 6]}) + file_content = test_data.to_csv(index=False).encode('utf-8') + + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_PARQUET_STORAGE=True, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + self.app.post( + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content), + ), + ) + + csv_response = self.app.get( + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE) + '?file_format=csv', + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + parquet_response = self.app.get( + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE) + '?file_format=parquet', + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + csv_obj = io.StringIO(csv_response.text) + prq_obj = io.BytesIO(parquet_response.content) + setattr(prq_obj, 'name', 'ri_scope.parquet') + + input_data = OedExposure(ri_scope=test_data) + return_csv = OedExposure(ri_scope=csv_obj) + return_prq = OedExposure(ri_scope=prq_obj) + + pd.testing.assert_frame_equal( + return_csv.ri_scope.dataframe, + input_data.ri_scope.dataframe) + pd.testing.assert_frame_equal( + return_prq.ri_scope.dataframe, + input_data.ri_scope.dataframe) + + +class PortfolioReinsuranceInfoFile(WebTestMixin, TestCase): + def 
test_user_is_not_authenticated___response_is_forbidden(self): + portfolio = fake_portfolio() + + response = self.app.get(portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), expect_errors=True) + self.assertIn(response.status_code, [401, 403]) + + def test_reinsurance_info_file_is_not_present___get_response_is_404(self): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.get( + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_reinsurance_info_file_is_not_present___delete_response_is_404(self): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.delete( + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + + self.assertEqual(404, response.status_code) + + def test_reinsurance_info_file_is_not_a_valid_format___response_is_400(self): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.post( + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file.tar', b'content'), + ), + expect_errors=True, + ) + + self.assertEqual(400, response.status_code) + + @given(file_content=binary(min_size=1), content_type=sampled_from(['text/csv', 'application/json'])) + def test_reinsurance_info_file_is_uploaded___file_can_be_retrieved(self, file_content, content_type): + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_PARQUET_STORAGE=False, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + self.app.post( + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content), + ), + ) + + response = self.app.get( + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + self.assertEqual(response.body, file_content) + self.assertEqual(response.content_type, content_type) + + def test_reinsurance_info_file_invalid_uploaded___parquet_exception_raised(self): + content_type = 'text/csv' + file_content = b'\xf2hb\xca\xd2\xe6\xf3\xb0\xc1\xc7' + + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_PARQUET_STORAGE=True, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + response = self.app.post( + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content), + ), + expect_errors=True + ) + self.assertEqual(400, response.status_code) + + def test_reinsurance_info_file_is_uploaded_as_parquet___file_can_be_retrieved(self): + content_type = 'text/csv' + test_data = pd.DataFrame.from_dict({"A": [1, 2, 3], "B": [4, 5, 6]}) + file_content = 
test_data.to_csv(index=False).encode('utf-8') + + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_PARQUET_STORAGE=True, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + self.app.post( + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content), + ), + ) + + csv_response = self.app.get( + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE) + '?file_format=csv', + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + parquet_response = self.app.get( + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE) + '?file_format=parquet', + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + csv_obj = io.StringIO(csv_response.text) + prq_obj = io.BytesIO(parquet_response.content) + setattr(prq_obj, 'name', 'ri_info.parquet') + + input_data = OedExposure(ri_info=test_data) + return_csv = OedExposure(ri_info=csv_obj) + return_prq = OedExposure(ri_info=prq_obj) + + pd.testing.assert_frame_equal( + return_csv.ri_info.dataframe, + input_data.ri_info.dataframe) + pd.testing.assert_frame_equal( + return_prq.ri_info.dataframe, + input_data.ri_info.dataframe) + + +LOCATION_DATA_VALID = """PortNumber,AccNumber,LocNumber,IsTenant,BuildingID,CountryCode,Latitude,Longitude,StreetAddress,PostalCode,OccupancyCode,ConstructionCode,LocPerilsCovered,BuildingTIV,OtherTIV,ContentsTIV,BITIV,LocCurrency,OEDVersion +1,A11111,10002082046,1,1,GB,52.76698052,-0.895469856,1 ABINGDON ROAD,LE13 0HL,1050,5000,WW1,220000,0,0,0,GBP,2.0.0 +1,A11111,10002082047,1,1,GB,52.76697956,-0.89536613,2 ABINGDON ROAD,LE13 0HL,1050,5000,WW1,790000,0,0,0,GBP,2.0.0 +1,A11111,10002082048,1,1,GB,52.76697845,-0.895247587,3 ABINGDON ROAD,LE13 0HL,1050,5000,WW1,160000,0,0,0,GBP,2.0.0 +1,A11111,10002082049,1,1,GB,52.76696096,-0.895473908,4 ABINGDON ROAD,LE13 0HL,1050,5000,WW1,30000,0,0,0,GBP,2.0.0 +""" + +ACCOUNT_DATA_VALID = """PortNumber,AccNumber,AccCurrency,PolNumber,PolPerilsCovered,PolInceptionDate,PolExpiryDate,LayerNumber,LayerParticipation,LayerLimit,LayerAttachment,OEDVersion +1,A11111,GBP,Layer1,WW1,2018-01-01,2018-12-31,1,0.3,5000000,500000,2.0.0 +1,A11111,GBP,Layer2,WW1,2018-01-01,2018-12-31,2,0.3,100000000,5500000,2.0.0 +""" + +INFO_DATA_VALID = """ReinsNumber,ReinsLayerNumber,ReinsName,ReinsPeril,ReinsInceptionDate,ReinsExpiryDate,CededPercent,RiskLimit,RiskAttachment,OccLimit,OccAttachment,PlacedPercent,ReinsCurrency,InuringPriority,ReinsType,RiskLevel,UseReinsDates,OEDVersion +1,1,ABC QS,WW1,2018-01-01,2018-12-31,1,0,0,0,0,1,GBP,1,SS,LOC,N,2.0.0 +""" + +SCOPE_DATA_VALID = """ReinsNumber,PortNumber,AccNumber,PolNumber,LocGroup,LocNumber,CedantName,ProducerName,LOB,CountryCode,ReinsTag,CededPercent,OEDVersion +1,1,A11111,,,10002082047,,,,GB,,0.1,2.0.0 +1,1,A11111,,,10002082048,,,,GB,,0.2,2.0.0 +""" + +LOCATION_DATA_INVALID = """Port,AccNumber,LocNumb,IsTenant,BuildingID,CountryCode,Latitude,Longitude,Street,PostalCode,OccupancyCode,ConstructionCode,LocPerilsCovered,BuildingTIV,OtherTIV,ContentsTIV,BITIV,LocCurrency,OEDVersion +1,A11111,10002082046,1,1,GB,52.76698052,-0.895469856,1 ABINGDON ROAD,LE13 0HL,1050,5000,WW1,220000,0,0,0,GBP,2.0.0 +1,A11111,10002082047,1,1,GB,52.76697956,-0.89536613,2 ABINGDON ROAD,LE13 0HL,1050,5000,XXYA,790000,0,0,0,GBP,2.0.0 
+1,A11111,10002082048,1,1,GB,52.76697845,,3 ABINGDON ROAD,LE13 0HL,1050,5000,WW1,160000,0,0,0,GBP,2.0.0 +1,A11111,10002082049,1,1,GB,52.76696096,-0.895473908,4 ABINGDON ROAD,LE13 0HL,1050,-1,WW1,30000,0,0,0,GBP,2.0.0 +""" + + +class PortfolioValidation(WebTestMixin, TestCase): + + def test_all_exposure__are_valid(self): + content_type = 'text/csv' + loc_data = pd.read_csv(io.StringIO(LOCATION_DATA_VALID)) + acc_data = pd.read_csv(io.StringIO(ACCOUNT_DATA_VALID)) + inf_data = pd.read_csv(io.StringIO(INFO_DATA_VALID)) + scp_data = pd.read_csv(io.StringIO(SCOPE_DATA_VALID)) + + loc_file_content = loc_data.to_csv(index=False).encode('utf-8') + acc_file_content = acc_data.to_csv(index=False).encode('utf-8') + inf_file_content = inf_data.to_csv(index=False).encode('utf-8') + scp_file_content = scp_data.to_csv(index=False).encode('utf-8') + + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + self.app.post( + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), loc_file_content), + ), + ) + self.app.post( + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), acc_file_content), + ), + ) + self.app.post( + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), inf_file_content), + ), + ) + self.app.post( + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), scp_file_content), + ), + ) + + validate_response = self.app.get( + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + # Get current validate status - Not yet run + self.assertEqual(200, validate_response.status_code) + self.assertEqual(validate_response.json, { + 'location_validated': False, + 'accounts_validated': False, + 'reinsurance_info_validated': False, + 'reinsurance_scope_validated': False}) + + # Run validate - check is valid + validate_response = self.app.post( + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + self.assertEqual(200, validate_response.status_code) + self.assertEqual(validate_response.json, { + 'location_validated': True, + 'accounts_validated': True, + 'reinsurance_info_validated': True, + 'reinsurance_scope_validated': True}) + + def test_location_file__is_valid(self): + content_type = 'text/csv' + test_data = pd.read_csv(io.StringIO(LOCATION_DATA_VALID)) + file_content = test_data.to_csv(index=False).encode('utf-8') + + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + self.app.post( + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), + headers={ + 
'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content), + ), + ) + + validate_response = self.app.get( + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + # Get current validate status - Not yet run + self.assertEqual(200, validate_response.status_code) + self.assertEqual(validate_response.json, { + 'location_validated': False, + 'accounts_validated': None, + 'reinsurance_info_validated': None, + 'reinsurance_scope_validated': None}) + + # Run validate - check is valid + validate_response = self.app.post( + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + self.assertEqual(200, validate_response.status_code) + self.assertEqual(validate_response.json, { + 'location_validated': True, + 'accounts_validated': None, + 'reinsurance_info_validated': None, + 'reinsurance_scope_validated': None}) + + def test_location_file__is_invalid__response_is_400(self): + content_type = 'text/csv' + test_data = pd.read_csv(io.StringIO(LOCATION_DATA_INVALID)) + file_content = test_data.to_csv(index=False).encode('utf-8') + + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + self.app.post( + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + upload_files=( + ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content), + ), + ) + validate_response = self.app.get( + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + ) + + # Get current validate status - Not yet run + self.assertEqual(200, validate_response.status_code) + self.assertEqual(validate_response.json, { + 'location_validated': False, + 'accounts_validated': None, + 'reinsurance_info_validated': None, + 'reinsurance_scope_validated': None}) + + # Run validate - check is invalid + validate_response = self.app.post( + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', + headers={ + 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) + }, + expect_errors=True, + ) + self.assertEqual(400, validate_response.status_code) + self.assertEqual(validate_response.json, [ + ['location', 'missing required column PortNumber'], + ['location', 'missing required column LocNumber'], + ['location', "column 'Port' is not a valid oed field"], + ['location', "column 'LocNumb' is not a valid oed field"], + ['location', "column 'Street' is not a valid oed field"], + ['location', 'LocPerilsCovered has invalid perils.\n AccNumber LocPerilsCovered\n1 A11111 XXYA'], + ['location', 'invalid ConstructionCode.\n AccNumber ConstructionCode\n3 A11111 -1'] + ]) + + def test_account_file__is_invalid__response_is_400(self): + content_type = 'text/csv' + test_data = pd.read_csv(io.StringIO(LOCATION_DATA_VALID)) + file_content = test_data.to_csv(index=False).encode('utf-8') + + with TemporaryDirectory() as d: + with override_settings(MEDIA_ROOT=d, PORTFOLIO_UPLOAD_VALIDATION=False): + user = fake_user() + portfolio = fake_portfolio() + + self.app.post( + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), + headers={ + 
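+                        # note: location-format columns are deliberately posted to the accounts endpoint to force validation errors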
+                        'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+                    },
+                    upload_files=(
+                        ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content),
+                    ),
+                )
+
+                validate_response = self.app.get(
+                    portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/',
+                    headers={
+                        'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+                    },
+                )
+
+                # Get current validate status - Not yet run
+                self.assertEqual(200, validate_response.status_code)
+                self.assertEqual(validate_response.json, {
+                    'location_validated': None,
+                    'accounts_validated': False,
+                    'reinsurance_info_validated': None,
+                    'reinsurance_scope_validated': None})
+
+                # Run validate - check is invalid
+                validate_response = self.app.post(
+                    portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/',
+                    headers={
+                        'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+                    },
+                    expect_errors=True,
+                )
+                self.assertEqual(400, validate_response.status_code)
+                self.assertEqual(validate_response.json, [
+                    ['account', 'missing required column AccCurrency'],
+                    ['account', 'missing required column PolNumber'],
+                    ['account', 'missing required column PolPerilsCovered'],
+                    ['account', "column 'LocNumber' is not a valid oed field"],
+                    ['account', "column 'IsTenant' is not a valid oed field"],
+                    ['account', "column 'BuildingID' is not a valid oed field"],
+                    ['account', "column 'CountryCode' is not a valid oed field"],
+                    ['account', "column 'Latitude' is not a valid oed field"],
+                    ['account', "column 'Longitude' is not a valid oed field"],
+                    ['account', "column 'StreetAddress' is not a valid oed field"],
+                    ['account', "column 'PostalCode' is not a valid oed field"],
+                    ['account', "column 'OccupancyCode' is not a valid oed field"],
+                    ['account', "column 'ConstructionCode' is not a valid oed field"],
+                    ['account', "column 'LocPerilsCovered' is not a valid oed field"],
+                    ['account', "column 'BuildingTIV' is not a valid oed field"],
+                    ['account', "column 'OtherTIV' is not a valid oed field"],
+                    ['account', "column 'ContentsTIV' is not a valid oed field"],
+                    ['account', "column 'BITIV' is not a valid oed field"],
+                    ['account', "column 'LocCurrency' is not a valid oed field"]
+                ])
+
+    def test_reinsurance_info_file__is_invalid__response_is_400(self):
+        content_type = 'text/csv'
+        test_data = pd.read_csv(io.StringIO(LOCATION_DATA_VALID))
+        file_content = test_data.to_csv(index=False).encode('utf-8')
+
+        with TemporaryDirectory() as d:
+            with override_settings(MEDIA_ROOT=d, PORTFOLIO_UPLOAD_VALIDATION=False):
+                user = fake_user()
+                portfolio = fake_portfolio()
+
+                self.app.post(
+                    portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE),
+                    headers={
+                        'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+                    },
+                    upload_files=(
+                        ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content),
+                    ),
+                )
+
+                validate_response = self.app.get(
+                    portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/',
+                    headers={
+                        'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+                    },
+                )
+
+                # Get current validate status - Not yet run
+                self.assertEqual(200, validate_response.status_code)
+                self.assertEqual(validate_response.json, {
+                    'location_validated': None,
+                    'accounts_validated': None,
+                    'reinsurance_info_validated': False,
+                    'reinsurance_scope_validated': None})
+
+                # Run validate - check is invalid
+                validate_response = self.app.post(
+                    portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/',
+                    headers={
+                        'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
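+                        # the POST actually runs the validation; failures come back in the 400 body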
+                    },
+                    expect_errors=True,
+                )
+                self.assertEqual(400, validate_response.status_code)
+                self.assertEqual(validate_response.json, [
+                    ['ri_info', 'missing required column ReinsNumber'],
+                    ['ri_info', 'missing required column ReinsPeril'],
+                    ['ri_info', 'missing required column PlacedPercent'],
+                    ['ri_info', 'missing required column ReinsCurrency'],
+                    ['ri_info', 'missing required column InuringPriority'],
+                    ['ri_info', 'missing required column ReinsType'],
+                    ['ri_info', "column 'PortNumber' is not a valid oed field"],
+                    ['ri_info', "column 'AccNumber' is not a valid oed field"],
+                    ['ri_info', "column 'LocNumber' is not a valid oed field"],
+                    ['ri_info', "column 'IsTenant' is not a valid oed field"],
+                    ['ri_info', "column 'BuildingID' is not a valid oed field"],
+                    ['ri_info', "column 'CountryCode' is not a valid oed field"],
+                    ['ri_info', "column 'Latitude' is not a valid oed field"],
+                    ['ri_info', "column 'Longitude' is not a valid oed field"],
+                    ['ri_info', "column 'StreetAddress' is not a valid oed field"],
+                    ['ri_info', "column 'PostalCode' is not a valid oed field"],
+                    ['ri_info', "column 'OccupancyCode' is not a valid oed field"],
+                    ['ri_info', "column 'ConstructionCode' is not a valid oed field"],
+                    ['ri_info', "column 'LocPerilsCovered' is not a valid oed field"],
+                    ['ri_info', "column 'BuildingTIV' is not a valid oed field"],
+                    ['ri_info', "column 'OtherTIV' is not a valid oed field"],
+                    ['ri_info', "column 'ContentsTIV' is not a valid oed field"],
+                    ['ri_info', "column 'BITIV' is not a valid oed field"],
+                    ['ri_info', "column 'LocCurrency' is not a valid oed field"]
+                ])
+
+    def test_reinsurance_scope_file__is_invalid__response_is_400(self):
+        content_type = 'text/csv'
+        test_data = pd.read_csv(io.StringIO(LOCATION_DATA_VALID))
+        file_content = test_data.to_csv(index=False).encode('utf-8')
+
+        with TemporaryDirectory() as d:
+            with override_settings(MEDIA_ROOT=d, PORTFOLIO_UPLOAD_VALIDATION=False):
+                user = fake_user()
+                portfolio = fake_portfolio()
+
+                self.app.post(
+                    portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE),
+                    headers={
+                        'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+                    },
+                    upload_files=(
+                        ('file', 'file{}'.format(mimetypes.guess_extension(content_type)), file_content),
+                    ),
+                )
+
+                validate_response = self.app.get(
+                    portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/',
+                    headers={
+                        'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+                    },
+                )
+
+                # Get current validate status - Not yet run
+                self.assertEqual(200, validate_response.status_code)
+                self.assertEqual(validate_response.json, {
+                    'location_validated': None,
+                    'accounts_validated': None,
+                    'reinsurance_info_validated': None,
+                    'reinsurance_scope_validated': False})
+
+                # Run validate - check is invalid
+                validate_response = self.app.post(
+                    portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/',
+                    headers={
+                        'Authorization': 'Bearer {}'.format(AccessToken.for_user(user))
+                    },
+                    expect_errors=True,
+                )
+                self.assertEqual(400, validate_response.status_code)
+                self.assertEqual(validate_response.json, [
+                    ['ri_scope', 'missing required column ReinsNumber'],
+                    ['ri_scope', "column 'IsTenant' is not a valid oed field"],
+                    ['ri_scope', "column 'BuildingID' is not a valid oed field"],
+                    ['ri_scope', "column 'Latitude' is not a valid oed field"],
+                    ['ri_scope', "column 'Longitude' is not a valid oed field"],
+                    ['ri_scope', "column 'StreetAddress' is not a valid oed field"],
+                    ['ri_scope', "column 'PostalCode' is not a valid oed field"],
+                    ['ri_scope', 
"column 'OccupancyCode' is not a valid oed field"], + ['ri_scope', "column 'ConstructionCode' is not a valid oed field"], + ['ri_scope', "column 'LocPerilsCovered' is not a valid oed field"], + ['ri_scope', "column 'BuildingTIV' is not a valid oed field"], + ['ri_scope', "column 'OtherTIV' is not a valid oed field"], + ['ri_scope', "column 'ContentsTIV' is not a valid oed field"], + ['ri_scope', "column 'BITIV' is not a valid oed field"], + ['ri_scope', "column 'LocCurrency' is not a valid oed field"] + ]) diff --git a/src/server/oasisapi/portfolios/v1_api/urls.py b/src/server/oasisapi/portfolios/v1_api/urls.py new file mode 100644 index 000000000..972a7d84b --- /dev/null +++ b/src/server/oasisapi/portfolios/v1_api/urls.py @@ -0,0 +1,13 @@ +from django.conf.urls import include, url +from rest_framework.routers import SimpleRouter +from .viewsets import PortfolioViewSet + + +app_name = 'portfolios' +v1_api_router = SimpleRouter() +v1_api_router.include_root_view = False +v1_api_router.register('portfolios', PortfolioViewSet, basename='portfolio') + +urlpatterns = [ + url(r'', include(v1_api_router.urls)), +] diff --git a/src/server/oasisapi/portfolios/v1_api/viewsets.py b/src/server/oasisapi/portfolios/v1_api/viewsets.py new file mode 100644 index 000000000..0f2a891c1 --- /dev/null +++ b/src/server/oasisapi/portfolios/v1_api/viewsets.py @@ -0,0 +1,267 @@ +from __future__ import absolute_import + +from django.utils.translation import gettext_lazy as _ +from django_filters import rest_framework as filters +from django.conf import settings as django_settings +from django.utils.decorators import method_decorator +from drf_yasg.utils import swagger_auto_schema +from rest_framework import viewsets +from rest_framework.decorators import action +from rest_framework.parsers import MultiPartParser +from rest_framework.response import Response +from rest_framework.settings import api_settings +from rest_framework.status import HTTP_201_CREATED + +from ...filters import TimeStampedFilter +from ...analyses.v1_api.serializers import AnalysisSerializer +from ...files.views import handle_related_file +from ...files.serializers import RelatedFileSerializer +from ..models import Portfolio +from ...schemas.custom_swagger import FILE_RESPONSE, FILE_FORMAT_PARAM, FILE_VALIDATION_PARAM +from ...schemas.serializers import StorageLinkSerializer +from .serializers import ( + PortfolioSerializer, + CreateAnalysisSerializer, + PortfolioStorageSerializer, + PortfolioListSerializer, + PortfolioValidationSerializer +) + + +class PortfolioFilter(TimeStampedFilter): + name = filters.CharFilter(help_text=_('Filter results by case insensitive names equal to the given string'), lookup_expr='iexact') + name__contains = filters.CharFilter(help_text=_( + 'Filter results by case insensitive name containing the given string'), lookup_expr='icontains', field_name='name') + user = filters.CharFilter( + help_text=_('Filter results by case insensitive `user` equal to the given string'), + lookup_expr='iexact', + field_name='creator__username' + ) + + class Meta: + model = Portfolio + fields = [ + 'name', + 'name__contains', + 'user', + ] + + +@method_decorator(name='list', decorator=swagger_auto_schema(responses={200: PortfolioSerializer(many=True)})) +class PortfolioViewSet(viewsets.ModelViewSet): + """ + list: + Returns a list of Portfolio objects. 
+
+    ### Examples
+
+    To get all portfolios with 'foo' in their name
+
+        /portfolio/?name__contains=foo
+
+    To get all portfolios created on 1970-01-01
+
+        /portfolio/?created__date=1970-01-01
+
+    To get all portfolios updated before 2000-01-01
+
+        /portfolio/?modified__lt=2000-01-01
+
+    retrieve:
+    Returns the specific portfolio entry.
+
+    create:
+    Creates a portfolio based on the input data.
+
+    update:
+    Updates the specified portfolio.
+
+    partial_update:
+    Partially updates the specified portfolio (only provided fields are updated).
+    """
+
+    queryset = Portfolio.objects.all().select_related(
+        'location_file',
+        'accounts_file',
+        'reinsurance_scope_file',
+        'reinsurance_info_file'
+    )
+    serializer_class = PortfolioSerializer
+    filterset_class = PortfolioFilter
+
+    supported_mime_types = [
+        'application/json',
+        'text/csv',
+        'application/gzip',
+        'application/x-bzip2',
+        'application/zip',
+        'application/octet-stream',
+    ]
+
+    def get_serializer_class(self):
+        if self.action == 'create_analysis':
+            return CreateAnalysisSerializer
+        elif self.action in ['list']:
+            return PortfolioListSerializer
+        elif self.action in ['set_storage_links', 'storage_links']:
+            return PortfolioStorageSerializer
+        elif self.action in ['validate']:
+            return PortfolioValidationSerializer
+        elif self.action in [
+            'accounts_file', 'location_file', 'reinsurance_info_file', 'reinsurance_scope_file',
+        ]:
+            return RelatedFileSerializer
+        else:
+            return super(PortfolioViewSet, self).get_serializer_class()
+
+    @property
+    def parser_classes(self):
+        upload_views = ['accounts_file', 'location_file', 'reinsurance_info_file', 'reinsurance_scope_file']
+        if getattr(self, 'action', None) in upload_views:
+            return [MultiPartParser]
+        else:
+            return api_settings.DEFAULT_PARSER_CLASSES
+
+    @action(methods=['post'], detail=True)
+    def create_analysis(self, request, pk=None, version=None):
+        """
+        Creates an analysis object from the portfolio.
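+
+        The request body must provide a `name` and a `model` id (illustrative values), e.g.
+
+            {"name": "my_analysis", "model": 1}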
+ """ + portfolio = self.get_object() + serializer = self.get_serializer(data=request.data, portfolio=portfolio, context=self.get_serializer_context()) + serializer.is_valid(raise_exception=True) + analysis = serializer.create(serializer.validated_data) + analysis.generate_inputs(request.user, run_mode_override='V1') + + return Response( + AnalysisSerializer(instance=analysis, context=self.get_serializer_context()).data, + status=HTTP_201_CREATED, + ) + + @swagger_auto_schema(methods=['post'], request_body=StorageLinkSerializer) + @action(methods=['get', 'post'], detail=True) + def storage_links(self, request, pk=None, version=None): + """ + get: + Gets the portfolios storage backed link references, `object keys` or `file paths` + """ + method = request.method.lower() + if method == 'get': + serializer = self.get_serializer(self.get_object()) + return Response(serializer.data) + else: + serializer = self.get_serializer(self.get_object(), data=request.data) + serializer.is_valid(raise_exception=True) + serializer.save() + return Response(serializer.data) + + @swagger_auto_schema(methods=['get'], responses={200: FILE_RESPONSE}, manual_parameters=[FILE_FORMAT_PARAM]) + @swagger_auto_schema(methods=['post'], manual_parameters=[FILE_VALIDATION_PARAM]) + @action(methods=['get', 'post', 'delete'], detail=True) + def accounts_file(self, request, pk=None, version=None): + """ + get: + Gets the portfolios `accounts_file` contents + + post: + Sets the portfolios `accounts_file` contents + + delete: + Disassociates the portfolios `accounts_file` with the portfolio + """ + method = request.method.lower() + if method == 'post': + store_as_parquet = django_settings.PORTFOLIO_PARQUET_STORAGE + oed_validate = request.GET.get('validate', str(django_settings.PORTFOLIO_UPLOAD_VALIDATION)).lower() == 'true' + else: + store_as_parquet = None + oed_validate = None + return handle_related_file(self.get_object(), 'accounts_file', request, self.supported_mime_types, store_as_parquet, oed_validate) + + @swagger_auto_schema(methods=['get'], responses={200: FILE_RESPONSE}, manual_parameters=[FILE_FORMAT_PARAM]) + @swagger_auto_schema(methods=['post'], manual_parameters=[FILE_VALIDATION_PARAM]) + @action(methods=['get', 'post', 'delete'], detail=True) + def location_file(self, request, pk=None, version=None): + """ + get: + Gets the portfolios `location_file` contents + + post: + Sets the portfolios `location_file` contents + + delete: + Disassociates the portfolios `location_file` contents + """ + method = request.method.lower() + if method == 'post': + store_as_parquet = django_settings.PORTFOLIO_PARQUET_STORAGE + oed_validate = request.GET.get('validate', str(django_settings.PORTFOLIO_UPLOAD_VALIDATION)).lower() == 'true' + else: + store_as_parquet = None + oed_validate = None + return handle_related_file(self.get_object(), 'location_file', request, self.supported_mime_types, store_as_parquet, oed_validate) + + @swagger_auto_schema(methods=['get'], responses={200: FILE_RESPONSE}, manual_parameters=[FILE_FORMAT_PARAM]) + @swagger_auto_schema(methods=['post'], manual_parameters=[FILE_VALIDATION_PARAM]) + @action(methods=['get', 'post', 'delete'], detail=True) + def reinsurance_info_file(self, request, pk=None, version=None): + """ + get: + Gets the portfolios `reinsurance_info_file` contents + + post: + Sets the portfolios `reinsurance_info_file` contents + + delete: + Disassociates the portfolios `reinsurance_info_file` contents + """ + method = request.method.lower() + if method == 'post': + store_as_parquet = 
+            store_as_parquet = django_settings.PORTFOLIO_PARQUET_STORAGE
+            oed_validate = request.GET.get('validate', str(django_settings.PORTFOLIO_UPLOAD_VALIDATION)).lower() == 'true'
+        else:
+            store_as_parquet = None
+            oed_validate = None
+        return handle_related_file(self.get_object(), 'reinsurance_info_file', request, self.supported_mime_types, store_as_parquet, oed_validate)
+
+    @swagger_auto_schema(methods=['get'], responses={200: FILE_RESPONSE}, manual_parameters=[FILE_FORMAT_PARAM])
+    @swagger_auto_schema(methods=['post'], manual_parameters=[FILE_VALIDATION_PARAM])
+    @action(methods=['get', 'post', 'delete'], detail=True)
+    def reinsurance_scope_file(self, request, pk=None, version=None):
+        """
+        get:
+        Gets the portfolio's `reinsurance_scope_file` contents
+
+        post:
+        Sets the portfolio's `reinsurance_scope_file` contents
+
+        delete:
+        Disassociates the `reinsurance_scope_file` from the portfolio
+        """
+        method = request.method.lower()
+        if method == 'post':
+            store_as_parquet = django_settings.PORTFOLIO_PARQUET_STORAGE
+            oed_validate = request.GET.get('validate', str(django_settings.PORTFOLIO_UPLOAD_VALIDATION)).lower() == 'true'
+        else:
+            store_as_parquet = None
+            oed_validate = None
+        return handle_related_file(self.get_object(), 'reinsurance_scope_file', request, self.supported_mime_types, store_as_parquet, oed_validate)
+
+    @action(methods=['get', 'post'], detail=True)
+    def validate(self, request, pk=None, version=None):
+        """
+        get:
+        Returns the OED validation status of each attached exposure file
+
+        post:
+        Runs OED validation on the attached exposure files
+        """
+        method = request.method.lower()
+        instance = self.get_object()
+
+        if method == 'post':
+            instance.run_oed_validation()
+
+        serializer = self.get_serializer(instance)
+        return Response(serializer.data)
diff --git a/src/server/oasisapi/portfolios/v2_api/__init__.py b/src/server/oasisapi/portfolios/v2_api/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/server/oasisapi/portfolios/serializers.py b/src/server/oasisapi/portfolios/v2_api/serializers.py
similarity index 95%
rename from src/server/oasisapi/portfolios/serializers.py
rename to src/server/oasisapi/portfolios/v2_api/serializers.py
index b15c75c40..6a7f3644b 100644
--- a/src/server/oasisapi/portfolios/serializers.py
+++ b/src/server/oasisapi/portfolios/v2_api/serializers.py
@@ -13,13 +13,13 @@
 from azure.core.exceptions import ResourceNotFoundError as Blob_ResourceNotFoundError
 from azure.storage.blob import BlobLeaseClient
 
-from .models import Portfolio
-from ..analyses.serializers import AnalysisSerializer
-from ..files.models import RelatedFile
-from ..files.models import file_storage_link
-from ..files.upload import wait_for_blob_copy
-from ..permissions.group_auth import validate_and_update_groups, validate_user_is_owner
-from ..schemas.serializers import (
+from ..models import Portfolio
+from ...analyses.v2_api.serializers import AnalysisSerializer
+from ...files.models import RelatedFile
+from ...files.models import file_storage_link
+from ...files.upload import wait_for_blob_copy
+from ...permissions.group_auth import validate_and_update_groups, validate_user_is_owner
+from ...schemas.serializers import (
     LocFileSerializer,
     AccFileSerializer,
     ReinsInfoFileSerializer,
@@ -31,7 +31,6 @@ class PortfolioListSerializer(serializers.Serializer):
     """
     Read Only Portfolio Deserializer for efficiently returning a list of all Portfolios in DB
     """
-    id = serializers.IntegerField(read_only=True)
     name = serializers.CharField(read_only=True)
     created = 
serializers.DateTimeField(read_only=True) @@ -320,13 +319,14 @@ def update(self, instance, validated_data): files_for_removal = list() user = self.context['request'].user for field in validated_data: - content_type = self.get_content_type(validated_data[field]) + old_file_name = validated_data[field] + content_type = self.get_content_type(old_file_name) + fname = path.basename(old_file_name) + new_file_name = default_storage.get_alternative_name(fname, '') # S3 storage - File copy needed if hasattr(default_storage, 'bucket'): - fname = path.basename(validated_data[field]) new_file = ContentFile(b'') - new_file.name = default_storage.get_alternative_name(fname, '') new_related_file = RelatedFile.objects.create( file=new_file, filename=fname, @@ -343,12 +343,11 @@ def update(self, instance, validated_data): elif hasattr(default_storage, 'azure_container'): # https://docs.microsoft.com/en-us/azure/storage/blobs/storage-blob-copy?tabs=python - fname = path.basename(validated_data[field]) - new_file_name = default_storage.get_alternative_name(validated_data[field], '') + new_file_name = default_storage.get_alternative_name(old_file_name, '') new_blobname = '/'.join([default_storage.location, path.basename(new_file_name)]) # Copies a blob asynchronously. - source_blob = default_storage.client.get_blob_client(validated_data[field]) + source_blob = default_storage.client.get_blob_client(old_file_name) dest_blob = default_storage.client.get_blob_client(new_blobname) try: @@ -367,17 +366,17 @@ def update(self, instance, validated_data): file=File(stored_blob, name=new_file_name), filename=fname, content_type=content_type, - creator=self.context['request'].user, + creator=user, store_as_filename=True, ) # Shared-fs else: - stored_file = default_storage.open(validated_data[field]) - new_file = File(stored_file, name=validated_data[field]) + stored_file = default_storage.open(old_file_name) + new_file = File(stored_file, name=new_file_name) new_related_file = RelatedFile.objects.create( file=new_file, - filename=validated_data[field], + filename=fname, content_type=content_type, creator=user, store_as_filename=True, diff --git a/src/server/oasisapi/portfolios/v2_api/tests/__init__.py b/src/server/oasisapi/portfolios/v2_api/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/server/oasisapi/portfolios/v2_api/tests/fakes.py b/src/server/oasisapi/portfolios/v2_api/tests/fakes.py new file mode 100644 index 000000000..5d69077cf --- /dev/null +++ b/src/server/oasisapi/portfolios/v2_api/tests/fakes.py @@ -0,0 +1,7 @@ +from model_mommy import mommy + +from src.server.oasisapi.portfolios.models import Portfolio + + +def fake_portfolio(**kwargs): + return mommy.make(Portfolio, **kwargs) diff --git a/src/server/oasisapi/portfolios/tests/test_portfolio.py b/src/server/oasisapi/portfolios/v2_api/tests/test_portfolio.py similarity index 89% rename from src/server/oasisapi/portfolios/tests/test_portfolio.py rename to src/server/oasisapi/portfolios/v2_api/tests/test_portfolio.py index fdef8368b..eaa680f4f 100644 --- a/src/server/oasisapi/portfolios/tests/test_portfolio.py +++ b/src/server/oasisapi/portfolios/v2_api/tests/test_portfolio.py @@ -16,23 +16,24 @@ from rest_framework_simplejwt.tokens import AccessToken from ods_tools.oed.exposure import OedExposure -from ...files.tests.fakes import fake_related_file -from ...analysis_models.tests.fakes import fake_analysis_model -from ...analyses.models import Analysis -from ...auth.tests.fakes import fake_user, add_fake_group -from ..models 
import Portfolio +from src.server.oasisapi.files.tests.fakes import fake_related_file +from src.server.oasisapi.analysis_models.v2_api.tests.fakes import fake_analysis_model +from src.server.oasisapi.analyses.models import Analysis +from src.server.oasisapi.auth.tests.fakes import fake_user, add_fake_group +from src.server.oasisapi.portfolios.models import Portfolio from .fakes import fake_portfolio # Override default deadline for all tests to 10s settings.register_profile("ci", deadline=1000.0) settings.load_profile("ci") +NAMESPACE = 'v2-portfolios' class PortfolioApi(WebTestMixin, TestCase): def test_user_is_not_authenticated___response_is_forbidden(self): portfolio = fake_portfolio() - response = self.app.get(portfolio.get_absolute_url(), expect_errors=True) + response = self.app.get(portfolio.get_absolute_url(namespace=NAMESPACE), expect_errors=True) self.assertIn(response.status_code, [401, 403]) def test_user_is_authenticated_object_does_not_exist___response_is_404(self): @@ -40,7 +41,7 @@ def test_user_is_authenticated_object_does_not_exist___response_is_404(self): portfolio = fake_portfolio() response = self.app.get( - reverse('portfolio-detail', kwargs={'version': 'v1', 'pk': portfolio.pk + 1}), + reverse(f'{NAMESPACE}:portfolio-detail', kwargs={'pk': portfolio.pk + 1}), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -53,7 +54,7 @@ def test_name_is_not_provided___response_is_400(self): user = fake_user() response = self.app.post( - reverse('portfolio-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:portfolio-list'), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -69,7 +70,7 @@ def test_cleaned_name_is_empty___response_is_400(self, name): user = fake_user() response = self.app.post( - reverse('portfolio-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:portfolio-list'), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -92,7 +93,7 @@ def test_cleaned_name_is_present___object_is_created(self, name, group_name): add_fake_group(user, group_name) response = self.app.post( - reverse('portfolio-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:portfolio-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -109,7 +110,7 @@ def test_cleaned_name_is_present___object_is_created(self, name, group_name): portfolio.save() response = self.app.get( - portfolio.get_absolute_url(), + portfolio.get_absolute_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -123,26 +124,26 @@ def test_cleaned_name_is_present___object_is_created(self, name, group_name): 'modified': portfolio.modified.strftime('%Y-%m-%dT%H:%M:%S.%fZ'), 'groups': [group_name], 'accounts_file': { - "uri": response.request.application_url + portfolio.get_absolute_accounts_file_url(), + "uri": response.request.application_url + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), "name": portfolio.accounts_file.filename, "stored": str(portfolio.accounts_file.file) }, 'location_file': { - "uri": response.request.application_url + portfolio.get_absolute_location_file_url(), + "uri": response.request.application_url + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), "name": portfolio.location_file.filename, "stored": str(portfolio.location_file.file) }, 'reinsurance_info_file': { - "uri": response.request.application_url + 
portfolio.get_absolute_reinsurance_info_file_url(), + "uri": response.request.application_url + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), "name": portfolio.reinsurance_info_file.filename, "stored": str(portfolio.reinsurance_info_file.file) }, 'reinsurance_scope_file': { - "uri": response.request.application_url + portfolio.get_absolute_reinsurance_scope_file_url(), + "uri": response.request.application_url + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), "name": portfolio.reinsurance_scope_file.filename, "stored": str(portfolio.reinsurance_scope_file.file) }, - 'storage_links': response.request.application_url + portfolio.get_absolute_storage_url() + 'storage_links': response.request.application_url + portfolio.get_absolute_storage_url(namespace=NAMESPACE) }, response.json) @given( @@ -155,7 +156,7 @@ def test_default_empty_groups___visible_for_users_without_groups_only(self, name add_fake_group(user_with_groups, group_name) response = self.app.post( - reverse('portfolio-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:portfolio-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user_with_groups)) }, @@ -169,7 +170,7 @@ def test_default_empty_groups___visible_for_users_without_groups_only(self, name # The same user who created it should also see it response = self.app.get( - reverse('portfolio-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:portfolio-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user_with_groups)) }, @@ -179,7 +180,7 @@ def test_default_empty_groups___visible_for_users_without_groups_only(self, name # Another user that don't belong to the group should not see it response = self.app.get( - reverse('portfolio-list', kwargs={'version': 'v1'}), + reverse(f'{NAMESPACE}:portfolio-list'), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(fake_user())) }, @@ -192,7 +193,7 @@ class PortfolioApiCreateAnalysis(WebTestMixin, TestCase): def test_user_is_not_authenticated___response_is_forbidden(self): portfolio = fake_portfolio() - response = self.app.get(portfolio.get_absolute_create_analysis_url(), expect_errors=True) + response = self.app.get(portfolio.get_absolute_create_analysis_url(namespace=NAMESPACE), expect_errors=True) self.assertIn(response.status_code, [401, 403]) def test_user_is_authenticated_object_does_not_exist___response_is_404(self): @@ -200,7 +201,7 @@ def test_user_is_authenticated_object_does_not_exist___response_is_404(self): portfolio = fake_portfolio() response = self.app.post( - reverse('portfolio-create-analysis', kwargs={'version': 'v1', 'pk': portfolio.pk + 1}), + reverse(f'{NAMESPACE}:portfolio-create-analysis', kwargs={'pk': portfolio.pk + 1}), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -215,7 +216,7 @@ def test_name_is_not_provided___response_is_400(self): portfolio = fake_portfolio() response = self.app.post( - portfolio.get_absolute_create_analysis_url(), + portfolio.get_absolute_create_analysis_url(namespace=NAMESPACE), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -232,7 +233,7 @@ def test_portfolio_does_not_have_location_file_set___response_is_400(self): portfolio = fake_portfolio() response = self.app.post( - portfolio.get_absolute_create_analysis_url(), + portfolio.get_absolute_create_analysis_url(namespace=NAMESPACE), expect_errors=True, headers={ 'Authorization': 'Bearer 
{}'.format(AccessToken.for_user(user)) @@ -249,7 +250,7 @@ def test_model_is_not_provided___response_is_400(self): portfolio = fake_portfolio() response = self.app.post( - portfolio.get_absolute_create_analysis_url(), + portfolio.get_absolute_create_analysis_url(namespace=NAMESPACE), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -266,7 +267,7 @@ def test_model_does_not_exist___response_is_400(self): model = fake_analysis_model() response = self.app.post( - portfolio.get_absolute_create_analysis_url(), + portfolio.get_absolute_create_analysis_url(namespace=NAMESPACE), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -284,7 +285,7 @@ def test_cleaned_name_is_empty___response_is_400(self, name): portfolio = fake_portfolio() response = self.app.post( - portfolio.get_absolute_create_analysis_url(), + portfolio.get_absolute_create_analysis_url(namespace=NAMESPACE), expect_errors=True, headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) @@ -304,10 +305,11 @@ def test_cleaned_name_and_model_are_present___object_is_created_inputs_are_gener user = fake_user() model = fake_analysis_model() + model.run_mode = model.run_mode_choices.V2 portfolio = fake_portfolio(location_file=fake_related_file()) response = self.app.post( - portfolio.get_absolute_create_analysis_url(), + portfolio.get_absolute_create_analysis_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -327,8 +329,9 @@ def test_cleaned_name_and_model_are_present___object_is_created_inputs_are_gener analysis.run_traceback_file = fake_related_file() analysis.save() + ANALYSES_NAMESPACE = 'v2-analyses' response = self.app.get( - analysis.get_absolute_url(), + analysis.get_absolute_url(namespace=ANALYSES_NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -341,19 +344,22 @@ def test_cleaned_name_and_model_are_present___object_is_created_inputs_are_gener self.assertEqual(response.json['name'], name) self.assertEqual(response.json['portfolio'], portfolio.pk) self.assertEqual(response.json['model'], model.pk) - self.assertEqual(response.json['settings_file'], response.request.application_url + analysis.get_absolute_settings_file_url()) - self.assertEqual(response.json['input_file'], response.request.application_url + analysis.get_absolute_input_file_url()) + self.assertEqual(response.json['settings_file'], response.request.application_url + + analysis.get_absolute_settings_file_url(namespace=ANALYSES_NAMESPACE)) + self.assertEqual(response.json['input_file'], response.request.application_url + + analysis.get_absolute_input_file_url(namespace=ANALYSES_NAMESPACE)) self.assertEqual(response.json['lookup_errors_file'], response.request.application_url + - analysis.get_absolute_lookup_errors_file_url()) + analysis.get_absolute_lookup_errors_file_url(namespace=ANALYSES_NAMESPACE)) self.assertEqual(response.json['lookup_success_file'], response.request.application_url + - analysis.get_absolute_lookup_success_file_url()) + analysis.get_absolute_lookup_success_file_url(namespace=ANALYSES_NAMESPACE)) self.assertEqual(response.json['lookup_validation_file'], response.request.application_url + - analysis.get_absolute_lookup_validation_file_url()) + analysis.get_absolute_lookup_validation_file_url(namespace=ANALYSES_NAMESPACE)) self.assertEqual(response.json['input_generation_traceback_file'], response.request.application_url + - 
analysis.get_absolute_input_generation_traceback_file_url()) - self.assertEqual(response.json['output_file'], response.request.application_url + analysis.get_absolute_output_file_url()) + analysis.get_absolute_input_generation_traceback_file_url(namespace=ANALYSES_NAMESPACE)) + self.assertEqual(response.json['output_file'], response.request.application_url + + analysis.get_absolute_output_file_url(namespace=ANALYSES_NAMESPACE)) self.assertEqual(response.json['run_traceback_file'], response.request.application_url + - analysis.get_absolute_run_traceback_file_url()) + analysis.get_absolute_run_traceback_file_url(namespace=ANALYSES_NAMESPACE)) generate_mock.assert_called_once_with(analysis, user) @@ -361,7 +367,7 @@ class PortfolioAccountsFile(WebTestMixin, TestCase): def test_user_is_not_authenticated___response_is_forbidden(self): portfolio = fake_portfolio() - response = self.app.get(portfolio.get_absolute_accounts_file_url(), expect_errors=True) + response = self.app.get(portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), expect_errors=True) self.assertIn(response.status_code, [401, 403]) def test_accounts_file_is_not_present___get_response_is_404(self): @@ -369,7 +375,7 @@ def test_accounts_file_is_not_present___get_response_is_404(self): portfolio = fake_portfolio() response = self.app.get( - portfolio.get_absolute_accounts_file_url(), + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -383,7 +389,7 @@ def test_accounts_file_is_not_present___delete_response_is_404(self): portfolio = fake_portfolio() response = self.app.delete( - portfolio.get_absolute_accounts_file_url(), + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -399,7 +405,7 @@ def test_accounts_file_is_not_a_valid_format___response_is_400(self): portfolio = fake_portfolio() response = self.app.post( - portfolio.get_absolute_accounts_file_url(), + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -422,7 +428,7 @@ def test_accounts_file_user_is_not_permitted___response_is_403(self, file_conten portfolio.save() response = self.app.post( - portfolio.get_absolute_accounts_file_url(), + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -442,7 +448,7 @@ def test_accounts_file_is_uploaded___file_can_be_retrieved(self, file_content, c portfolio = fake_portfolio() self.app.post( - portfolio.get_absolute_accounts_file_url(), + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -452,7 +458,7 @@ def test_accounts_file_is_uploaded___file_can_be_retrieved(self, file_content, c ) response = self.app.get( - portfolio.get_absolute_accounts_file_url(), + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -471,7 +477,7 @@ def test_accounts_file_invalid_uploaded___parquet_exception_raised(self): portfolio = fake_portfolio() response = self.app.post( - portfolio.get_absolute_accounts_file_url(), + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -493,7 +499,7 @@ def 
test_accounts_file_is_uploaded_as_parquet___file_can_be_retrieved(self): portfolio = fake_portfolio() self.app.post( - portfolio.get_absolute_accounts_file_url(), + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -503,14 +509,14 @@ def test_accounts_file_is_uploaded_as_parquet___file_can_be_retrieved(self): ) csv_response = self.app.get( - portfolio.get_absolute_accounts_file_url() + '?file_format=csv', + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE) + '?file_format=csv', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, ) parquet_response = self.app.get( - portfolio.get_absolute_accounts_file_url() + '?file_format=parquet', + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE) + '?file_format=parquet', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -536,7 +542,7 @@ class PortfolioLocationFile(WebTestMixin, TestCase): def test_user_is_not_authenticated___response_is_forbidden(self): portfolio = fake_portfolio() - response = self.app.get(portfolio.get_absolute_location_file_url(), expect_errors=True) + response = self.app.get(portfolio.get_absolute_location_file_url(namespace=NAMESPACE), expect_errors=True) self.assertIn(response.status_code, [401, 403]) def test_location_file_is_not_present___get_response_is_404(self): @@ -544,7 +550,7 @@ def test_location_file_is_not_present___get_response_is_404(self): portfolio = fake_portfolio() response = self.app.get( - portfolio.get_absolute_location_file_url(), + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -558,7 +564,7 @@ def test_location_file_is_not_present___delete_response_is_404(self): portfolio = fake_portfolio() response = self.app.delete( - portfolio.get_absolute_location_file_url(), + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -574,7 +580,7 @@ def test_location_file_is_not_a_valid_format___response_is_400(self): portfolio = fake_portfolio() response = self.app.post( - portfolio.get_absolute_location_file_url(), + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -594,7 +600,7 @@ def test_location_file_is_uploaded___file_can_be_retrieved(self, file_content, c portfolio = fake_portfolio() self.app.post( - portfolio.get_absolute_location_file_url(), + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -604,7 +610,7 @@ def test_location_file_is_uploaded___file_can_be_retrieved(self, file_content, c ) response = self.app.get( - portfolio.get_absolute_location_file_url(), + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -623,7 +629,7 @@ def test_location_file_invalid_uploaded___parquet_exception_raised(self): portfolio = fake_portfolio() response = self.app.post( - portfolio.get_absolute_location_file_url(), + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -645,7 +651,7 @@ def test_location_file_is_uploaded_as_parquet___file_can_be_retrieved(self): portfolio = fake_portfolio() self.app.post( - 
portfolio.get_absolute_location_file_url(), + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -655,13 +661,13 @@ def test_location_file_is_uploaded_as_parquet___file_can_be_retrieved(self): ) csv_response = self.app.get( - portfolio.get_absolute_location_file_url() + '?file_format=csv', + portfolio.get_absolute_location_file_url(namespace=NAMESPACE) + '?file_format=csv', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, ) parquet_response = self.app.get( - portfolio.get_absolute_location_file_url() + '?file_format=parquet', + portfolio.get_absolute_location_file_url(namespace=NAMESPACE) + '?file_format=parquet', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -687,7 +693,7 @@ class PortfolioReinsuranceSourceFile(WebTestMixin, TestCase): def test_user_is_not_authenticated___response_is_forbidden(self): portfolio = fake_portfolio() - response = self.app.get(portfolio.get_absolute_reinsurance_scope_file_url(), expect_errors=True) + response = self.app.get(portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), expect_errors=True) self.assertIn(response.status_code, [401, 403]) def test_reinsurance_scope_file_is_not_present___get_response_is_404(self): @@ -695,7 +701,7 @@ def test_reinsurance_scope_file_is_not_present___get_response_is_404(self): portfolio = fake_portfolio() response = self.app.get( - portfolio.get_absolute_reinsurance_scope_file_url(), + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -709,7 +715,7 @@ def test_reinsurance_scope_file_is_not_present___delete_response_is_404(self): portfolio = fake_portfolio() response = self.app.delete( - portfolio.get_absolute_reinsurance_scope_file_url(), + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -725,7 +731,7 @@ def test_reinsurance_scope_file_is_not_a_valid_format___response_is_400(self): portfolio = fake_portfolio() response = self.app.post( - portfolio.get_absolute_reinsurance_scope_file_url(), + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -745,7 +751,7 @@ def test_reinsurance_scope_file_is_uploaded___file_can_be_retrieved(self, file_c portfolio = fake_portfolio() self.app.post( - portfolio.get_absolute_reinsurance_scope_file_url(), + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -755,7 +761,7 @@ def test_reinsurance_scope_file_is_uploaded___file_can_be_retrieved(self, file_c ) response = self.app.get( - portfolio.get_absolute_reinsurance_scope_file_url(), + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -774,7 +780,7 @@ def test_reinsurance_scope_file_invalid_uploaded___parquet_exception_raised(self portfolio = fake_portfolio() response = self.app.post( - portfolio.get_absolute_reinsurance_scope_file_url(), + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -796,7 +802,7 @@ def 
test_reinsurance_scope_file_is_uploaded_as_parquet___file_can_be_retrieved(s portfolio = fake_portfolio() self.app.post( - portfolio.get_absolute_reinsurance_scope_file_url(), + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -806,14 +812,14 @@ def test_reinsurance_scope_file_is_uploaded_as_parquet___file_can_be_retrieved(s ) csv_response = self.app.get( - portfolio.get_absolute_reinsurance_scope_file_url() + '?file_format=csv', + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE) + '?file_format=csv', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, ) parquet_response = self.app.get( - portfolio.get_absolute_reinsurance_scope_file_url() + '?file_format=parquet', + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE) + '?file_format=parquet', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -839,7 +845,7 @@ class PortfolioReinsuranceInfoFile(WebTestMixin, TestCase): def test_user_is_not_authenticated___response_is_forbidden(self): portfolio = fake_portfolio() - response = self.app.get(portfolio.get_absolute_reinsurance_info_file_url(), expect_errors=True) + response = self.app.get(portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), expect_errors=True) self.assertIn(response.status_code, [401, 403]) def test_reinsurance_info_file_is_not_present___get_response_is_404(self): @@ -847,7 +853,7 @@ def test_reinsurance_info_file_is_not_present___get_response_is_404(self): portfolio = fake_portfolio() response = self.app.get( - portfolio.get_absolute_reinsurance_info_file_url(), + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -861,7 +867,7 @@ def test_reinsurance_info_file_is_not_present___delete_response_is_404(self): portfolio = fake_portfolio() response = self.app.delete( - portfolio.get_absolute_reinsurance_info_file_url(), + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -877,7 +883,7 @@ def test_reinsurance_info_file_is_not_a_valid_format___response_is_400(self): portfolio = fake_portfolio() response = self.app.post( - portfolio.get_absolute_reinsurance_info_file_url(), + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -897,7 +903,7 @@ def test_reinsurance_info_file_is_uploaded___file_can_be_retrieved(self, file_co portfolio = fake_portfolio() self.app.post( - portfolio.get_absolute_reinsurance_info_file_url(), + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -907,7 +913,7 @@ def test_reinsurance_info_file_is_uploaded___file_can_be_retrieved(self, file_co ) response = self.app.get( - portfolio.get_absolute_reinsurance_info_file_url(), + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -926,7 +932,7 @@ def test_reinsurance_info_file_invalid_uploaded___parquet_exception_raised(self) portfolio = fake_portfolio() response = self.app.post( - portfolio.get_absolute_reinsurance_info_file_url(), + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), 
headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -948,7 +954,7 @@ def test_reinsurance_info_file_is_uploaded_as_parquet___file_can_be_retrieved(se portfolio = fake_portfolio() self.app.post( - portfolio.get_absolute_reinsurance_info_file_url(), + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -958,14 +964,14 @@ def test_reinsurance_info_file_is_uploaded_as_parquet___file_can_be_retrieved(se ) csv_response = self.app.get( - portfolio.get_absolute_reinsurance_info_file_url() + '?file_format=csv', + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE) + '?file_format=csv', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, ) parquet_response = self.app.get( - portfolio.get_absolute_reinsurance_info_file_url() + '?file_format=parquet', + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE) + '?file_format=parquet', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1036,7 +1042,7 @@ def test_all_exposure__are_valid(self): portfolio = fake_portfolio() self.app.post( - portfolio.get_absolute_location_file_url(), + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1045,7 +1051,7 @@ def test_all_exposure__are_valid(self): ), ) self.app.post( - portfolio.get_absolute_accounts_file_url(), + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1054,7 +1060,7 @@ def test_all_exposure__are_valid(self): ), ) self.app.post( - portfolio.get_absolute_reinsurance_info_file_url(), + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1063,7 +1069,7 @@ def test_all_exposure__are_valid(self): ), ) self.app.post( - portfolio.get_absolute_reinsurance_scope_file_url(), + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1073,7 +1079,7 @@ def test_all_exposure__are_valid(self): ) validate_response = self.app.get( - portfolio.get_absolute_url() + 'validate/', + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1089,7 +1095,7 @@ def test_all_exposure__are_valid(self): # Run validate - check is valid validate_response = self.app.post( - portfolio.get_absolute_url() + 'validate/', + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1112,7 +1118,7 @@ def test_location_file__is_valid(self): portfolio = fake_portfolio() self.app.post( - portfolio.get_absolute_location_file_url(), + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1122,7 +1128,7 @@ def test_location_file__is_valid(self): ) validate_response = self.app.get( - portfolio.get_absolute_url() + 'validate/', + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1138,7 +1144,7 @@ def test_location_file__is_valid(self): # Run validate - check is valid validate_response = self.app.post( - 
portfolio.get_absolute_url() + 'validate/', + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1161,7 +1167,7 @@ def test_location_file__is_invalid__response_is_400(self): portfolio = fake_portfolio() self.app.post( - portfolio.get_absolute_location_file_url(), + portfolio.get_absolute_location_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1170,7 +1176,7 @@ def test_location_file__is_invalid__response_is_400(self): ), ) validate_response = self.app.get( - portfolio.get_absolute_url() + 'validate/', + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1186,7 +1192,7 @@ def test_location_file__is_invalid__response_is_400(self): # Run validate - check is invalid validate_response = self.app.post( - portfolio.get_absolute_url() + 'validate/', + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1214,7 +1220,7 @@ def test_account_file__is_invalid__response_is_400(self): portfolio = fake_portfolio() self.app.post( - portfolio.get_absolute_accounts_file_url(), + portfolio.get_absolute_accounts_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1224,7 +1230,7 @@ def test_account_file__is_invalid__response_is_400(self): ) validate_response = self.app.get( - portfolio.get_absolute_url() + 'validate/', + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1240,7 +1246,7 @@ def test_account_file__is_invalid__response_is_400(self): # Run validate - check is valid validate_response = self.app.post( - portfolio.get_absolute_url() + 'validate/', + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1280,7 +1286,7 @@ def test_reinsurance_info_file__is_invalid__response_is_400(self): portfolio = fake_portfolio() self.app.post( - portfolio.get_absolute_reinsurance_info_file_url(), + portfolio.get_absolute_reinsurance_info_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1290,7 +1296,7 @@ def test_reinsurance_info_file__is_invalid__response_is_400(self): ) validate_response = self.app.get( - portfolio.get_absolute_url() + 'validate/', + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1306,7 +1312,7 @@ def test_reinsurance_info_file__is_invalid__response_is_400(self): # Run validate - check is valid validate_response = self.app.post( - portfolio.get_absolute_url() + 'validate/', + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1351,7 +1357,7 @@ def test_reinsurance_scope_file__is_invalid__response_is_400(self): portfolio = fake_portfolio() self.app.post( - portfolio.get_absolute_reinsurance_scope_file_url(), + portfolio.get_absolute_reinsurance_scope_file_url(namespace=NAMESPACE), headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1361,7 +1367,7 @@ def test_reinsurance_scope_file__is_invalid__response_is_400(self): ) validate_response = 
self.app.get( - portfolio.get_absolute_url() + 'validate/', + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, @@ -1377,7 +1383,7 @@ def test_reinsurance_scope_file__is_invalid__response_is_400(self): # Run validate - check is valid validate_response = self.app.post( - portfolio.get_absolute_url() + 'validate/', + portfolio.get_absolute_url(namespace=NAMESPACE) + 'validate/', headers={ 'Authorization': 'Bearer {}'.format(AccessToken.for_user(user)) }, diff --git a/src/server/oasisapi/portfolios/v2_api/urls.py b/src/server/oasisapi/portfolios/v2_api/urls.py new file mode 100644 index 000000000..da976f087 --- /dev/null +++ b/src/server/oasisapi/portfolios/v2_api/urls.py @@ -0,0 +1,13 @@ +from django.conf.urls import include, url +from rest_framework.routers import SimpleRouter +from .viewsets import PortfolioViewSet + + +app_name = 'portfolios' +v2_api_router = SimpleRouter() +v2_api_router.include_root_view = False +v2_api_router.register('portfolios', PortfolioViewSet, basename='portfolio') + +urlpatterns = [ + url(r'', include(v2_api_router.urls)), +] diff --git a/src/server/oasisapi/portfolios/viewsets.py b/src/server/oasisapi/portfolios/v2_api/viewsets.py similarity index 94% rename from src/server/oasisapi/portfolios/viewsets.py rename to src/server/oasisapi/portfolios/v2_api/viewsets.py index 10ebd6d58..e777a8b74 100644 --- a/src/server/oasisapi/portfolios/viewsets.py +++ b/src/server/oasisapi/portfolios/v2_api/viewsets.py @@ -11,9 +11,9 @@ from rest_framework.settings import api_settings from rest_framework.status import HTTP_201_CREATED -from .models import Portfolio -from ..schemas.custom_swagger import FILE_RESPONSE, FILE_FORMAT_PARAM, FILE_VALIDATION_PARAM -from ..schemas.serializers import StorageLinkSerializer +from ..models import Portfolio +from ...schemas.custom_swagger import FILE_RESPONSE, FILE_FORMAT_PARAM, FILE_VALIDATION_PARAM +from ...schemas.serializers import StorageLinkSerializer from .serializers import ( PortfolioSerializer, CreateAnalysisSerializer, @@ -21,13 +21,13 @@ PortfolioListSerializer, PortfolioValidationSerializer ) -from ..analyses.serializers import AnalysisSerializer -from ..files.serializers import RelatedFileSerializer -from ..files.views import handle_related_file -from ..filters import TimeStampedFilter -from ..permissions.group_auth import VerifyGroupAccessModelViewSet -from ..schemas.custom_swagger import FILE_RESPONSE, FILE_FORMAT_PARAM -from ..schemas.serializers import StorageLinkSerializer +from ...analyses.v2_api.serializers import AnalysisSerializer +from ...files.serializers import RelatedFileSerializer +from ...files.views import handle_related_file +from ...filters import TimeStampedFilter +from ...permissions.group_auth import VerifyGroupAccessModelViewSet +from ...schemas.custom_swagger import FILE_RESPONSE, FILE_FORMAT_PARAM +from ...schemas.serializers import StorageLinkSerializer class PortfolioFilter(TimeStampedFilter): diff --git a/src/server/oasisapi/queues/consumers.py b/src/server/oasisapi/queues/consumers.py index 55257f845..8ee119ee2 100644 --- a/src/server/oasisapi/queues/consumers.py +++ b/src/server/oasisapi/queues/consumers.py @@ -56,7 +56,7 @@ def wrap_message_content(message_type, content, status=ContentStatus.SUCCESS): def build_task_status_message(items: List[TaskStatusMessageItem], message_type='queue_status.updated'): - from src.server.oasisapi.analyses.serializers import AnalysisSerializerWebSocket, 
AnalysisTaskStatusSerializer + from src.server.oasisapi.analyses.v2_api.serializers import AnalysisSerializerWebSocket, AnalysisTaskStatusSerializer from src.server.oasisapi.queues.serializers import QueueSerializer return wrap_message_content( @@ -115,6 +115,7 @@ def build_all_queue_status_message(analysis_filter=None, message_type='queue_sta class QueueStatusConsumer(GuardedAsyncJsonWebsocketConsumer): + # class QueueStatusConsumer(AsyncJsonWebsocketConsumer): groups = ['queue_status'] async def connect(self): diff --git a/src/server/oasisapi/queues/routing.py b/src/server/oasisapi/queues/routing.py index de9725735..acc34c529 100644 --- a/src/server/oasisapi/queues/routing.py +++ b/src/server/oasisapi/queues/routing.py @@ -3,5 +3,5 @@ from src.server.oasisapi.queues import consumers websocket_urlpatterns = [ - path('v1/queue-status/', consumers.QueueStatusConsumer), + path('v2/queue-status/', consumers.QueueStatusConsumer), ] diff --git a/src/server/oasisapi/queues/serializers.py b/src/server/oasisapi/queues/serializers.py index b0a2a89e1..619590d1a 100644 --- a/src/server/oasisapi/queues/serializers.py +++ b/src/server/oasisapi/queues/serializers.py @@ -2,43 +2,61 @@ from rest_framework import serializers from src.server.oasisapi.analysis_models.models import AnalysisModel -from src.server.oasisapi.analysis_models.serializers import AnalysisModelSerializer -from src.server.oasisapi.analyses.serializers import AnalysisSerializer, AnalysisTaskStatusSerializer +from src.server.oasisapi.analysis_models.v2_api.serializers import AnalysisModelSerializer +from src.server.oasisapi.analyses.v2_api.serializers import AnalysisSerializerWebSocket, AnalysisTaskStatusSerializer class QueueSerializer(serializers.Serializer): name = serializers.CharField() pending_count = serializers.IntegerField() - worker_count = serializers.IntegerField() queued_count = serializers.IntegerField() running_count = serializers.IntegerField() + queue_message_count = serializers.IntegerField() + worker_count = serializers.IntegerField() models = serializers.SerializerMethodField() - @swagger_serializer_method(serializer_or_field=AnalysisModelSerializer) + @swagger_serializer_method(serializer_or_field=AnalysisModelSerializer(many=True)) def get_models(self, instance, *args, **kwargs): - models = [m for m in AnalysisModel.objects.all() if str(m) == instance['name']] + queue_name = instance['name'].removesuffix('-v2') + models = [m for m in AnalysisModel.objects.all() if str(m) == queue_name] return AnalysisModelSerializer(instance=models, many=True).data -class WebsocketSerializer(serializers.Serializer): - """ This is a 'dummy' Serializer to document - the WebSocket schema - """ +class WebsocketAnalysesSerializer(serializers.Serializer): + analysis = serializers.SerializerMethodField() + updated_tasks = serializers.SerializerMethodField() + + @swagger_serializer_method(serializer_or_field=AnalysisSerializerWebSocket()) + def get_analysis(self, instance, *args, **kwargs): + pass + + @swagger_serializer_method(serializer_or_field=AnalysisTaskStatusSerializer(many=True)) + def get_updated_tasks(self, instance, *args, **kwargs): + pass + + +class WebsocketContentSerializer(serializers.Serializer): queue = serializers.SerializerMethodField() analyses = serializers.SerializerMethodField() - updated_tasks = serializers.SerializerMethodField() - time = serializers.DateField() - type = serializers.CharField() - status = serializers.CharField() @swagger_serializer_method(serializer_or_field=QueueSerializer) def get_queue(self, 
instance, *args, **kwargs): pass - @swagger_serializer_method(serializer_or_field=AnalysisSerializer(many=True)) + @swagger_serializer_method(serializer_or_field=WebsocketAnalysesSerializer(many=True)) def get_analyses(self, instance, *args, **kwargs): pass - @swagger_serializer_method(serializer_or_field=AnalysisTaskStatusSerializer(many=True)) - def get_updated_tasks(self, instance, *args, **kwargs): + +class WebsocketSerializer(serializers.Serializer): + """ This is a 'dummy' Serializer to document + the WebSocket schema + """ + time = serializers.DateField() + type = serializers.CharField() + status = serializers.CharField() + content = serializers.SerializerMethodField() + + @swagger_serializer_method(serializer_or_field=WebsocketContentSerializer(many=True)) + def get_content(self, instance, *args, **kwargs): pass diff --git a/src/server/oasisapi/queues/urls.py b/src/server/oasisapi/queues/urls.py new file mode 100644 index 000000000..d57ad78de --- /dev/null +++ b/src/server/oasisapi/queues/urls.py @@ -0,0 +1,15 @@ +from rest_framework.routers import SimpleRouter +from django.conf.urls import url, include +from .viewsets import QueueViewSet +from .viewsets import WebsocketViewSet + + +app_name = 'queue' +v2_api_router = SimpleRouter() +v2_api_router.include_root_view = False +v2_api_router.register('queue', QueueViewSet, basename='queue') +v2_api_router.register('queue-status', WebsocketViewSet, basename='queue') + +urlpatterns = [ + url(r'', include(v2_api_router.urls)), +] diff --git a/src/server/oasisapi/queues/utils.py b/src/server/oasisapi/queues/utils.py index b0bcab868..a111c696c 100644 --- a/src/server/oasisapi/queues/utils.py +++ b/src/server/oasisapi/queues/utils.py @@ -5,42 +5,51 @@ from django.db.models import Count from kombu import Connection -from src.server.oasisapi.celery_app import celery_app +from src.server.oasisapi.celery_app_v2 import v2 as celery_app_v2 QueueInfo = Dict[str, int] +def _get_queue_consumers(queue_name): + with celery_app_v2.pool.acquire(block=True) as conn: + chan = conn.channel() + name, message_count, consumers = chan.queue_declare(queue=queue_name, passive=True) + chan.close() + return consumers + + +def _get_queue_message_count(queue_name): + with celery_app_v2.pool.acquire(block=True) as conn: + chan = conn.channel() + name, message_count, consumers = chan.queue_declare(queue=queue_name, passive=True) + chan.close() + return message_count + + def _add_to_dict(d, k, v): - d[k] = v + if k not in d: + d[k] = v + else: + d[k] += v return d def _get_broker_queue_names(): if settings.BROKER_URL.startswith('amqp://'): c = Connection(settings.BROKER_URL) - return (q['name'] for q in c.connection.client.manager.get_queues()) + return (q['name'] for q in c.connection.client.manager.get_queues() if 'pidbox' not in q['name'] and 'celeryev' not in q['name']) if settings.BROKER_URL.startswith('redis://'): c = Connection(settings.BROKER_URL) - return (q['name'] for q in c.connection.client.manager.channel.active_queues) + return (q['name'] for q in c.connection.client.manager.channel.active_queues if 'pidbox' not in q['name'] and 'celeryev' not in q['name']) elif settings.BROKER_URL.startswith('memory://'): # # TODO: figure out how to get this to work for memory broker # - return (celery_app.conf.task_default_routing_key, ) + return (celery_app_v2.conf.task_default_routing_key, ) raise NotImplementedError('Support for your broker is not yet supported') -def _get_active_queues(): - if settings.BROKER_URL.startswith('memory://'): - # - # TODO: figure out how 
to get this to work for memory broker - # - return {} - - return celery_app.control.inspect().active_queues() - - def get_queues_info() -> List[QueueInfo]: """ Gets a list of dictionaries containing information about the queues in the system. @@ -55,34 +64,18 @@ def get_queues_info() -> List[QueueInfo]: from src.server.oasisapi.analyses.models import AnalysisTaskStatus - # setup an entry for every element in the broker (this will include - # queues with no workers yet) + # setup an entry for every element in the broker res = [ { 'name': q, 'pending_count': 0, 'queued_count': 0, 'running_count': 0, - 'worker_count': 0, + 'queue_message_count': _get_queue_message_count(q), + 'worker_count': _get_queue_consumers(q), } for q in _get_broker_queue_names() ] - # increment the number of workers available for each queue - queues = _get_active_queues() - if queues: - for worker in queues.values(): - for queue in worker: - try: - next(r for r in res if r['name'] == queue['routing_key'])['worker_count'] += 1 - except StopIteration: - # in case there are workers around still for inactive queues add it here - res.append({ - 'name': queue['routing_key'], - 'queued_count': 0, - 'running_count': 0, - 'worker_count': 1, - }) - # get the stats of the running and queued tasks pending = reduce( lambda current, value: _add_to_dict(current, value['queue_name'], value['count']), diff --git a/src/server/oasisapi/queues/viewsets.py b/src/server/oasisapi/queues/viewsets.py index 981e61c8f..53e88ec6d 100644 --- a/src/server/oasisapi/queues/viewsets.py +++ b/src/server/oasisapi/queues/viewsets.py @@ -19,17 +19,17 @@ def list(self, request, *args, **kwargs): class WebsocketViewSet(viewsets.ViewSet): - @swagger_auto_schema(responses={200: WebsocketSerializer(many=True, read_only=True)}) + @swagger_auto_schema(responses={200: WebsocketSerializer(many=False, read_only=True)}) def list(self, request, *args, **kwargs): """ This endpoint documents the schema for the WebSocket used for async status updates at - `ws://<host>:<port>/ws/v1/queue-status/` + `ws://<host>:<port>/ws/v2/queue-status/` Issuing a GET call returns the current state returned from the WebSocket.
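As an illustrative aside (not part of the diff): the queue-status feed documented above can also be read with the `websocket-client` package. The host, port, and lack of authentication below are assumptions taken from the `ws_echo` example that follows; a real deployment will typically require an access token.

```python
# Minimal sketch: subscribe to the v2 queue-status WebSocket and print updates.
# Assumes a local dev server on port 8001, matching the ws_echo example below.
import json
import websocket  # pip install websocket-client

def on_message(ws, message):
    payload = json.loads(message)
    # Messages are wrapped as documented by WebsocketSerializer:
    # {'time': ..., 'type': 'queue_status.updated', 'status': ..., 'content': [...]}
    print(payload.get('type'), len(payload.get('content', [])))

websocket.WebSocketApp(
    'ws://localhost:8001/ws/v2/queue-status/',
    on_message=on_message,
).run_forever()
```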
To print the websocket directly use the following: ``` pip install websocket_client - ./manage.py ws_echo --url ws://localhost:8001/ws/v1/queue-status/ + ./manage.py ws_echo --url ws://localhost:8001/ws/v2/queue-status/ ``` """ return Response(build_all_queue_status_message()) diff --git a/src/server/oasisapi/schemas/custom_swagger.py b/src/server/oasisapi/schemas/custom_swagger.py index ec06b449a..5b1ed0fdd 100644 --- a/src/server/oasisapi/schemas/custom_swagger.py +++ b/src/server/oasisapi/schemas/custom_swagger.py @@ -3,6 +3,7 @@ 'HEALTHCHECK', 'TOKEN_REFRESH_HEADER', 'FILE_FORMAT_PARAM', + 'RUN_MODE_PARAM', 'SUBTASK_STATUS_PARAM', 'SUBTASK_SLUG_PARAM', 'FILE_VALIDATION_PARAM', @@ -71,6 +72,15 @@ enum=['csv', 'parquet'] ) +RUN_MODE_PARAM = openapi.Parameter( + 'run_mode_override', + openapi.IN_QUERY, + required=False, + description="Override task run_mode, `V1 = Single server` or `V2 = distributed`", + type=openapi.TYPE_STRING, + enum=['V1', 'V2'] +) + SUBTASK_STATUS_PARAM = openapi.Parameter( 'subtask_status', openapi.IN_QUERY, diff --git a/src/server/oasisapi/schemas/serializers.py b/src/server/oasisapi/schemas/serializers.py index d4ce6d59d..eb0b2779b 100644 --- a/src/server/oasisapi/schemas/serializers.py +++ b/src/server/oasisapi/schemas/serializers.py @@ -6,6 +6,10 @@ 'ReinsScopeFileSerializer', 'AnalysisSettingsSerializer', 'ModelParametersSerializer', + 'GroupNameSerializer', + 'QueueNameSerializer', + 'TaskCountSerializer', + 'TaskErrorSerializer', ] import json @@ -20,6 +24,11 @@ from ods_tools.oed.common import OdsException +TaskErrorSerializer = serializers.ListField(child=serializers.IntegerField()) +GroupNameSerializer = serializers.ListField(child=serializers.CharField()) +QueueNameSerializer = serializers.ListField(child=serializers.CharField()) + + class TokenObtainPairResponseSerializer(serializers.Serializer): refresh_token = serializers.CharField(read_only=True) access_token = serializers.CharField(read_only=True) @@ -106,6 +115,23 @@ def update(self, instance, validated_data): raise NotImplementedError() +class TaskCountSerializer(serializers.Serializer): + TOTAL_IN_QUEUE = serializers.IntegerField() + TOTAL = serializers.IntegerField() + PENDING = serializers.IntegerField() + QUEUED = serializers.IntegerField() + STARTED = serializers.IntegerField() + COMPLETED = serializers.IntegerField() + CANCELLED = serializers.IntegerField() + ERROR = serializers.IntegerField() + + def create(self, validated_data): + raise NotImplementedError() + + def update(self, instance, validated_data): + raise NotImplementedError() + + def update_links(link_prefix, d): """ Linking in pre-defined scheams with path links will be nested diff --git a/src/server/oasisapi/settings/__init__.py b/src/server/oasisapi/settings/__init__.py new file mode 100644 index 000000000..9b5ed21c9 --- /dev/null +++ b/src/server/oasisapi/settings/__init__.py @@ -0,0 +1 @@ +from .base import * diff --git a/src/server/oasisapi/settings.py b/src/server/oasisapi/settings/base.py similarity index 92% rename from src/server/oasisapi/settings.py rename to src/server/oasisapi/settings/base.py index a6a5b6c93..51a6ee616 100644 --- a/src/server/oasisapi/settings.py +++ b/src/server/oasisapi/settings/base.py @@ -17,10 +17,10 @@ from django.core.exceptions import ImproperlyConfigured from rest_framework.reverse import reverse_lazy -from ...common.shared import set_aws_log_level -from ...conf import iniconf # noqa -from ...conf.celeryconf import * # noqa -from ...common.shared import set_aws_log_level, set_azure_log_level 
+from ....common.shared import set_aws_log_level +from ....conf import iniconf # noqa +from ....common.shared import set_aws_log_level, set_azure_log_level +from ....conf.base import * IN_TEST = 'test' in sys.argv @@ -32,7 +32,11 @@ # Quick-start development settings - unsuitable for production # See https://docs.djangoproject.com/en/2.0/howto/deployment/checklist/ -if (len(sys.argv) >= 2 and sys.argv[1] == 'runserver'): +IS_UNITTEST = sys.argv[0].endswith('pytest') +IS_TESTSERVER = len(sys.argv) >= 2 and sys.argv[1] == 'runserver' +IS_SWAGGER_GEN = len(sys.argv) >= 2 and sys.argv[1] == 'generate_swagger' + +if IS_UNITTEST or IS_TESTSERVER: # Always set Debug mode when in dev environment MEDIA_ROOT = './shared-fs/' DEBUG = True @@ -47,6 +51,17 @@ URL_SUB_PATH = iniconf.settings.getboolean('server', 'URL_SUB_PATH', fallback=True) CONSOLE_DEBUG = False + +# Generate All +DEFAULT_GENERATOR_CLASS = 'drf_yasg.generators.OpenAPISchemaGenerator' # Generate All +if IS_SWAGGER_GEN: + # generate only V1 endpoints + if iniconf.settings.getboolean('server', 'GEN_SWAGGER_V1', fallback=False): + DEFAULT_GENERATOR_CLASS = 'src.server.oasisapi.swagger.CustomGeneratorClassV1' + # generate only V2 endpoints + if iniconf.settings.getboolean('server', 'GEN_SWAGGER_V2', fallback=False): + DEFAULT_GENERATOR_CLASS = 'src.server.oasisapi.swagger.CustomGeneratorClassV2' + # Django 3.2 - the default pri-key field changed to 'BigAutoField.', # https://docs.djangoproject.com/en/3.2/releases/3.2/#customizing-type-of-auto-created-primary-keys DEFAULT_AUTO_FIELD = 'django.db.models.AutoField' @@ -77,11 +92,6 @@ 'django.contrib.messages', 'django.contrib.staticfiles', - 'django_filters', - 'rest_framework', - 'drf_yasg', - 'channels', - 'storages', 'src.server.oasisapi.oidc', 'src.server.oasisapi.files', @@ -93,6 +103,16 @@ 'src.server.oasisapi.info', 'src.server.oasisapi.queues', 'django_cleanup.apps.CleanupConfig', + + + # 'django_extensions', + 'django_filters', + 'rest_framework', + 'drf_yasg', + 'channels', + 'storages', + + ] MIDDLEWARE = [ @@ -171,8 +191,7 @@ 'src.server.oasisapi.filters.Backend', ), 'DATETIME_FORMAT': '%Y-%m-%dT%H:%M:%S.%fZ', - 'DEFAULT_VERSIONING_CLASS': 'rest_framework.versioning.URLPathVersioning', - 'DEFAULT_VERSION': 'v1', + 'DEFAULT_VERSIONING_CLASS': 'rest_framework.versioning.NamespaceVersioning', } # Password validation @@ -180,9 +199,9 @@ AUTHENTICATION_BACKENDS = iniconf.settings.get('server', 'auth_backends', fallback='django.contrib.auth.backends.ModelBackend').split(',') AUTH_PASSWORD_VALIDATORS = [] - API_AUTH_TYPE = iniconf.settings.get('server', 'API_AUTH_TYPE', fallback='') + if API_AUTH_TYPE == 'keycloak': INSTALLED_APPS += ( @@ -203,6 +222,7 @@ OIDC_VERIFY_SSL = False SWAGGER_SETTINGS = { + 'DEFAULT_GENERATOR_CLASS': DEFAULT_GENERATOR_CLASS, 'USE_SESSION_AUTH': False, 'SECURITY_DEFINITIONS': { "keycloak": { @@ -229,6 +249,7 @@ } SWAGGER_SETTINGS = { + 'DEFAULT_GENERATOR_CLASS': DEFAULT_GENERATOR_CLASS, 'DEFAULT_INFO': 'src.server.oasisapi.urls.api_info', 'LOGIN_URL': reverse_lazy('rest_framework:login'), 'LOGOUT_URL': reverse_lazy('rest_framework:logout'), @@ -252,6 +273,9 @@ # Place the app in a sub path (swagger still available in /) # FORCE_SCRIPT_NAME = '/api/' +# limit analyses logs access to admin accounts +RESTRICT_SYSTEM_LOGS = iniconf.settings.getboolean('server', 'RESTRICT_SYSTEM_LOGS', fallback=False) + # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/2.0/howto/static-files/ MEDIA_URL = '/api/media/' @@ -441,6 +465,9 @@ 
CELERY_TASK_ALWAYS_EAGER = True +# Option to remove the 'v2' routes and only run the server with 'v1' endpoints +DISABLE_V2_API = iniconf.settings.getboolean('server', 'disable_v2_api', fallback=False) + if DEBUG_TOOLBAR: INTERNAL_IPS = [ '127.0.0.1', diff --git a/src/server/oasisapi/settings/v1.py b/src/server/oasisapi/settings/v1.py new file mode 100644 index 000000000..5209b0ec2 --- /dev/null +++ b/src/server/oasisapi/settings/v1.py @@ -0,0 +1,4 @@ +from .base import * +from ....conf.celeryconf_v1 import * + +INSTALLED_APPS += ['src.server.oasisapi.analyses.v1_api'] diff --git a/src/server/oasisapi/settings/v2.py b/src/server/oasisapi/settings/v2.py new file mode 100644 index 000000000..02c117b6f --- /dev/null +++ b/src/server/oasisapi/settings/v2.py @@ -0,0 +1,4 @@ +from .base import * +from ....conf.celeryconf_v2 import * + +INSTALLED_APPS += ['src.server.oasisapi.analyses.v2_api'] diff --git a/src/server/oasisapi/swagger.py b/src/server/oasisapi/swagger.py new file mode 100644 index 000000000..5e4d21c93 --- /dev/null +++ b/src/server/oasisapi/swagger.py @@ -0,0 +1,55 @@ +__all__ = [ + 'CustomGeneratorClassV1', + 'CustomGeneratorClassV2', +] + +from drf_yasg.generators import OpenAPISchemaGenerator +from django.conf.urls import include, url +from django.conf import settings + + +# API v1 Routes +api_v1_urlpatterns = [ + url(r'^v1/', include('src.server.oasisapi.analysis_models.v1_api.urls', namespace='v1-models')), + url(r'^v1/', include('src.server.oasisapi.portfolios.v1_api.urls', namespace='v1-portfolios')), + url(r'^v1/', include('src.server.oasisapi.analyses.v1_api.urls', namespace='v1-analyses')), + url(r'^v1/', include('src.server.oasisapi.data_files.v1_api.urls', namespace='v1-files')), +] + +# API v2 Routes +api_v2_urlpatterns = [ + url(r'^v2/', include('src.server.oasisapi.analysis_models.v2_api.urls', namespace='v2-models')), + url(r'^v2/', include('src.server.oasisapi.analyses.v2_api.urls', namespace='v2-analyses')), + url(r'^v2/', include('src.server.oasisapi.portfolios.v2_api.urls', namespace='v2-portfolios')), + url(r'^v2/', include('src.server.oasisapi.data_files.v2_api.urls', namespace='v2-files')), + url(r'^v2/', include('src.server.oasisapi.queues.urls', namespace='v2-queues')), +] + +if settings.URL_SUB_PATH: + swagger_v1_urlpatterns = [url(r'^api/', include(api_v1_urlpatterns))] + swagger_v2_urlpatterns = [url(r'^api/', include(api_v2_urlpatterns))] +else: + swagger_v1_urlpatterns = api_v1_urlpatterns + swagger_v2_urlpatterns = api_v2_urlpatterns + + +class CustomGeneratorClassV1(OpenAPISchemaGenerator): + def __init__(self, info, version='', url=None, patterns=None, urlconf=None): + super().__init__( + info=info, + version='v1', + url=url, + patterns=swagger_v1_urlpatterns, + urlconf=urlconf + ) + + +class CustomGeneratorClassV2(OpenAPISchemaGenerator): + def __init__(self, info, version='', url=None, patterns=None, urlconf=None): + super().__init__( + info=info, + version='v2', + url=url, + patterns=swagger_v2_urlpatterns, + urlconf=urlconf + ) diff --git a/src/server/oasisapi/urls.py b/src/server/oasisapi/urls.py index 2f69e3a6a..070ee4e90 100644 --- a/src/server/oasisapi/urls.py +++ b/src/server/oasisapi/urls.py @@ -1,42 +1,21 @@ from django.conf import settings from django.conf.urls import include, url from django.conf.urls.static import static -from django.contrib import admin from drf_yasg import openapi from drf_yasg.views import get_schema_view from rest_framework import permissions -from rest_framework_nested import routers -from
.analysis_models.viewsets import AnalysisModelViewSet, ModelSettingsView, SettingsTemplateViewSet -from .analyses.viewsets import AnalysisViewSet, AnalysisSettingsView, AnalysisTaskStatusViewSet -from .portfolios.viewsets import PortfolioViewSet -from .healthcheck.views import HealthcheckView -from .data_files.viewsets import DataFileViewset -from .info.views import PerilcodesView -from .info.views import ServerInfoView -from .queues.viewsets import QueueViewSet -from .queues.viewsets import WebsocketViewSet +from .swagger import ( + api_v1_urlpatterns, + api_v2_urlpatterns, + CustomGeneratorClassV1, + CustomGeneratorClassV2, +) if settings.DEBUG_TOOLBAR: from django.urls import path import debug_toolbar -admin.autodiscover() - -api_router = routers.DefaultRouter() -api_router.include_root_view = False -api_router.register('portfolios', PortfolioViewSet, basename='portfolio') -api_router.register('analyses', AnalysisViewSet, basename='analysis') -api_router.register('analysis-task-statuses', AnalysisTaskStatusViewSet, basename='analysis-task-status') -api_router.register('models', AnalysisModelViewSet, basename='analysis-model') -api_router.register('data_files', DataFileViewset, basename='data-file') -api_router.register('queue', QueueViewSet, basename='queue') -api_router.register('queue-status', WebsocketViewSet, basename='queue') -# api_router.register('files', FilesViewSet, basename='file') - -templates_router = routers.NestedSimpleRouter(api_router, r'models', lookup='models') -templates_router.register('setting_templates', SettingsTemplateViewSet, basename='models-setting_templates') - api_info_description = """ # Workflow @@ -71,59 +50,57 @@ 7. Run the analysis (post to `/analyses/<id>/run/`) 8. Get the outputs (get `/analyses/<id>/output_file/`)""" + api_info = openapi.Info( title="Oasis Platform", - default_version='v1', + default_version='v2', description=api_info_description, ) - -schema_view = get_schema_view( +schema_view_all = get_schema_view( api_info, public=True, permission_classes=(permissions.AllowAny,), ) +schema_view_v1 = get_schema_view( + api_info, + public=True, + permission_classes=(permissions.AllowAny,), + generator_class=CustomGeneratorClassV1, +) -""" Developer note: - -These are custom routes to use the endpoint 'settings' -adding the method 'def settings( ..
)' fails under -viewsets.ModelViewSet due to it overriding -the internal Django settings object -""" - -model_settings = ModelSettingsView.as_view({ 'get': 'model_settings', 'post': 'model_settings', 'delete': 'model_settings' -}) -analyses_settings = AnalysisSettingsView.as_view({ 'get': 'analysis_settings', 'post': 'analysis_settings', 'delete': 'analysis_settings' -}) +schema_view_v2 = get_schema_view( + api_info, + public=True, + permission_classes=(permissions.AllowAny,), + generator_class=CustomGeneratorClassV2, +) -urlpatterns = [ - url(r'^(?P<version>[^/]+)/models/(?P<pk>\d+)/settings/', model_settings, name='model-settings'), - url(r'^(?P<version>[^/]+)/analyses/(?P<pk>\d+)/settings/', analyses_settings, name='analysis-settings'), - url(r'^(?P<format>\.json|\.yaml)$', schema_view.without_ui(cache_timeout=0), name='schema-json'), - url(r'^$', schema_view.with_ui('swagger', cache_timeout=0), name='schema-ui'), - url(r'^', include('src.server.oasisapi.auth.urls', namespace='auth')), - url(r'^healthcheck/$', HealthcheckView.as_view(), name='healthcheck'), - url(r'^oed_peril_codes/$', PerilcodesView.as_view(), name='perilcodes'), - url(r'^server_info/$', ServerInfoView.as_view(), name='serverinfo'), - url(r'^auth/', include('rest_framework.urls')), - url(r'^admin/', admin.site.urls), - url(r'^(?P<version>[^/]+)/', include(api_router.urls)), - url(r'^(?P<version>[^/]+)/', include(templates_router.urls)), +api_urlpatterns = [ + # Main Swagger page + url(r'^(?P<format>\.json|\.yaml)$', schema_view_all.without_ui(cache_timeout=0), name='schema-json'), + url(r'^$', schema_view_all.with_ui('swagger', cache_timeout=0), name='schema-ui'), + # V1 only swagger endpoints + url(r'^v1/$', schema_view_v1.with_ui('swagger', cache_timeout=0), name='schema-ui-v1'), + url(r'^v1/(?P<format>\.json|\.yaml)$', schema_view_v1.without_ui(cache_timeout=0), name='schema-json'), + # V2 only swagger endpoints + url(r'^v2/$', schema_view_v2.with_ui('swagger', cache_timeout=0), name='schema-ui-v2'), + url(r'^v2/(?P<format>\.json|\.yaml)$', schema_view_v2.without_ui(cache_timeout=0), name='schema-json'), + # basic urls (auth, server info) + url(r'^', include('src.server.oasisapi.base_urls')), ] +api_urlpatterns += api_v1_urlpatterns +if not settings.DISABLE_V2_API: + api_urlpatterns += api_v2_urlpatterns -urlpatterns += static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) +urlpatterns = static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) if settings.URL_SUB_PATH: urlpatterns += static(settings.STATIC_URL, document_root=settings.STATIC_ROOT) - urlpatterns = [url(r'^api/', include(urlpatterns))] + urlpatterns + urlpatterns += [url(r'^api/', include(api_urlpatterns))] else: urlpatterns += static(settings.STATIC_DEBUG_URL, document_root=settings.STATIC_ROOT) + urlpatterns += [url(r'^', include(api_urlpatterns))] + if settings.DEBUG_TOOLBAR: urlpatterns.append(path('__debug__/', include(debug_toolbar.urls))) diff --git a/src/startup_tester.sh b/src/startup_tester.sh deleted file mode 100755 index d30ed19a7..000000000 --- a/src/startup_tester.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -host=server:8000 -SERVER_HEALTH=0 -echo "curl -s -X GET 'http://$host/healthcheck/'" - -until [ $SERVER_HEALTH -gt 0 ] ; do - SERVER_HEALTH=$(curl -s -X GET "http://$host/healthcheck/" -H "accept: application/json" | grep -c "OK") - - if [ "$SERVER_HEALTH" -lt 1 ]; then - >&2 echo "Waiting for Server" - sleep 2 - fi -done - -echo "Server is available - exit" -set -e - -if [ -z "${TEST_TIMEOUT}" ]; then - TEST_TIMEOUT=180 -fi - -timeout $TEST_TIMEOUT pytest -v -p no:django
/home/worker/tests/integration/api_integration.py "$@" diff --git a/src/startup_tester_S3.sh b/src/startup_tester_S3.sh deleted file mode 100755 index 3aca45274..000000000 --- a/src/startup_tester_S3.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash - -host=server:8000 -SERVER_HEALTH=0 -BUCKET_UP=0 -echo "curl -s -X GET 'http://$host/healthcheck/'" - -# Wait for SERVER -until [ $SERVER_HEALTH -gt 0 ] ; do - SERVER_HEALTH=$(curl -s -X GET "http://$host/healthcheck/" -H "accept: application/json" | grep -c "OK") - if [ "$SERVER_HEALTH" -lt 1 ]; then - >&2 echo "Waiting for Server" - sleep 2 - fi -done - -# Wait for LOCALSTACK -until [ $BUCKET_UP -gt 0 ] ; do - BUCKET_UP=$(curl -sI "http://localstack-s3:4572/example-bucket" | grep -c "OK") - if [ "$BUCKET_UP" -lt 1 ]; then - >&2 echo "Waiting for LocalStack S3 bucket" - sleep 2 - fi -done - -echo "Server + Localstack are available - exit" -set -e - -if [ -z "${TEST_TIMEOUT}" ]; then - TEST_TIMEOUT=180 -fi - -timeout $TEST_TIMEOUT pytest -v -p no:django /home/worker/tests/integration/api_integration.py "$@" diff --git a/src/startup_worker.sh b/src/startup_worker.sh index 98c673058..f73db23c2 100755 --- a/src/startup_worker.sh +++ b/src/startup_worker.sh @@ -7,21 +7,15 @@ export PYTHONPATH=$SCRIPT_DIR # Set the ini file path export OASIS_INI_PATH="${SCRIPT_DIR}/conf.ini" + # Delete celeryd.pid file - fix que pickup issues on reboot of server rm -f /home/worker/celeryd.pid -# ---- needs to be one or the other?? --- - # OLD - wait for it (test and remove) - #./src/utils/wait-for-it.sh "$OASIS_RABBIT_HOST:$OASIS_RABBIT_PORT" -t 60 - # use for REDIS - ./src/utils/wait-for-it.sh "$OASIS_CELERY_BROKER_URL" -t 60 - +# Check connectivity +./src/utils/wait-for-it.sh "$OASIS_CELERY_BROKER_URL" -t 60 ./src/utils/wait-for-it.sh "$OASIS_CELERY_DB_HOST:$OASIS_CELERY_DB_PORT" -t 60 -# Start current worker on init -# celery --app src.model_execution_worker.tasks worker --concurrency=1 --loglevel=INFO -Q "${OASIS_MODEL_SUPPLIER_ID}-${OASIS_MODEL_ID}-${OASIS_MODEL_VERSION_ID}" |& tee -a /var/log/oasis/worker.log - # set concurrency flag if [ -z "$OASIS_CELERY_CONCURRENCY" ] @@ -31,5 +25,26 @@ else WORKER_CONCURRENCY='--concurrency '$OASIS_CELERY_CONCURRENCY fi + +# Oasis select API version +SELECT_RUN_MODE=$(echo "$OASIS_RUN_MODE" | tr '[:upper:]' '[:lower:]') +case "$SELECT_RUN_MODE" in + "v1") + API_VER='' + TASK_FILE='src.model_execution_worker.tasks' + ;; + "v2") + API_VER='-v2' + TASK_FILE='src.model_execution_worker.distributed_tasks' + ;; + *) + echo "Invalid value for api version:" + echo " Set 'OASIS_RUN_MODE=v1' For Single server execution" + echo " Set 'OASIS_RUN_MODE=v2' For Distributed execution" + exit 1 + ;; +esac + + # Start new worker on init -celery --app src.model_execution_worker.distributed_tasks worker $WORKER_CONCURRENCY --loglevel=INFO -Q "${OASIS_MODEL_SUPPLIER_ID}-${OASIS_MODEL_ID}-${OASIS_MODEL_VERSION_ID}" ${OASIS_CELERY_EXTRA_ARGS} |& tee -a /var/log/oasis/worker.log +celery --app $TASK_FILE worker $WORKER_CONCURRENCY --loglevel=INFO -Q "${OASIS_MODEL_SUPPLIER_ID}-${OASIS_MODEL_ID}-${OASIS_MODEL_VERSION_ID}${API_VER}" ${OASIS_CELERY_EXTRA_ARGS} |& tee -a /var/log/oasis/worker.log diff --git a/tests/base.py b/tests/base.py deleted file mode 100644 index 4ddf822bf..000000000 --- a/tests/base.py +++ /dev/null @@ -1,26 +0,0 @@ -import string -from random import choice -from unittest import TestCase - -import os - -from src.server import app -from src.conf.iniconf import settings - - -class AppTestCase(TestCase): - def setUp(self): - 
self._init_testing = app.APP.config['TESTING']
-        self.app = app.APP.test_client()
-
-    def tearDown(self):
-        app.APP.config['TESTING'] = self._init_testing
-
-    def create_input_file(self, path, size_in_bytes=None, data=b''):
-        path = os.path.join(settings.get('server', 'INPUTS_DATA_DIRECTORY'), path)
-
-        with open(path, 'wb') as outfile:
-            for x in range(size_in_bytes or 0):
-                outfile.write(choice(string.ascii_letters).encode())
-
-            outfile.write(data)
diff --git a/tests/integration/admin_auth.json b/tests/integration/admin_auth.json
deleted file mode 100644
index ffbcaafa9..000000000
--- a/tests/integration/admin_auth.json
+++ /dev/null
@@ -1,4 +0,0 @@
-{
-    "username": "admin",
-    "password": "password"
-}
diff --git a/tests/integration/api_integration.py b/tests/integration/api_integration.py
deleted file mode 100644
index 7adbbaf20..000000000
--- a/tests/integration/api_integration.py
+++ /dev/null
@@ -1,282 +0,0 @@
-import pytest
-import socket
-import os
-import tarfile
-import configparser
-
-import pandas as pd
-
-from pandas.util.testing import assert_frame_equal
-
-from oasislmf.api.client import APIClient
-
-
-# ------------ load command line options -------------------- #
-cli_test_conf = os.environ.get('PY_CONFIG', '/var/oasis/test/conf.ini')
-cli_test_output = True if os.environ.get('PY_TEST_OUTPUT') else False
-cli_test_case = os.environ.get('PY_TEST_CASE').split(' ') if os.environ.get('PY_TEST_CASE') else None
-cli_test_model = os.environ.get('PY_TEST_MODEL') if os.environ.get('PY_TEST_MODEL') else None
-cli_test_retry = int(os.environ.get('PY_TEST_RETRY')) if os.environ.get('PY_TEST_RETRY') else 1
-
-config = configparser.ConfigParser()
-config.read(os.path.abspath(cli_test_conf))
-
-
-def get_path(section, var, config=config):
-    try:
-        return os.path.abspath(config.get(section, var))
-    except configparser.NoOptionError:
-        return None
-
-
-def check_expected(result_path, expected_path):
-    comparison_list = []
-    cwd = os.getcwd()
-    os.chdir(expected_path)
-    for rootdir, _, filelist in os.walk('.'):
-        for f in filelist:
-            comparison_list.append(os.path.join(rootdir[2:], f))
-
-    print(comparison_list)
-    os.chdir(cwd)
-    for csv in comparison_list:
-        print(csv)
-        df_expect = pd.read_csv(os.path.join(expected_path, csv))
-        df_found = pd.read_csv(os.path.join(result_path, csv))
-        assert_frame_equal(df_expect, df_found)
-
-
-def check_non_empty(result_path):
-    comparison_list = []
-    cwd = os.getcwd()
-    os.chdir(result_path)
-    for rootdir, _, filelist in os.walk('.'):
-        for f in filelist:
-            comparison_list.append(os.path.join(rootdir[2:], f))
-
-    print(comparison_list)
-    os.chdir(cwd)
-    for csv in comparison_list:
-        file_path = os.path.join(result_path, csv)
-        file_size = os.path.getsize(file_path)
-        print(f'{file_size} Bytes: -> {csv}')
-        assert file_size > 0
-
-
-# --- Test Parametrization -------------------------------------------------- #
-
-
-if cli_test_model:
-    test_model = cli_test_model
-else:
-    test_model = config.get('default', 'TEST_MODEL').lower()
-
-if cli_test_case:
-    test_cases = cli_test_case
-    print('Loading test cases from command line args:')
-else:
-    test_cases = config.get(test_model, 'RUN_TEST_CASES').split(' ')
-    print('Loading default test cases from conf.ini:')
-
-base_dir = os.path.dirname(os.path.abspath(cli_test_conf))
-os.chdir(base_dir)
-
-
-# --- API connection Fixture ------------------------------------------------ #
-
-
-@pytest.fixture(scope="module", params=test_cases)
-def session_fixture(request):
-    server_addr = config.get('server', 'API_HOST')
-    server_port = config.get('server', 'API_PORT')
-    server_vers = config.get('server', 'API_VERS')
-    server_user = config.get('server', 'API_USER')
-    server_pass = config.get('server', 'API_PASS')
-
-    print(request.param)
-    try:
-        server_url = 'http://{}:{}'.format(socket.gethostbyname(server_addr), server_port)
-    except Exception:
-        server_url = 'http://{}:{}'.format('localhost', server_port)
-    session = APIClient(server_url, server_vers, server_user, server_pass)
-
-    print(session.api.tkn_access)
-    return request.param, session
-
-
-# --- Test Case Fixture ----------------------------------------------------- #
-
-
-@pytest.fixture(scope="module")
-def case_fixture(session_fixture):
-    case, session = session_fixture
-    ids = {}
-
-    # Add or find model
-    _model = {
-        'supplier_id': config.get(test_model, 'SUPPLIER_ID'),
-        'model_id': config.get(test_model, 'MODEL_ID'),
-        'version_id': config.get(test_model, 'VERSION_ID'),
-    }
-
-    r_model = session.models.search(_model)
-    if len(r_model.json()) < 1:
-        # Model not found - add a new model
-        r_model = session.models.create(**_model)
-        ids['model'] = r_model.json()['id']
-    else:
-        # Model found - use the search result
-        ids['model'] = r_model.json()[0]['id']
-
-    # Create Portfolio
-    loc_fp = get_path(f'{test_model}.{case}', 'LOC_FILE')
-    print(loc_fp)
-    assert os.path.isfile(loc_fp)
-    acc_fp = get_path(f'{test_model}.{case}', 'ACC_FILE')
-    inf_fp = get_path(f'{test_model}.{case}', 'INF_FILE')
-    scp_fp = get_path(f'{test_model}.{case}', 'SCP_FILE')
-
-    r_portfolio = session.upload_inputs(
-        portfolio_name='Integration_test_{}_{}'.format(test_model, case),
-        location_fp=loc_fp,
-        accounts_fp=acc_fp,
-        ri_info_fp=inf_fp,
-        ri_scope_fp=scp_fp)
-    ids['portfolio'] = r_portfolio['id']
-
-    # Create analysis
-    settings_fp = get_path(f'{test_model}.{case}', 'SETTINGS_RUN')
-    assert os.path.isfile(settings_fp)
-    r_analysis = session.create_analysis(
-        analysis_name='Integration_test_{}_{}'.format(test_model, case),
-        portfolio_id=ids['portfolio'],
-        model_id=ids['model'])
-
-    ids['analysis'] = r_analysis['id']
-    r_upload_settings = session.analyses.settings_file.upload(ids['analysis'], settings_fp, 'application/json')
-    assert r_upload_settings.ok
-
-    return session, case, ids
-
-# --- Test Functions --------------------------------------------------------- #
-
-
-def test_connection(case_fixture):
-    session, case, ids = case_fixture
-    assert session.api.health_check().ok
-
-
-def test_uploaded(case_fixture):
-    session, case, ids = case_fixture
-    analysis = session.analyses.get(ids['analysis'])
-    portfolio = session.portfolios.get(ids['portfolio'])
-
-    assert portfolio.ok
-    assert analysis.ok
-    assert analysis.json()['status'] == 'NEW'
-    print(analysis.json())
-    print(portfolio.json())
-
-
-def test_generate(case_fixture):
-    session, case, ids = case_fixture
-    analysis = session.analyses.get(ids['analysis'])
-    if analysis.json()['status'] not in ['NEW', 'INPUTS_GENERATION_ERROR']:
-        pytest.skip('setup error in previous step')
-
-    for r in range(cli_test_retry):
-        print(f'Attempt: {r+1}')
-        session.run_generate(ids['analysis'])
-        analysis = session.analyses.get(ids['analysis'])
-        if analysis.json()['status'] == 'READY':
-            break
-
-    assert analysis.ok
-    assert analysis.json()['status'] == 'READY'
-
-
-def test_generated_files(case_fixture):
-    session, case, ids = case_fixture
-    analysis = session.analyses.get(ids['analysis'])
-    if analysis.json()['status'] not in ['READY']:
-        pytest.skip('Error in file generation step')
-
-    output_dir = os.path.abspath(config.get('default', 'TEST_OUTPUT_DIR'))
-    download_to = '{0}/{1}_input.tar.gz'.format(output_dir, case)
-    extract_to = os.path.join(output_dir, case, 'input')
-
-    if os.path.isfile(download_to):
-        os.remove(download_to)
-    r = session.analyses.input_file.download(ids['analysis'], download_to)
-    assert r.ok
-
-    tar_object = tarfile.open(download_to)
-    csv_only = [f for f in tar_object.getmembers() if '.csv' in f.name]
-    tar_object.extractall(path=extract_to, members=csv_only)
-    tar_object.close()
-    if os.path.isfile(download_to):
-        os.remove(download_to)
-
-
-def test_analysis_run(case_fixture):
-    session, case, ids = case_fixture
-    analysis = session.analyses.get(ids['analysis'])
-    if analysis.json()['status'] not in ['READY', 'RUN_ERROR']:
-        pytest.skip('Error in file generation step')
-
-    for r in range(cli_test_retry):
-        print(f'Attempt: {r+1}')
-        session.run_analysis(ids['analysis'])
-        analysis = session.analyses.get(ids['analysis'])
-        if analysis.json()['status'] == 'RUN_COMPLETED':
-            break
-
-    assert analysis.ok
-    assert analysis.json()['status'] == 'RUN_COMPLETED'
-
-
-def test_analysis_output(case_fixture):
-    session, case, ids = case_fixture
-    analysis = session.analyses.get(ids['analysis'])
-    if analysis.json()['status'] not in ['RUN_COMPLETED']:
-        pytest.skip('Error in analysis run step')
-
-    if not get_path(test_model, 'EXPECTED_OUTPUT_DIR'):
-        pytest.skip('Expected data missing')
-
-    expected_results = os.path.join(get_path(test_model, 'EXPECTED_OUTPUT_DIR'), case, 'output')
-    output_dir = os.path.abspath(config.get('default', 'TEST_OUTPUT_DIR'))
-    download_to = '{0}/{1}_output.tar.gz'.format(output_dir, case)
-    extract_to = os.path.join(output_dir, case)
-
-    if os.path.isfile(download_to):
-        os.remove(download_to)
-    r = session.analyses.output_file.download(ids['analysis'], download_to)
-    assert r.ok
-
-    tar_object = tarfile.open(download_to)
-    csv_only = [f for f in tar_object.getmembers() if '.csv' in f.name]
-
-    tar_object.extractall(path=extract_to, members=csv_only)
-    tar_object.close()
-
-    if cli_test_output:
-        check_expected(os.path.join(extract_to, 'output'), expected_results)
-    else:
-        check_non_empty(os.path.join(extract_to, 'output'))
-
-    if os.path.isfile(download_to):
-        os.remove(download_to)
-
-
-def test_cleanup(case_fixture):
-    if not config.getboolean('default', 'CLEAN_UP'):
-        pytest.skip('Skip clean up')
-
-    session, case, ids = case_fixture
-    r_del_analyses = session.analyses.delete(ids['analysis'])
-    r_del_portfolios = session.portfolios.delete(ids['portfolio'])
-    session.api.close()
-
-    assert r_del_analyses.ok
-    assert r_del_portfolios.ok
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
deleted file mode 100644
index 7d62950f5..000000000
--- a/tests/integration/conftest.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# content of conftest.py
-# import pytest
-
-
-def pytest_addoption(parser):
-    parser.addoption(
-        '--test-case', required=False, dest='data_case',
-        action='store', nargs='+', default=None,
-        help='The directory of data files used to test a model.'
-    )
-    parser.addoption(
-        '--test-output', required=False, dest='data_test',
-        action='store_true', default=False,
-        help='Compare results against expected output data.'
-    )
-    parser.addoption(
-        '--test-model', required=False, dest='test-model',
-        action='store', nargs='?', default=None,
-        help='Model to test from `conf.ini`'
-    )
-    parser.addoption(
-        '--test-retry', required=False, dest='test-retry',
-        action='store', nargs='?', default=1, type=int,
-        help='Number of times to re-submit API tests'
-    )
-    parser.addoption(
-        '--config', required=False, dest='test_conf',
-        action='store', nargs='?', default='/var/oasis/test/conf.ini',
-        help='File path to test configuration ini'
-    )
-
-
-def pytest_configure(config):
-    import os
-    if config.getoption('--test-case'):
-        os.environ["PY_TEST_CASE"] = " ".join(config.getvalue('--test-case'))
-    if config.getoption('--test-output'):
-        os.environ["PY_TEST_OUTPUT"] = 'True'
-    if config.getoption('--test-model'):
-        os.environ["PY_TEST_MODEL"] = config.getvalue('--test-model')
-    if config.getoption('--test-retry'):
-        os.environ["PY_TEST_RETRY"] = str(config.getvalue('--test-retry'))
-    if config.getoption('--config'):
-        os.environ["PY_CONFIG"] = config.getvalue('--config')
diff --git a/tox.ini b/tox.ini
index 825a7b961..991483fda 100644
--- a/tox.ini
+++ b/tox.ini
@@ -4,8 +4,16 @@ skipsdist = True
 
 [testenv]
 deps = -rrequirements.txt
-commands = pytest --junitxml={toxinidir}/pytest_report.xml --cov=src {posargs}
+commands = pytest --junitxml={toxinidir}/pytest_report.xml --cov=src {posargs}
 setenv =
     COV_CORE_SOURCE={toxinidir}/src
    COV_CORE_CONFIG={toxinidir}/setup.cfg
    COVERAGE_FILE={toxinidir}/.coverage.{envname}
+
+[flake8]
+per-file-ignores =
+    # imported but unused
+    src/server/oasisapi/settings/__init__.py: F401
+    src/conf/celeryconf_v1.py: F401
+    src/conf/celeryconf_v2.py: F401
+    src/server/oasisapi/settings/base.py: F401
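
A note on the `[flake8]` section added to tox.ini above: `per-file-ignores` switches off F401 ("imported but unused") only for the listed modules, which import names purely to re-export them, rather than relaxing the check across the whole codebase. A minimal sketch of the pattern those entries cover (the module layout and names below are hypothetical, not the real contents of the listed files):

```python
# settings/__init__.py-style aggregator (hypothetical names).
# Within this file the imports are never referenced, so flake8 reports
# "F401 'DEBUG' imported but unused" -- yet they are intentional: they let
# callers write `from server.oasisapi.settings import DEBUG`.
from .base import DEBUG, INSTALLED_APPS

# Per-line `# noqa: F401` comments or listing the names in `__all__`
# would be the alternatives to the tox.ini per-file-ignores entry.
```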
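Separately, the unchanged `setenv` block keeps one coverage data file per tox environment (`COVERAGE_FILE={toxinidir}/.coverage.{envname}`), so runs of different envs do not overwrite each other's results, and the pieces can be merged afterwards. A rough sketch using coverage.py's public API (the envname suffixes are assumptions for illustration):

```python
import coverage

# Merge the per-environment data files written by the tox runs into a
# single dataset; '.coverage.py3' and '.coverage.flake' are illustrative.
cov = coverage.Coverage(data_file='.coverage')
cov.combine(data_paths=['.coverage.py3', '.coverage.flake'])  # consumes the inputs
cov.save()
total = cov.report()  # prints a text report and returns the total percentage
print(f'combined coverage: {total:.1f}%')
```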
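Finally, for anyone reimplementing the removed integration suite elsewhere: the heart of the deleted `test_generate` and `test_analysis_run` tests is a bounded re-submit loop around the client's blocking run calls. A stripped-down sketch of that pattern (generic callables, not the OasisLMF client API):

```python
def run_with_retry(submit, read_status, want, attempts=1):
    """Re-submit a job up to `attempts` times until its status equals `want`.

    `submit` starts (or re-starts) the server-side job and blocks until it
    finishes; `read_status` fetches the resulting status string afterwards.
    """
    status = None
    for attempt in range(attempts):
        print(f'Attempt: {attempt + 1}')
        submit()
        status = read_status()
        if status == want:
            break
    return status

# Hypothetical usage, mirroring the deleted tests:
# run_with_retry(lambda: session.run_generate(analysis_id),
#                lambda: session.analyses.get(analysis_id).json()['status'],
#                want='READY', attempts=cli_test_retry)
```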