diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..cda19c3 --- /dev/null +++ b/.env.example @@ -0,0 +1,24 @@ +# Tiled setup +TILED_KEY= +DEFAULT_TILED_URI=https://tiled-demo.blueskyproject.io/api/v1/metadata/rsoxs/raw/ +DEFAULT_TILED_QUERY="/primary/data/Small Angle CCD Detector_image" + +# Directory setup +DATA_DIR=/path/to/read/data + +# Services URLs +# If running in docker +# SPLASH_URL=http://splash:80/api/v0 +# MLEX_COMPUTE_URL=http://job-service:8080/api/v0 +# MLEX_CONTENT_URL=http://content-api:8000/api/v0 +# MONGO_DB_USERNAME= +# MONGO_DB_PASSWORD= +# If running locally +SPLASH_URL=http://localhost:8087/api/v0 +MLEX_COMPUTE_URL=http://localhost:8080/api/v0 +MLEX_CONTENT_URL=http://localhost:8001/api/v0 +HOST_NICKNAME=local + +# Static Tiled setup [Optional] +STATIC_TILED_URI= +STATIC_TILED_API_KEY= diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..a9f89c5 --- /dev/null +++ b/.flake8 @@ -0,0 +1,7 @@ +[flake8] +# 127 is width of the Github code viewer, +# black default is 88 so this will only warn about comments >127 +max-line-length = 127 +# Ignore errors due to incompatibility with black +#https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html +extend-ignore = E203,E701 diff --git a/.github/workflows/publish_docker.yml b/.github/workflows/publish_docker.yml new file mode 100644 index 0000000..0210af4 --- /dev/null +++ b/.github/workflows/publish_docker.yml @@ -0,0 +1,49 @@ +name: Docker + +on: + push: + branches: [ "main" ] + tags: [ 'v*.*.*' ] + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + + +jobs: + build-and-push-image: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Log in to the Container registry + uses: docker/login-action@v2 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Extract Docker metadata + id: meta + uses: docker/metadata-action@v4 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + # Build and push Docker image with Buildx (don't push on PR) + # https://github.com/docker/build-push-action + - name: Build and push Docker image + id: build-and-push + uses: docker/build-push-action@v4 + with: + context: . + file: Dockerfile + push: true + platforms: linux/amd64,linux/arm64 + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml new file mode 100644 index 0000000..0280bd0 --- /dev/null +++ b/.github/workflows/test-lint.yml @@ -0,0 +1,31 @@ +name: test-and-lint + +on: pull_request + +jobs: + test-and-lint: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Set up Python 3.11.8 + uses: actions/setup-python@v3 + with: + python-version: '3.11.8' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Run isort + uses: isort/isort-action@master + - name: Test formatting with black + run: | + black . --check diff --git a/.gitignore b/.gitignore index cf53f0e..15bfaab 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,171 @@ -/data/* -*pycache* -.DS_Store -.keras -.Makefile.swp -._* +# Docker override +docker-compose.override.y*ml +database/* + +# Byte-compiled / optimized / DLL files +**pycache** +**cache** +*.py[cod] +*$py.class + +# C extensions +*.so + +# MacOS +**.DS_Store** + +# Project +data/* + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..6df81cc --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,34 @@ +default_language_version: + python: python3 +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-ast + - id: check-case-conflict + - id: check-merge-conflict + - id: check-symlinks + - id: check-yaml + - id: debug-statements + - repo: https://github.com/gitguardian/ggshield + rev: v1.25.0 + hooks: + - id: ggshield + language_version: python3 + stages: [commit] + # Using this mirror lets us use mypyc-compiled black, which is about 2x faster + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 24.2.0 + hooks: + - id: black + - repo: https://github.com/pycqa/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + args: ["--profile", "black"] diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..9fe21fc --- /dev/null +++ b/Dockerfile @@ -0,0 +1,18 @@ +FROM python:3.11 +MAINTAINER THE MLEXCHANGE TEAM + +RUN ls +COPY pyproject.toml pyproject.toml +COPY README.md README.md + +RUN pip install --upgrade pip &&\ + pip install . + +WORKDIR /app/work +ENV HOME /app/work +COPY src src +COPY frontend.py frontend.py +COPY gunicorn_config.py gunicorn_config.py + +CMD ["bash"] +CMD gunicorn -c gunicorn_config.py --reload frontend:server diff --git a/README.md b/README.md index 5498d22..23d8ed8 100644 --- a/README.md +++ b/README.md @@ -3,19 +3,25 @@ This app provides a training/testing platform for image classification with supervised deep-learning approaches. ## Running as a standalone application -First, let's install docker: -* https://docs.docker.com/engine/install/ +1. Start the compute and content services in the [MLExchange platform](https://github.com/mlexchange/mlex). Before moving to the next step, please make sure that the computing API and the content registry are up and running. For more information, please refer to their respective +README files. -Next, let's setup its dependencies: -* [mlex_computing_api](https://github.com/mlexchange/mlex_computing_api) -* [mlex_content_registry](https://github.com/mlexchange/mlex_content_registry) +2. Start [splash-ml](https://github.com/als-computing/splash-ml) -Before moving to the next step, please make sure that the computing API and the content -registry are up and running. For more information, please refer to their respective -README files. 
-* Next, cd into mlex_mlcoach -* type `docker-compose up --build` into your terminal +3. Create a new Python environment and install dependencies: +``` +conda create -n new_env python==3.11 +conda activate new_env +pip install . +``` + +4. Create a `.env` file using `.env.example` as reference. Update this file accordingly. + +5. Start the example app: +``` +python frontend.py +``` Finally, you can access MLCoach at: * Dash app: http://localhost:8062/ @@ -26,10 +32,12 @@ to use this application. # Model Description **TF-NeuralNetworks:** Assortment of neural networks implemented in [TensorFlow](https://www.tensorflow.org). -Further information can be found in [concepts](/docs/concepts.md). +Further information can be found in [mlex_image_classification](https://github.com/mlexchange/mlex_image_classification). + +To make existing algorithms available in MLCoach, make sure to upload the `model description` to the content registry. # Copyright -MLExchange Copyright (c) 2021, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved. +MLExchange Copyright (c) 2024, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved. If you have questions about your rights to use or distribute this software, please contact Berkeley Lab's Intellectual Property Office at IPO@lbl.gov. diff --git a/docker-compose.yml b/docker-compose.yml index c14881d..64c3257 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,21 +1,62 @@ -version: "3" +version: '3.7' services: - front-end: + mlcoach: restart: "unless-stopped" container_name: "mlcoach" - build: - context: "." 
- dockerfile: "docker/Dockerfile" + image: ghcr.io/mlexchange/mlex_mlcoach:main + volumes: + - $DATA_DIR:/app/work/data + environment: + DIR_MOUNT: "${DATA_DIR}" + DATA_DIR: "/app/work/data" + PROJECT_ID: "-1" + SPLASH_URL: "http://splash:80/api/v0" + MLEX_COMPUTE_URL: "http://job-service:8080/api/v0" + MLEX_CONTENT_URL: "http://content-api:8000/api/v0" + HOST_NICKNAME: "${HOST_NICKNAME}" + TILED_KEY: "${TILED_KEY}" + DEFAULT_TILED_URI: "${DEFAULT_TILED_URI}" + DEFAULT_TILED_SUB_URI: "${DEFAULT_TILED_SUB_URI}" + APP_HOST: "0.0.0.0" + APP_PORT: "8050" + ports: + - 127.0.0.1:8062:8050 + depends_on: + - splash + networks: + - computing_api_default + + splash_db: + image: mongo:4.4 + container_name: splash_db + working_dir: /data/db environment: - DATA_DIR: "${PWD}/data/" + MONGO_INITDB_ROOT_USERNAME: '${MONGO_DB_USERNAME}' + MONGO_INITDB_ROOT_PASSWORD: '${MONGO_DB_PASSWORD}' volumes: - - ./data:/app/work/data + - "${PWD}/database/:/data/db" + networks: + - computing_api_default + + splash: + image: ghcr.io/als-computing/splash-ml:master + container_name: splash + environment: + APP_MODULE: "tagging.api:app" + LOGLEVEL: DEBUG + MONGO_DB_URI: "mongodb://${MONGO_DB_USERNAME}:${MONGO_DB_PASSWORD}@splash_db:27017" + MAX_WORKERS: 1 ports: - - "8062:8062" + - 127.0.0.1:8087:80 + depends_on: + - splash_db networks: - computing_api_default + classifier: + image: ghcr.io/mlexchange/mlex_image_classification:main + networks: computing_api_default: external: true diff --git a/docker/Dockerfile b/docker/Dockerfile deleted file mode 100644 index 2694d08..0000000 --- a/docker/Dockerfile +++ /dev/null @@ -1,26 +0,0 @@ -FROM python:3.9 -MAINTAINER THE MLEXCHANGE TEAM - -RUN ls -COPY docker/requirements.txt requirements.txt - -RUN apt-get update && apt-get install -y \ - build-essential \ - wget \ - python3-pip\ - ffmpeg\ - libsm6\ - libxext6 - -RUN pip3 install --upgrade pip &&\ - pip3 install --timeout=2000 -r requirements.txt - -EXPOSE 8000 - -WORKDIR /app/work -ENV HOME /app/work -COPY src src -ENV PYTHONUNBUFFERED=1 - -CMD ["bash"] -CMD python3 src/frontend.py diff --git a/docker/requirements.txt b/docker/requirements.txt deleted file mode 100644 index 43ddfc6..0000000 --- a/docker/requirements.txt +++ /dev/null @@ -1,20 +0,0 @@ -config==0.5.1 -dash==1.21.0 -dash-auth==1.3.2 -dash_bootstrap_components==0.13.1 -dash_core_components==1.17.1 -dash_daq==0.5.0 -dash_html_components==1.1.4 -dash_table==4.12.0 -dash-uploader==0.6.0 -flask==2.0.2 -imageio==2.9.0 -keras==2.6.0 -numpy==1.19.5 -pandas==1.3.4 -Pillow==8.3.2 -plotly==5.4.0 -plotly-express==0.4.1 -scikit_image==0.18.3 -requests==2.26.0 -werkzeug==2.0.0 diff --git a/frontend.py b/frontend.py new file mode 100644 index 0000000..ffffda6 --- /dev/null +++ b/frontend.py @@ -0,0 +1,250 @@ +import json +import os +import pathlib +import pickle +import shutil +import tempfile +from uuid import uuid4 + +from dash import ClientsideFunction, Input, Output, State, dcc +from dash_component_editor import JSONParameterEditor +from file_manager.data_project import DataProject + +from src.app_layout import DATA_DIR, TILED_KEY, USER, app, long_callback_manager +from src.callbacks.display import ( # noqa: F401 + close_warning_modal, + open_warning_modal, + refresh_image, + refresh_label, + refresh_results, + update_slider_boundaries_new_dataset, + update_slider_boundaries_prediction, + update_slider_value, +) +from src.callbacks.download import disable_download, toggle_storage_modal # noqa: F401 +from src.callbacks.execute import close_resources_popup, execute # noqa: 
F401 +from src.callbacks.load_labels import load_from_splash_modal # noqa: F401 +from src.callbacks.table import delete_row, open_job_modal, update_table # noqa: F401 +from src.utils.data_utils import get_input_params, prepare_directories +from src.utils.job_utils import MlexJob +from src.utils.model_utils import get_gui_components, get_model_content + +APP_HOST = os.getenv("APP_HOST", "127.0.0.1") +APP_PORT = os.getenv("APP_PORT", "8062") +DIR_MOUNT = os.getenv("DIR_MOUNT", DATA_DIR) + + +server = app.server + +app.clientside_callback( + ClientsideFunction(namespace="clientside", function_name="transform_image"), + Output("img-output", "src"), + Input("log-transform", "on"), + Input("img-output-store", "data"), + prevent_initial_call=True, +) + + +app.clientside_callback( + """ + function(n) { + if (typeof Intl === 'object' && typeof Intl.DateTimeFormat === 'function') { + const dtf = Intl.DateTimeFormat(); + if (typeof dtf === 'object' && typeof dtf.resolvedOptions === 'function') { + const ro = dtf.resolvedOptions(); + if (typeof ro === 'object' && typeof ro.timeZone === 'string') { + return ro.timeZone; + } + } + } + return 'Timezone information not available'; + } + """, + Output("timezone-browser", "value"), + Input("interval", "n_intervals"), +) + + +@app.callback( + Output("app-parameters", "children"), + Input("model-selection", "value"), + Input("action", "value"), + prevent_intial_call=True, +) +def load_parameters(model_selection, action_selection): + """ + This callback dynamically populates the parameters and contents of the website according to the + selected action & model. + Args: + model_selection: Selected model (from content registry) + action_selection: Selected action (pre-defined actions in Data Clinic) + Returns: + app-parameters: Parameters according to the selected model & action + """ + parameters = get_gui_components(model_selection, action_selection) + gui_item = JSONParameterEditor( + _id={"type": str(uuid4())}, # pattern match _id (base id), name + json_blob=parameters, + ) + gui_item.init_callbacks(app) + return gui_item + + +@app.long_callback( + Output("download-out", "data"), + Input("download-button", "n_clicks"), + State("jobs-table", "data"), + State("jobs-table", "selected_rows"), + manager=long_callback_manager, + prevent_intial_call=True, +) +def save_results(download, job_data, row): + """ + This callback saves the experimental results as a ZIP file + Args: + download: Download button + job_data: Table of jobs + row: Selected job/row + Returns: + ZIP file with results + """ + if download and row: + experiment_id = job_data[row[0]]["experiment_id"] + experiment_path = pathlib.Path(f"{DATA_DIR}/mlex_store/{USER}/{experiment_id}") + with tempfile.TemporaryDirectory(): + tmp_dir = tempfile.gettempdir() + archive_path = os.path.join(tmp_dir, "results") + shutil.make_archive(archive_path, "zip", experiment_path) + return dcc.send_file(f"{archive_path}.zip") + else: + return None + + +@app.long_callback( + Output("job-alert-confirm", "is_open"), + Input("submit", "n_clicks"), + State("app-parameters", "children"), + State("num-cpus", "value"), + State("num-gpus", "value"), + State("action", "value"), + State("jobs-table", "data"), + State("jobs-table", "selected_rows"), + State({"base_id": "file-manager", "name": "data-project-dict"}, "data"), + State("model-name", "value"), + State("event-id", "value"), + State("model-selection", "value"), + State("log-transform", "on"), + State("img-labeled-indx", "options"), + running=[(Output("job-alert", "is_open"), 
"True", "False")], + manager=long_callback_manager, + prevent_initial_call=True, +) +def submit_ml_job( + submit, + children, + num_cpus, + num_gpus, + action_selection, + job_data, + row, + data_project_dict, + model_name, + event_id, + model_id, + log, + labeled_dropdown, +): + """ + This callback submits a job request to the compute service according to the selected action & model + Args: + submit: Submit button + children: Model parameters + num_cpus: Number of CPUs assigned to job + num_gpus: Number of GPUs assigned to job + action_selection: Action selected + job_data: Lists of jobs + row: Selected row (job) + data_project_dict: Data project dictionary + model_name: User-defined name for training or prediction model + event_id: Tagging event id for version control of tags + model_id: UID of model in content registry + log: Log toggle + labeled_dropdown: Indexes of the labeled images in this data set + Returns: + open the alert indicating that the job was submitted + """ + # Get model information from content registry + model_uri, [train_cmd, prediction_cmd] = get_model_content(model_id) + + # Get model parameters + input_params = get_input_params(children) + input_params["log"] = log + + kwargs = {} + data_project = DataProject.from_dict(data_project_dict, api_key=TILED_KEY) + + if action_selection == "train_model": + experiment_id, orig_out_path, data_info = prepare_directories( + USER, + data_project, + labeled_indices=labeled_dropdown, + correct_path=(DATA_DIR == DIR_MOUNT), + ) + # Find the relative data directory in docker container + if DIR_MOUNT == DATA_DIR: + relative_data_dir = "/app/work/data" + out_path = "/app/work/data" + str(orig_out_path).split(DATA_DIR, 1)[-1] + data_info = "/app/work/data" + str(data_info).split(DATA_DIR, 1)[-1] + else: + relative_data_dir = DATA_DIR + command = f"{train_cmd} -d {data_info} -o {out_path} -e {event_id}" + else: + experiment_id, orig_out_path, data_info = prepare_directories( + USER, data_project, train=False, correct_path=(DATA_DIR == DIR_MOUNT) + ) + # Find the relative data directory in docker container + if DIR_MOUNT == DATA_DIR: + relative_data_dir = "/app/work/data" + out_path = "/app/work/data" + str(orig_out_path).split(DATA_DIR, 1)[-1] + data_info = "/app/work/data" + str(data_info).split(DATA_DIR, 1)[-1] + else: + relative_data_dir = DATA_DIR + training_exp_id = job_data[row[0]]["experiment_id"] + model_path = pathlib.Path( + f"{relative_data_dir}/mlex_store/{USER}/{training_exp_id}" + ) + command = f"{prediction_cmd} -d {data_info} -m {model_path} -o {out_path}" + kwargs = {"train_params": job_data[row[0]]["parameters"]} + + with open(f"{orig_out_path}/.file_manager_vars.pkl", "wb") as file: + pickle.dump( + data_project_dict, + file, + ) + + # Define MLExjob + job = MlexJob( + service_type="backend", + description=model_name, + working_directory="{}".format(DIR_MOUNT), + job_kwargs={ + "uri": model_uri, + "type": "docker", + "cmd": f"{command} -p '{json.dumps(input_params)}'", + "kwargs": { + "job_type": action_selection, + "experiment_id": experiment_id, + "dataset": data_project.project_id, + "params": input_params, + **kwargs, + }, + }, + ) + + # Submit job + job.submit(USER, num_cpus, num_gpus) + return True + + +if __name__ == "__main__": + app.run_server(debug=True, host=APP_HOST, port=APP_PORT) diff --git a/gunicorn_config.py b/gunicorn_config.py new file mode 100644 index 0000000..5a13d37 --- /dev/null +++ b/gunicorn_config.py @@ -0,0 +1,13 @@ +import os + +from dotenv import load_dotenv + +# Load environment 
variables from .env file +load_dotenv() + +APP_HOST = os.getenv("APP_HOST", "127.0.0.1") +APP_PORT = os.getenv("APP_PORT", "8062") + +# Gunicorn configuration +bind = f"{APP_HOST}:{APP_PORT}" +workers = 4 diff --git a/models/Makefile b/models/Makefile deleted file mode 100644 index dc2cb1c..0000000 --- a/models/Makefile +++ /dev/null @@ -1,59 +0,0 @@ -TAG := latest -USER := mlexchange1 -PROJECT := tensorflow-neural-networks -DATA_PATH := /data/tanchavez/Datasets/born -DATA_PATH_NPZ := /data/tanchavez/Datasets - -IMG_WEB_SVC := ${USER}/${PROJECT}:${TAG} -IMG_WEB_SVC_JYP := ${USER}/${PROJECT_JYP}:${TAG} -ID_USER := ${shell id -u} -ID_GROUP := ${shell id -g} - -.PHONY: - -test: - echo ${IMG_WEB_SVC} - echo ${TAG} - echo ${PROJECT} - echo ${PROJECT}:${TAG} - echo ${ID_USER} - -build_docker: - docker build -t ${IMG_WEB_SVC} -f ./docker/Dockerfile . - -build_docker_gpu: - docker build -t ${IMG_WEB_SVC} -f ./docker/Dockerfile_gpu . - -build_docker_arm64: - docker build -t ${IMG_WEB_SVC} -f ./docker/Dockerfile_arm64 . - -run_docker: - docker run -u ${ID_USER $USER}:${ID_GROUP $USER} --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --memory-swap -1 -it -v ${PWD}:/app/work/ -v ${DATA_PATH}:/app/work/data -p 8888:8888 ${IMG_WEB_SVC} - -train_example_gpu: - docker run -u ${ID_USER $USER}:${ID_GROUP $USER} --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --memory-swap -1 -it --gpus all -v ${DATA_PATH}:/app/work/data ${IMG_WEB_SVC} python3 src/train_model.py data/train data/logs '{"rotation_angle": 0, "image_flip": "None", "batch_size": 32, "val_pct": 0, "shuffle": true, "seed": 45, "pooling": "None", "optimizer": "Adam", "loss_function": "categorical_crossentropy", "learning_rate": 0.01, "epochs": 3, "nn_model": "ResNet50"}' - -train_example_cpu: - docker run -u ${ID_USER $USER}:${ID_GROUP $USER} --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --memory-swap -1 -it -v ${DATA_PATH}:/app/work/data ${IMG_WEB_SVC} python3 src/train_model.py data/train data/logs '{"rotation_angle": 0, "image_flip": "None", "batch_size": 32, "val_pct": 10, "shuffle": true, "seed": 45, "pooling": "None", "optimizer": "Adam", "loss_function": "categorical_crossentropy", "learning_rate": 0.01, "epochs": 3, "nn_model": "ResNet50"}' - -train_example_npz: - docker run -u ${ID_USER $USER}:${ID_GROUP $USER} --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --memory-swap -1 -it --gpus all -v ${DATA_PATH_NPZ}:/app/work/data ${IMG_WEB_SVC} python3 src/train_model.py data/test.npz data/logs '{"rotation_angle": 0, "image_flip": "None", "batch_size": 32, "val_pct": 10, "shuffle": true, "seed": 45, "pooling": "None", "optimizer": "Adam", "loss_function": "categorical_crossentropy", "learning_rate": 0.01, "epochs": 3, "nn_model": "ResNet50", "x_key": "x_train", "y_key": "y_train"}' - -evaluate_example: - docker run -u ${ID_USER $USER}:${ID_GROUP $USER} --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --memory-swap -1 -it -v ${DATA_PATH}:/app/work/data ${IMG_WEB_SVC} python3 src/evaluate_model.py data/test data/logs/model.h5 '{"rotation_angle": 0, "image_flip": "None", "batch_size": 1}' - -predict_example: - docker run -u ${ID_USER $USER}:${ID_GROUP $USER} --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --memory-swap -1 -it -v ${DATA_PATH}:/app/work/data ${IMG_WEB_SVC} python3 src/predict_model.py data/test data/logs/model.h5 data/logs '{"rotation_angle": 0, "image_flip": "None", "batch_size": 32}' - -transfer_example: - docker run -u ${ID_USER $USER}:${ID_GROUP $USER} --shm-size=1g 
--ulimit memlock=-1 --ulimit stack=67108864 --memory-swap -1 -it --gpus all -v ${DATA_PATH}:/app/work/data ${IMG_WEB_SVC} python3 src/transfer_learning.py data/train data/logs/model.h5 data/logs '{"rotation_angle": 0, "image_flip": "None", "batch_size": 32, "val_pct": 20, "shuffle": true, "seed": 45, "pooling": "None","optimizer": "Adam", "loss_function": "categorical_crossentropy", "learning_rate": 0.01, "epochs": 3, "init_layer": 100}' - - -clean: - find -name "*~" -delete - -rm .python_history - -rm -rf .config - -rm -rf .cache - -push_docker: - docker push ${IMG_WEB_SVC} diff --git a/models/Tensorflow-NN_v1.0.0-2.json b/models/Tensorflow-NN_v1.0.0-2.json deleted file mode 100644 index fa7c524..0000000 --- a/models/Tensorflow-NN_v1.0.0-2.json +++ /dev/null @@ -1,397 +0,0 @@ -{ - "name": "TF-NeuralNetworks", - "version": "1.0.1", - "type": "supervised", - "user": "mlexchange team", - "uri": "mlexchange/mlcoach", - "application": ["mlcoach"], - "description": "Tensorflow neural networks for image classification", - "gui_parameters": [ - { - "type": "slider", - "name": "rotation_angle", - "title": "Rotation Angle", - "param_key": "rotation_angle", - "min": 0, - "max": 360, - "value": 0, - "comp_group": "train_model" - }, - { - "type": "radio", - "name": "image_flip", - "title": "Image Flip", - "param_key": "image_flip", - "value": "None", - "options": - [ - {"label": "None", "value": "None"}, - {"label": "Vertical", "value": "Vertical"}, - {"label": "Horizontal", "value": "Horizontal"}, - {"label": "Both", "value": "Both"} - ], - "comp_group": "train_model" - }, - { - "type": "radio", - "name": "shuffle", - "title": "Shuffle Data", - "param_key": "shuffle", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], - "comp_group": "train_model" - }, - { - "type": "slider", - "name": "batch_size", - "title": "Batch Size", - "param_key": "batch_size", - "min": 16, - "max": 128, - "step": 16, - "value": 32, - "comp_group": "train_model" - }, - { - "type": "slider", - "name": "val_pct", - "title": "Validation Percentage", - "param_key": "val_pct", - "min": 0, - "max": 100, - "step": 5, - "value": 20, - "marks": { - "0": "0", - "100": "100" - }, - "comp_group": "train_model" - }, - { - "type": "dropdown", - "name": "weights", - "title": "Weights", - "param_key": "weights", - "value": "None", - "options": [ - {"label": "None", "value": "None"}, - {"label": "imagenet", "value": "imagenet"} - ], - "comp_group": "train_model" - }, - { - "type": "dropdown", - "name": "optimizer", - "title": "Optimizer", - "param_key": "optimizer", - "value": "Adam", - "options": [ - {"label": "Adadelta", "value": "Adadelta"}, - {"label": "Adagrad", "value": "Adagrad"}, - {"label": "Adam", "value": "Adam"}, - {"label": "Adamax", "value": "Adamax"}, - {"label": "Ftrl", "value": "Ftrl"}, - {"label": "Nadam", "value": "Nadam"}, - {"label": "RMSprop", "value": "RMSprop"}, - {"label": "SGD", "value": "SGD"} - ], - "comp_group": "train_model" - }, - { - "type": "dropdown", - "name": "loss_function", - "title": "Loss Function", - "param_key": "loss_function", - "value": "categorical_crossentropy", - "options": [ - {"label": "BinaryCrossentropy", "value": "binary_crossentropy"}, - {"label": "BinaryFocalCrossentropy", "value": "binary_focal_crossentropy"}, - {"label": "CategoricalCrossentropy", "value": "categorical_crossentropy"}, - {"label": "CategoricalHinge", "value": "categorical_hinge"}, - {"label": "CosineSimilarity", "value": "cosine_similarity"}, - 
{"label": "Hinge", "value": "hinge"}, - {"label": "Huber", "value": "huber"}, - {"label": "LogCosh", "value": "log_cosh"}, - {"label": "KullbackLeiblerDivergence", "value": "kullback_leibler_divergence"}, - {"label": "MeanAbsoluteError", "value": "mean_absolute_error"}, - {"label": "MeanAbsolutePercentageError", "value": "mean_absolute_percentage_error"}, - {"label": "MeanSquaredError", "value": "mean_squared_error"}, - {"label": "MeanSquaredLogarithmicError", "value": "mean_squared_logarithmic_error"}, - {"label": "Poisson", "value": "poisson"}, - {"label": "SparseCategoricalCrossentropy", "value": "sparse_categorical_crossentropy"}, - {"label": "SquaredHinge", "value": "squared_hinge"} - ], - "comp_group": "train_model" - }, - { - "type": "float", - "name": "learning_rate", - "title": "Learning Rate", - "param_key": "learning_rate", - "value": 0.001, - "comp_group": "train_model" - }, - { - "type": "slider", - "name": "epochs", - "title": "Number of epoch", - "param_key": "epochs", - "min": 1, - "max": 1000, - "value": 30, - "comp_group": "train_model" - }, - { - "type": "dropdown", - "name": "nn_model", - "title": "ML Model", - "param_key": "nn_model", - "value": "ResNet50", - "options": [ - {"label": "VGG16", "value": "VGG16"}, - {"label": "VGG19", "value": "VGG19"}, - {"label": "ResNet101", "value": "ResNet101"}, - {"label": "ResNet152", "value": "ResNet152"}, - {"label": "ResNet50V2", "value": "ResNet50V2"}, - {"label": "ResNet50", "value": "ResNet50"}, - {"label": "ResNet152V2", "value": "ResNet152V2"}, - {"label": "DenseNet201", "value": "DenseNet201"}, - {"label": "NASNetLarge", "value": "NASNetLarge"}, - {"label": "DenseNet169", "value": "DenseNet169"} - ], - "comp_group": "train_model" - }, - { - "type": "int", - "name": "seed", - "title": "Seed", - "param_key": "seed", - "value": 0, - "comp_group": "train_model" - }, - { - "type": "slider", - "name": "rotation_angle", - "title": "Rotation Angle", - "param_key": "rotation_angle", - "min": 0, - "max": 360, - "value": 0, - "comp_group": "evaluate_model" - }, - { - "type": "radio", - "name": "image_flip", - "title": "Image Flip", - "param_key": "image_flip", - "value": "None", - "options": - [ - {"label": "None", "value": "None"}, - {"label": "Vertical", "value": "Vertical"}, - {"label": "Horizontal", "value": "Horizontal"}, - {"label": "Both", "value": "Both"} - ], - "comp_group": "evaluate_model" - }, - { - "type": "radio", - "name": "shuffle", - "title": "Shuffle Data", - "param_key": "shuffle", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], - "comp_group": "evaluate_model" - }, - { - "type": "slider", - "name": "batch_size", - "title": "Batch Size", - "param_key": "batch_size", - "min": 16, - "max": 128, - "step": 16, - "value": 32, - "comp_group": "evaluate_model" - }, - { - "type": "int", - "name": "seed", - "title": "Seed", - "param_key": "seed", - "value": 0, - "comp_group": "evaluate_model" - }, - { - "type": "slider", - "name": "rotation_angle", - "title": "Rotation Angle", - "param_key": "rotation_angle", - "min": 0, - "max": 360, - "value": 0, - "comp_group": "prediction_model" - }, - { - "type": "radio", - "name": "image_flip", - "title": "Image Flip", - "param_key": "image_flip", - "value": "None", - "options": - [ - {"label": "None", "value": "None"}, - {"label": "Vertical", "value": "Vertical"}, - {"label": "Horizontal", "value": "Horizontal"}, - {"label": "Both", "value": "Both"} - ], - "comp_group": "prediction_model" - }, - { - "type": 
"radio", - "name": "shuffle", - "title": "Shuffle Data", - "param_key": "shuffle", - "value": false, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], - "comp_group": "prediction_model" - }, - { - "type": "slider", - "name": "batch_size", - "title": "Batch Size", - "param_key": "batch_size", - "min": 16, - "max": 128, - "step": 16, - "value": 32, - "comp_group": "prediction_model" - }, - { - "type": "int", - "name": "seed", - "title": "Seed", - "param_key": "seed", - "value": 0, - "comp_group": "prediction_model" - }, - { - "type": "slider", - "name": "rotation_angle", - "title": "Rotation Angle", - "param_key": "rotation_angle", - "min": 0, - "max": 360, - "value": 0, - "comp_group": "transfer_learning" - }, - { - "type": "radio", - "name": "image_flip", - "title": "Image Flip", - "param_key": "image_flip", - "value": "None", - "options": - [ - {"label": "None", "value": "None"}, - {"label": "Vertical", "value": "Vertical"}, - {"label": "Horizontal", "value": "Horizontal"}, - {"label": "Both", "value": "Both"} - ], - "comp_group": "transfer_learning" - }, - { - "type": "radio", - "name": "shuffle", - "title": "Shuffle Data", - "param_key": "shuffle", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], - "comp_group": "transfer_learning" - }, - { - "type": "slider", - "name": "batch_size", - "title": "Batch Size", - "param_key": "batch_size", - "min": 16, - "max": 128, - "step": 16, - "value": 32, - "comp_group": "transfer_learning" - }, - { - "type": "slider", - "name": "val_pct", - "title": "Validation Percentage", - "param_key": "val_pct", - "min": 0, - "max": 100, - "step": 5, - "value": 20, - "marks": { - "0": "0", - "100": "100" - }, - "comp_group": "transfer_learning" - }, - { - "type": "int", - "name": "init_layer", - "title": "Choose a trained model from the job list and select a layer to start training at below", - "param_key": "init_layer", - "value": 1, - "min": 0, - "comp_group": "transfer_learning" - }, - { - "type": "int", - "name": "seed", - "title": "Seed", - "param_key": "seed", - "value": 0, - "comp_group": "transfer_learning" - } - ], - "cmd": ["python3 src/train_model.py", "python3 src/evaluate_model.py", "python3 src/predict_model.py"], - "reference": "xxx", - "content_type": "model", - "public": false, - "service_type": "backend" -} \ No newline at end of file diff --git a/models/docker/Dockerfile b/models/docker/Dockerfile deleted file mode 100644 index e45a40e..0000000 --- a/models/docker/Dockerfile +++ /dev/null @@ -1,21 +0,0 @@ -FROM tensorflow/tensorflow:2.7.0 - -COPY docker/requirements.txt requirements.txt - -RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential\ - libsm6 \ - libxext6 \ - libxrender-dev \ - ffmpeg \ - graphviz \ - tree \ - python3-pip &&\ - pip install --upgrade pip - -RUN pip install -r requirements.txt - -WORKDIR /app/work/ -COPY src/ src/ - -CMD ["bash"] diff --git a/models/docker/Dockerfile_arm64 b/models/docker/Dockerfile_arm64 deleted file mode 100644 index 9706441..0000000 --- a/models/docker/Dockerfile_arm64 +++ /dev/null @@ -1,22 +0,0 @@ -#FROM armswdev/tensorflow-arm-neoverse:latest -FROM sonoisa/deep-learning-coding:pytorch1.12.0_tensorflow2.9.1 - -COPY docker/requirements_arm64.txt requirements.txt -USER root -RUN apt-get update && sudo apt-get install -y --no-install-recommends \ - build-essential\ - libsm6 \ - libxext6 \ - libxrender-dev \ - ffmpeg \ - graphviz \ - 
tree \ - python3-pip &&\ - pip3 install --upgrade pip - -RUN pip3 install -r requirements.txt -RUN pip3 install pydantic -WORKDIR /app/work/ -COPY src/ src/ - -CMD ["bash"] diff --git a/models/docker/Dockerfile_gpu b/models/docker/Dockerfile_gpu deleted file mode 100644 index 14f4fd2..0000000 --- a/models/docker/Dockerfile_gpu +++ /dev/null @@ -1,24 +0,0 @@ -FROM tensorflow/tensorflow:latest-gpu - -COPY docker/requirements.txt requirements.txt - -RUN rm /etc/apt/sources.list.d/cuda.list -RUN rm /etc/apt/sources.list.d/nvidia-ml.list - -RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential\ - libsm6 \ - libxext6 \ - libxrender-dev \ - ffmpeg \ - graphviz \ - tree \ - python3-pip &&\ - pip install --upgrade pip - -RUN pip install -r requirements.txt - -WORKDIR /app/work/ -COPY src/ src/ - -CMD ["bash"] diff --git a/models/docker/requirements.txt b/models/docker/requirements.txt deleted file mode 100644 index b890e19..0000000 --- a/models/docker/requirements.txt +++ /dev/null @@ -1,9 +0,0 @@ -h5py==3.1.0 -keras==2.7.0 -pandas==1.3.1 -pyarrow==9.0.0 -# fastparquet==0.8.1 -Pillow==8.3.1 -pydantic==1.8.2 -pydot==1.4.2 -scipy==1.7.1 diff --git a/models/docker/requirements_arm64.txt b/models/docker/requirements_arm64.txt deleted file mode 100644 index fbab654..0000000 --- a/models/docker/requirements_arm64.txt +++ /dev/null @@ -1,13 +0,0 @@ -#h5py==3.1.0 -keras==2.7.0 -# graphviz==0.19.1 -pandas==1.3.1 -pyarrow==9.0.0 -Pillow==8.3.1 -pydantic==1.8.2 -pydot==1.4.2 -#tensorflow-deps -#tensorflow-macos -#tensorflow-metal -scipy==1.7.1 - diff --git a/models/old_json_files/evaluate_model.json b/models/old_json_files/evaluate_model.json deleted file mode 100644 index 4f9066a..0000000 --- a/models/old_json_files/evaluate_model.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "model_name": "evaluate_model", - "version": "0.0.1", - "type": "testing", - "user": "mlexchange-team", - "uri": "TBD", - "application": [ - "labelmaker" - ], - "description": "", - "gui_parameters": [ - { - "type": "intslider", - "name": "rotation_angle", - "title": "Rotation Angle", - "value": 0, - "min": 0, - "max": 360 - }, - { - "type": "strchecklist", - "name": "image_flip", - "title": "Image FLip", - "value": [], - "options": [ - { - "label": "Horizontal", - "value": "horiz" - }, - { - "label": "Vertical", - "value": "vert" - } - ] - }, - { - "type": "intslider", - "name": "batch_size", - "title": "Batch Size", - "value": 32, - "min": 16, - "max": 1500, - "step": 8, - "tooltip": { - "always_visible": true, - "placement": "bottom"} - } - ], - "cmd": [] -} \ No newline at end of file diff --git a/models/old_json_files/kwarg_editor.py b/models/old_json_files/kwarg_editor.py deleted file mode 100644 index f7b11ea..0000000 --- a/models/old_json_files/kwarg_editor.py +++ /dev/null @@ -1,397 +0,0 @@ -import json -import re -from functools import partial -from typing import Callable, Dict -# noinspection PyUnresolvedReferences -from inspect import signature, _empty - -import dash -import dash_core_components as dcc -import dash_html_components as html -import dash_bootstrap_components as dbc -import dash_daq as daq -from dash.dependencies import Input, ALL, Output, State - -from targeted_callbacks import targeted_callback - -""" -{'name', 'title', 'value', 'type', -""" - - -def regularize_name(name): - return ''.join([c for c in name if c not in [' ']]) - - -class SimpleItem(dbc.FormGroup): - def __init__(self, - name, - base_id, - title=None, - type='number', - debounce=True, - visible=True, - **kwargs): - 
self.name = name - - self.label = dbc.Label(title or name) - self.input = dbc.Input(type=type, - debounce=debounce, - id={**base_id, - 'name': name, - 'layer': 'input'}, - **kwargs) - style = {} - if not visible: - style['display'] = 'none' - - super(SimpleItem, self).__init__(id={**base_id, - 'name': name, - 'layer': 'form_group'}, - children=[self.label, self.input], - style=style) - - -class FloatItem(SimpleItem): - pass - - -class IntItem(SimpleItem): - def __init__(self, *args, **kwargs): - if 'min' not in kwargs: - kwargs['min'] = -9007199254740991 # min must be set for int validation to be enabled - super(IntItem, self).__init__(*args, step=1, **kwargs) - - -class StrItem(SimpleItem): - def __init__(self, *args, **kwargs): - super(StrItem, self).__init__(*args, type='text', **kwargs) - - -class SliderItem(dbc.FormGroup): - def __init__(self, - name, - base_id, - title=None, - visible=True, - **kwargs): - - self.label = dbc.Label(title or name) - self.input = dcc.Slider(id={**base_id, - 'name': name, - 'layer': 'input'}, - **kwargs) - - style = {} - if not visible: - style['display'] = 'none' - - super(SliderItem, self).__init__(id={**base_id, - 'name': name, - 'layer': 'form_group'}, - children=[self.label, self.input], - style=style) - - -class ChecklistItem(dbc.FormGroup): - def __init__(self, - name, - base_id, - options, - title=None, - visible=True, - **kwargs): - - self.label = dbc.Label(title or name) - self.input = dcc.Checklist(id={**base_id, - 'name': name, - 'layer': 'input'}, - options=options, - **kwargs) - - style = {} - if not visible: - style['display'] = 'none' - - super(ChecklistItem, self).__init__(id={**base_id, - 'name': name, - 'layer': 'form_group'}, - children=[self.label, self.input], - style=style) - - -class DropdownItem(dbc.FormGroup): - def __init__(self, - name, - base_id, - options, - title=None, - visible=True, - **kwargs): - - self.label = dbc.Label(title or name) - self.input = dcc.Dropdown(id={**base_id, - 'name': name, - 'layer': 'input'}, - options=options, - **kwargs) - - style = {} - if not visible: - style['display'] = 'none' - - super(DropdownItem, self).__init__(id={**base_id, - 'name': name, - 'layer': 'form_group'}, - children=[self.label, self.input], - style=style) - - -class RadioItem(dbc.FormGroup): - def __init__(self, - name, - base_id, - options, - title=None, - visible=True, - **kwargs): - - self.label = dbc.Label(title or name) - self.input = dbc.RadioItems(id={**base_id, - 'name': name, - 'layer': 'input'}, - options=options, - **kwargs) - - style = {} - if not visible: - style['display'] = 'none' - - super(RadioItem, self).__init__(id={**base_id, - 'name': name, - 'layer': 'form_group'}, - children=[self.label, self.input], - style=style) - - -class BoolItem(dbc.FormGroup): - def __init__(self, - name, - base_id, - title=None, - visible=True, - **kwargs): - - self.label = dbc.Label(title or name) - self.input = daq.ToggleSwitch(id={**base_id, - 'name': name, - 'layer': 'input'}, - **kwargs) - self.output_label = dbc.Label('False/True') - - style = {} - if not visible: - style['display'] = 'none' - - super(BoolItem, self).__init__(id={**base_id, - 'name': name, - 'layer': 'form_group'}, - children=[self.label, self.input, self.output_label], - style=style) - - -class ParameterEditor(dbc.Form): - - type_map = {float: FloatItem, - int: IntItem, - str: StrItem, - } - - def __init__(self, _id, parameters, **kwargs): - self._parameters = parameters - - super(ParameterEditor, self).__init__(id=_id, children=[], 
className='kwarg-editor', **kwargs) - self.children = self.build_children() - - def init_callbacks(self, app): - targeted_callback(self.stash_value, - Input({**self.id, - 'name': ALL, - 'layer': 'input'}, - 'value'), - Output(self.id, 'n_submit'), - State(self.id, 'n_submit'), - app=app) - - def stash_value(self, value): - # find the changed item name from regex - r = '(?<=\"name\"\:\")[\w\-_]+(?=\")' - matches = re.findall(r, dash.callback_context.triggered[0]['prop_id']) - - if not matches: - raise LookupError('Could not find changed item name. Check that all parameter names use simple chars (\\w)') - - name = matches[0] - self.parameters[name]['value'] = value - - print(self.values) - - return next(iter(dash.callback_context.states.values())) or 0 + 1 - - @property - def values(self): - return {param['name']: param.get('value', None) for param in self._parameters} - - @property - def parameters(self): - return {param['name']: param for param in self._parameters} - - def _determine_type(self, parameter_dict): - if 'type' in parameter_dict: - if parameter_dict['type'] in self.type_map: - return parameter_dict['type'] - elif parameter_dict['type'].__name__ in self.type_map: - return parameter_dict['type'].__name__ - elif type(parameter_dict['value']) in self.type_map: - return type(parameter_dict['value']) - raise TypeError(f'No item type could be determined for this parameter: {parameter_dict}') - - def build_children(self, values=None): - children = [] - for parameter_dict in self._parameters: - parameter_dict = parameter_dict.copy() - if values and parameter_dict['name'] in values: - parameter_dict['value'] = values[parameter_dict['name']] - type = self._determine_type(parameter_dict) - parameter_dict.pop('type', None) - item = self.type_map[type](**parameter_dict, base_id=self.id) - children.append(item) - - return children - - -class JSONParameterEditor(ParameterEditor): - type_map = {'float': FloatItem, - 'int': IntItem, - 'str': StrItem, - 'intslider': SliderItem, - 'strdropdown': DropdownItem, - 'radio': RadioItem, - 'bool': BoolItem, - "strchecklist": ChecklistItem - } - - def _determine_type(self, parameter_dict): - if 'type' in parameter_dict: - if parameter_dict['type'] in self.type_map: - return parameter_dict['type'] - elif parameter_dict['type'].__name__ in self.type_map: - return parameter_dict['type'].__name__ - elif type(parameter_dict['value']).__name__ in self.type_map: - return type(parameter_dict['value']).__name__ - raise TypeError(f'No item type could be determined for this parameter: {parameter_dict}') - - -class KwargsEditor(ParameterEditor): - def __init__(self, instance_index, func: Callable, **kwargs): - self.func = func - self._instance_index = instance_index - - parameters = [{'name': name, 'value': param.default} for name, param in - signature(func).parameters.items() - if param.default is not _empty] - - super(KwargsEditor, self).__init__(dict(index=instance_index, type='kwargs-editor'), - parameters=parameters, **kwargs) - - @staticmethod - def parameters_from_func(func, prefix=''): - parameters = [{'name': prefix + name, - 'title': name, - 'value': param.default} - for name, param in signature(func).parameters.items() - if param.default is not _empty] - return parameters - - def new_record(self): - return {name: p.default for name, p in signature(self.func).parameters.items() if p.default is not _empty} - - -class StackedKwargsEditor(html.Div): - def __init__(self, instance_index, funcs: Dict[str, Callable], selector_label: str, id='kwargs-editor', 
**kwargs): - self.func_selector = dbc.Select(id=dict(index=instance_index, type=id, layer='stack'), - options=[{'label': name, 'value': name} for i, name in enumerate(funcs.keys())], - value=next(iter(funcs.keys()))) - - self.funcs = funcs - - parameters = [] - for i, (name, func) in enumerate(funcs.items()): - regularized_name = regularize_name(name) - func_params = KwargsEditor.parameters_from_func(func, prefix=f'{regularized_name}-') - if i: - for param in func_params: - param['visible'] = False - # self._param_map[name] = func_params.keys()) - parameters.extend(func_params) - - self.parameter_editor = ParameterEditor(dict(index=instance_index, type=id, layer='editor'), - parameters=parameters, - **kwargs) - - super(StackedKwargsEditor, self).__init__(children=[dbc.Label(selector_label), - html.Br(), - self.func_selector, - dbc.CardBody(children=self.parameter_editor)]) - - def init_callbacks(self, app): - for child in self.parameter_editor.children: - targeted_callback(partial(self.update_visibility, name=child.id['name']), - Input(self.func_selector.id, 'value'), - Output(child.id, 'style'), - prevent_initial_call=True, - app=app) - self.parameter_editor.init_callbacks(app) - - def update_visibility(self, value: str, name:str): - if name.startswith(f'{regularize_name(value)}-'): - return {'display': 'block'} - else: - return {'display': 'none'} - - -if __name__ == '__main__': - - app_kwargs = {'external_stylesheets': [dbc.themes.BOOTSTRAP]} - app = dash.Dash(__name__, **app_kwargs) - - item_list = ParameterEditor(_id={'type': 'parameter_editor'}, - parameters=[{'name': 'test', 'value': 2, 'max': 10, 'min': 0}, - {'name': 'test2', 'value': 'blah'}, - {'name': 'test3', 'value': 3.2, 'type': float}]) - - with open('example.json') as f: - json_file = json.load(f) - - json_items = JSONParameterEditor(_id={'type': 'json_parameter_editor'}, - parameters=json_file) - - def my_func(a='t', p=1, d='blah', e=23.4): - ... - - def my_func2(a, b, x=1, w='blah', z=23.4): - ... 
- - kwarg_list = KwargsEditor(0, func=my_func) - - func_editor = StackedKwargsEditor(1, funcs={'my_func': my_func, 'my_func2': my_func2}, - selector_label='test') - - kwarg_list.init_callbacks(app) - item_list.init_callbacks(app) - func_editor.init_callbacks(app) - json_items.init_callbacks(app) - - app.layout = html.Div([json_items]) - - app.run_server(debug=True) diff --git a/models/old_json_files/prediction_model.json b/models/old_json_files/prediction_model.json deleted file mode 100644 index 1a362d8..0000000 --- a/models/old_json_files/prediction_model.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "model_name": "prediction_model", - "version": "0.0.1", - "type": "testing", - "user": "mlexchange-team", - "uri": "TBD", - "application": [ - "labelmaker" - ], - "description": "", - "gui_parameters": [ - { - "type": "intslider", - "name": "rotation_angle", - "title": "Rotation Angle", - "value": 0, - "min": 0, - "max": 360 - }, - { - "type": "strchecklist", - "name": "image_flip", - "title": "Image FLip", - "value": [], - "options": [ - { - "label": "Horizontal", - "value": "horiz" - }, - { - "label": "Vertical", - "value": "vert" - } - ] - }, - { - "type": "intslider", - "name": "batch_size", - "title": "Batch Size", - "value": 32, - "min": 16, - "max": 1500, - "step": 8, - "tooltip": { - "always_visible": true, - "placement": "bottom"} - } - ], - "cmd": [] -} \ No newline at end of file diff --git a/models/old_json_files/train_model.json b/models/old_json_files/train_model.json deleted file mode 100644 index b93cf72..0000000 --- a/models/old_json_files/train_model.json +++ /dev/null @@ -1,97 +0,0 @@ -{ - "model_name": "train_model", - "version": "0.0.1", - "type": "training", - "user": "mlexchange-team", - "uri": "TBD", - "application": [ - "labelmaker" - ], - "description": "", - "gui_parameters": [ - { - "type": "intslider", - "name": "rotation_angle", - "title": "Rotation Angle", - "value": 0, - "min": 0, - "max": 360, - "tooltip": { - "always_visible": true, - "placement": "bottom"} - }, - { - "type": "strchecklist", - "name": "image_flip", - "title": "Image FLip", - "value": [], - "options": [ - { - "label": "Horizontal", - "value": "horiz" - }, - { - "label": "Vertical", - "value": "vert" - } - ] - }, - { - "type": "intslider", - "name": "batch_size", - "title": "Batch Size", - "value": 32, - "min": 16, - "max": 1500, - "step": 8, - "tooltip": { - "always_visible": true, - "placement": "bottom"} - }, - { - "type": "radio", - "name": "pooling_opts", - "title": "Pooling Options", - "options": [ - {"label": "None", "value": "None"}, - {"label": "Maximum", "value": "max"}, - {"label": "Average", "value": "avg"} - ], - "value": "None", - "labelStyle": {"display": "inline-block"} - }, - { - "type": "intslider", - "name": "num_epoch", - "title": "Number of Epoch", - "value": 3, - "min": 1, - "max": 100, - "tooltip": { - "always_visible": true, - "placement": "bottom"} - }, - { - "type": "strdropdown", - "name": "ml_model", - "title": "ML Model", - "options": [ - {"label": "Xception", "value": "Xception"}, - {"label": "VGG16", "value": "VGG16"}, - {"label": "VGG19", "value": "VGG19"}, - {"label": "ResNet101", "value": "ResNet101"}, - {"label": "ResNet152", "value": "ResNet152"}, - {"label": "ResNet50V2", "value": "ResNet50V2"}, - {"label": "ResNet50", "value": "ResNet50"}, - {"label": "ResNet152V2", "value": "ResNet152V2"}, - {"label": "InceptionV3", "value": "InceptionV3"}, - {"label": "DenseNet201", "value": "DenseNet201"}, - {"label": "NASNetLarge", "value": "NASNetLarge"}, - {"label": 
"InceptionResNetV2", "value": "InceptionResNetV2"}, - {"label": "DenseNet169", "value": "DenseNet169"} - ], - "value": "Xception" - } - ], - "cmd": [] -} \ No newline at end of file diff --git a/models/src/evaluate_model.py b/models/src/evaluate_model.py deleted file mode 100644 index 875944c..0000000 --- a/models/src/evaluate_model.py +++ /dev/null @@ -1,27 +0,0 @@ -import argparse -import json -import os -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' - -from tensorflow.keras.models import load_model -from model_validation import DataAugmentationParams -from helper_utils import data_processing - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('test_dir', help='output directory') - parser.add_argument('model_dir', help='input directory') - parser.add_argument('parameters', help='list of parameters') - args = parser.parse_args() - - test_dir = args.test_dir - model_dir = args.model_dir - parameters = DataAugmentationParams(**json.loads(args.parameters)) - - (test_generator, tmp) = data_processing(parameters, test_dir) - - loaded_model = load_model(model_dir) - results = loaded_model.evaluate(test_generator, - verbose=False) - print("test loss, test acc: " + str(results[0]) + ", " + str(results[1]), flush=True) diff --git a/models/src/helper_utils.py b/models/src/helper_utils.py deleted file mode 100644 index f1279b7..0000000 --- a/models/src/helper_utils.py +++ /dev/null @@ -1,290 +0,0 @@ -import glob -import os -import math - -import pandas as pd -import numpy as np -import requests -from scipy.ndimage import interpolation -import tensorflow as tf -from tensorflow.keras.preprocessing.image import ImageDataGenerator - - -# keras callbacks for model training. Threads while keras functions are running -# so that you can see training or evaluation of the model in progress -class TrainCustomCallback(tf.keras.callbacks.Callback): - # For model training - def on_epoch_end(self, epoch, logs=None): - if logs.get('val_loss'): - if epoch == 0: - print('epoch loss val_loss accuracy val_accuracy\n', flush=True) - loss = logs.get('loss') - val_loss = logs.get('val_loss') - accuracy = logs.get('accuracy') - val_accuracy = logs.get('val_accuracy') - print(str(epoch) + ' ' + str(loss) + ' ' + str(val_loss) + ' ' + str(accuracy) + ' ' + str(val_accuracy) - + '\n', flush=True) - else: - if epoch == 0: - print('epoch loss accuracy\n', flush=True) - loss = logs.get('loss') - accuracy = logs.get('accuracy') - print(str(epoch) + ' ' + str(loss) + ' ' + str(accuracy) + '\n', flush=True) - - def on_train_end(self, logs=None): - print('Train process completed', flush=True) - - -# keras callbacks for model training. 
Threads while keras functions are running -# so that you can see training or evaluation of the model in progress -class TestCustomCallback(tf.keras.callbacks.Callback): - def __init__(self, filenames=None, classes=None): - self.classes = classes - self.filenames = filenames - - def on_predict_begin(self, logs=None): - print('Prediction process started\n', flush=True) - - def on_predict_batch_end(self, batch, logs=None): - out = logs['outputs'] - batch_size = out.shape[0] - if batch==0: - msg = ['filename'] + self.classes - print(' '.join(msg) + '\n', flush=True) - filenames = self.filenames[batch*batch_size:(batch+1)*batch_size] - for row in range(batch_size): # when batch>1 - prob = np.char.mod('%f', out[row,:]) - print(filenames[row]+ ' ' + ' '.join(prob) + '\n', flush=True) - - def on_predict_end(self, logs=None): - print('Prediction process completed', flush=True) - - -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' -IMG_SIZE = (224, 224) # dimensions for images: fixed due to TF models -COLOR_MODE = 'rgb' # fixed due to TF models -SPLASH_CLIENT = 'http://splash:80/api/v0' - - -# Get dataset from splash-ml -def load_from_splash(uri_list): - ''' - This function queries labels from splash-ml. - Args: - uri_list: URI of dataset (e.g. file path) - Returns: - splash_df: Dataframe of labeled images (docker path) - ''' - url = f'{SPLASH_CLIENT}/datasets/search' - params = {"page[offset]": 0, "page[limit]": len(uri_list)} - data = {'uris': uri_list} - datasets = requests.post(url, params=params, json=data).json() - labels_name_data = [] - for dataset in datasets: - for tag in dataset['tags']: - if tag['name'] == 'labelmaker': - labels_name_data.append([dataset['uri'], tag['locator']['path']]) - splash_df = pd.DataFrame(data=labels_name_data, - index=None, - columns=['filename', 'class']) - classes = list(splash_df['class'].unique()) - return splash_df, classes - - -# Data Augmentation + Batch Size -def data_processing(parameters, data_dir, no_label=False): - rotation_angle = parameters.rotation_angle - image_flip = parameters.image_flip - if image_flip=='None': - horizontal_flip = False - vertical_flip = False - if image_flip=='Vertical': - horizontal_flip = False - vertical_flip = True - if image_flip=='Horizontal': - horizontal_flip = True - vertical_flip = False - if image_flip=='Both': - horizontal_flip = True - vertical_flip = True - batch_size = parameters.batch_size - target_width = 224 - target_height = 224 - if parameters.shuffle: - shuffle = parameters.shuffle - else: - shuffle = False - if parameters.seed: - seed = parameters.seed - else: - seed = 45 # fixed seed - uri_list = parameters.splash - if parameters.val_pct: - datagen = ImageDataGenerator(rotation_range=rotation_angle, - rescale=1/255, - horizontal_flip=horizontal_flip, - vertical_flip=vertical_flip, - validation_split=parameters.val_pct/100) - else: - datagen = ImageDataGenerator(rotation_range=rotation_angle, - rescale=1/255, - horizontal_flip=horizontal_flip, - vertical_flip=vertical_flip) - - data_generator = [] - if not(uri_list): - first_data = glob.glob(data_dir + '/**/*.*', recursive=True) - if len(first_data) > 0: - data_type = os.path.splitext(first_data[0])[-1] - if data_type in ['.tiff', '.tif', '.jpg', '.jpeg', '.png']: - if no_label: - classes = None - datagen = ImageDataGenerator(rescale=1/255) - - list_filename = [] - for dirpath, subdirs, files in os.walk(data_dir): - for file in files: - if os.path.splitext(file)[-1] in ['.tiff', '.tif', '.jpg', '.jpeg', '.png'] and not ('.' 
in os.path.splitext(file)[0]): - filename = os.path.join(dirpath, file) - list_filename.append(filename) - - data_df = pd.DataFrame(data=list_filename, - index=None, - columns=['filename']) - train_generator = datagen.flow_from_dataframe(data_df, - directory=data_dir, - x_col='filename', - target_size=(target_width, target_height), - color_mode=COLOR_MODE, - class_mode=None, - batch_size=batch_size, - shuffle=shuffle, - seed=seed, - ) - #train_generator = datagen.flow_from_directory('/'.join(data_dir.split('/')[0:-1]), - # target_size=(target_width, target_height), - # color_mode=COLOR_MODE, - # class_mode=None, - # batch_size=batch_size, - # shuffle=shuffle, - # seed=seed) - valid_generator = [] - elif parameters.val_pct: - classes = [subdir for subdir in sorted(os.listdir(data_dir)) if os.path.isdir(os.path.join(data_dir, subdir))] - train_generator = datagen.flow_from_directory(data_dir, - target_size=(target_width, target_height), - color_mode=COLOR_MODE, - class_mode='categorical', - batch_size=batch_size, - shuffle=shuffle, - seed=seed, - subset='training') - valid_generator = datagen.flow_from_directory(data_dir, - target_size=(target_width, target_height), - color_mode=COLOR_MODE, - class_mode='categorical', - batch_size=batch_size, - shuffle=shuffle, - seed=seed, - subset='validation') - else: - classes = [subdir for subdir in sorted(os.listdir(data_dir)) if os.path.isdir(os.path.join(data_dir, subdir))] - train_generator = datagen.flow_from_directory(data_dir, - target_size=(target_width, target_height), - color_mode=COLOR_MODE, - class_mode='categorical', - batch_size=batch_size, - shuffle=shuffle, - seed=seed) - valid_generator = [] - data_generator = (train_generator, valid_generator) - - if os.path.splitext(data_dir)[-1] == '.npz': - x_key = parameters.x_key - y_key = parameters.y_key - data = np.load(data_dir) - x_data = data[x_key] - y_data = data[y_key] - y_data = tf.keras.utils.to_categorical(y_data, num_classes=len(np.unique(y_data))) - if target_width != x_data.shape[1] or target_height != x_data.shape[2]: # resize if needed - data_shape = list(x_data.shape) - w = target_width/data_shape[1] - h = target_height/data_shape[2] - for channel in data_shape[0]: - x_data[channel,:] = interpolation.zoom(x_data[channel,:], [1,w,h]) - if data_shape[0]==1: - x_data = np.repeat(x_data[:,:,:,np.newaxis], 3, axis=3) # RGB - print(y_data.shape) - if parameters.val_pct: - train_generator = datagen.flow(x=x_data, - y=y_data, - batch_size=batch_size, - shuffle=shuffle, - seed=seed, - subset='training') - valid_generator = datagen.flow(x=x_data, - y=y_data, - batch_size=batch_size, - shuffle=shuffle, - seed=seed, - subset='validation') - else: - train_generator = datagen.flow(x=x_data, - y=data[y_key], - batch_size=batch_size, - shuffle=shuffle, - seed=seed) - valid_generator = [] - classes = np.unique(train_generator.__dict__['y'], axis=0) - data_generator = (train_generator, valid_generator) - - else: - splash_df, classes = load_from_splash(uri_list) - #print(splash_df) - if parameters.val_pct: - train_generator = datagen.flow_from_dataframe(splash_df, - directory=data_dir, - x_col='filename', - y_col='class', - target_size=(target_width, target_height), - color_mode=COLOR_MODE, - classes=None, - class_mode='categorical', - batch_size=batch_size, - shuffle=shuffle, - seed=seed, - subset='training', - ) - - valid_generator = datagen.flow_from_dataframe(splash_df, - directory=data_dir, - x_col='filename', - y_col='class', - target_size=(target_width, target_height), - 
color_mode=COLOR_MODE, - classes=None, - class_mode='categorical', - batch_size=batch_size, - shuffle=shuffle, - seed=seed, - subset='validation', - ) - else: - train_generator = datagen.flow_from_dataframe(splash_df, - directory=data_dir, - x_col='filename', - y_col='class', - target_size=(target_width, target_height), - color_mode=COLOR_MODE, - classes=None, - class_mode='categorical', - batch_size=batch_size, - shuffle=shuffle, - seed=seed - ) - - valid_generator = [] - - data_generator = (train_generator, valid_generator) - - return data_generator, classes diff --git a/models/src/model_validation.py b/models/src/model_validation.py deleted file mode 100644 index 39836e7..0000000 --- a/models/src/model_validation.py +++ /dev/null @@ -1,90 +0,0 @@ -from enum import Enum -from pydantic import BaseModel, Field -from typing import Optional, List - - -class NNModel(str, Enum): - xception = 'Xception' - vgg16 = 'VGG16' - vgg19 = 'VGG19' - resnet101 = 'ResNet101' - resnet152 = 'ResNet152' - resnet50v2 = 'ResNet50V2' - resnet50 = 'ResNet50' - resnet152v2 = 'ResNet152V2' - inceptionv3 = 'InceptionV3' - densenet201 = 'DenseNet201' - nasnetlarge = 'NASNetLarge' - inceptionresnetv2 = 'InceptionResNetV2' - densenet169 = 'DenseNet169' - - -class Optimizer(str, Enum): - Adadelta = "Adadelta" - Adagrad = "Adagrad" - Adam = "Adam" - Adamax = "Adamax" - Ftrl = "Ftrl" - Nadam = "Nadam" - RMSprop = "RMSprop" - SGD = "SGD" - - -class Weights(str, Enum): - none = 'None' - imagenet = 'imagenet' - - -class LossFunction(str, Enum): - binary_crossentropy = "binary_crossentropy" - binary_focal_crossentropy = "binary_focal_crossentropy" - categorical_crossentropy = "categorical_crossentropy" - categorical_hinge = "categorical_hinge" - cosine_similarity = "cosine_similarity" - hinge = "hinge" - huber = "huber" - log_cosh = "log_cosh" - kullback_leibler_divergence = "kullback_leibler_divergence" - mean_absolute_error = "mean_absolute_error" - mean_absolute_percentage_error = "mean_absolute_percentage_error" - mean_squared_error = "mean_squared_error" - mean_squared_logarithmic_error = "mean_squared_logarithmic_error" - poisson = "poisson" - sparse_categorical_crossentropy = "sparse_categorical_crossentropy" - squared_hinge = "squared_hinge" - - -class ImageFlip(str, Enum): - none = 'None' - vert = 'Vertical' - horiz = 'Horizontal' - both = 'Both' - - -class DataAugmentationParams(BaseModel): - rotation_angle: int = Field(description='rotation angle') - image_flip: ImageFlip - batch_size: int = Field(description='batch size') - val_pct: Optional[int] = Field(description='validation percentage') - shuffle: Optional[bool] = Field(description='shuffle data') - target_width: Optional[int] = Field(description='data target width') - target_height: Optional[int] = Field(description='data target height') - x_key: Optional[str] = Field(description='keyword for x data in NPZ') - y_key: Optional[str] = Field(description='keyword for y data in NPZ') - seed: Optional[int] = Field(description='random seed') - splash: Optional[List[str]] = Field(description='List of URIs in splash-ml') - - -class TrainingParams(DataAugmentationParams): - weights: Weights - optimizer: Optimizer - loss_function: LossFunction - learning_rate: float = Field(description='learning rate') - epochs: int = Field(description='number of epochs') - nn_model: NNModel - - -class TransferLearningParams(DataAugmentationParams): - epochs: int = Field(description='number of epochs') - init_layer: int = Field(description='initial layer') - nn_model: Optional[NNModel] 
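For context on how these schemas were consumed before their removal: the training entry points validate the JSON parameter string passed on the command line against TrainingParams and DataAugmentationParams. The following is a minimal sketch of that flow, assuming a pydantic v1 environment with the pre-removal model_validation module on the path; every parameter value shown is an illustrative placeholder, not taken from this repository.

import json

from model_validation import DataAugmentationParams, TrainingParams

# Illustrative JSON payload; the field names follow the schemas above,
# but every value here is a placeholder chosen only for this example.
raw_params = json.dumps(
    {
        "rotation_angle": 15,
        "image_flip": "Horizontal",
        "batch_size": 32,
        "val_pct": 20,
        "shuffle": True,
        "seed": 45,
        "weights": "imagenet",
        "optimizer": "Adam",
        "loss_function": "categorical_crossentropy",
        "learning_rate": 0.001,
        "epochs": 10,
        "nn_model": "Xception",
    }
)

# The removed train_model.py validated the same CLI argument twice: once against
# the full training schema and once against the augmentation-only subset of fields.
train_parameters = TrainingParams(**json.loads(raw_params))
data_parameters = DataAugmentationParams(**json.loads(raw_params))

print(train_parameters.nn_model.value, data_parameters.batch_size)

Because pydantic v1 ignores unknown fields by default, the same JSON blob can be validated against both schemas, which is how the removed scripts split training settings from data-augmentation settings.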
diff --git a/models/src/predict_model.py b/models/src/predict_model.py deleted file mode 100644 index 00344d2..0000000 --- a/models/src/predict_model.py +++ /dev/null @@ -1,52 +0,0 @@ -import argparse -import json -import os - -import numpy as np -import pandas as pd -from tensorflow.keras.models import load_model - -from model_validation import DataAugmentationParams -from helper_utils import TestCustomCallback, data_processing - - -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('test_dir', help='output directory') - parser.add_argument('model_dir', help='input directory') - parser.add_argument('out_dir', help='output directory') - parser.add_argument('parameters', help='list of parameters') - args = parser.parse_args() - - test_dir = args.test_dir - model_dir = args.model_dir - out_dir = args.out_dir - data_parameters = DataAugmentationParams(**json.loads(args.parameters)) - - (test_generator, tmp), tmp_class = data_processing(data_parameters, test_dir, True) - test_dir = '/'.join(test_dir.split('/')[0:-1]) - try: - test_filenames = test_generator.filenames #[f'{test_dir}/{x}' for x in test_generator.filenames] - class_dir = os.path.split(model_dir)[0] - classes = pd.read_csv(class_dir+'/'+'classes.csv') - classes = classes.values.tolist() - classes = [x for xs in classes for x in xs] - except Exception as e: - test_filenames = list(range(len(test_generator.__dict__['x']))) # list of indexes - test_filenames = [f'{test_dir}/{x}' for x in test_filenames] # full docker path filenames - classes = np.unique(test_generator.__dict__['y'], axis=0) # list of classes - classes = [str(x) for x in classes] - - df_files = pd.DataFrame(test_filenames, columns=['filename']) - loaded_model = load_model(model_dir) - prob = loaded_model.predict(test_generator, - verbose=0, - callbacks=[TestCustomCallback(test_filenames, classes)]) - df_prob = pd.DataFrame(prob, columns=classes) - #print(df_prob) - df_results = pd.concat([df_files,df_prob], axis=1) - df_results = df_results.set_index(['filename']) - df_results.to_parquet(out_dir + '/results.parquet') diff --git a/models/src/train_model.py b/models/src/train_model.py deleted file mode 100644 index 1efab72..0000000 --- a/models/src/train_model.py +++ /dev/null @@ -1,79 +0,0 @@ -import argparse -import json -import os - -import numpy as np -import pandas as pd -import tensorflow as tf - -from model_validation import TrainingParams, DataAugmentationParams -from helper_utils import TrainCustomCallback, data_processing -#from keras.layers import VersionAwareLayers -import tensorflow.keras.layers as layers - -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' -#layers = VersionAwareLayers() - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('train_dir', help='output directory') - parser.add_argument('out_dir', help='output directory') - parser.add_argument('parameters', help='list of training parameters') - args = parser.parse_args() - - train_dir = args.train_dir - out_dir = args.out_dir - train_parameters = TrainingParams(**json.loads(args.parameters)) - data_parameters = DataAugmentationParams(**json.loads(args.parameters)) - print(tf.test.gpu_device_name()) - (train_generator, valid_generator), classes = data_processing(data_parameters, train_dir) - try: - train_filenames = train_generator.filenames - except Exception as e: - train_filenames = list(range(len(train_generator.__dict__['x']))) # list of indexes - class_num = len(classes) - - weights 
= train_parameters.weights - epochs = train_parameters.epochs - nn_model = train_parameters.nn_model - optimizer = train_parameters.optimizer.value - learning_rate = train_parameters.learning_rate - loss_func = train_parameters.loss_function.value - - opt_code = compile(f'tf.keras.optimizers.{optimizer}(learning_rate={learning_rate})', '', 'eval') - print(f'weights: {weights}') - if weights != 'None': - model_code = compile(f"tf.keras.applications.{nn_model}(include_top=False, input_shape=(224,224,3), weights='imagenet', input_tensor=None)", - "", 'eval') - base_model = eval(model_code) - - x = base_model.output - x = layers.Flatten(name="flatten")(x) - x = layers.Dense(4096, activation="relu", name="fc1")(x) - x = layers.Dense(4096, activation="relu", name="fc2")(x) - predictions = layers.Dense(class_num, activation='softmax', name="predictions")(x) - model = tf.keras.models.Model(inputs=base_model.input, outputs=predictions) - else: - model_code = compile(f"tf.keras.applications.{nn_model}(include_top=True, weights=None, input_tensor=None, classes={class_num})", - "", 'eval') - model = eval(model_code) - model.compile(optimizer=eval(opt_code), # default adam - loss=loss_func, # default categorical_crossentropy - metrics=['accuracy']) - model.summary() - # tf.keras.utils.plot_model(model, out_dir+'/model_layout.png', show_shapes=True) # plot NN - print('Length:', len(model.layers), 'layers') # number of layers - - # fit model while also keeping track of data for dash plots. - model.fit(train_generator, - validation_data=valid_generator, - epochs=epochs, - verbose=0, - callbacks=[TrainCustomCallback()], - shuffle=data_parameters.shuffle) - - # save model - model.save(out_dir+'/model.h5') - df_classes = pd.DataFrame(classes) - df_classes.to_csv(out_dir + '/classes.csv', index=False) - print("Saved to disk") diff --git a/models/src/transfer_learning.py b/models/src/transfer_learning.py deleted file mode 100644 index c89aebd..0000000 --- a/models/src/transfer_learning.py +++ /dev/null @@ -1,55 +0,0 @@ -import argparse -import json -import os - -import tensorflow as tf -from tensorflow.keras.models import load_model - -from model_validation import TransferLearningParams, DataAugmentationParams -from helper_utils import TrainCustomCallback, data_processing - - -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('train_dir', help='output directory') - parser.add_argument('valid_dir', help='output directory') - parser.add_argument('model_dir', help='output directory') - parser.add_argument('out_dir', help='output directory') - parser.add_argument('parameters', help='list of training parameters') - args = parser.parse_args() - - train_dir = args.train_dir - valid_dir = args.valid_dir - model_dir = args.model_dir - out_dir = args.out_dir - transfer_parameters = TransferLearningParams(**json.loads(args.parameters)) - data_parameters = DataAugmentationParams(**json.loads(args.parameters)) - - print('Device: ', tf.test.gpu_device_name()) - - (train_generator, valid_generator) = data_processing(data_parameters, train_dir) - - epochs = transfer_parameters.epochs - start_layer = transfer_parameters.init_layer - - model = load_model(model_dir) - model.trainable = True - for layers in model.layers[:start_layer]: - layers.trainable = False - - model.summary() - - # fit model while also keeping track of data for dash plots. 
- model.fit(train_generator, - epochs=epochs, - verbose=0, - validation_data=valid_generator, - callbacks=[TrainCustomCallback()], - shuffle=data_parameters.shuffle) - - # save model - model.save(out_dir+'/model.h5') - print("Saved to disk") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..33985a1 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,46 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] + packages = ["src/**/*"] + +[tool.hatch.metadata] +allow-direct-references = true + +[project] +name = "mlcoach" +version = "0.2.0" +description = "MLCoach" +readme = { file = "README.md", content-type = "text/markdown" } +requires-python = ">=3.11" + +dependencies = [ + "dash==2.9.3", + "dash[diskcache]", + "dash_bootstrap_components==1.0.2", + "dash_daq==0.5.0", + "dash-extensions==0.0.71", + "flask==3.0.0", + "Flask-Caching", + "kaleido", + "dash_component_editor@git+https://github.com/mlexchange/mlex_dash_component_editor", + "mlex_file_manager@git+https://github.com/mlexchange/mlex_file_manager", + "numpy>=1.19.5", + "pandas", + "Pillow", + "plotly==5.21.0", + "plotly-express", + "python-dotenv", + "requests==2.26.0", + "diskcache==5.6.3" +] + +[project.optional-dependencies] +dev = [ + "black==24.2.0", + "flake8==7.0.0", + "isort==5.13.2", + "pre-commit==3.6.2", + "pytest==8.1.1", +] diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/app_layout.py b/src/app_layout.py new file mode 100644 index 0000000..2d3ad84 --- /dev/null +++ b/src/app_layout.py @@ -0,0 +1,88 @@ +import logging +import os +import pathlib + +import dash +import dash_bootstrap_components as dbc +import diskcache +from dash import html +from dash.long_callback import DiskcacheLongCallbackManager +from dotenv import load_dotenv +from file_manager.main import FileManager + +from src.components.header import header +from src.components.job_table import job_table +from src.components.main_display import main_display +from src.components.resources_setup import resources_setup +from src.components.sidebar import sidebar +from src.utils.job_utils import TableJob, get_host +from src.utils.model_utils import get_model_list + +load_dotenv(".env") + +USER = "admin" +DATA_DIR = os.getenv("DATA_DIR") +DOCKER_DATA = pathlib.Path.home() / "data" +UPLOAD_FOLDER_ROOT = DOCKER_DATA / "upload" +SPLASH_URL = os.getenv("SPLASH_URL") +DEFAULT_TILED_URI = os.getenv("DEFAULT_TILED_URI") +TILED_KEY = os.getenv("TILED_KEY") +if TILED_KEY == "": + TILED_KEY = None +HOST_NICKNAME = os.getenv("HOST_NICKNAME") +num_processors, num_gpus = get_host(HOST_NICKNAME) + +# SETUP LOGGING +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# SETUP DASH APP +cache = diskcache.Cache("./cache") +long_callback_manager = DiskcacheLongCallbackManager(cache) + +external_stylesheets = [ + dbc.themes.BOOTSTRAP, + "../assets/mlex-style.css", + "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css", +] +app = dash.Dash( + __name__, + external_stylesheets=external_stylesheets, + long_callback_manager=long_callback_manager, +) +app.title = "MLCoach" +app._favicon = "mlex.ico" +dash_file_explorer = FileManager( + DATA_DIR, + open_explorer=False, + api_key=TILED_KEY, + logger=logger, +) +dash_file_explorer.init_callbacks(app) + +# DEFINE LAYOUT +app.layout = html.Div( + [ + header("MLExchange | MLCoach", "https://github.com/mlexchange/mlex_mlcoach"), + dbc.Container( + 
[ + dbc.Row( + [ + dbc.Col( + sidebar( + dash_file_explorer.file_explorer, + get_model_list(), + TableJob.get_counter(USER), + ), + width=4, + ), + dbc.Col(main_display(job_table()), width=8), + html.Div(id="dummy-output"), + ] + ), + resources_setup(num_processors, num_gpus), + ], + fluid=True, + ), + ] +) diff --git a/src/assets/image_transformation.js b/src/assets/image_transformation.js new file mode 100644 index 0000000..3211099 --- /dev/null +++ b/src/assets/image_transformation.js @@ -0,0 +1,357 @@ +if (typeof window.dash_clientside === 'undefined') { + window.dash_clientside = {}; +} + +if (typeof window.dash_clientside.clientside === 'undefined') { + window.dash_clientside.clientside = {}; +} + + +window.dash_clientside.clientside.transform_image = function(logToggle, data) { + console.log("Received logToggle:", logToggle); // Check logToggle value + console.log("Received src:", data); // Check src value + src = data; + // If logToggle is false or src is not provided, return the original src + if (!logToggle || !src) { + console.log("Returning original image without transformation."); + return Promise.resolve(src); + } + + return new Promise(function(resolve, reject) { + // Create an Image element + var image = new Image(); + image.onload = function() { + // If logToggle is true, proceed with the transformation + var canvas = document.createElement('canvas'); + canvas.width = image.width; + canvas.height = image.height; + var ctx = canvas.getContext('2d'); + + ctx.drawImage(image, 0, 0, image.width, image.height); + var imageData = ctx.getImageData(0, 0, canvas.width, canvas.height); + var data = imageData.data; + var floatData = new Float32Array(data.length); + + // Apply log(1+x) transformation to each pixel + var min = Infinity; + var max = -Infinity; + for (var i = 0; i < data.length; i += 4) { + floatData[i] = Math.log1p(data[i]); // Red + floatData[i + 1] = Math.log1p(data[i + 1]); // Green + floatData[i + 2] = Math.log1p(data[i + 2]); // Blue + // Alpha channel remains unchanged + + // Update min and max + min = Math.min(min, floatData[i], floatData[i + 1], floatData[i + 2]); + max = Math.max(max, floatData[i], floatData[i + 1], floatData[i + 2]); + } + + // Apply min-max normalization and scale to 0-255 + for (var i = 0; i < floatData.length; i += 4) { + floatData[i] = (floatData[i] - min) / (max - min) * 255; // Red + floatData[i + 1] = (floatData[i + 1] - min) / (max - min) * 255; // Green + floatData[i + 2] = (floatData[i + 2] - min) / (max - min) * 255; // Blue + floatData[i + 3] = data[i + 3]; // Alpha channel remains unchanged + } + + // Convert floatData back to Uint8ClampedArray for imageData + for (var i = 0; i < data.length; i++) { + data[i] = Math.round(floatData[i]); + } + + ctx.putImageData(imageData, 0, 0); + resolve(canvas.toDataURL()); // Convert the canvas back to a base64 URL + }; + image.onerror = function() { + console.error("Failed to load image"); + reject(new Error('Failed to load image')); + }; + image.src = src; + }); +} + +window.dash_clientside.clientside.transform_raw_data = function(logToggle, mask, data, minMaxValues) { + console.log("Received logToggle:", logToggle); // Check logToggle value + console.log("Received mask:", mask); // Check maskPath value + console.log("Received minMaxValues:", minMaxValues); // Check minMaxValues value + console.log("Received data:", data); // Check data value + if (typeof data === 'undefined' || data === null || !Array.isArray(data)) { + console.log("Data is not defined or null."); + console.log("Returning 
original image without transformation."); + return Promise.resolve(src); + } + src = data; + // If src is not provided, or maskPath or minMaxValues have changed, return the original src + if (!src || !minMaxValues || minMaxValues.length !== 2) { + console.log("Returning original image without transformation."); + return Promise.resolve(src); + } + console.log("Received data size:", data.length); // Check data value + + return new Promise(function(resolve, reject) { + var maskImage = new Image(); + var outImage = new Image(); + + if (data[0].length > data.length) { + outImage.width = 200; + outImage.height = 200 * (data.length / data[0].length); + } else { + outImage.width = 200 * (data[0].length / data.length); + outImage.height = 200; + } + + // Get the dimensions of the original 2D array + var originalWidth = data[0].length; + var originalHeight = data.length; + + var maskPromise = mask ? + new Promise((resolve, reject) => { maskImage.onload = resolve; maskImage.onerror = reject; maskImage.src = mask; }) : + Promise.resolve(); + + Promise.all([ + maskPromise + ]).then(() => { + // If logToggle is true, proceed with the transformation + var data = src; + + // Flatten the data + data = data.flat(); + + // Create a new canvas and context + var canvas = document.createElement('canvas'); + var ctx = canvas.getContext('2d'); + + // Set the canvas size to match the image size + canvas.width = originalWidth; + canvas.height = originalHeight; + + console.log("Canvas width:", canvas.width); + console.log("Canvas height:", canvas.height); + + var maskData; + if (mask) { + console.log("Mask image loaded successfully"); + ctx.clearRect(0, 0, canvas.width, canvas.height); + // Resize the mask if it's larger than the image + ctx.drawImage(maskImage, 0, 0, originalWidth, originalHeight); + var tempMaskData = ctx.getImageData(0, 0, canvas.width, canvas.height).data; + maskData = new Uint8Array(data.length); + var j = 0; + for (var i = 0; i < data.length; i++) { + maskData[i] = tempMaskData[j] < 254 ? 0 : tempMaskData[j]; + j += 4; + } + } else { + // Default mask of all 1s + maskData = new Uint8Array(data.length).fill(1); + } + + // Determine the number of channels + var numChannels = src.numChannels || 1; // Assuming src.numChannels is defined. If not, default to 1. 
+ + // Apply mask and log transformation to each pixel + new_min = 255; + new_max = 0; + var new_data = new Float32Array(data.length / numChannels); + console.log("New data length:", new_data.length); + + for (var i = 0; i < new_data.length; i += numChannels) { + var grey = data[i * numChannels]; + if (maskData[i]!=0 && grey>0 && !isNaN(grey)){ + // Clip data between min - max values + grey = Math.max(minMaxValues[0], Math.min(minMaxValues[1], grey)); + // Normalize the data between the min-max values + grey = (grey - minMaxValues[0]) / (minMaxValues[1] - minMaxValues[0]); + if (logToggle) { + grey = Math.log(grey + 0.000000000001); + } + if (grey < new_min) { + new_min = grey; + } + if (grey > new_max) { + new_max = grey; + } + new_data[i] = grey; + + } else { + new_data[i] = 0; + } + + } + console.log("New min value:", new_min); + console.log("New max value:", new_max); + + var reshapedData = new Uint8ClampedArray(originalWidth * originalHeight * 4); + + for (var i = 0, j = 0; i < new_data.length; i++, j += 4) { + var tmp = Math.round((new_data[i] - new_min) / (new_max - new_min) * 255); + if (maskData[i] != 0){ + reshapedData[j] = tmp; // Red channel + reshapedData[j + 1] = tmp; // Green channel + reshapedData[j + 2] = tmp; // Blue channel + reshapedData[j + 3] = 255; // Alpha channel + } + else { + reshapedData[j] = 0; // Red channel + reshapedData[j + 1] = 0; // Green channel + reshapedData[j + 2] = 0; // Blue channel + reshapedData[j + 3] = 255; // Alpha channel + } + } + + // Create a new ImageData object + var imageData = new ImageData(reshapedData, originalWidth, originalHeight); + + // Create a temporary canvas to hold the original image + var tempCanvas = document.createElement('canvas'); + var tempCtx = tempCanvas.getContext('2d'); + tempCanvas.width = originalWidth; + tempCanvas.height = originalHeight; + tempCtx.putImageData(imageData, 0, 0); + + // Change the canvas dimensions to the new size + canvas.width = outImage.width; + canvas.height = outImage.height; + + // Clear the canvas and draw the image again, this time resizing it + ctx.clearRect(0, 0, canvas.width, canvas.height); + ctx.drawImage(tempCanvas, 0, 0, originalWidth, originalHeight, 0, 0, canvas.width, canvas.height); + + // Convert the canvas to a data URL + var dataUrl = canvas.toDataURL(); + + // Set the source of the outImage to the data URL of the canvas + outImage.src = dataUrl; + resolve(dataUrl); // Convert the canvas back to a base64 URL + + }).catch((err) => { + console.error("Failed to load image", err); + reject(new Error('Failed to load image')); + }); + }); +} + +window.dash_clientside.clientside.expanded_transform_image = function(logToggle, mask, minMaxValues, data) { + console.log("Received logToggle:", logToggle); // Check logToggle value + console.log("Received mask:", mask); // Check maskPath value + console.log("Received minMaxValues:", minMaxValues); // Check minMaxValues value + src = data; + // // If src is not provided, or maskPath or minMaxValues have changed, return the original src + if (!src || !minMaxValues || minMaxValues.length !== 2) { + console.log("Returning original image without transformation."); + return Promise.resolve(src); + } + + return new Promise(function(resolve, reject) { + // Create an Image element + var image = new Image(); + var maskImage = new Image(); + + var maskPromise = mask ? 
+ new Promise((resolve, reject) => { maskImage.onload = resolve; maskImage.onerror = reject; maskImage.src = mask; }) : + Promise.resolve(); + + Promise.all([ + new Promise((resolve, reject) => { image.onload = resolve; image.onerror = reject; image.src = src; }), + maskPromise + ]).then(() => { + // If logToggle is true, proceed with the transformation + var canvas = document.createElement('canvas'); + canvas.width = image.width; + canvas.height = image.height; + var ctx = canvas.getContext('2d'); + + ctx.drawImage(image, 0, 0, image.width, image.height); + var imageData = ctx.getImageData(0, 0, canvas.width, canvas.height); + var data = imageData.data; + + var maskData; + if (mask) { + console.log("Mask image loaded successfully"); + ctx.clearRect(0, 0, canvas.width, canvas.height); + // Resize the mask if it's larger than the image + ctx.drawImage(maskImage, 0, 0, image.width, image.height); + var tempMaskData = ctx.getImageData(0, 0, canvas.width, canvas.height).data; + maskData = new Uint8Array(tempMaskData.length); + for (var i = 0; i < tempMaskData.length; i++) { + maskData[i] = tempMaskData[i] < 254 ? 0 : tempMaskData[i]; + } + } else { + console.log("Loading default mask"); + // Default mask of all 1s + maskData = new Uint8Array(data.length / 1).fill(1); + } + // console.log("Min value of mask:", Math.min(...maskData)); + // console.log("Max value of mask:", Math.max(...maskData)); + + // Apply mask and log transformation to each pixel + new_min = 255; + new_max = 0; + var new_data = new Float32Array(data.length / 1); + + for (var i = 0; i < data.length; i += 4) { + var red = data[i]; + var green = data[i + 1]; + var blue = data[i + 2]; + + if (maskData[i]!=0 && red>0 && green>0 && blue>0 && !isNaN(red) && !isNaN(green) && !isNaN(blue)){ + + // Clip data between min - max values + red = Math.max(minMaxValues[0], Math.min(minMaxValues[1], red)); + green = Math.max(minMaxValues[0], Math.min(minMaxValues[1], green)); + blue = Math.max(minMaxValues[0], Math.min(minMaxValues[1], blue)); + + // Normalize the data between the min-max values + red = (red - minMaxValues[0]) / (minMaxValues[1] - minMaxValues[0]); + green = (green - minMaxValues[0]) / (minMaxValues[1] - minMaxValues[0]); + blue = (blue - minMaxValues[0]) / (minMaxValues[1] - minMaxValues[0]); + + if (logToggle) { + red = Math.log(red + 0.000000000001); + green = Math.log(green + 0.000000000001); + blue = Math.log(blue + 0.000000000001); + } + + min_value = Math.min(red, green, blue); + if (min_value < new_min) { + new_min = min_value; + } + max_value = Math.max(red, green, blue); + if (max_value > new_max) { + new_max = max_value; + } + + new_data[i] = red; + new_data[i + 1] = green; + new_data[i + 2] = blue; + + } else { + + new_data[i] = 0; + new_data[i + 1] = 0; + new_data[i + 2] = 0; + + } + + } + + for (var i = 0; i < data.length; i += 4) { + var red = data[i]; + var green = data[i + 1]; + var blue = data[i + 2]; + + if (maskData[i] != 0 && red>0 && green>0 && blue>0 && !isNaN(red) && !isNaN(green) && !isNaN(blue)){ + data[i] = Math.round((new_data[i] - new_min) / (new_max - new_min) * 255); + data[i + 1] = Math.round((new_data[i + 1] - new_min) / (new_max - new_min) * 255); + data[i + 2] = Math.round((new_data[i + 2] - new_min) / (new_max - new_min) * 255); + } + } + + ctx.putImageData(imageData, 0, 0); + resolve(canvas.toDataURL()); // Convert the canvas back to a base64 URL + }).catch((err) => { + console.error("Failed to load image", err); + reject(new Error('Failed to load image')); + }); + }); +} diff --git 
a/src/assets/segmentation-style.css b/src/assets/mlex-style.css similarity index 91% rename from src/assets/segmentation-style.css rename to src/assets/mlex-style.css index 4cb18ba..4f68490 100644 --- a/src/assets/segmentation-style.css +++ b/src/assets/mlex-style.css @@ -22,7 +22,7 @@ label { margin-bottom: 0; } -#transparent-loader-wrapper > div { +.transparent-loader-wrapper > div { visibility: visible !important; } diff --git a/src/callbacks/__init__.py b/src/callbacks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/callbacks/display.py b/src/callbacks/display.py new file mode 100644 index 0000000..ce68944 --- /dev/null +++ b/src/callbacks/display.py @@ -0,0 +1,333 @@ +import os +import pathlib +import pickle +import time +import traceback + +import dash +import pandas as pd +import requests +from dash import Input, Output, State, callback +from dash.exceptions import PreventUpdate +from file_manager.data_project import DataProject + +from src.app_layout import DATA_DIR, SPLASH_URL, TILED_KEY, USER, logger +from src.utils.job_utils import TableJob +from src.utils.plot_utils import generate_loss_plot, get_class_prob, plot_figure + + +@callback( + Output("img-output-store", "data"), + Output("img-uri", "data"), + Input("img-slider", "value"), + State({"base_id": "file-manager", "name": "data-project-dict"}, "data"), + State("jobs-table", "selected_rows"), + State("jobs-table", "data"), + prevent_initial_call=True, +) +def refresh_image( + img_ind, + data_project_dict, + row, + data_table, +): + """ + This callback updates the image in the display + Args: + img_ind: Index of image according to the slider value + log: Log toggle + data_project_dict: Selected data + row: Selected job (model) + data_table: Data in table of jobs + Returns: + img-output: Output figure + """ + start = time.time() + # Get selected job type + if row and len(row) > 0 and row[0] < len(data_table): + selected_job_type = data_table[row[0]]["job_type"] + else: + selected_job_type = None + + if selected_job_type == "prediction_model": + job_id = data_table[row[0]]["experiment_id"] + data_path = pathlib.Path(f"{DATA_DIR}/mlex_store/{USER}/{job_id}") + + with open(f"{data_path}/.file_manager_vars.pkl", "rb") as file: + data_project_dict = pickle.load(file) + data_project = DataProject.from_dict(data_project_dict, api_key=TILED_KEY) + if ( + len(data_project.datasets) > 0 + and data_project.datasets[-1].cumulative_data_count > 0 + ): + fig, uri = data_project.read_datasets(indices=[img_ind], resize=True) + fig = fig[0] + uri = uri[0] + else: + uri = None + fig = plot_figure() + logger.info(f"Time to read data: {time.time() - start}") + return fig, uri + + +@callback( + Output("img-slider", "max", allow_duplicate=True), + Output("img-slider", "value", allow_duplicate=True), + Input("jobs-table", "selected_rows"), + Input("jobs-table", "data"), + State("img-slider", "value"), + prevent_initial_call=True, +) +def update_slider_boundaries_prediction( + row, + data_table, + slider_ind, +): + """ + This callback updates the slider boundaries according to the selected job type + Args: + row: Selected row (job) + data_table: Lists of jobs + slider_ind: Slider index + Returns: + img-slider: Maximum value of the slider + img-slider: Slider index + """ + # Get selected job type + if row and len(row) > 0 and row[0] < len(data_table): + selected_job_type = data_table[row[0]]["job_type"] + else: + selected_job_type = None + + # If selected job type is train_model or tune_model + if selected_job_type == 
"prediction_model": + job_id = data_table[row[0]]["experiment_id"] + data_path = pathlib.Path(f"{DATA_DIR}/mlex_store/{USER}/{job_id}") + + with open(f"{data_path}/.file_manager_vars.pkl", "rb") as file: + data_project_dict = pickle.load(file) + data_project = DataProject.from_dict(data_project_dict, api_key=TILED_KEY) + + # Check if slider index is out of bounds + if ( + len(data_project.datasets) > 0 + and slider_ind > data_project.datasets[-1].cumulative_data_count - 1 + ): + slider_ind = 0 + + return data_project.datasets[-1].cumulative_data_count - 1, slider_ind + + else: + raise PreventUpdate + + +@callback( + Output("img-slider", "max"), + Output("img-slider", "value"), + Input({"base_id": "file-manager", "name": "data-project-dict"}, "data"), + Input("jobs-table", "selected_rows"), + State("img-slider", "value"), + prevent_initial_call=True, +) +def update_slider_boundaries_new_dataset( + data_project_dict, + row, + slider_ind, +): + """ + This callback updates the slider boundaries according to the selected job type + Args: + data_project_dict: Data project dictionary + row: Selected row (job) + slider_ind: Slider index + Returns: + img-slider: Maximum value of the slider + img-slider: Slider index + """ + data_project = DataProject.from_dict(data_project_dict, api_key=TILED_KEY) + if len(data_project.datasets) > 0: + max_ind = data_project.datasets[-1].cumulative_data_count - 1 + else: + max_ind = 0 + + slider_ind = min(slider_ind, max_ind) + return max_ind, slider_ind + + +@callback( + Output("img-slider", "value", allow_duplicate=True), + Input("img-labeled-indx", "value"), + prevent_initial_call=True, +) +def update_slider_value(labeled_img_ind): + """ + This callback updates the slider value according to the labeled image index + Args: + labeled_img_ind: Index of labeled image + Returns: + img-slider: Slider index + """ + return labeled_img_ind + + +@callback( + Output("img-label", "children"), + Input("img-uri", "data"), + Input("event-id", "value"), + State({"base_id": "file-manager", "name": "data-project-dict"}, "data"), + prevent_initial_call=True, +) +def refresh_label(uri, event_id, data_project_dict): + """ + This callback updates the label of the image in the display + Args: + uri: URI of the image + event_id: Event ID + data_project_dict: Data project dictionary + Returns: + img-label: Label of the image + """ + data_project = DataProject.from_dict(data_project_dict, api_key=TILED_KEY) + label = "Not labeled" + if event_id is not None and uri is not None: + datasets = requests.get( + f"{SPLASH_URL}/datasets", + params={ + "uris": uri, + "event_id": event_id, + "project": data_project.project_id, + }, + ).json() + if len(datasets) > 0: + for dataset in datasets: + for tag in dataset["tags"]: + if tag["event_id"] == event_id: + label = f"Label: {tag['name']}" + break + return label + + +@callback( + Output("results-plot", "figure"), + Output("results-plot", "style"), + Input("img-slider", "value"), + Input("jobs-table", "selected_rows"), + Input("interval", "n_intervals"), + State("jobs-table", "data"), + State("results-plot", "figure"), + prevent_initial_call=True, +) +def refresh_results(img_ind, row, interval, data_table, current_fig): + """ + This callback updates the results in the display + Args: + img_ind: Index of image according to the slider value + row: Selected job (model) + data_table: Data in table of jobs + current_fig: Current loss plot + Returns: + results_plot: Output results with probabilities per class + results_style: Modify visibility of output 
results + """ + changed_id = dash.callback_context.triggered[-1]["prop_id"] + results_fig = dash.no_update + results_style_fig = dash.no_update + + if row is not None and len(row) > 0 and row[0] < len(data_table): + # Get the job logs + try: + job_data = TableJob.get_job( + USER, "mlcoach", job_id=data_table[row[0]]["job_id"] + ) + except Exception: + logger.error(traceback.format_exc()) + raise PreventUpdate + log = job_data["logs"] + + # Plot classification probabilities per class + if ( + "interval" not in changed_id + and data_table[row[0]]["job_type"] == "prediction_model" + ): + job_id = data_table[row[0]]["experiment_id"] + data_path = pathlib.Path(f"{DATA_DIR}/mlex_store/{USER}/{job_id}") + + # Check if the results file exists + if os.path.exists(f"{data_path}/results.parquet"): + df_prob = pd.read_parquet(f"{data_path}/results.parquet") + + # Get the probabilities for the selected image + probs = df_prob.iloc[img_ind] + results_fig = get_class_prob(probs) + results_style_fig = { + "width": "100%", + "height": "100%", + "display": "block", + } + + # Plot the loss plot + elif log and data_table[row[0]]["job_type"] == "train_model": + if data_table[row[0]]["job_type"] == "train_model": + job_id = data_table[row[0]]["experiment_id"] + loss_file_path = ( + f"{DATA_DIR}/mlex_store/{USER}/{job_id}/training_log.csv" + ) + if os.path.exists(loss_file_path): + results_fig = generate_loss_plot(loss_file_path) + results_style_fig = { + "width": "100%", + "height": "100%", + "display": "block", + } + + else: + results_fig = [] + results_style_fig = {"display": "none"} + + # Do not update the plot unless loss plot changed + if ( + current_fig + and results_fig != dash.no_update + and current_fig["data"][0]["y"] == list(results_fig["data"][0]["y"]) + ): + results_fig = dash.no_update + results_style_fig = dash.no_update + + return results_fig, results_style_fig + elif current_fig: + return [], {"display": "none"} + else: + raise PreventUpdate + + +@callback( + Output("warning-modal", "is_open"), + Output("warning-msg", "children"), + Input("warning-cause", "data"), + Input("warning-cause-execute", "data"), + prevent_initial_call=True, +) +def open_warning_modal(warning_cause, warning_cause_exec): + """ + This callback opens a warning/error message + Args: + warning_cause: Cause that triggered the warning + warning_cause_exec: Execution-related cause that triggered the warning + is_open: Close/open state of the warning + """ + if warning_cause_exec == "no_row_selected": + return False, "Please select a trained model from the List of Jobs." + elif warning_cause_exec == "no_dataset": + return False, "Please upload the dataset before submitting the job." 
+ else: + return False, "" + + +@callback( + Output("warning-modal", "is_open", allow_duplicate=True), + Output("warning-msg", "children", allow_duplicate=True), + Input("ok-button", "n_clicks"), + prevent_initial_call=True, +) +def close_warning_modal(ok_n_clicks): + return False, "" diff --git a/src/callbacks/download.py b/src/callbacks/download.py new file mode 100644 index 0000000..6fc97db --- /dev/null +++ b/src/callbacks/download.py @@ -0,0 +1,31 @@ +from dash import Input, Output, State, callback + + +@callback( + Output("download-button", "disabled"), + Input("jobs-table", "selected_rows"), + State("jobs-table", "data"), + prevent_initial_call=True, +) +def disable_download(row, job_table): + """ + This callback enables or disables the download button + """ + disabled_button = True + if row is not None and len(row) > 0 and job_table[row[0]]["status"] == "complete": + disabled_button = False + return disabled_button + + +@callback( + Output("storage-modal", "is_open"), + Input("download-button", "n_clicks"), + Input("close-storage-modal", "n_clicks"), + State("storage-modal", "is_open"), + prevent_initial_call=True, +) +def toggle_storage_modal(download, close_modal, is_open): + """ + This callback toggles the storage message in modal + """ + return not (is_open) diff --git a/src/callbacks/execute.py b/src/callbacks/execute.py new file mode 100644 index 0000000..7f17d80 --- /dev/null +++ b/src/callbacks/execute.py @@ -0,0 +1,47 @@ +from dash import Input, Output, State, callback + + +@callback( + Output("resources-setup", "is_open"), + Output("warning-cause-execute", "data"), + Input("execute", "n_clicks"), + State("action", "value"), + State("jobs-table", "data"), + State("jobs-table", "selected_rows"), + State({"base_id": "file-manager", "name": "total-num-data-points"}, "data"), + prevent_initial_call=True, +) +def execute(execute, action_selection, job_data, row, num_data_points): + """ + This callback submits a job request to the compute service according to the selected action & model + Args: + execute: Execute button + action_selection: Action selected + job_data: Lists of jobs + row: Selected row (job) + num_data_points: Total number of data points in the dataset + Returns: + open/close the resources setup modal, and submits the training/prediction job accordingly + warning_cause: Activates a warning pop-up window if needed + """ + if num_data_points == 0: + return False, "no_dataset" + elif action_selection != "train_model" and not row: + return False, "no_row_selected" + elif ( + action_selection != "train_model" + and job_data[row[0]]["job_type"] != "train_model" + ): + return False, "no_row_selected" + else: + return True, "" + + +@callback( + Output("resources-setup", "is_open", allow_duplicate=True), + Output("warning-cause", "data", allow_duplicate=True), + Input("submit", "n_clicks"), + prevent_initial_call=True, +) +def close_resources_popup(submit): + return False, "" diff --git a/src/callbacks/load_labels.py b/src/callbacks/load_labels.py new file mode 100644 index 0000000..e4fdc12 --- /dev/null +++ b/src/callbacks/load_labels.py @@ -0,0 +1,120 @@ +import time +from datetime import datetime, timezone + +import dash +import requests +from dash import Input, Output, State, callback +from file_manager.data_project import DataProject + +from src.app_layout import SPLASH_URL, TILED_KEY, logger + + +@callback( + Output("event-id", "options"), + Output("modal-load-splash", "is_open"), + Input("button-load-splash", "n_clicks"), + Input("confirm-load-splash", "n_clicks"), 
+ Input({"base_id": "file-manager", "name": "data-project-dict"}, "data"), + State("timezone-browser", "value"), + prevent_initial_call=True, +) +def load_from_splash_modal( + load_n_click, confirm_load, data_project_dict, timezone_browser +): + """ + Load labels from splash-ml associated with the project_id + Args: + load_n_click: Number of clicks in load from splash-ml button + confirm_load: Number of clicks in confim button within loading from splash-ml modal + data_project_dict: Data project information + timezone_browser: Timezone of the browser + Returns: + event_id: Available tagging event IDs associated with the current data project + modal_load_splash: True/False to open/close loading from splash-ml modal + """ + changed_id = dash.callback_context.triggered[-1]["prop_id"] + if ( + changed_id == "confirm-load-splash.n_clicks" + ): # if confirmed, load chosen tagging event + return dash.no_update, False + # If unconfirmed, retrieve the tagging event IDs associated with the current data project + data_project = DataProject.from_dict(data_project_dict, api_key=TILED_KEY) + if len(data_project.datasets) > 0: + start = time.time() + response = requests.get( + f"{SPLASH_URL}/events", params={"page[offset]": 0, "page[limit]": 1000} + ) + + event_ids = response.json() + + # Present the tagging event options with their corresponding tagger id and runtime + temp = [] + for tagging_event in event_ids: + tagger_id = tagging_event["tagger_id"] + utc_tagging_event_time = tagging_event["run_time"] + tagging_event_time = datetime.strptime( + utc_tagging_event_time, "%Y-%m-%dT%H:%M:%S.%f" + ) + tagging_event_time = ( + tagging_event_time.replace(tzinfo=timezone.utc) + .astimezone(tz=None) + .strftime("%d-%m-%Y %H:%M:%S") + ) + temp.append( + ( + tagging_event_time, + { + "label": f"Tagger ID: {tagger_id}, modified: {tagging_event_time}", + "value": tagging_event["uid"], + }, + ) + ) + + # Sort temp by time in descending order and extract the dictionaries + options = [item[1] for item in sorted(temp, key=lambda x: x[0], reverse=True)] + + logger.info(f"Time taken to fetch tagging events: {time.time() - start}") + return options, True + else: + return dash.no_update, dash.no_update + + +@callback( + Output("img-labeled-indx", "options"), + Input("confirm-load-splash", "n_clicks"), + State({"base_id": "file-manager", "name": "data-project-dict"}, "data"), + State("event-id", "value"), + prevent_initial_call=True, +) +def get_labeled_indx(confirm_load, data_project_dict, event_id): + """ + This callback retrieves the indexes of labeled images + Args: + confirm_load: Number of clicks in "confirm loading from splash" button + data_project_dict: Data project information + event_id: Tagging event id for version control of tags + Returns: + List of indexes of labeled images in this tagging event + """ + data_project = DataProject.from_dict(data_project_dict, api_key=TILED_KEY) + num_imgs = data_project.datasets[-1].cumulative_data_count + data_uris = data_project.read_datasets(list(range(num_imgs)), just_uri=True) + options = [] + if num_imgs > 0: + response = requests.post( + f"{SPLASH_URL}/datasets/search", + params={"page[limit]": num_imgs}, + json={"event_id": event_id}, + ) + for dataset in response.json(): + index = next( + ( + i + for i, uri in enumerate(data_uris) + if uri == dataset["uri"] and len(dataset["tags"]) > 0 + ), + None, + ) + if index is not None: + options.append(index) + return sorted(options) diff --git a/src/callbacks/table.py b/src/callbacks/table.py new file mode 100644 index 
0000000..86e3701 --- /dev/null +++ b/src/callbacks/table.py @@ -0,0 +1,115 @@ +import dash +from dash import Input, Output, State, callback, dcc + +from src.app_layout import USER +from src.utils.job_utils import TableJob + + +@callback( + Output("jobs-table", "data"), + Input("interval", "n_intervals"), + State("jobs-table", "data"), +) +def update_table(n, current_job_table): + """ + This callback updates the job table + Args: + n: Time intervals that trigger this callback + current_job_table: Current job table + Returns: + jobs-table: Updates the job table + """ + job_list = TableJob.get_job(USER, "mlcoach") + data_table = [] + if job_list is not None: + for job in job_list: + simple_job = TableJob.compute_job_to_table_job(job) + data_table.insert(0, simple_job.__dict__) + if data_table == current_job_table: + data_table = dash.no_update + return data_table + + +@callback( + Output("info-modal", "is_open"), + Output("info-display", "children"), + Input("show-info", "n_clicks"), + Input("modal-close", "n_clicks"), + State("jobs-table", "data"), + State("info-modal", "is_open"), + State("jobs-table", "selected_rows"), +) +def open_job_modal(n_clicks, close_clicks, current_job_table, is_open, rows): + """ + This callback updates shows the job logs and parameters + Args: + n_clicks: Number of clicks in "show details" button + close_clicks: Close modal with close-up details of selected cell + current_job_table: Current job table + is_open: Open/close modal state + rows: Selected rows in jobs table + Returns: + info-modal: Open/closes the modal + info-display: Display the job logs and parameters + """ + if not is_open and rows is not None and len(rows) > 0: + job_id = current_job_table[rows[0]]["job_id"] + job_data = TableJob.get_job(USER, "mlcoach", job_id=job_id) + logs = job_data["logs"] + params = job_data["job_kwargs"]["kwargs"]["params"] + info_display = dcc.Textarea( + value=f"Parameters: {params}\n\nLogs: {logs}", + style={"width": "100%", "height": "30rem", "font-family": "monospace"}, + ) + return True, info_display + else: + return False, dash.no_update + + +@callback( + Output("jobs-table", "selected_rows"), + Input("deselect-row", "n_clicks"), + prevent_initial_call=True, +) +def deselect_row(n_click): + """ + This callback deselects the row in the data table + """ + if n_click: + return [] + else: + raise dash.exceptions.PreventUpdate + + +@callback( + Output("delete-modal", "is_open"), + Input("confirm-delete-row", "n_clicks"), + Input("delete-row", "n_clicks"), + Input("stop-row", "n_clicks"), + State("jobs-table", "selected_rows"), + State("jobs-table", "data"), + prevent_initial_call=True, +) +def delete_row(confirm_delete, delete, stop, row, job_data): + """ + This callback deletes the selected model in the table + Args: + confirm_delete: Number of clicks in "confirm delete row" button + delete: Number of clicks in "delete row" button + stop: Number of clicks in "stop job at row" button + row: Selected row in jobs table + job_data: Data within jobs table + Returns: + Open/closes confirmation modal + """ + changed_id = [p["prop_id"] for p in dash.callback_context.triggered][0] + if "delete-row.n_clicks" == changed_id: + return True + elif "stop-row.n_clicks" == changed_id: + job_uid = job_data[row[0]]["job_id"] + TableJob.terminate_job(job_uid) + return False + else: + job_uid = job_data[row[0]]["job_id"] + TableJob.delete_job(job_uid) + return False diff --git a/src/components/__init__.py b/src/components/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/src/components/header.py b/src/components/header.py
new file mode 100644
index 0000000..21be06c
--- /dev/null
+++ b/src/components/header.py
@@ -0,0 +1,92 @@
+import dash_bootstrap_components as dbc
+from dash import html
+
+
+def header(app_title, github_url):
+    """
+    This header will exist at the top of the webpage rather than browser tab
+    Args:
+        app_title: Title of dash app
+        github_url: URL to github repo
+    """
+    header = dbc.Navbar(
+        dbc.Container(
+            [
+                dbc.Row(
+                    [
+                        dbc.Col(
+                            html.Img(
+                                id="logo",
+                                src="assets/mlex.png",
+                                height="60px",
+                            ),
+                            md="auto",
+                        ),
+                        dbc.Col(
+                            [
+                                html.Div(
+                                    [
+                                        html.H3(app_title),
+                                    ],
+                                    id="app-title",
+                                )
+                            ],
+                            md=True,
+                            align="center",
+                        ),
+                    ],
+                    align="center",
+                ),
+                dbc.Row(
+                    [
+                        dbc.Col(
+                            [
+                                dbc.NavbarToggler(id="navbar-toggler"),
+                                dbc.Collapse(
+                                    dbc.Nav(
+                                        [
+                                            dbc.NavItem(
+                                                dbc.Button(
+                                                    className="fa fa-github",
+                                                    style={
+                                                        "font-size": "30px",
+                                                        "margin-right": "1rem",
+                                                        "color": "#00313C",
+                                                        "border": "0px",
+                                                        "background-color": "white",
+                                                    },
+                                                    href=github_url,
+                                                )
+                                            ),
+                                            dbc.NavItem(
+                                                dbc.Button(
+                                                    className="fa fa-question-circle-o",
+                                                    style={
+                                                        "font-size": "30px",
+                                                        "color": "#00313C",
+                                                        "background-color": "white",
+                                                        "border": "0px",
+                                                    },
+                                                    href="https://mlexchange.als.lbl.gov",
+                                                )
+                                            ),
+                                        ],
+                                        navbar=True,
+                                    ),
+                                    id="navbar-collapse",
+                                    navbar=True,
+                                ),
+                            ],
+                            md=2,
+                        ),
+                    ],
+                    align="center",
+                ),
+            ],
+            fluid=True,
+        ),
+        dark=True,
+        color="dark",
+        sticky="top",
+    )
+    return header
diff --git a/src/components/job_table.py b/src/components/job_table.py
new file mode 100644
index 0000000..872e644
--- /dev/null
+++ b/src/components/job_table.py
@@ -0,0 +1,193 @@
+import dash_bootstrap_components as dbc
+from dash import dash_table, dcc
+
+
+def job_table():
+    job_table = dbc.Card(
+        style={"margin-top": "0rem"},
+        children=[
+            dbc.CardHeader("List of Jobs"),
+            dbc.CardBody(
+                children=[
+                    dbc.Row(
+                        [
+                            dbc.Toast(
+                                "Your ML job is being prepared; it will be shown in the table shortly.",
+                                id="job-alert",
+                                dismissable=True,
+                                is_open=False,
+                                header="Job Notification",
+                                icon="info",
+                                style={
+                                    "position": "fixed",
+                                    "top": "85%",
+                                    "right": "1%",
+                                    "width": 350,
+                                    "z-index": "1050",
+                                },
+                                duration=5000,
+                            ),
+                            dbc.Toast(
+                                "Your ML job has been successfully submitted.",
+                                id="job-alert-confirm",
+                                dismissable=True,
+                                is_open=False,
+                                header="Job Notification",
+                                icon="success",
+                                style={
+                                    "position": "fixed",
+                                    "top": "75%",
+                                    "right": "1%",
+                                    "width": 350,
+                                    "z-index": "1050",
+                                },
+                                duration=5000,
+                            ),
+                            dbc.Col(
+                                dbc.Button(
+                                    "Deselect Row",
+                                    id="deselect-row",
+                                    style={"width": "100%", "margin-bottom": "1rem"},
+                                )
+                            ),
+                            dbc.Col(
+                                dbc.Button(
+                                    "Show Details",
+                                    id="show-info",
+                                    style={"width": "100%", "margin-bottom": "1rem"},
+                                )
+                            ),
+                            dbc.Col(
+                                [
+                                    dbc.Button(
+                                        "Download Results",
+                                        id="download-button",
+                                        style={
+                                            "width": "100%",
+                                            "margin-bottom": "1rem",
+                                        },
+                                        disabled=True,
+                                    ),
+                                    dcc.Download(id="download-out"),
+                                    dbc.Modal(
+                                        [
+                                            dbc.ModalBody(
+                                                "Download will start shortly"
+                                            ),
+                                            dbc.ModalFooter(
+                                                dbc.Button(
+                                                    "OK", id="close-storage-modal"
+                                                )
+                                            ),
+                                        ],
+                                        id="storage-modal",
+                                        is_open=False,
+                                    ),
+                                ]
+                            ),
+                            dbc.Col(
+                                dbc.Button(
+                                    "Stop Job",
+                                    id="stop-row",
+                                    color="warning",
+                                    style={"width": "100%"},
+                                )
+                            ),
+                            dbc.Col(
+                                dbc.Button(
+                                    "Delete Job",
+                                    id="delete-row",
+                                    color="danger",
+                                    style={"width": "100%"},
+                                )
+                            ),
+                        ]
+                    ),
+                    dash_table.DataTable(
+                        id="jobs-table",
+                        columns=[
+                            {"name": "Job ID", "id": "job_id"},
+                            {"name": "Type", "id": "job_type"},
+                            {"name": "Name", "id": "name"},
+                            {"name": "Status", "id": "status"},
+                            {"name": "Parameters", "id": "parameters"},
+                            {"name": "Experiment ID", "id": "experiment_id"},
+                            {"name": "Dataset", "id": "dataset"},
+                        ],
+                        data=[],
+                        hidden_columns=[
+                            "job_id",
+                            "experiment_id",
+                            "dataset",
+                            "parameters",
+                        ],
+                        row_selectable="single",
+                        style_cell={
+                            "padding": "1rem",
+                            "textAlign": "left",
+                            "overflow": "hidden",
+                            "textOverflow": "ellipsis",
+                            "maxWidth": 0,
+                        },
+                        fixed_rows={"headers": True},
+                        css=[{"selector": ".show-hide", "rule": "display: none"}],
+                        page_size=2,
+                        style_data_conditional=[
+                            {
+                                "if": {
+                                    "column_id": "status",
+                                    "filter_query": "{status} = complete",
+                                },
+                                "backgroundColor": "green",
+                                "color": "white",
+                            },
+                            {
+                                "if": {
+                                    "column_id": "status",
+                                    "filter_query": "{status} = failed",
+                                },
+                                "backgroundColor": "red",
+                                "color": "white",
+                            },
+                        ],
+                        style_table={"overflowY": "auto"},
+                    ),
+                ],
+            ),
+            dbc.Modal(
+                [
+                    dbc.ModalHeader("Warning"),
+                    dbc.ModalBody(
+                        'Models cannot be recovered after deletion. \
+                        Do you still want to proceed?'
+                    ),
+                    dbc.ModalFooter(
+                        [
+                            dbc.Button(
+                                "OK",
+                                id="confirm-delete-row",
+                                color="danger",
+                                outline=False,
+                                className="ms-auto",
+                                n_clicks=0,
+                            ),
+                        ]
+                    ),
+                ],
+                id="delete-modal",
+                is_open=False,
+            ),
+            dbc.Modal(
+                [
+                    dbc.ModalHeader("Job Information"),
+                    dbc.ModalBody(id="info-display"),
+                    dbc.ModalFooter(
+                        dbc.Button("Close", id="modal-close", className="ml-auto")
+                    ),
+                ],
+                id="info-modal",
+                size="xl",
+            ),
+        ],
+    )
+    return job_table
diff --git a/src/components/main_display.py b/src/components/main_display.py
new file mode 100644
index 0000000..3da7a07
--- /dev/null
+++ b/src/components/main_display.py
@@ -0,0 +1,129 @@
+import dash_bootstrap_components as dbc
+from dash import dcc, html
+
+from src.utils.plot_utils import plot_figure
+
+
+def main_display(job_table):
+    """
+    Creates the dash components within the main display in the app
+    Args:
+        job_table: Job table
+    """
+    main_display = html.Div(
+        [
+            dbc.Row(
+                [
+                    dbc.Col(
+                        dbc.Card(
+                            id="inter_graph",
+                            style={"width": "100%"},
+                            children=[
+                                dbc.CardHeader("Data Overview"),
+                                dbc.CardBody(
+                                    children=[
+                                        dcc.Loading(
+                                            id="loading-display",
+                                            parent_className="transparent-loader-wrapper",
+                                            children=[
+                                                html.Div(
+                                                    [
+                                                        html.Img(
+                                                            id="img-output",
+                                                            src=plot_figure(),
+                                                            style={
+                                                                "height": "60%",
+                                                                "display": "block",
+                                                                "margin": "auto",
+                                                            },
+                                                        ),
+                                                        dcc.Store(
+                                                            id="img-output-store",
+                                                            data=None,
+                                                        ),
+                                                    ]
+                                                ),
+                                            ],
+                                        ),
+                                        dcc.Store(id="img-uri", data=None),
+                                        html.Div(
+                                            [
+                                                dbc.Label(
+                                                    id="img-label",
+                                                    style={"height": "2rem"},
+                                                ),
+                                                dcc.Slider(
+                                                    id="img-slider",
+                                                    min=0,
+                                                    step=1,
+                                                    marks=None,
+                                                    value=0,
+                                                    tooltip={
+                                                        "placement": "bottom",
+                                                        "always_visible": True,
+                                                    },
+                                                ),
+                                                dbc.Row(
+                                                    [
+                                                        dbc.Col(
+                                                            dbc.Label(
+                                                                "List of labeled images:",
+                                                                style={
+                                                                    "height": "100%",
+                                                                    "display": "flex",
+                                                                    "align-items": "center",
+                                                                },
+                                                            ),
+                                                        ),
+                                                        dbc.Col(
+                                                            dcc.Loading(
+                                                                id="loading-labeled-imgs",
+                                                                parent_className="transparent-loader-wrapper",
+                                                                children=[
+                                                                    dcc.Dropdown(
+                                                                        id="img-labeled-indx",
+                                                                        options=[],
+                                                                        clearable=False,
+                                                                    ),
+                                                                ],
+                                                            ),
+                                                        ),
+                                                    ]
+                                                ),
+                                            ],
+                                            style={"vertical-align": "bottom"},
+                                        ),
+                                    ],
+                                    style={
+                                        "height": "45vh",
+                                        "vertical-align": "bottom",
+                                    },
+                                ),
+                            ],
+                        ),
+                        width=5,
+                    ),
+                    dbc.Col(
+                        dbc.Card(
+                            id="results",
+                            children=[
+                                dbc.CardHeader("Results"),
+                                dbc.CardBody(
+                                    children=[
+                                        dcc.Graph(
+                                            id="results-plot", style={"display": "none"}
+                                        )
+                                    ],
+                                    style={"height": "45vh"},
+                                ),
+                            ],
+                        ),
+                        width=7,
+                    ),
+                ],
+            ),
job_table, + dcc.Interval(id="interval", interval=5 * 1000, n_intervals=0), + ] + ) + return main_display
diff --git a/src/components/resources_setup.py b/src/components/resources_setup.py new file mode 100644 index 0000000..9f8454a --- /dev/null +++ b/src/components/resources_setup.py @@ -0,0 +1,55 @@ +import dash_bootstrap_components as dbc +from dash import html + + +def resources_setup(num_processors, num_gpus): + """ + Modal window to set up computing resources before job execution + Args: + num_processors: Maximum number of processors at host + num_gpus: Maximum number of GPUs at host + """ + resources_setup = html.Div( + [ + dbc.Modal( + [ + dbc.ModalHeader("Choose number of computing resources:"), + dbc.ModalBody( + children=[ + dbc.Row( + [ + dbc.Label( + f"Number of CPUs (Maximum available: {num_processors})" + ), + dbc.Input(id="num-cpus", type="number", value=2), + ] + ), + dbc.Row( + [ + dbc.Label( + f"Number of GPUs (Maximum available: {num_gpus})" + ), + dbc.Input(id="num-gpus", type="number", value=0), + ] + ), + dbc.Row( + [ + dbc.Label("Model Name"), + dbc.Input(id="model-name", type="text", value=""), + ] + ), + ] + ), + dbc.ModalFooter( + dbc.Button( + "Submit Job", id="submit", className="ms-auto", n_clicks=0 + ) + ), + ], + id="resources-setup", + centered=True, + is_open=False, + ), + ] + ) + return resources_setup
diff --git a/src/components/sidebar.py b/src/components/sidebar.py new file mode 100644 index 0000000..b5494b3 --- /dev/null +++ b/src/components/sidebar.py @@ -0,0 +1,175 @@ +import dash_bootstrap_components as dbc +import dash_daq as daq +from dash import dcc + + +def sidebar(file_explorer, models, counters): + """ + Creates the dash components in the left sidebar of the app + Args: + file_explorer: Dash file explorer + models: Currently available ML algorithms in content registry + counters: Initial training and testing model counters to be used by default when no + job description/name is added + """ + sidebar = [ + dbc.Accordion( + id="sidebar", + children=[ + dbc.AccordionItem( + title="Data selection", + children=[ + file_explorer, + dbc.Button( + "Load Labels from Splash-ML", + id="button-load-splash", + color="primary", + style={"width": "100%", "margin-top": "10px"}, + ), + dbc.Modal( + [ + dbc.ModalHeader(dbc.ModalTitle("Labeling versions")), + dbc.ModalBody( + [ + dcc.Input( + id="timezone-browser", + style={"display": "none"}, + ), + dcc.Dropdown(id="event-id"), + ] + ), + dbc.ModalFooter( + [ + dbc.Button( + "LOAD", + id="confirm-load-splash", + color="primary", + outline=False, + className="ms-auto", + n_clicks=0, + ) + ] + ), + ], + id="modal-load-splash", + is_open=False, + ), + ], + ), + dbc.AccordionItem( + title="Data transformation", + children=[ + dbc.Label("Log-transform"), + daq.BooleanSwitch( + id="log-transform", + on=False, + ), + ], + ), + dbc.AccordionItem( + title="Model configuration", + children=[ + dbc.Row( + [ + dbc.Col( + dbc.Label( + "Action", + style={ + "height": "100%", + "display": "flex", + "align-items": "center", + }, + ), + width=2, + ), + dbc.Col( + dcc.Dropdown( + id="action", + options=[ + {"label": "Train", "value": "train_model"}, + { + "label": "Prediction", + "value": "prediction_model", + }, + ], + value="train_model", + ), + width=10, + ), + ], + className="mb-3", + ), + dbc.Row( + [ + dbc.Col( + dbc.Label( + "Model", + style={ + "height": "100%", + "display": "flex", + "align-items": "center", + }, + ), + width=2, + ), + dbc.Col( + dcc.Dropdown( + id="model-selection", + options=models, + value=models[0]["value"], + ),
+ width=10, + ), + ], + className="mb-3", + ), + dbc.Card( + [ + dbc.CardBody( + id="app-parameters", + style={ + "overflowY": "scroll", + "height": "50vh", # Adjust as needed + }, + ), + ] + ), + dbc.Button( + "Execute", + id="execute", + n_clicks=0, + style={ + "width": "100%", + "margin-left": "0px", + "margin-top": "10px", + }, + ), + ], + ), + ], + ), + dbc.Modal( + [ + dbc.ModalHeader("Warning"), + dbc.ModalBody(id="warning-msg"), + dbc.ModalFooter( + [ + dbc.Button( + "OK", + id="ok-button", + color="danger", + outline=False, + className="ms-auto", + n_clicks=0, + ), + ] + ), + ], + id="warning-modal", + is_open=False, + ), + dcc.Store(id="warning-cause", data=""), + dcc.Store(id="warning-cause-execute", data=""), + dcc.Store(id="counters", data=counters), + ] + return sidebar diff --git a/src/file_manager.py b/src/file_manager.py deleted file mode 100644 index cdf62ab..0000000 --- a/src/file_manager.py +++ /dev/null @@ -1,411 +0,0 @@ -import copy -import os -import pathlib -import requests - -import dash_bootstrap_components as dbc -import dash_core_components as dcc -import dash_daq as daq -import dash_html_components as html -import dash_table -import dash_uploader as du - -DOCKER_DATA = pathlib.Path.home() / 'data' -LOCAL_DATA = str(os.environ['DATA_DIR']) -DOCKER_HOME = str(DOCKER_DATA) + '/' -LOCAL_HOME = str(LOCAL_DATA) - -UPLOAD_FOLDER_ROOT = DOCKER_DATA / 'upload' -DATAPATH_DEFAULT, FILENAMES_DEFAULT = [], [] -try: - DATAPATH = requests.get(f'http://labelmaker-api:8005/api/v0/datapath/export_dataset').json() -except Exception as e: - DATAPATH = False - print(e) -if DATAPATH: - if bool(DATAPATH['datapath']): - if DATAPATH['datapath']['where'] != 'splash': - if DATAPATH['datapath']['file_path']: - if os.path.isdir(DATAPATH['datapath']['file_path'][0]): - DATAPATH_DEFAULT = DATAPATH['datapath']['file_path'][0] - FILENAMES_DEFAULT = DATAPATH['filenames'] - -# FILES DISPLAY -file_paths_table = html.Div( - children=[ - dash_table.DataTable( - id='files-table', - columns=[ - {'name': 'type', 'id': 'file_type'}, - {'name': 'File Table', 'id': 'file_path'}, - ], - data = [], - hidden_columns = ['file_type'], - row_selectable='single', #'multi', - style_cell={'padding': '0.5rem', 'textAlign': 'left'}, - fixed_rows={'headers': False}, - css=[{"selector": ".show-hide", "rule": "display: none"}], - style_data_conditional=[ - {'if': {'filter_query': '{file_type} = dir'}, - 'color': 'blue'}, - ], - style_table={'height':'18rem', 'overflowY': 'auto'} - ) - ] - ) - - -# UPLOAD DATASET OR USE PRE-DEFINED DIRECTORY -data_access = html.Div([ - dbc.Card([ - dbc.CardBody(id='data-body', - children=[ - dbc.Label('Upload a new file or folder (zip) to work dir:', className='mr-2'), - html.Div([html.Div([du.Upload( - id="dash-uploader", - max_file_size=1800, # 1800 Mb - cancel_button=True, - pause_button=True)], - style={ # wrapper div style - 'textAlign': 'center', - 'width': '300px', - 'padding': '5px', - 'display': 'inline-block', - 'margin-bottom': '30px', - 'margin-right': '20px'}), - html.Div([ - dbc.Col([ - dbc.Label("Dataset is by default uploaded to '{}'. 
\ - You can move the selected files or dirs (from File Table) \ - into a new dir.".format(UPLOAD_FOLDER_ROOT), className='mr-5'), - dbc.Label("Home data dir (HOME) is '{}'.".format(DOCKER_DATA), - className='mr-5'), - html.Div([ - dbc.Label('Move data into dir:'.format(DOCKER_DATA), className='mr-5'), - dcc.Input(id='dest-dir-name', placeholder="Input relative path to HOME", - style={'width': '40%', 'margin-bottom': '10px'}), - dbc.Button("Move", - id="move-dir", - className="ms-auto", - color="secondary", - outline=True, - n_clicks=0, - #disabled = True, - style={'width': '22%', 'margin': '5px'}), - ], - style = {'width': '100%', 'display': 'flex', 'align-items': 'center'}, - ) - ]) - ]) - ], - style = {'width': '100%', 'display': 'flex', 'align-items': 'center'} - ), - dbc.Label('Choose files/directories:', className='mr-2'), - html.Div( - [dbc.Button("Browse", - id="browse-dir", - className="ms-auto", - color="secondary", - outline=True, - n_clicks=0, - style={'width': '15%', 'margin': '5px'}), - html.Div([ - dcc.Dropdown( - id='browse-format', - options=[ - {'label': 'dir', 'value': 'dir'}, - {'label': 'all (*)', 'value': '*'}, - {'label': '.png', 'value': '*.png'}, - {'label': '.jpg/jpeg', 'value': '*.jpg,*.jpeg'}, - {'label': '.tif/tiff', 'value': '*.tif,*.tiff'}, - {'label': '.txt', 'value': '*.txt'}, - {'label': '.csv', 'value': '*.csv'}, - {'label': '.npz', 'value': '*.npz'}, - ], - value='dir') - ], - style={"width": "15%", 'margin-right': '60px'} - ), - dbc.Button("Delete the Selected", - id="delete-files", - className="ms-auto", - color="danger", - outline=True, - n_clicks=0, - style={'width': '22%', 'margin-right': '10px'} - ), - dbc.Modal( - [ - dbc.ModalHeader("Warning"), - dbc.ModalBody("Files cannot be recovered after deletion. \ - Do you still want to proceed?"), - dbc.ModalFooter([ - dbc.Button( - "Delete", id="confirm-delete", color='danger', outline=False, - className="ms-auto", n_clicks=0 - ), - ]), - ], - id="modal", - is_open=False, - style = {'color': 'red'} - ), - dbc.Button("Import", - id="import-dir", - className="ms-auto", - color="secondary", - outline=True, - n_clicks=0, - style={'width': '22%', 'margin': '5px'} - ), - html.Div([ - dcc.Dropdown( - id='import-format', - options=[ - {'label': 'all files (*)', 'value': '*'}, - {'label': '.png', 'value': '*.png'}, - {'label': '.jpg/jpeg', 'value': '*.jpg,*.jpeg'}, - {'label': '.tif/tiff', 'value': '*.tif,*.tiff'}, - {'label': '.txt', 'value': '*.txt'}, - {'label': '.csv', 'value': '*.csv'}, - ], - value='*') - ], - style={"width": "15%"} - ), - ], - style = {'width': '100%', 'display': 'flex', 'align-items': 'center'}, - ), - html.Div([ html.Div([dbc.Label('Show Local/Docker Path')], style = {'margin-right': '10px'}), - daq.ToggleSwitch( - id='my-toggle-switch', - value=False - )], - style = {'width': '100%', 'display': 'flex', 'align-items': 'center', 'margin': '10px', - 'margin-left': '0px'}, - ), - file_paths_table, - ]), - ], - id="data-access", - ) -]) - - -file_explorer = html.Div( - [ - dbc.Row([ - dbc.Col(dbc.Button( - "Load/Refresh Data", - id="refresh-data", - size="lg", - className='m-1', - color="secondary", - outline=True, - n_clicks=0, - style={'width': '100%'}), width=7), - dbc.Col(dbc.Button( - "Clear Data", - id="clear-data", - size="lg", - className='m-1', - color="secondary", - outline=True, - n_clicks=0, - style={'width': '100%'}), width=5) - ], - justify = 'center' - ), - dbc.Button( - "Open File Manager", - id="collapse-button", - size="lg", - className='m-1', - color="secondary", - 
outline=True, - n_clicks=0, - style={'width': '100%', 'justify-content': 'center'} - ), - dbc.Modal( - data_access, - id="collapse", - is_open=False, - size='xl' - ), - dbc.Modal( - [ - dbc.ModalHeader("You have selected a ZIP file"), - dbc.ModalBody([ - dbc.Label("Keyword for images (x_train):"), - dcc.Dropdown(id='npz-img-key'), - dbc.Label("Keyword for labels (y_train):"), - dcc.Dropdown(id='npz-label-key'), - ]), - dbc.ModalFooter([ - dbc.Button( - "Confirm Import", id="confirm-import", outline=False, - className="ms-auto", n_clicks=0 - ), - ]), - ], - id="npz-modal", - is_open=False, - ), - dcc.Store(id='dummy-data', data=[]), - dcc.Store(id='docker-file-paths', data=FILENAMES_DEFAULT), - dcc.Store(id='data-path', data=DATAPATH_DEFAULT), - ] -) - - - -def move_a_file(source, destination): - ''' - Args: - source, str: full path of a file from source directory - destination, str: full path of destination directory - ''' - pathlib.Path(destination).mkdir(parents=True, exist_ok=True) - filename = source.split('/')[-1] - new_destination = destination + '/' + filename - os.rename(source, new_destination) - - -def move_dir(source, destination): - ''' - Args: - source, str: full path of source directory - destination, str: full path of destination directory - ''' - dir_path, list_dirs, filenames = next(os.walk(source)) - original_dir_name = dir_path.split('/')[-1] - destination = destination + '/' + original_dir_name - pathlib.Path(destination).mkdir(parents=True, exist_ok=True) - for filename in filenames: - file_source = dir_path + '/' + filename - move_a_file(file_source, destination) - - for dirname in list_dirs: - dir_source = dir_path + '/' + dirname - move_dir(dir_source, destination) - - -def add_paths_from_dir(dir_path, supported_formats, list_file_path): - ''' - Args: - dir_path, str: full path of a directory - supported_formats, list: supported formats, e.g., ['tiff', 'tif', 'jpg', 'jpeg', 'png'] - list_file_path, [str]: list of absolute file paths - - Returns: - Adding unique file paths to list_file_path, [str] - ''' - root_path, list_dirs, filenames = next(os.walk(dir_path)) - for filename in filenames: - exts = filename.split('.') - if exts[-1] in supported_formats and exts[0] != '': - file_path = root_path + '/' + filename - if file_path not in list_file_path: - list_file_path.append(file_path) - - for dirname in list_dirs: - new_dir_path = dir_path + '/' + dirname - list_file_path = add_paths_from_dir(new_dir_path, supported_formats, list_file_path) - - return list_file_path - - -def filename_list(directory, form): - ''' - Args: - directory, str: full path of a directory - format, list(str): list of supported formats - Return: - a full list of absolute file path (filtered by file formats) inside a directory. 
- ''' - hidden_formats = ['DS_Store'] - files = [] - if form == 'dir': - if os.path.exists(directory): - for filepath in pathlib.Path(directory).glob('**/*'): - if os.path.isdir(filepath): - files.append({'file_path': str(filepath.absolute()), 'file_type': 'dir'}) - else: - form = form.split(',') - for f_ext in form: - if os.path.exists(directory): - for filepath in pathlib.Path(directory).glob('**/{}'.format(f_ext)): - if os.path.isdir(filepath): - files.append({'file_path': str(filepath.absolute()), 'file_type': 'dir'}) - else: - filename = str(filepath).split('/')[-1] - exts = filename.split('.') - if exts[-1] not in hidden_formats and exts[0] != '': - files.append({'file_path': str(filepath.absolute()), 'file_type': 'file'}) - - return files - - -def check_duplicate_filename(dir_path, filename): - root_path, list_dirs, filenames = next(os.walk(dir_path)) - if filename in filenames: - return True - else: - return False - - -def docker_to_local_path(paths, docker_home, local_home, type='list-dict'): - ''' - Args: - paths: docker file paths - docker_home, str: full path of home dir (ends with '/') in docker environment - local_home, str: full path of home dir (ends with '/') mounted in local machine - type: - list-dict, default: a list of dictionary (docker paths), e.g., [{'file_path': 'docker_path1'},{...}] - str: a single file path string - Return: - replace docker path with local path. - ''' - if type == 'list-dict': - files = copy.deepcopy(paths) - for file in files: - if not file['file_path'].startswith(local_home): - file['file_path'] = local_home + file['file_path'].split(docker_home)[-1] - - if type == 'str': - if not paths.startswith(local_home): - files = local_home + paths.split(docker_home)[-1] - else: - files = paths - - return files - - -def local_to_docker_path(paths, docker_home, local_home, type='list'): - ''' - Args: - paths: selected local (full) paths - docker_home, str: full path of home dir (ends with '/') in docker environment - local_home, str: full path of home dir (ends with '/') mounted in local machine - type: - list: a list of path string - str: single path string - Return: - replace local path with docker path - ''' - if type == 'list': - files = [] - for i in range(len(paths)): - if not paths[i].startswith(docker_home): - files.append(docker_home + paths[i].split(local_home)[-1]) - else: - files.append(paths[i]) - - if type == 'str': - if not paths.startswith(docker_home): - files = docker_home + paths.split(local_home)[-1] - else: - files = paths - - return files - diff --git a/src/frontend.py b/src/frontend.py deleted file mode 100644 index 30295f9..0000000 --- a/src/frontend.py +++ /dev/null @@ -1,962 +0,0 @@ -import ast -import json -import os -import pathlib -import shutil -import zipfile - -import dash -from dash.dependencies import Input, Output, State, MATCH, ALL -import dash_bootstrap_components as dbc -import dash_core_components as dcc -import dash_html_components as html -import dash_table -import dash_daq as daq -import dash_uploader as du -import numpy as np -import pandas as pd -import PIL.Image as Image -import plotly.graph_objects as go -import uuid -import requests - -from file_manager import filename_list, move_a_file, move_dir, add_paths_from_dir, \ - check_duplicate_filename, docker_to_local_path, local_to_docker_path, \ - file_explorer, DOCKER_DATA, DOCKER_HOME, LOCAL_HOME, UPLOAD_FOLDER_ROOT -from helpers import SimpleJob -from helpers import get_job, generate_figure, get_class_prob, model_list_GET_call, plot_figure, 
get_gui_components,\ - get_counter, load_from_splash, get_host -from kwarg_editor import JSONParameterEditor -import templates - -external_stylesheets = [dbc.themes.BOOTSTRAP, "../assets/segmentation-style.css"] -app = dash.Dash(__name__, external_stylesheets=external_stylesheets, suppress_callback_exceptions=True) - -# Global variables -DATA_DIR = str(os.environ['DATA_DIR']) -USER = 'admin' -MODELS = model_list_GET_call() -HOST_NICKNAME = str(os.environ['HOST_NICKNAME']) -num_processors, num_gpus = get_host(HOST_NICKNAME) - - -RESOURCES_SETUP = html.Div( - [ - dbc.Modal( - [ - dbc.ModalHeader("Choose number of computing resources:"), - dbc.ModalBody( - children=[ - dbc.FormGroup([ - dbc.Label(f'Number of CPUs (Maximum available: {num_processors})'), - dbc.Input(id='num-cpus', - type="int", - value=2)]), - dbc.FormGroup([ - dbc.Label(f'Number of GPUs (Maximum available: {num_gpus})'), - dbc.Input(id='num-gpus', - type="int", - value=0)]), - dbc.FormGroup([ - dbc.Label('Model Name'), - dbc.Input(id='model-name', - type="str", - value="")]) - ]), - dbc.ModalFooter( - dbc.Button( - "Submit Job", id="submit", className="ms-auto", n_clicks=0 - ) - ), - ], - id="resources-setup", - centered=True, - is_open=False, - ), - ] -) - - -# Job Status Display -JOB_STATUS = dbc.Card( - children=[ - dbc.CardHeader("List of Jobs"), - dbc.CardBody( - children=[ - dbc.Row( - [ - dbc.Button("Deselect Row", id="deselect-row", style={'margin-left': '1rem'}), - dbc.Button("Stop Job", id="stop-row", color='warning'), - dbc.Button("Delete Job", id="delete-row", color='danger'), - ] - ), - dash_table.DataTable( - id='jobs-table', - columns=[ - {'name': 'Job ID', 'id': 'job_id'}, - {'name': 'Type', 'id': 'job_type'}, - {'name': 'Name', 'id': 'name'}, - {'name': 'Status', 'id': 'status'}, - {'name': 'Parameters', 'id': 'parameters'}, - {'name': 'Experiment ID', 'id': 'experiment_id'}, - {'name': 'Dataset', 'id': 'dataset'}, - {'name': 'Logs', 'id': 'job_logs'} - ], - data=[], - hidden_columns=['job_id', 'experiment_id', 'dataset'], - row_selectable='single', - style_cell={'padding': '1rem', - 'textAlign': 'left', - 'overflow': 'hidden', - 'textOverflow': 'ellipsis', - 'maxWidth': 0}, - fixed_rows={'headers': True}, - css=[{"selector": ".show-hide", "rule": "display: none"}], - style_data_conditional=[ - {'if': {'column_id': 'status', 'filter_query': '{status} = complete'}, - 'backgroundColor': 'green', - 'color': 'white'}, - {'if': {'column_id': 'status', 'filter_query': '{status} = failed'}, - 'backgroundColor': 'red', - 'color': 'white'} - ], - page_size=8, - style_table={'height': '30rem', 'overflowY': 'auto', 'overflowX': 'scroll'} - ) - ], - ), - dbc.Modal( - [ - dbc.ModalHeader("Warning"), - dbc.ModalBody('Models cannot be recovered after deletion. 
\ - Do you still want to proceed?"'), - dbc.ModalFooter([ - dbc.Button( - "OK", id="confirm-delete-row", color='danger', outline=False, - className="ms-auto", n_clicks=0 - ), - ]), - ], - id="delete-modal", - is_open=False, - ), - dbc.Modal([ - dbc.ModalHeader("Job Logs"), - dbc.ModalBody(id='log-display'), - dbc.ModalFooter(dbc.Button("Close", id="modal-close", className="ml-auto")), - ], - id='log-modal', - size='xl') - ] -) - -# Sidebar with actions, model, and parameters selection -SIDEBAR = [ - dbc.Card( - id="sidebar", - children=[ - dbc.CardHeader("Select an Action & a Model"), - dbc.CardBody([ - dbc.FormGroup([ - dbc.Label('Action'), - dcc.Dropdown( - id='action', - options=[ - {'label': 'Model Training', 'value': 'train_model'}, - # {'label': 'Evaluate Model on Data', 'value': 'evaluate_model'}, - {'label': 'Test Prediction using Model', 'value': 'prediction_model'}, - # {'label': 'Transfer Learning', 'value': 'transfer_learning'}, - ], - value='train_model') - ]), - dbc.FormGroup([ - dbc.Label('Model'), - dcc.Dropdown( - id='model-selection', - options=MODELS, - value=MODELS[0]['value']) - ]), - dbc.FormGroup([ - dbc.Label('Data'), - file_explorer, - ]), - dbc.Button('Execute', - id='execute', - n_clicks=0, - className='m-1', - style={'width': '100%', 'justify-content': 'center'}) - ]) - ] - ), - dbc.Card( - children=[ - dbc.CardHeader("Parameters"), - dbc.CardBody(html.Div(id='app-parameters')) - ] - ), - dbc.Modal( - [ - dbc.ModalHeader("Warning"), - dbc.ModalBody(id="warning-msg"), - dbc.ModalFooter([ - dbc.Button( - "OK", id="ok-button", color='danger', outline=False, - className="ms-auto", n_clicks=0 - ), - ]), - ], - id="warning-modal", - is_open=False, - ), - dcc.Store(id='warning-cause', data=''), - dcc.Store(id='warning-cause-execute', data=''), - dcc.Store(id='counter', data=get_counter(USER)), - dcc.Store(id='splash-indicator', data=False) -] - -# App contents (right hand side) -CONTENT = [ - html.Div([dbc.Row([ - dbc.Col(dbc.Card( - children=[dbc.CardHeader('Data Overview'), - dbc.CardBody(children=[ - html.Div( - id='app-content', - children = [dcc.Graph(id='img-output'), - html.Output(id='label-output', - style={'height': '2rem', 'overflow': 'hidden', - 'text-overflow': 'hidden'}), - dbc.Label(id='current-image-label'), - dcc.Slider(id='img-slider', - min=0, - value=0, - tooltip={'always_visible': True, 'placement': 'bottom'}) - ], - style={'display': 'none'}), - ], style={'height': '34rem'}) - ]), - width=5), - dbc.Col(dbc.Card( - id = 'results', - children=[dbc.CardHeader('Results'), - dbc.CardBody(children = [dcc.Graph(id='results-plot', - style={'display': 'none'}), - dcc.Textarea(id='results-text', - style={'display': 'none'}, - className='mb-2'), - dbc.Button('Download Results', - id='download-button', - n_clicks=0, - className='m-1', - style={'display': 'None'}), - dcc.Download(id='download-out') - ], - style={'height': '34rem'})]), - width=7)]), - dcc.Interval(id='interval', interval=5 * 1000, n_intervals=0) - ]), - JOB_STATUS -] - -# Setting up initial webpage layout -app.title = 'MLCoach' -app._favicon = 'mlex.ico' -du.configure_upload(app, UPLOAD_FOLDER_ROOT, use_upload_id=False) -app.layout = html.Div([templates.header(), - dbc.Container([ - dbc.Row([dbc.Col(SIDEBAR, width=3), - dbc.Col(CONTENT, - width=9, - style={'align-items': 'center', 'justify-content': 'center'}), - html.Div(id='dummy-output') - ]), - RESOURCES_SETUP], - fluid=True - )]) - - -@app.callback( - Output("collapse", "is_open"), - - Input("collapse-button", "n_clicks"), - 
Input("import-dir", "n_clicks"), - - State("collapse", "is_open") -) -def toggle_collapse(collapse_button, import_button, is_open): - ''' - This callback toggles the file manager - Args: - collapse_button: "Open File Manager" button - import_button: Import button - is_open: Open/close File Manager modal state - ''' - if collapse_button or import_button: - return not is_open - return is_open - - -@app.callback( - Output("warning-modal", "is_open"), - Output("warning-msg", "children"), - - Input("warning-cause", "data"), - Input("warning-cause-execute", "data"), - Input("ok-button", "n_clicks"), - - State("warning-modal", "is_open"), - prevent_initial_call=True -) -def toggle_warning_modal(warning_cause, warning_cause_exec, ok_n_clicks, is_open): - ''' - This callback toggles a warning/error message - Args: - warning_cause: Cause that triggered the warning - ok_n_clicks: Close the warning - is_open: Close/open state of the warning - ''' - changed_id = dash.callback_context.triggered[0]['prop_id'] - if 'ok-button.n_clicks' in changed_id: - return False, "" - if warning_cause == 'wrong_dataset': - return not is_open, "The dataset you have selected is not supported. Please select (1) a data directory " \ - "where each subfolder corresponds to a given category, OR (2) an NPZ file." - if warning_cause == 'different_size': - return not is_open, "The number of images and labels do not match. Please select a different dataset." - if warning_cause_exec == 'no_row_selected': - return not is_open, "Please select a trained model from the List of Jobs." - if warning_cause_exec == 'no_dataset': - return not is_open, "Please upload the dataset before submitting the job." - else: - return False, "" - - -@app.callback( - Output("modal", "is_open"), - - Input("delete-files", "n_clicks"), - Input("confirm-delete", "n_clicks"), - - State("modal", "is_open") -) -def toggle_modal(n1, n2, is_open): - ''' - This callback toggles a confirmation message for file manager - Args: - n1: Delete files button - n2: Confirm delete button - is_open: Open/close confirmation modal state - ''' - if n1 or n2: - return not is_open - return is_open - - -@app.callback( - Output("npz-modal", "is_open"), - Output("npz-img-key", "options"), - Output("npz-label-key", "options"), - - Input("import-dir", "n_clicks"), - Input("confirm-import", "n_clicks"), - Input("npz-img-key", "value"), - Input("npz-label-key", "value"), - - State("npz-modal", "is_open"), - State("docker-file-paths", "data"), -) -def toggle_modal_keyword(import_button, confirm_import, img_key, label_key, is_open, npz_path): - ''' - This callback opens the modal to select the keywords within the NPZ file. When a keyword is selected for images or - labels, this option is removed from the options of the other. 
- Args: - import_button: Import button - confirm_import: Confirm import button - img_key: Selected keyword for the images - label_key: Selected keyword for the labels - is_open: Open/close status of the modal - npz_path: Path to NPZ file - Returns: - toggle_modal: Open/close modal - img_options: Keyword options for images - label_options: Keyword options for labels - ''' - img_options = [] - label_options = [] - toggle_modal = is_open - changed_id = dash.callback_context.triggered[0]['prop_id'] - if npz_path: - if npz_path[0].split('.')[-1] == 'npz': - data = np.load(npz_path[0]) - img_key_list = list(data.keys()) - label_key_list = list(data.keys()) - # if this value has been previously selected, it is removed from its options - if label_key in img_key_list: - img_key_list.remove(label_key) - df_img = pd.DataFrame({'c': img_key_list}) - if img_key in label_key_list: - label_key_list.remove(img_key) - df_label = pd.DataFrame({'c': label_key_list}) - img_options = [{'label':i, 'value':i} for i in df_img['c']] - label_options = [{'label':i, 'value':i} for i in df_label['c']] - toggle_modal = True - if is_open and 'confirm-import.n_clicks' in changed_id: - toggle_modal = False - return toggle_modal, img_options, label_options - - -@app.callback( - Output('dummy-data', 'data'), - - Input('dash-uploader', 'isCompleted'), - - State('dash-uploader', 'fileNames') -) -def upload_zip(iscompleted, upload_filename): - ''' - This callback uploads a ZIP file - Args: - iscompleted: The upload operation is completed (bool) - upload_filename: Filename of the uploaded content - ''' - if not iscompleted: - return 0 - if upload_filename is not None: - path_to_zip_file = pathlib.Path(UPLOAD_FOLDER_ROOT) / upload_filename[0] - if upload_filename[0].split('.')[-1] == 'zip': # unzip files and delete zip file - zip_ref = zipfile.ZipFile(path_to_zip_file) # create zipfile object - path_to_folder = pathlib.Path(UPLOAD_FOLDER_ROOT) / upload_filename[0].split('.')[-2] - if (upload_filename[0].split('.')[-2] + '/') in zip_ref.namelist(): - zip_ref.extractall(pathlib.Path(UPLOAD_FOLDER_ROOT)) # extract file to dir - else: - zip_ref.extractall(path_to_folder) - zip_ref.close() # close file - os.remove(path_to_zip_file) - return 0 - - -@app.callback( - Output('files-table', 'data'), - Output('docker-file-paths', 'data'), - Output('data-path', 'data'), - Output('splash-indicator', 'data'), - - Input('browse-format', 'value'), - Input('browse-dir', 'n_clicks'), - Input('import-dir', 'n_clicks'), - Input('confirm-delete', 'n_clicks'), - Input('move-dir', 'n_clicks'), - Input('files-table', 'selected_rows'), - Input('data-path', 'data'), - Input('import-format', 'value'), - Input('my-toggle-switch', 'value'), - Input('jobs-table', 'selected_rows'), - Input("clear-data", "n_clicks"), - Input("refresh-data", "n_clicks"), - - State('dest-dir-name', 'value'), - State('jobs-table', 'data') -) -def file_manager(browse_format, browse_n_clicks, import_n_clicks, delete_n_clicks, move_dir_n_clicks, rows, - selected_paths, import_format, docker_path, job_rows, clear_data, refresh_data, dest, job_data): - ''' - This callback displays manages the actions of file manager - Args: - browse_format: File extension to browse - browse_n_clicks: Browse button - import_n_clicks: Import button - delete_n_clicks: Delete button - move_dir_n_clicks: Move button - rows: Selected rows - selected_paths: Selected paths in cache - import_format: File extension to import - docker_path: [bool] docker vs local path - job_rows: Selected rows in job table. 
If it's not a "training" model, it will load its results - instead of the data uploaded through File Manager. This is so that the user can observe - previous evaluation results - clear_data: Clear the loaded images - refresh_data: Refresh the loaded images - dest: Destination path - job_data: Data in job table - Returns - files: Filenames to be displayed in File Manager according to browse_format from docker/local path - list_filename: List of selected filenames in the directory AND SUBDIRECTORIES FROM DOCKER PATH - selected_files: List of selected filename FROM DOCKER PATH (no subdirectories) - splash: Bool variable that indicates whether the labels are retrieved from splash-ml or not - ''' - changed_id = dash.callback_context.triggered[0]['prop_id'] - splash = dash.no_update - - # if a previous job is selected, it's data is automatically plotted - if 'jobs-table.selected_rows' in changed_id and job_rows is not None: - if len(job_rows)>0: - if job_data[job_rows[0]]["job_type"].split()[0] != 'train_model': - filenames = add_paths_from_dir(job_data[job_rows[0]]["dataset"], ['tiff', 'tif', 'jpg', 'jpeg', 'png'], []) - return dash.no_update, filenames, dash.no_update, splash - - supported_formats = [] - import_format = import_format.split(',') - if import_format[0] == '*': - supported_formats = ['tiff', 'tif', 'jpg', 'jpeg', 'png'] - else: - for ext in import_format: - supported_formats.append(ext.split('.')[1]) - - # files = [] - # if browse_n_clicks or import_n_clicks: - files = filename_list(DOCKER_DATA, browse_format) - - selected_files = [] - list_filename = [] - if bool(rows): - for row in rows: - file_path = files[row] - selected_files.append(file_path) - if file_path['file_type'] == 'dir': - list_filename = add_paths_from_dir(file_path['file_path'], supported_formats, list_filename) - else: - list_filename.append(file_path['file_path']) - - if browse_n_clicks and changed_id == 'confirm-delete.n_clicks': - for filepath in selected_files: - if os.path.isdir(filepath['file_path']): - shutil.rmtree(filepath['file_path']) - else: - os.remove(filepath['file_path']) - selected_files = [] - files = filename_list(DOCKER_DATA, browse_format) - - if browse_n_clicks and changed_id == 'move-dir.n_clicks': - if dest is None: - dest = '' - destination = DOCKER_DATA / dest - destination.mkdir(parents=True, exist_ok=True) - if bool(rows): - sources = selected_paths - for source in sources: - if os.path.isdir(source['file_path']): - move_dir(source['file_path'], str(destination)) - shutil.rmtree(source['file_path']) - else: - move_a_file(source['file_path'], str(destination)) - selected_files = [] - files = filename_list(DOCKER_DATA, browse_format) - if not docker_path: - files = docker_to_local_path(files, DOCKER_HOME, LOCAL_HOME) - - if changed_id == 'refresh-data.n_clicks': - list_filename, selected_files = [], [] - datapath = requests.get(f'http://labelmaker-api:8005/api/v0/datapath/export_dataset').json() - if datapath: - if bool(datapath['datapath']) and os.path.isdir(datapath['datapath']['file_path'][0]): - list_filename, selected_files = datapath['filenames'], datapath['datapath']['file_path'][0] - if datapath['datapath']['where'] == 'splash': - splash = True - return files, list_filename, selected_files, splash - - elif changed_id == 'import-dir.n_clicks': - return files, list_filename, selected_files, False - - elif changed_id == 'clear-data.n_clicks': - return [], [], [], False - - else: - return files, dash.no_update, dash.no_update, splash - - -@app.callback( - Output('jobs-table', 
'data'), - Output('results-plot', 'figure'), - Output('results-plot', 'style'), - Output('results-text', 'value'), - Output('results-text', 'style'), - Output('log-modal', 'is_open'), - Output('log-display', 'children'), - Output('jobs-table', 'active_cell'), - - Input('interval', 'n_intervals'), - Input('jobs-table', 'selected_rows'), - Input('jobs-table', 'active_cell'), - Input('img-slider', 'value'), - Input('modal-close', 'n_clicks'), - - State('docker-file-paths', 'data'), - State('jobs-table', 'data'), - State('results-plot', 'figure'), - prevent_initial_call=True -) -def update_table(n, row, active_cell, slider_value, close_clicks, filenames, current_job_table, current_fig): - ''' - This callback updates the job table, loss plot, and results according to the job status in the compute service. - Args: - n: Time intervals that trigger this callback - row: Selected row (job) - slider_value: Image slider value (current image) - filenames: Selected data files - current_job_table: Current job table - current_fig: Current loss plot - Returns: - jobs-table: Updates the job table - show-plot: Shows/hides the loss plot - loss-plot: Updates the loss plot according to the job status (logs) - results: Testing results (probability) - ''' - changed_id = [p['prop_id'] for p in dash.callback_context.triggered][0] - if 'modal-close.n_clicks' in changed_id: - return dash.no_update, dash.no_update, dash.no_update, dash.no_update, dash.no_update, False, dash.no_update, None - job_list = get_job(USER, 'mlcoach') - data_table = [] - if job_list is not None: - for job in job_list: - params = str(job['job_kwargs']['kwargs']['params']) - if job['job_kwargs']['kwargs']['job_type'].split(' ')[0] != 'train_model': - params = params + '\nTraining Parameters: ' + str(job['job_kwargs']['kwargs']['train_params']) - data_table.insert(0, - dict( - job_id=job['uid'], - job_type=job['job_kwargs']['kwargs']['job_type'], - name=job['description'], - status=job['status']['state'], - parameters=params, - experiment_id=job['job_kwargs']['kwargs']['experiment_id'], - job_logs=job['logs'], - dataset=job['job_kwargs']['kwargs']['dataset']) - ) - is_open = dash.no_update - log_display = dash.no_update - if active_cell: - row_log = active_cell["row"] - col_log = active_cell["column_id"] - if col_log == 'job_logs': # show job logs - is_open = True - log_display = dcc.Textarea(value=data_table[row_log]["job_logs"], - style={'width': '100%', 'height': '30rem', 'font-family':'monospace'}) - if col_log == 'parameters': # show job parameters - is_open = True - log_display = dcc.Textarea(value=str(job['job_kwargs']['kwargs']['params']), - style={'width': '100%', 'height': '30rem', 'font-family': 'monospace'}) - style_fig = {'display': 'none'} - style_text = {'display': 'none'} - val = '' - fig = go.Figure(go.Scatter(x=[], y=[])) - if row: - if row[0] < len(data_table): - log = data_table[row[0]]["job_logs"] - if log: - if data_table[row[0]]['job_type'].split(' ')[0] == 'train_model': - start = log.find('epoch') - if start > -1 and len(log) > start + 5: - fig = generate_figure(log, start) - style_fig = {'width': '100%', 'display': 'block'} - if data_table[row[0]]['job_type'].split(' ')[0] == 'evaluate_model': - val = log - style_text = {'width': '100%', 'display': 'block'} - if data_table[row[0]]['job_type'].split(' ')[0] == 'prediction_model': - start = log.find('filename ') - if start > -1 and len(log) > start + 10 and len(filenames)>slider_value: - fig = get_class_prob(log, start, filenames[slider_value]) - style_fig = {'width': 
'100%', 'display': 'block'} - if current_fig: - try: - if current_fig['data'][0]['y'] == list(fig['data'][0]['y']): - fig = dash.no_update - except Exception as e: - print(e) - if data_table == current_job_table: - data_table = dash.no_update - return data_table, fig, style_fig, val, style_text, is_open, log_display, None - - -@app.callback( - Output('jobs-table', 'selected_rows'), - Input('deselect-row', 'n_clicks'), - prevent_initial_call=True -) -def deselect_row(n_click): - ''' - This callback deselects the row in the data table - ''' - return [] - - -@app.callback( - Output('delete-modal', 'is_open'), - Input('confirm-delete-row', 'n_clicks'), - Input('delete-row', 'n_clicks'), - Input('stop-row', 'n_clicks'), - State('jobs-table', 'selected_rows'), - State('jobs-table', 'data'), - prevent_initial_call=True -) -def delete_row(confirm_delete, delete, stop, row, job_data): - ''' - This callback deletes the selected model in the table - ''' - changed_id = [p['prop_id'] for p in dash.callback_context.triggered][0] - if 'delete-row.n_clicks' == changed_id: - return True - elif 'stop-row.n_clicks' == changed_id: - job_uid = job_data[row[0]]['job_id'] - requests.patch(f'http://job-service:8080/api/v0/jobs/{job_uid}/terminate') - return False - else: - job_uid = job_data[row[0]]['job_id'] - requests.delete(f'http://job-service:8080/api/v0/jobs/{job_uid}/delete') - return False - - -@app.callback( - Output("app-parameters", "children"), - Output("download-button", "style"), - - Input("model-selection", "value"), - Input("action", "value"), - Input("jobs-table", "selected_rows"), - prevent_intial_call=True) -def load_parameters(model_selection, action_selection, row): - ''' - This callback dynamically populates the parameters of the website according to the selected action & model. 
- Args: - model_selection: Selected model (from content registry) - action_selection: Selected action (pre-defined actions in MLCoach) - row: Selected job (model) - Returns: - app-parameters: Parameters according to the selected model & action - download-button: Shows the download button - ''' - parameters = get_gui_components(model_selection, action_selection) - gui_item = JSONParameterEditor(_id={'type': 'parameter_editor'}, # pattern match _id (base id), name - json_blob=parameters) - gui_item.init_callbacks(app) - style = dash.no_update - if row is not None: - style = {'width': '100%', 'justify-content': 'center'} - return gui_item, style - - -@app.callback( - Output("img-output", "figure"), - Output("current-image-label", 'children'), - Output("label-output", "children"), - Output("img-slider", "max"), - Output("img-slider", "value"), - Output("app-content", "style"), - Output("warning-cause", "data"), - - Input("import-dir", "n_clicks"), - Input("confirm-import", "n_clicks"), - Input("img-slider", "value"), - Input("docker-file-paths", "data"), - - State("npz-img-key", "value"), - State("npz-label-key", "value"), - State("npz-modal", "is_open"), - State('splash-indicator', 'data'), - prevent_intial_call=True -) -def refresh_image(import_dir, confirm_import, img_ind, filenames, img_keyword, label_keyword, npz_modal, splash): - ''' - This callback updates the image in the display - Args: - import_dir: Import button - confirm_import: Confirm import button - img_ind: Index of image according to the slider value - filenames: Selected data files - jobs-table: Data in table of jobs - img_keyword: Keyword for images in NPZ file - label_keyword: Keyword for labels in NPZ file - npz_modal: Open/close status of NPZ modal - Returns: - img-output: Output figure - label-output: Output label - img-slider-max: Maximum value of the slider according to the dataset (train vs test) - img-slider-value: Current value of the slider - content_style: Content visibility - warning-cause: Cause that triggered warning pop-up - splash: Bool variable that indicates whether the labels are retrieved from splash-ml or not - ''' - current_im_label='' - changed_id = [p['prop_id'] for p in dash.callback_context.triggered][0] - if len(filenames)>0 and not npz_modal: - try: - if filenames[0].split('.')[-1] == 'npz': # npz file - if img_keyword is not None and label_keyword is not None: - data_npz = np.load(filenames[0]) - data_npy = np.squeeze(data_npz[img_keyword]) - label_npy = np.squeeze(data_npz[label_keyword]) - if len(data_npy) != len(label_npy): - return dash.no_update, dash.no_update, dash.no_update, dash.no_update, dash.no_update, dash.no_update, 'different_size' - slider_max = len(data_npy) - 1 - if img_ind>slider_max: - img_ind = 0 - fig = plot_figure(data_npy[img_ind]) - current_im_label = f"Image: {filenames[0]}" - label = f"Label: {label_npy[img_ind]}" - else: - return dash.no_update, dash.no_update, dash.no_update, dash.no_update, dash.no_update, {'display': 'None'}, dash.no_update - else: # directory - slider_max = len(filenames)-1 - if img_ind>slider_max: - img_ind = 0 - image = Image.open(filenames[img_ind]) - fig = plot_figure(image) - current_im_label = f"Image: {filenames[img_ind]}" - if splash: - label = load_from_splash(filenames[img_ind]) - else: - label = filenames[img_ind].split('/')[-2] # determined by the last directory in the path - label = f"Label: {label}" - return fig, current_im_label, label, slider_max, img_ind, {'display': 'block'}, dash.no_update - except Exception as e: - 
print(f'Exception in refresh_image callback {e}') - return dash.no_update, dash.no_update, dash.no_update, dash.no_update, dash.no_update, {'display': 'None'}, 'wrong_dataset' - else: - return dash.no_update, dash.no_update, dash.no_update, dash.no_update, dash.no_update, {'display': 'None'}, dash.no_update - - -@app.callback( - Output("resources-setup", "is_open"), - Output("counter", "data"), - Output("warning-cause-execute", "data"), - - Input("execute", "n_clicks"), - Input("submit", "n_clicks"), - - State("app-parameters", "children"), - State("num-cpus", "value"), - State("num-gpus", "value"), - State("action", "value"), - State("jobs-table", "data"), - State("jobs-table", "selected_rows"), - State('data-path', 'data'), - State("docker-file-paths", "data"), - State("counter", "data"), - State("npz-img-key", "value"), - State("npz-label-key", "value"), - State("model-name", "value"), - State('splash-indicator', 'data'), - prevent_intial_call=True) -def execute(execute, submit, children, num_cpus, num_gpus, action_selection, job_data, row, data_path, filenames, - counters, x_key, y_key, model_name, splash): - ''' - This callback submits a job request to the compute service according to the selected action & model - Args: - execute: Execute button - submit: Submit button - children: Model parameters - num_cpus: Number of CPUs assigned to job - num_gpus: Number of GPUs assigned to job - action_selection: Action selected - job_data: Lists of jobs - row: Selected row (job) - data_path: Local path to data - filenames: Filenames in dataset - counters: List of counters to assign a number to each job according to its action (train vs evaluate) - x_key: Keyword for x data in NPZ file - y_key: Keyword for y data in NPZ file - splash: Bool variable that indicates whether the labels are retrieved from splash-ml or not - Returns: - open/close the resources setup modal - ''' - changed_id = [p['prop_id'] for p in dash.callback_context.triggered][0] - if 'execute.n_clicks' in changed_id: - if len(filenames) == 0: - return False, counters, 'no_dataset' - if action_selection != 'train_model' and not row: - return False, counters, 'no_row_selected' - if row: - if action_selection != 'train_model' and job_data[row[0]]['job_type'].split(' ')[0] != 'train_model': - return False, counters, 'no_row_selected' - return True, counters, '' - if 'submit.n_clicks' in changed_id: - counters = get_counter(USER) - experiment_id = str(uuid.uuid4()) - out_path = pathlib.Path('/app/work/data/mlexchange_store/{}/{}'.format(USER, experiment_id)) - out_path.mkdir(parents=True, exist_ok=True) - input_params = {'x_key': x_key, 'y_key': y_key} - if splash: - input_params['splash'] = filenames - kwargs = {} - if bool(children): - try: - for child in children['props']['children']: - key = child["props"]["children"][1]["props"]["id"]["param_key"] - value = child["props"]["children"][1]["props"]["value"] - input_params[key] = value - except Exception: - for child in children: - key = child["props"]["children"][1]["props"]["id"] - value = child["props"]["children"][1]["props"]["value"] - input_params[key] = value - try: - data_path = data_path[0]['file_path'] - except Exception as e: - print(e) - if action_selection == 'train_model': - counters[0] = counters[0] + 1 - count = counters[0] - command = "python3 src/train_model.py" - directories = [data_path, str(out_path)] - else: - training_exp_id = job_data[row[0]]['experiment_id'] - in_path = pathlib.Path('/app/work/data/mlexchange_store/{}/{}'.format(USER, training_exp_id)) - if 
action_selection == 'evaluate_model': - counters[1] = counters[1] + 1 - count = counters[1] - command = "python3 src/evaluate_model.py" - directories = [data_path, str(in_path) + '/model.h5'] - if action_selection == 'prediction_model': - counters[2] = counters[2] + 1 - count = counters[2] - command = "python3 src/predict_model.py" - kwargs = {'train_params': job_data[row[0]]['parameters']} - directories = [data_path, str(in_path) + '/model.h5', str(out_path)] - if action_selection == 'transfer_learning': - counters[3] = counters[3] + 1 - count = counters[3] - command = "python3 src/transfer_learning.py" - directories = [data_path, str(in_path) + '/model.h5', str(out_path)] - if len(model_name)==0: # if model_name was not defined - model_name = f'{action_selection} {count}' - job = SimpleJob(service_type='backend', - description=model_name, - working_directory='{}'.format(DATA_DIR), - uri='mlexchange1/tensorflow-neural-networks', - cmd=' '.join([command] + directories + ['\'' + json.dumps(input_params) + '\'']), - kwargs={'job_type': action_selection, - 'experiment_id': experiment_id, - 'dataset': data_path, - 'params': input_params, - **kwargs}) - job.submit(USER, num_cpus, num_gpus) - return False, counters, '' - return False, counters, '' - - -@app.callback( - Output("download-out", "data"), - Input("download-button", "n_clicks"), - State("jobs-table", "data"), - State("jobs-table", "selected_rows"), - prevent_intial_call=True) -def save_results(download, job_data, row): - ''' - This callback saves the experimental results as a ZIP file - Args: - download: Download button - job_data: Table of jobs - row: Selected job/row - Returns: - ZIP file with results - ''' - if download and row: - experiment_id = job_data[row[0]]["experiment_id"] - experiment_path = pathlib.Path('data/mlexchange_store/{}/{}'.format(USER, experiment_id)) - shutil.make_archive('/app/tmp/results', 'zip', experiment_path) - return dcc.send_file('/app/tmp/results.zip') - else: - return None - - -if __name__ == '__main__': - app.run_server(debug=True, host='0.0.0.0', port=8062)#, dev_tools_ui=False) diff --git a/src/helpers.py b/src/helpers.py deleted file mode 100644 index d8aa649..0000000 --- a/src/helpers.py +++ /dev/null @@ -1,230 +0,0 @@ -import sys -if sys.version_info[0] < 3: - from StringIO import StringIO -else: - from io import StringIO - -import pandas as pd -import plotly.express as px -import plotly.graph_objects as go -from plotly.subplots import make_subplots -import requests - - -SPLASH_CLIENT = 'http://splash:80/api/v0' - - -class SimpleJob: - def __init__(self, - service_type, - description, - working_directory, - uri, - cmd, - kwargs=None, - mlex_app='mlcoach'): - self.mlex_app = mlex_app - self.description = description - self.service_type = service_type - self.working_directory = working_directory - self.job_kwargs = {'uri': uri, - 'type': 'docker', - 'cmd': cmd, - 'kwargs': kwargs} - - def submit(self, user, num_cpus, num_gpus): - ''' - Sends job to computing service - Args: - user: user UID - num_cpus: Number of CPUs - num_gpus: Number of GPUs - Returns: - Workflow status - ''' - workflow = {'user_uid': user, - 'job_list': [self.__dict__], - 'host_list': ['mlsandbox.als.lbl.gov', 'local.als.lbl.gov', 'vaughan.als.lbl.gov'], - 'dependencies': {'0':[]}, - 'requirements': {'num_processors': num_cpus, - 'num_gpus': num_gpus, - 'num_nodes': 1}} - url = 'http://job-service:8080/api/v0/workflows' - return requests.post(url, json=workflow).status_code - - -# Queries the job from the computing database -def 
get_job(user, mlex_app, job_type=None, deploy_location=None): - url = 'http://job-service:8080/api/v0/jobs?' - if user: - url += ('&user=' + user) - if mlex_app: - url += ('&mlex_app=' + mlex_app) - if job_type: - url += ('&job_type=' + job_type) - if deploy_location: - url += ('&deploy_location=' + deploy_location) - response = requests.get(url) - return response.json() - - -def get_class_prob(log, start, filename): - end = log.find('Prediction process completed') - if end == -1: - end = len(log) - log = log[start:end] - try: - #print(log) - df = pd.read_csv(StringIO(log.replace('\n\n', '\n')), sep=' ') - #print(df) - res = df.loc[df['filename'] == filename] # search results for the selected file - #print(res) - if res.shape[0]>1: - res = res.iloc[[0]] - fig = px.bar(res.iloc[: , 1:]) - fig.update_layout(yaxis_title="probability") - fig.update_xaxes(showgrid=False, - showticklabels=False, - zeroline=False) - return fig #res.to_string(index=False) - except Exception as err: - print(err) - return go.Figure(go.Scatter(x=[], y=[])) - - -# Generate loss plot -def generate_figure(log, start): - end = log.find('Train process completed') - if end == -1: - end = len(log) - log = log[start:end] - df = pd.read_csv(StringIO(log.replace('\n\n', '\n')), sep=' ') - try: - fig = make_subplots(specs=[[{"secondary_y": True}]]) - for col in list(df.columns)[1:]: - if 'loss' in col: - fig.add_trace(go.Scatter(x=df['epoch'], y=df[col], name=col), secondary_y=False) - fig.update_yaxes(title_text="loss", secondary_y=False) - else: - fig.add_trace(go.Scatter(x=df['epoch'], y=df[col], name=col), secondary_y=True) - fig.update_yaxes(title_text="accuracy", secondary_y=True, range=[0,1]) - fig.update_layout(xaxis_title="epoch", margin=dict(l=20, r=20, t=20, b=20)) - return fig - except Exception as e: - print(e) - return go.Figure(go.Scatter(x=[], y=[])) - - -def plot_figure(image): - fig = px.imshow(image, height=350) - fig.update_xaxes(showgrid=False, - showticklabels=False, - zeroline=False) - fig.update_yaxes(showgrid=False, - showticklabels=False, - zeroline=False) - fig.update_layout(margin=dict(l=0, r=0, t=0, b=10)) - try: - fig.update_traces(dict(showscale=False, coloraxis=None)) - except Exception as e: - print(e) - return fig - - -# saves model as an .h5 file on local disk -def save_model(model, save_path='my_model.h5'): - # save model - model.save(save_path) - print("Saved to disk") - - -def model_list_GET_call(): - """ - Get a list of algorithms from content registry - """ - url = 'http://content-api:8000/api/v0/models' - model_list = requests.get(url).json() - models = [] - for item in model_list: - if 'mlcoach' in item['application']: - models.append({'label': item['name'], 'value': item['content_id']}) - return models - - -def get_model(model_uid): - ''' - This function gets the algorithm dict from content registry - Args: - model_uid: Model UID - Returns: - service_type: Backend/Frontend - content_uri: URI - ''' - url = 'http://content-api:8000/api/v0/contents/{}/content'.format(model_uid) - content = requests.get(url).json() - if 'map' in content: - return content['service_type'], content['uri'] - return content['service_type'], content['uri'] - - -def get_gui_components(model_uid, comp_group): - ''' - Returns the GUI components of the corresponding model and action - Args: - model_uid: Model UID - comp_group: Action, e.g. 
training, testing, etc - Returns: - params: List of model parameters - ''' - url = f'http://content-api:8000/api/v0/models/{model_uid}/model/{comp_group}/gui_params' - response = requests.get(url) - return response.json() - - -def get_counter(username): - job_list = get_job(username, 'mlcoach') - job_types = ['train_model', 'evaluate_model', 'prediction_model', 'transfer_learning'] - counters = [-1, -1, -1, -1] - if job_list is not None: - for indx, job_type in enumerate(job_types): - for job in reversed(job_list): - last_job = job['job_kwargs']['kwargs']['job_type'] - if job['description']: - job_name = job['description'].split() - else: - job_name = job['job_kwargs']['kwargs']['job_type'].split() - if last_job == job_type and job_name[0] == job_type and len(job_name)==2 and job_name[-1].isdigit(): - value = int(job_name[-1]) - counters[indx] = value - break - return counters - - -def load_from_splash(filename): - ''' - This function queries labels from splash-ml. - Args: - filename: URI of dataset (e.g. filename) - Returns: - label: Label assigned to the filename - ''' - url = f'{SPLASH_CLIENT}/datasets?' - try: - params = {'uris': [filename]} - datasets = requests.get(url, params=params).json() - except Exception as e: - print(f'Loading from splash exception: {e}') - datasets = [] - for tag in datasets[0]['tags']: - if tag['name'] == 'labelmaker': - label = tag['locator']['path'] - break # just one tag at this time - return label - - -def get_host(host_nickname): - hosts = requests.get(f'http://job-service:8080/api/v0/hosts?&nickname={host_nickname}').json() - max_processors = hosts[0]['backend_constraints']['num_processors'] - max_gpus = hosts[0]['backend_constraints']['num_gpus'] - return max_processors, max_gpus - diff --git a/src/kwarg_editor.py b/src/kwarg_editor.py deleted file mode 100644 index c2aff4f..0000000 --- a/src/kwarg_editor.py +++ /dev/null @@ -1,313 +0,0 @@ -import re -from typing import Callable -# noinspection PyUnresolvedReferences -from inspect import signature, _empty - -import dash -import dash_html_components as html -import dash_core_components as dcc -import dash_bootstrap_components as dbc -import dash_daq as daq - -from dash.dependencies import Input, ALL, Output, State - -from targeted_callbacks import targeted_callback - -# Procedural dash form generation - -""" -{'name', 'title', 'value', 'type', -""" - - -class SimpleItem(dbc.FormGroup): - def __init__(self, - name, - base_id, - title=None, - param_key=None, - type='number', - debounce=True, - **kwargs): - self.name = name - self.label = dbc.Label(title or name) - self.input = dbc.Input(type=type, - debounce=debounce, - id={**base_id, - 'name': name, - 'param_key': param_key}, - **kwargs) - - super(SimpleItem, self).__init__(children=[self.label, self.input]) - - -class FloatItem(SimpleItem): - pass - - -class IntItem(SimpleItem): - def __init__(self, *args, **kwargs): - if 'min' not in kwargs: - kwargs['min'] = -9007199254740991 - super(IntItem, self).__init__(*args, step=1, **kwargs) - - -class StrItem(SimpleItem): - def __init__(self, *args, **kwargs): - super(StrItem, self).__init__(*args, type='text', **kwargs) - - -class SliderItem(dbc.FormGroup): - def __init__(self, - name, - base_id, - title=None, - param_key=None, - debounce=True, - visible=True, - **kwargs): - self.label = dbc.Label(title or name) - self.input = dcc.Slider(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'input'}, - tooltip={"placement": "bottom", "always_visible": True}, - **kwargs) - - style = {} - if not 
visible: - style['display'] = 'none' - - super(SliderItem, self).__init__(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'form_group'}, - children=[self.label, self.input], - style=style) - - -class DropdownItem(dbc.FormGroup): - def __init__(self, - name, - base_id, # shared by all components - title=None, - param_key=None, - debounce=True, - visible=True, - **kwargs): - self.label = dbc.Label(title or name) - self.input = dcc.Dropdown(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'input'}, - **kwargs) - - style = {} - if not visible: - style['display'] = 'none' - - super(DropdownItem, self).__init__(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'form_group'}, - children=[self.label, self.input], - style=style) - - -class RadioItem(dbc.FormGroup): - def __init__(self, - name, - base_id, - title=None, - param_key=None, - visible=True, - **kwargs): - self.label = dbc.Label(title or name) - self.input = dbc.RadioItems(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'input'}, - **kwargs) - - style = {} - if not visible: - style['display'] = 'none' - - super(RadioItem, self).__init__(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'form_group'}, - children=[self.label, self.input], - style=style) - - -class BoolItem(dbc.FormGroup): - def __init__(self, - name, - base_id, - title=None, - param_key=None, - visible=True, - **kwargs): - self.label = dbc.Label(title or name) - self.input = daq.ToggleSwitch(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'input'}, - **kwargs) - self.output_label = dbc.Label('False/True') - - style = {} - if not visible: - style['display'] = 'none' - - super(BoolItem, self).__init__(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'form_group'}, - children=[self.label, self.input, self.output_label], - style=style) - - -class GraphItem(dbc.FormGroup): - def __init__(self, - name, - base_id, - param_key=None, - title=None, - visible=True, - **kwargs): - self.label = dbc.Label(title or name) - self.input = dcc.Graph(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'input'}, - **kwargs) - - style = {} - if not visible: - style['display'] = 'none' - - super(GraphItem, self).__init__(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'form_group'}, - children=[self.label, self.input], - style=style) - - -class ParameterEditor(dbc.Form): # initialize dbc form object with input parameters (for each component) - - type_map = {float: FloatItem, - int: IntItem, - str: StrItem, - } - - def __init__(self, _id, parameters, **kwargs): - self._parameters = parameters - - super(ParameterEditor, self).__init__(id=_id, children=[], className='kwarg-editor', **kwargs) - self.children = self.build_children() - - def init_callbacks(self, app): - targeted_callback(self.stash_value, - Input({**self.id, - 'name': ALL}, - 'value'), - Output(self.id, 'n_submit'), - State(self.id, 'n_submit'), - app=app) - - def stash_value(self, value): - # find the changed item name from regex - r = '(?<=\"name\"\:\")[\w\-_]+(?=\")' - matches = re.findall(r, dash.callback_context.triggered[0]['prop_id']) - - if not matches: - raise LookupError('Could not find changed item name. 
Check that all parameter names use simple chars (\\w)') - - name = matches[0] - self.parameters[name]['value'] = value - - print(self.values) - - return next(iter(dash.callback_context.states.values())) or 0 + 1 - - @property - def values(self): - return {param['name']: param.get('value', None) for param in self._parameters} - - @property - def parameters(self): - return {param['name']: param for param in self._parameters} - - def _determine_type(self, parameter_dict): - if 'type' in parameter_dict: - if parameter_dict['type'] in self.type_map: - return parameter_dict['type'] - elif parameter_dict['type'].__name__ in self.type_map: - return parameter_dict['type'].__name__ - elif type(parameter_dict['value']) in self.type_map: - return type(parameter_dict['value']) - raise TypeError(f'No item type could be determined for this parameter: {parameter_dict}') - - def build_children(self, values=None): - children = [] - for parameter_dict in self._parameters: - parameter_dict = parameter_dict.copy() - if values and parameter_dict['name'] in values: - parameter_dict['value'] = values[parameter_dict['name']] - type = self._determine_type(parameter_dict) - parameter_dict.pop('type', None) - item = self.type_map[type](**parameter_dict, base_id=self.id) - children.append(item) - - return children - - -class JSONParameterEditor(ParameterEditor): - type_map = {'float': FloatItem, - 'int': IntItem, - 'str': StrItem, - 'slider': SliderItem, - 'dropdown': DropdownItem, - 'radio': RadioItem, - 'bool': BoolItem, - 'graph': GraphItem, - } - - def __init__(self, _id, json_blob, **kwargs): - super(ParameterEditor, self).__init__(id=_id, children=[], className='kwarg-editor', **kwargs) - self._json_blob = json_blob - self.children = self.build_children() - - def build_children(self, values=None): - children = [] - for json_record in self._json_blob: - ... - # build a parameter dict from self.json_blob - ... 
- type = json_record.get('type', self._determine_type(json_record)) - json_record = json_record.copy() - if values and json_record['name'] in values: - json_record['value'] = values[json_record['name']] - json_record.pop('type', None) - item = self.type_map[type](**json_record, base_id=self.id) - children.append(item) - - return children - - -class KwargsEditor(ParameterEditor): - def __init__(self, instance_index, func: Callable, **kwargs): - self.func = func - self._instance_index = instance_index - - parameters = [{'name': name, 'value': param.default} for name, param in signature(func).parameters.items() - if param.default is not _empty] - - super(KwargsEditor, self).__init__(dict(index=instance_index, type='kwargs-editor'), parameters=parameters, - **kwargs) - - def new_record(self): - return {name: p.default for name, p in signature(self.func).parameters.items() if p.default is not _empty} \ No newline at end of file diff --git a/src/main.cfg b/src/main.cfg deleted file mode 100644 index 0c145d5..0000000 --- a/src/main.cfg +++ /dev/null @@ -1,5 +0,0 @@ -DATA_DIR: 'data' -TRAIN_DATA_DIR: ${DATA_DIR} + '/train' -VALIDATION_DATA_DIR: ${DATA_DIR} +'/val' -TEST_DATA_DIR: ${DATA_DIR} +'/test' -MODEL_SAVE_DIR: ${DATA_DIR} +'/logs' diff --git a/src/targeted_callbacks.py b/src/targeted_callbacks.py deleted file mode 100644 index 80eb185..0000000 --- a/src/targeted_callbacks.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Union, Callable -from dash._utils import create_callback_id -from dash.dependencies import handle_callback_args, State -from dash.dependencies import Input, Output -from dash.exceptions import PreventUpdate -import dash -from dataclasses import dataclass -import json -import warnings - -app = None - -_targeted_callbacks = [] - - -@dataclass -class Callback: - input: Input - output: Output - callable: Callable - - -def _dispatcher(*_): - triggered = dash.callback_context.triggered - if not triggered: - raise PreventUpdate - - for callback in _targeted_callbacks: - _id, _property = triggered[0]['prop_id'].split('.') - if '{' in _id: - _id = json.loads(_id) - _input = Input(_id, _property) - _id, _property = dash.callback_context.outputs_list.values() - _output = Output(_id, _property) - if callback.input == _input and callback.output == _output: - return_value = callback.callable(triggered[0]['value']) - if return_value is None: - warnings.warn( - f'A callback returned None. Perhaps you forgot a return value? Callback: {repr(callback.callable)}') - return return_value - - -def targeted_callback(callback, input: Input, output: Output, *states: State, app=app, prevent_initial_call=None): - if prevent_initial_call is None: - prevent_initial_call = app.config.prevent_initial_callbacks - - callback_id = create_callback_id(output) - if callback_id in app.callback_map: - if app.callback_map[callback_id]["callback"].__name__ != '_dispatcher': - raise ValueError('Attempting to use a targeted callback with an output already assigned to a' - 'standard callback. These are not compatible.') - - # app.callback_map['state'].extend(states) - # app.callback_map['inputs'].extend(input.) - - for callback_spec in app._callback_list: - if callback_spec['output'] == callback_id: - if callback_spec['prevent_initial_call'] != prevent_initial_call: - raise ValueError('A callback has already been registered to this output with a conflicting value' - 'for prevent_initial_callback. 
You should decide which you want.') - callback_spec['inputs'].append(input.to_dict()) - callback_spec['state'].extend([state.to_dict() for state in states]) - else: - app.callback(output, input, *states, prevent_initial_call=prevent_initial_call)(_dispatcher) - - _targeted_callbacks.append(Callback(input, output, callback)) \ No newline at end of file diff --git a/src/templates.py b/src/templates.py deleted file mode 100644 index c34d749..0000000 --- a/src/templates.py +++ /dev/null @@ -1,42 +0,0 @@ -import dash_html_components as html -import dash_bootstrap_components as dbc -import dash_core_components as dcc - - -def header(): - header = dbc.Navbar( - dbc.Container( - [ - dbc.Row( - [ - dbc.Col( - html.Img( - id="logo", - src='assets/mlex.png', - height="60px", - ), - md="auto", - ), - dbc.Col( - [ - html.Div( - [ - html.H3("MLExchange | MLCoach"), - ], - id="app-title", - ) - ], - md=True, - align="center", - ), - ], - align="center", - ) - ], - fluid=True, - ), - dark=True, - color="dark", - sticky="top", - ) - return header diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/data_utils.py b/src/utils/data_utils.py new file mode 100644 index 0000000..3b98c3c --- /dev/null +++ b/src/utils/data_utils.py @@ -0,0 +1,89 @@ +import pathlib +import uuid + +import pandas as pd + +from src.app_layout import DATA_DIR, TILED_KEY + + +def prepare_directories( + user_id, data_project, labeled_indices=None, train=True, correct_path=False +): + """ + Prepare data directories that host experiment results and data movements processes for tiled + If data is served through tiled, a local copy will be made for ML training and inference + processes in file system located at data/mlexchange_store/user_id/tiledprojectid_localprojectid + Args: + user_id: User ID + data_project: List of data sets in the application + labeled_indices: List of indexes of labeled images in the data set + train: Flag to indicate if the data is used for training or inference + correct_path: Flag to indicate if the path should be corrected + Returns: + experiment_id: ML experiment ID + out_path: Path were experiment results will be stored + info_file: Filename of a parquet file that contains the list of data sets within the + current project + """ + experiment_id = str(uuid.uuid4()) + out_path = pathlib.Path(f"{DATA_DIR}/mlex_store/{user_id}/{experiment_id}") + out_path.mkdir(parents=True, exist_ok=True) + data_type = data_project.data_type + if data_type == "tiled" and train: + # Download tiled data to local + uri_list = data_project.tiled_to_local_project( + DATA_DIR, indices=labeled_indices, correct_path=correct_path + ) + splash_uris = data_project.read_datasets(labeled_indices, just_uri=True) + data_info = pd.DataFrame({"uri": uri_list, "splash_uri": splash_uris}) + elif data_type == "tiled": + # Save sub uris + root_uri = data_project.root_uri + data_info = pd.DataFrame({"root_uri": [root_uri]}) + sub_uris_df = pd.DataFrame( + {"sub_uris": [dataset.uri for dataset in data_project.datasets]} + ) + data_info = pd.concat([data_info, sub_uris_df], axis=1) + data_info["api_key"] = [TILED_KEY] * len(data_info) + else: + # Save filenames + uri_list = [] + for dataset in data_project.datasets: + uri_list.extend( + [dataset.uri + "/" + filename for filename in dataset.filenames] + ) + if correct_path: + root_uri = "/app/work/data" + data_project.root_uri.split(DATA_DIR, 1)[-1] + data_info = pd.DataFrame( + {"uri": [root_uri + "/" + uri for uri in uri_list]} + ) 
+ data_info["splash_uri"] = [ + data_project.root_uri + "/" + uri for uri in uri_list + ] + else: + root_uri = data_project.root_uri + data_info = pd.DataFrame( + {"uri": [root_uri + "/" + uri for uri in uri_list]} + ) + data_info["type"] = data_type + data_info.to_parquet(f"{out_path}/data_info.parquet", engine="pyarrow") + return experiment_id, out_path, f"{out_path}/data_info.parquet" + + +def get_input_params(children): + """ + Gets the model parameters and its corresponding values + """ + input_params = {} + if bool(children): + try: + for child in children["props"]["children"]: + key = child["props"]["children"][1]["props"]["id"]["param_key"] + value = child["props"]["children"][1]["props"]["value"] + input_params[key] = value + except Exception: + for child in children: + key = child["props"]["children"][1]["props"]["id"] + value = child["props"]["children"][1]["props"]["value"] + input_params[key] = value + return input_params diff --git a/src/utils/job_utils.py b/src/utils/job_utils.py new file mode 100644 index 0000000..799e3c5 --- /dev/null +++ b/src/utils/job_utils.py @@ -0,0 +1,165 @@ +import json +import os +import urllib + +import requests +from dotenv import load_dotenv + +load_dotenv(".env") + +COMPUTE_URL = str(os.environ["MLEX_COMPUTE_URL"]) + + +class TableJob: + def __init__( + self, job_id, job_name, job_type, job_status, job_params, experiment_id + ): + self.job_id = job_id + self.name = job_name + self.job_type = job_type + self.status = job_status + self.parameters = job_params + self.experiment_id = experiment_id + pass + + @staticmethod + def compute_job_to_table_job(input_job): + compute_job = MlexJob(**input_job) + params = compute_job.job_kwargs["kwargs"]["params"] + if compute_job.job_kwargs["kwargs"]["job_type"].split()[0] != "train_model": + params = f"{params}\nTraining Parameters: {compute_job.job_kwargs['kwargs']['train_params']}" + return TableJob( + compute_job.uid, + compute_job.description, + compute_job.job_kwargs["kwargs"]["job_type"], + compute_job.status["state"], + str(params), + compute_job.job_kwargs["kwargs"]["experiment_id"], + ) + + @staticmethod + def get_job(user, mlex_app, job_type=None, deploy_location=None, job_id=None): + """ + Queries the job from the computing database + Args: + user: username + mlex_app: mlexchange application + job_type: type of job + deploy_location: deploy location + Returns: + list of jobs that match the query + """ + url = f"{COMPUTE_URL}/jobs?" 
+ if job_id: + response = urllib.request.urlopen(f"{url[:-1]}/{job_id}") + else: + if user: + url += "&user=" + user + if mlex_app: + url += "&mlex_app=" + mlex_app + if job_type: + url += "&job_type=" + job_type + if deploy_location: + url += "&deploy_location=" + deploy_location + response = urllib.request.urlopen(url) + data = json.loads(response.read()) + return data + + @staticmethod + def terminate_job(job_uid): + requests.patch(f"{COMPUTE_URL}/jobs/{job_uid}/terminate") + pass + + @staticmethod + def delete_job(job_uid): + requests.delete(f"{COMPUTE_URL}/jobs/{job_uid}/delete") + pass + + @staticmethod + def get_counter(username): + job_list = TableJob.get_job(username, "mlcoach") + job_types = ["train_model", "prediction_model"] + counters = [-1, -1] + if job_list is not None: + for indx, job_type in enumerate(job_types): + for job in reversed(job_list): + last_job = job["job_kwargs"]["kwargs"]["job_type"] + if job["description"]: + job_name = job["description"].split() + else: + job_name = job["job_kwargs"]["kwargs"]["job_type"].split() + if ( + last_job == job_type + and job_name[0] == job_type + and len(job_name) == 2 + and job_name[-1].isdigit() + ): + value = int(job_name[-1]) + counters[indx] = value + break + return counters + + +class MlexJob: + def __init__( + self, + service_type, + description, + working_directory, + job_kwargs, + mlex_app="mlcoach", + status={"state": "queue"}, + logs="", + uid="", + **kwargs, + ): + self.uid = uid + self.mlex_app = mlex_app + self.description = description + self.service_type = service_type + self.working_directory = working_directory + self.job_kwargs = job_kwargs + self.status = status + self.logs = logs + + def submit(self, user, num_cpus, num_gpus): + """ + Sends job to computing service + Args: + user: user UID + num_cpus: Number of CPUs + num_gpus: Number of GPUs + Returns: + Workflow status + """ + workflow = { + "user_uid": user, + "job_list": [self.__dict__], + "host_list": [ + "mlsandbox.als.lbl.gov", + "local.als.lbl.gov", + "vaughan.als.lbl.gov", + ], + "dependencies": {"0": []}, + "requirements": { + "num_processors": num_cpus, + "num_gpus": num_gpus, + "num_nodes": 1, + }, + } + url = f"{COMPUTE_URL}/workflows" + return requests.post(url, json=workflow).status_code + + +def get_host(host_nickname): + hosts = requests.get(f"{COMPUTE_URL}/hosts?&nickname={host_nickname}").json() + max_processors = hosts[0]["backend_constraints"]["num_processors"] + max_gpus = hosts[0]["backend_constraints"]["num_gpus"] + return max_processors, max_gpus + + +def str_to_dict(text): + text = text.replace("True", "true") + text = text.replace("False", "false") + text = text.replace("None", "null") + return json.loads(text.replace("'", '"')) diff --git a/src/utils/model_utils.py b/src/utils/model_utils.py new file mode 100644 index 0000000..5505262 --- /dev/null +++ b/src/utils/model_utils.py @@ -0,0 +1,40 @@ +import os + +import requests + +CONTENT_URL = str(os.environ["MLEX_CONTENT_URL"]) + + +def get_model_list(): + """ + Get a list of algorithms from content registry + """ + response = requests.get(f"{CONTENT_URL}/models") + models = [] + for item in response.json(): + if "mlcoach" in item["application"]: + models.append({"label": item["name"], "value": item["content_id"]}) + return models + + +def get_gui_components(model_uid, comp_group): + """ + Returns the GUI components of the corresponding model and action + Args: + model_uid: Model UID + comp_group: Action, e.g. 
training, testing, etc + Returns: + params: List of model parameters + """ + response = requests.get( + f"{CONTENT_URL}/models/{model_uid}/model/{comp_group}/gui_params" + ) + return response.json() + + +def get_model_content(content_id): + """ + Get the model content: uri and commands + """ + response = requests.get(f"{CONTENT_URL}/contents/{content_id}/content").json() + return response["uri"], response["cmd"] diff --git a/src/utils/plot_utils.py b/src/utils/plot_utils.py new file mode 100644 index 0000000..2d22af9 --- /dev/null +++ b/src/utils/plot_utils.py @@ -0,0 +1,75 @@ +import base64 + +import numpy as np +import pandas as pd +import plotly +import plotly.express as px +import plotly.graph_objects as go +from PIL import Image +from plotly.subplots import make_subplots + + +def generate_loss_plot(loss_file_path): + """ + Generate loss plot + Args: + loss_file_path: Path to the loss file + Returns: + loss plot + """ + df = pd.read_csv(loss_file_path) + df.set_index("epoch", inplace=True) + fig = make_subplots(specs=[[{"secondary_y": True}]]) + cols = list(df.columns) + for col in cols: + if "loss" in col: + fig.add_trace( + go.Scatter(x=df.index, y=df[col], name=col), secondary_y=False + ) + fig.update_yaxes(title_text="loss", secondary_y=False) + else: + fig.add_trace(go.Scatter(x=df.index, y=df[col], name=col), secondary_y=True) + fig.update_yaxes(title_text="accuracy", secondary_y=True, range=[0, 1]) + fig.update_layout(xaxis_title="epoch", margin=dict(l=20, r=20, t=20, b=20)) + return fig + + +def get_class_prob(probs): + """ + Generate plot of probabilities per class + Args: + prob:L probabilities per class + Returns: + plot of probabilities per class + """ + probs.name = None + probs = probs.to_frame().T + fig = px.bar(probs) + fig.update_layout( + yaxis_title="probability", + legend_title_text="Labels", + margin=dict(l=20, r=20, t=20, b=20), + ) + fig.update_xaxes( + showgrid=False, visible=False, showticklabels=False, zeroline=False + ) + return fig + + +def plot_figure(image=None): + """ + Plot input data + """ + if not image: # Create a blank image + blank_image = np.zeros((200, 200, 3), dtype=np.uint8) + image = Image.fromarray(blank_image) + fig = px.imshow(image, height=200, width=200) + else: + fig = px.imshow(image, height=500) + fig.update_xaxes(showgrid=False, showticklabels=False, zeroline=False) + fig.update_yaxes(showgrid=False, showticklabels=False, zeroline=False) + fig.update_layout(coloraxis_showscale=False) + fig.update_layout(margin=dict(l=0, r=10, t=0, b=10)) + png = plotly.io.to_image(fig, format="jpg") + png_base64 = base64.b64encode(png).decode("ascii") + return "data:image/jpg;base64,{}".format(png_base64)
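
Usage note (not part of the diff): the new src/utils/data_utils.py centralizes experiment staging. prepare_directories() creates DATA_DIR/mlex_store/<user>/<experiment_id>, writes a data_info.parquet describing the selected datasets (local URIs, or Tiled root/sub-URIs plus API key), and returns the experiment ID, output path, and parquet filename. The sketch below is illustrative only; the user ID, the DataProject instance, and the labeled indices are assumptions supplied by the surrounding app, not values defined in this change.

# Illustrative sketch: the DataProject instance and labeled indices are assumed
# to come from the surrounding application; only the calls into data_utils are real.
import pandas as pd

from src.utils.data_utils import prepare_directories


def stage_training_data(user_id, data_project, labeled_indices):
    """Create the experiment folder and inspect the generated data_info.parquet."""
    experiment_id, out_path, info_file = prepare_directories(
        user_id,
        data_project,                      # DataProject built elsewhere in the app
        labeled_indices=labeled_indices,   # for Tiled data, only these frames are downloaded
        train=True,
    )
    # Columns vary by data type: uri/splash_uri/type for local or downloaded data,
    # root_uri/sub_uris/api_key for Tiled data that stays remote at prediction time.
    data_info = pd.read_parquet(info_file)
    print(f"Experiment {experiment_id} staged at {out_path}")
    return experiment_id, out_path, data_info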
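
src/utils/job_utils.py wraps the compute API: MlexJob.submit() posts a workflow to MLEX_COMPUTE_URL/workflows, TableJob.get_counter() recovers the next "train_model N" / "prediction_model N" index from previous job descriptions, and get_host() reports the CPU/GPU limits for a host nickname. A minimal sketch of assembling a training job follows; the container URI, command, and parameter values are placeholders, not values taken from this diff.

# Hypothetical sketch of submitting a training job through the new job_utils API.
# The "uri"/"cmd" entries and params dict below are illustrative placeholders.
from src.utils.job_utils import MlexJob, TableJob, get_host


def submit_training(user, host_nickname, experiment_id, out_path, params):
    num_cpus, num_gpus = get_host(host_nickname)        # respect host constraints
    train_count = TableJob.get_counter(user)[0] + 1     # next "train_model N" index
    job = MlexJob(
        service_type="backend",
        description=f"train_model {train_count}",
        working_directory=str(out_path),
        job_kwargs={
            "uri": "ghcr.io/example/trainer:latest",    # placeholder model image
            "cmd": "python train.py",                   # placeholder command
            "kwargs": {
                "job_type": "train_model",
                "experiment_id": experiment_id,
                "params": params,
            },
        },
    )
    return job.submit(user, num_cpus, num_gpus)         # POST to /workflows, returns status code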
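
src/utils/model_utils.py talks to the content registry: get_model_list() filters registered models by the "mlcoach" application, get_gui_components() fetches the GUI parameter schema for a given action, and get_model_content() resolves a content ID to its container URI and command. A short sketch, assuming the content API at MLEX_CONTENT_URL is reachable, that the registry's content_id doubles as the model UID, and that "train_model" is a valid action name (both assumptions, not confirmed by this diff):

# Hypothetical sketch tying the new model_utils helpers together.
from src.utils.model_utils import (
    get_gui_components,
    get_model_content,
    get_model_list,
)


def describe_first_model():
    models = get_model_list()               # [{"label": name, "value": content_id}, ...]
    if not models:
        raise RuntimeError("No mlcoach models registered in the content registry")
    content_id = models[0]["value"]
    # "train_model" as the action name is an assumption for this sketch.
    train_params = get_gui_components(content_id, "train_model")
    uri, cmd = get_model_content(content_id)  # container image and entry command
    return {"uri": uri, "cmd": cmd, "train_params": train_params}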
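
src/utils/plot_utils.py builds the dashboard figures: generate_loss_plot() reads a per-epoch CSV and plots loss columns on the primary axis and the remaining metrics on a secondary axis, get_class_prob() turns a pandas Series of per-class probabilities into a bar chart, and plot_figure() renders an image (or a blank 200x200 placeholder when called without one) as a base64 JPEG data URI. A small sketch with made-up inputs; the CSV columns and probability values are illustrative only.

# Illustrative inputs for the new plotting helpers.
import pandas as pd

from src.utils.plot_utils import generate_loss_plot, get_class_prob, plot_figure

# Per-epoch metrics in the layout generate_loss_plot() expects:
# an "epoch" column plus loss/metric columns.
pd.DataFrame(
    {"epoch": [0, 1, 2], "loss": [1.2, 0.8, 0.5], "accuracy": [0.4, 0.7, 0.9]}
).to_csv("training_log.csv", index=False)
loss_fig = generate_loss_plot("training_log.csv")

# Probabilities for a single image, one entry per label.
probs = pd.Series({"cat": 0.10, "dog": 0.85, "bird": 0.05})
prob_fig = get_class_prob(probs)

# Without an argument, plot_figure() returns a blank placeholder image
# encoded as a "data:image/jpg;base64,..." string for use as an html.Img src.
placeholder_src = plot_figure()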