diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 05d94e896..64ef0302e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -7,10 +7,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -25,10 +25,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -42,10 +42,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -59,10 +59,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -77,10 +77,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -96,10 +96,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -113,10 +113,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -130,10 +130,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -148,10 +148,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -166,10 +166,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -184,10 +184,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install pytest @@ -199,7 +199,7 @@ jobs: pip install .[pytorch_cpu] - name: Run pytest tests run: | - pytest -vx tests/version_test.py + pytest -vx tests/test_version.py pytest -vx tests/test_num_params.py pytest -vx tests/test_param_shapes.py pytest -vx tests/test_param_types.py @@ -208,10 +208,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install pytest diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 89b5ef288..992496b69 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -7,17 +7,17 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.11.10 - name: Install pylint run: | python -m pip install --upgrade pip pip install pylint==2.16.1 - name: Run pylint run: | - pylint algorithmic_efficiency + pylint algoperf pylint reference_algorithms pylint prize_qualification_baselines pylint submission_runner.py @@ -27,14 +27,14 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.11.10 - name: Install isort run: | python -m pip install --upgrade pip - pip install isort + pip install isort==5.12.0 - name: Run isort run: | isort . --check --diff @@ -43,14 +43,14 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.11.10 - name: Install yapf run: | python -m pip install --upgrade pip - pip install yapf==0.32 + pip install yapf==0.32 toml - name: Run yapf run: | yapf . --diff --recursive diff --git a/.github/workflows/regression_tests_python_upgrade.yml b/.github/workflows/regression_tests_python_upgrade.yml deleted file mode 100644 index 783395353..000000000 --- a/.github/workflows/regression_tests_python_upgrade.yml +++ /dev/null @@ -1,183 +0,0 @@ -name: Containerized Regression Tests Python Upgrades - -on: - pull_request: - branches: - - 'python_test_env_upgrade' - -jobs: - build_and_push_jax_docker_image: - runs-on: self-hosted - steps: - - uses: actions/checkout@v2 - - name: Build and push docker images - run: | - GIT_BRANCH=${{ github.head_ref || github.ref_name }} - FRAMEWORK=jax - IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" - cd $HOME/algorithmic-efficiency/docker - docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH - BUILD_RETURN=$? - if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi - docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - build_and_push_pytorch_docker_image: - runs-on: self-hosted - steps: - - uses: actions/checkout@v2 - - name: Build and push docker images - run: | - GIT_BRANCH=${{ github.head_ref || github.ref_name }} - FRAMEWORK=pytorch - IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" - cd $HOME/algorithmic-efficiency/docker - docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH - BUILD_RETURN=$? - if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi - docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - fastmri_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d fastmri -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w fastmri -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - imagenet_resnet_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d imagenet -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w imagenet_resnet -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - imagenet_vit_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d imagenet -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w imagenet_vit -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - ogbg_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d ogbg -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w ogbg -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - criteo_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w criteo1tb -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - librispeech_conformer_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w librispeech_conformer -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - librispeech_deepspeech_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w librispeech_deepspeech -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - wmt_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d wmt -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w wmt -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - fastmri_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d fastmri -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w fastmri -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - imagenet_resnet_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d imagenet -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w imagenet_resnet -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - imagenet_vit_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d imagenet -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w imagenet_vit -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - ogbg_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d ogbg -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w ogbg -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - criteo_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w criteo1tb -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - exit $? - librispeech_conformer_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d librispeech -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w librispeech_conformer -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - librispeech_deepspeech_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d librispeech -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w librispeech_deepspeech -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - wmt_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d wmt -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w wmt -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false diff --git a/.github/workflows/regression_tests_variants.yml b/.github/workflows/regression_tests_variants.yml index ef1585d0d..b234575b7 100644 --- a/.github/workflows/regression_tests_variants.yml +++ b/.github/workflows/regression_tests_variants.yml @@ -72,7 +72,7 @@ jobs: run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w criteo1tb_resnet -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - criteo_resnet_pytorch: + criteo_embed_init_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image steps: diff --git a/.gitignore b/.gitignore index d2e212366..7d35f0ccc 100644 --- a/.gitignore +++ b/.gitignore @@ -12,8 +12,8 @@ makefile *.swp */data/ *events.out.tfevents* -algorithmic_efficiency/workloads/librispeech_conformer/data_dir -algorithmic_efficiency/workloads/librispeech_conformer/work_dir +algoperf/workloads/librispeech_conformer/data_dir +algoperf/workloads/librispeech_conformer/work_dir *.flac *.npy *.csv @@ -23,4 +23,6 @@ wandb/ scoring/plots/ !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv -!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv \ No newline at end of file +!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv + +algoperf/_version.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 95cd40775..685926506 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,34 +4,39 @@ - Finalized variant workload targets. - Fix in random_utils helper function. -- For conformer PyTorch Dropout layers set `inplace=True`. +- For conformer PyTorch Dropout layers set `inplace=True`. - Clear CUDA cache at begining of each trial for PyTorch. ## algoperf-benchmark-0.1.4 (2024-03-26) Upgrade CUDA version to CUDA 12.1: + - Upgrade CUDA version in Dockerfiles that will be used for scoring. - Update Jax and PyTorch package version tags to use local CUDA installation. -Add flag for completely disabling checkpointing. +Add flag for completely disabling checkpointing. + - Note that we will run with checkpointing off at scoring time. -Update Deepspeech and Conformer variant target setting configurations. -- Note that variant targets are not final. +Update Deepspeech and Conformer variant target setting configurations. + +- Note that variant targets are not final. Fixed bug in scoring code to take best trial in a study for external-tuning ruleset. -Added instructions for submission. +Added instructions for submission. -Changed default number of workers for PyTorch data loaders to 0. Running with >0 may lead to incorrect eval results see https://github.com/mlcommons/algorithmic-efficiency/issues/732. +Changed default number of workers for PyTorch data loaders to 0. Running with >0 may lead to incorrect eval results see . ## algoperf-benchmark-0.1.2 (2024-03-04) + Workload variant additions and fixes: + - Add Deepspeech workload variant - Fix bugs in Imagenet ResNet, WMT and Criteo1tb variants Add prize qualification logs for external tuning ruleset. -Note: FastMRI trials with dropout are not yet added due to https://github.com/mlcommons/algorithmic-efficiency/issues/664. +Note: FastMRI trials with dropout are not yet added due to . Add missing funcitonality to Docker startup script for self_tuning ruleset. Add self_tuning ruleset option to script that runs all workloads for scoring. @@ -41,6 +46,7 @@ Datasetup fixes. Fix tests that check training differences in PyTorch and JAX on GPU. ## algoperf-benchmark-0.1.1 (2024-01-19) + Bug fixes to FastMRI metric calculation and targets. Added workload variants and targets for ogbg, fastmri, librispeech_conformer, imagenet_resnet, imagenet_vit, criteo1tb to be used as held-out workloads. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 364bbee62..c98a5009e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -22,6 +22,7 @@ - [Style Testing](#style-testing) - [Unit and Integration Tests](#unit-and-integration-tests) - [Regression Tests](#regression-tests) + - [Versioning](#versioning) ## Contributing to MLCommons @@ -204,7 +205,7 @@ docker run -t -d \ -v $HOME/data/:/data/ \ -v $HOME/experiment_runs/:/experiment_runs \ -v $HOME/experiment_runs/logs:/logs \ --v $HOME/algorithmic-efficiency:/algorithmic-efficiency \ +-v $HOME/algorithmic-efficiency:/algoperf \ --gpus all \ --ipc=host \ \ @@ -228,7 +229,7 @@ To run the below commands, use the versions installed via `pip install -e '.[dev To automatically fix formatting errors, run the following (*WARNING:* this will edit your code, so it is suggested to make a git commit first!): ```bash -yapf -i -r -vv -p algorithmic_efficiency datasets prize_qualification_baselines reference_algorithms tests *.py +yapf -i -r -vv -p algoperf datasets prize_qualification_baselines reference_algorithms tests *.py ``` To sort all import orderings, run the following: @@ -246,7 +247,7 @@ isort . --check --diff To print out all offending pylint issues, run the following: ```bash -pylint algorithmic_efficiency +pylint algoperf pylint datasets pylint prize_qualification_baselines pylint reference_algorithms @@ -276,3 +277,15 @@ To run a regression test: 2. Turn on the self-hosted runner. 3. Run the self-hosted runner application for the runner to accept jobs. 4. Open a pull request into mian to trigger the workflow. + +### Versioning + +The package version is automatically determined by the `setuptools_scm` package based on the last git tag. +It follows the structure `major.minor.patch` + `devN` where `N` is the number of commits since the last tag. +It automatically increments the patch version (i.e. it guesses the next version) if there are commits after the last tag. +Additionally, if there are uncommitted changes, the version will include a suffix separated by a `+` character and includes the last commit hash plus the date on dirt workdir (see [setuptools_scm's documentation](https://setuptools-scm.readthedocs.io/en/latest/extending/#setuptools_scmlocal_scheme) with the default version and local scheme). +You can check what version `setuptools_scm` is creating by running `python -m setuptools_scm`. + +To create a new version, create a new release (and tag) in the GitHub UI. +The package version is automatically updated to the new version. +Once the package is installed, the version can be accessed as the package attribute `algoperf.__version__`, i.e. via `python -c "import algoperf; print(algoperf.__version__)"`. diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 8722a441e..795846efd 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -80,7 +80,7 @@ In principle, submissions are allowed to use the available hardware systems in a Submissions provide a [per-workload batch size](#batch-size-getter) to use. Specification of the batch size for each workload is necessary to avoid running out of memory for different workloads. Therefore, submitters can determine this batch size in advance and specify it as part of the submission. Submitters may also provide per-workload batch sizes for all [randomized workloads](#randomized-workloads). If no such batch size is provided for a randomized workload, by default, submissions will then use the batch size of the most similar [fixed workload](#fixed-workloads) (for example, if there is an ImageNet fixed workload and also a randomized workload with a similarly sized model on similarly sized images, the ImageNet batch size will be used for held-out workloads generated from this randomized workload). Note that submitters are *not* allowed to modify the *evaluation batch size*, which is set by the benchmarking codebase. However, you can file an issue if you believe that the evaluation batch size of a particular workload is set inappropriately. The working group will review this request and consider adjusting the evaluation batch size in the benchmarking codebase, thus affecting all submitters equally. -The **submission functions** are the *batch size getter*, *optimizer state initializer*, *variable update*, and *data selection functions*. The *fixed functions* are the *data augmentation/preprocessing*, *model initialization*, *forward pass*, and *loss function*. The trained model will be evaluated in a separate step that does not call any of the submitted code. +The **submission functions** are the *batch size getter*, *optimizer state initializer*, *variable update*, *prepare for evaluation function*, and *data selection functions*. The *fixed functions* are the *data augmentation/preprocessing*, *model initialization*, *forward pass*, and *loss function*. The trained model will be evaluated in a separate step that does not call any of the submitted code. ##### Fixed functions @@ -91,7 +91,7 @@ With the exception of `_build_input_queue`, submitters can call any of these fun def step_hint(self): -> int ``` -- The `step_hint` function gives the number of global steps the baseline algorithm was allowed to use to reach the targets for a workload. Note that the baseline algorithms may have reached the target in fewer steps than this, but these were the max number of steps the baseline algorithms used for their learning rate schedules. Submitters can use this to help specify learning rate (or other) schedules. +- The `step_hint` function gives the number of global steps the baseline algorithm can perform with the `max_runtime` to reach the targets for a workload. The `step_hint` is therefore dependent on the `max_runtime` and the workload. Note that the baseline algorithms may have reached the target in fewer steps than this, but these were the max number of steps the baseline algorithms used for their learning rate schedules. Submitters can use this to help specify learning rate (or other) schedules. ###### Data augmentation and preprocessing @@ -220,9 +220,34 @@ def update_params( - Cannot modify the given hyperparameters in a workload-conditional way (please see the [Valid submission](#valid-submissions) section). This rule is intended to prohibit circumventing the tuning rules by looking up a pre-tuned optimal set of hyperparameters for each workload. It is not intended to prohibit line searches and other similar techniques. - The fixed `init_model_fn` can optionally be called during training, for example, to reinitialize the model after a failed training effort. - Cannot replace the model parameters with pre-trained ones. -- This API supports Polyak averaging and similar methods that implement moving averages of model parameters. - Batch norm should work here because the `model_fn` will return updated batch norm moving averages when it is told to with `update_batch_norm`. +###### Prepare for evaluation function + +```python +def prepare_for_eval( + workload: Workload, + current_param_container: ParameterContainer, + current_params_types: ParameterTypeTree, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + loss_type: LossType, + optimizer_state: OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: RandomState +) -> (updated_optimizer_state, updated_variables, updated_model_state) +``` + +- Arguments are the same of `update_param`, with the only exception of `batch`. +- This function is called when a submission is deemed eligible for an evaluation (see [Evluation during training](#evaluation-during-training) section). + - The call to `prepare_for_eval` is timed and its runtime accumulates to the overall submission time. + - The returned model parameters are evaluated on the validation and test sets, provided that the accumulated submission time does not exceed the maximum runtime after this function call. +- This API supports Polyak averaging and similar methods that implement moving averages of model parameters. +- Allowed to update model state and model parameters. +- Allowed to update state for the optimizer. +- Cannot replace the model parameters with pre-trained ones. + ###### Data selection ```python @@ -252,7 +277,8 @@ def data_selection( In general, with noisy, non-deterministic training, evaluation frequency can affect training time measurements as more "bites of the apple" potentially allows the training code to exploit instability. We also want to discourage submissions from complicated and unrealistic logic that attempts to guess when training is close to complete and increases the evaluation rate, while not producing a well-sampled training curve at the start of training. Simply allowing submissions complete freedom over evaluation frequency encourages competitors to work to minimize the number of evaluations, which distracts from the primary goal of finding better training algorithms. -Submissions are eligible for an untimed eval every `eval_period` seconds, run as soon as the current call of `update_params` completes. Any additional evaluations performed by the submission code count against the runtime for scoring. The harness that runs the submission code will attempt to eval every `eval_period` seconds by checking between each submission step (call of `update_params`) whether it has been at least `eval_period` seconds since that last eval and, if so, pausing the clock and running an eval. This means that if calls to `update_params` typically take a lot more than `eval_period` seconds, such submissions will not receive as many untimed evals as a submission that had an `update_params` function that took less time. However, for appropriate settings of `eval_period`, we expect this to be quite rare. Submissions are always free to restructure their `update_params` code to split work into two subsequent steps to regain the potential benefits of these untimed model evaluations. For each workload, the `eval_period` will be set such that the total evaluation time is roughly between 10% and 20% of the total training time for the target-setting runs. +Submissions are eligible for an untimed eval every `eval_period` seconds. Before proceeding to evaluation, the submission can prepare the model through a call to `prepare_for_eval`, effectively modifying the model parameters and state as well as the the optimizer state. Any additional evaluations performed by the submission code count against the runtime for scoring. +The harness that runs the submission code will attempt to eval every `eval_period` seconds by checking between each submission step (call of `update_params`) whether it has been at least `eval_period` seconds since that last eval, if so, the submission is given the possibility to prepare for evaluation (through a timed call to `prepare_for_eval`). If the accumulated runtime does not exceed the maximum allowed runtime after the preparation step, the clock is paused, and the submission is evaluated. This means that if calls to `update_params` typically take a lot more than `eval_period` seconds, such submissions will not receive as many untimed evals as a submission that had an `update_params` function that took less time. However, for appropriate settings of `eval_period`, we expect this to be quite rare. Submissions are always free to restructure their `update_params` code to split work into two subsequent steps to regain the potential benefits of these untimed model evaluations. For each workload, the `eval_period` will be set such that the total evaluation time is roughly between 10% and 20% of the total training time for the target-setting runs. #### Valid submissions @@ -392,7 +418,7 @@ In each trial, the tuning trial with the fastest training time to achieve the *v Submissions to this ruleset are not allowed to have user-defined hyperparameters. This ruleset allows both submissions that use the same hyperparameters for all workloads, including the randomized ones (e.g. Adam with default parameters), as well as submissions that perform inner-loop tuning during their training run (e.g. SGD with line searches). -Submissions will run on one instance of the [benchmarking hardware](#benchmarking-hardware). As always, submissions are allowed to perform inner-loop tuning (e.g. for their learning rate) but the tuning efforts will be part of their score. A submission will run *S=5* times and its score will be the median time to reach the target evaluation metric value on the validation set. To account for the lack of external tuning, submissions have a longer time budget to reach the target performance. Compared to the [external tuning ruleset](#external-tuning-ruleset), the `max_runtime` is tripled. Runs that do not reach the target performance of the evaluation metric within this allotted time budget have an infinite time. +Submissions will run on one instance of the [benchmarking hardware](#benchmarking-hardware). As always, submissions are allowed to perform inner-loop tuning (e.g. for their learning rate) but the tuning efforts will be part of their score. A submission will run *S=5* times and its score will be the median time to reach the target evaluation metric value on the validation set. To account for the lack of external tuning, submissions have a longer time budget to reach the target performance. Compared to the [external tuning ruleset](#external-tuning-ruleset), the `max_runtime` is $1.5$ times longer. Runs that do not reach the target performance of the evaluation metric within this allotted time budget have an infinite time. ### Workloads @@ -413,11 +439,24 @@ The currently eight fixed workloads are: | | **Task** | **Dataset** | **Model** | **Loss** | **Metric** | Validation
**Target** | Test
**Target** | Maximum
**Runtime**
(in secs) | |------------|-------------------------------|-------------|-------------------------|----------|------------|--------------------------|----------------------|------------------------| | **1** | Clickthrough rate prediction | Criteo 1TB | DLRMsmall | CE | CE | 0.123735 | 0.126041 | 7,703 | -| **2** | MRI reconstruction | fastMRI | U-Net | L1 | SSIM | 0.723653 | 0.740633 | 8,859 | -| **3
4** | Image classification | ImageNet | ResNet-50
ViT | CE | ER | 0.22569
0.22691 | 0.3440
0.3481 | 63,008
77,520 | -| **5
6** | Speech recognition | LibriSpeech | Conformer
DeepSpeech | CTC | WER | 0.085884
0.119936 | 0.052981
0.074143 | 61,068
55,506 | -| **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 18,477 | -| **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 48,151 | +| **2** | MRI reconstruction | fastMRI | U-Net | L1 | SSIM | 0.723653 | 0.740633 | 4,430 | +| **3
4** | Image classification | ImageNet | ResNet-50
ViT | CE | ER | 0.22569
0.22691 | 0.3440
0.3481 | 66,159
69,768 | +| **5
6** | Speech recognition | LibriSpeech | Conformer
DeepSpeech | CTC | WER | 0.085884
0.119936 | 0.052981
0.074143 | 58,015
44,405 | +| **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 12,011 | +| **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 43,336 | + +Default Dropout Values for Different Workloads: + +| Workload | Dropout Values | +|------------------------|------------------------------------------------------------------------------------------------------| +| criteo 1tb | dropout_rate: 0.0 | +| fastmri | dropout_rate: 0.0 | +| imagenet_resnet | dropout not used | +| imagenet_vit | dropout_rate: 0.0 | +| librispeech_conformer | attention_dropout_rate: 0.0
attention_residual_dropout_rate: 0.1
conv_residual_dropout_rate: 0.0
feed_forward_dropout_rate: 0.0
feed_forward_residual_dropout_rate: 0.1
input_dropout_rate: 0.1 | +| librispeech_deepspeech | input_dropout_rate: 0.1
feed_forward_dropout_rate: 0.1
(Only for JAX - dropout_rate in CudnnLSTM class: 0.0) | +| ogbg | dropout_rate: 0.1 | +| wmt | dropout_rate: 0.1
attention_dropout_rate: 0.1 | #### Randomized workloads @@ -464,7 +503,7 @@ For self-reported results, it is acceptable to perform the tuning trials on hard Target performances on the validation and test sets will be defined for each [workload](#workloads) separately. For the [fixed workloads](#fixed-workloads), we take the best performance achievable by one of four standard algorithms (AdamW, NadamW, Nesterov Momentum, and Heavy Ball Momentum). These target-setting algorithms will follow the general process of the external tuning ruleset, with a significantly larger tuning budget of $200$ trials to guarantee competitive performance. Once the best algorithm and its hyperparameters are determined, training is repeated $20$ times. The median of the best achieved validation errors across seeds is used as the *validation* target. Out of the $10$ repeated runs that achieved this validation target, we took the worst achieved test error across seeds as our *test* target. Taking the median validation performance after rerunning the best hyperparameter point prevents our procedure from selecting a lucky outlier. To save computational resources, we only tuned two training algorithms instead of four, for the [randomized workloads](#randomized-workloads). For each workload variant, we used NadamW and the other best-performing training algorithm on the corresponding fixed workload the randomized workload is based on. -Both [tuning rulesets](#tuning) will use the same target performances. The runtime of the target-setting algorithms on each workload will be chosen to match published results and is constrained by the overall time budget of roughly a single week for all fixed workloads. The `max_runtime` for submissions on each workload is $\frac{1}{3}$ longer than the runtime of the target-setting algorithms (this `max_runtime` will be three times as much for the self-tuning ruleset, see the [Self-tuning ruleset](#self-tuning-ruleset) section). +Both [tuning rulesets](#tuning) will use the same target performances. The runtime of the target-setting algorithms on each workload will be chosen to match published results and is constrained by the overall time budget of roughly a single week for all fixed workloads. The initial `max_runtime` for submissions on each workload was $\frac{1}{3}$ longer than the runtime of the target-setting algorithms (this `max_runtime` will be $1.5$ times as much for the self-tuning ruleset, see the [Self-tuning ruleset](#self-tuning-ruleset) section). After the initial round of submissions, we have adapated the `max_runtime` based on the performance of the submissions (see [this issue](https://github.com/mlcommons/algorithmic-efficiency/issues/836)). #### Benchmark score using performance profiles @@ -602,4 +641,4 @@ That said, while submitting Adam with some novel heuristic to set various hyperp The JAX and PyTorch versions of the Criteo, FastMRI, Librispeech, OGBG, and WMT workloads use the same TensorFlow input pipelines. Due to differences in how JAX and PyTorch distribute computations across devices, the PyTorch workloads have an additional overhead for these workloads. Since we use PyTorch's [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) implementation, there is one Python process for each device. Depending on the hardware and the settings of the cluster, running a TensorFlow input pipeline in each Python process can lead to errors, since too many threads are created in each process. See [this PR thread](https://github.com/mlcommons/algorithmic-efficiency/pull/85) for more details. -While this issue might not affect all setups, we currently implement a different strategy: we only run the TensorFlow input pipeline in one Python process (with `rank == 0`), and [broadcast](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast) the batches to all other devices. This introduces an additional communication overhead for each batch. See the [implementation for the WMT workload](https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py#L215-L288) as an example. +While this issue might not affect all setups, we currently implement a different strategy: we only run the TensorFlow input pipeline in one Python process (with `rank == 0`), and [broadcast](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast) the batches to all other devices. This introduces an additional communication overhead for each batch. See the [implementation for the WMT workload](https://github.com/mlcommons/algorithmic-efficiency/blob/main/algoperf/workloads/wmt/wmt_pytorch/workload.py#L215-L288) as an example. diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 006b972ec..e59463f88 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -18,6 +18,8 @@ - [Docker Tips](#docker-tips) - [Score your Submission](#score-your-submission) - [Running workloads](#running-workloads) +- [Package your Submission code](#package-your-submission-code) +- [Package Logs for Self-Reporting Submissions](#package-logs-for-self-reporting-submissions) ## Set Up and Installation @@ -35,7 +37,7 @@ The specs on the benchmarking machines are: > **Prerequisites:** > -> - Python minimum requirement >= 3.8 +> - Python minimum requirement >= 3.11 > - CUDA 12.1 > - NVIDIA Driver version 535.104.05 @@ -56,7 +58,7 @@ To set up a virtual enviornment and install this repository cd algorithmic-efficiency ``` -3. Run the following pip3 install commands based on your chosen framework to install `algorithmic_efficiency` and its dependencies. +3. Run the following pip3 install commands based on your chosen framework to install `algoperf` and its dependencies. For **JAX**: @@ -80,7 +82,6 @@ To set up a virtual enviornment and install this repository pip3 install -e '.[full]' ``` -
Per workload installations @@ -414,22 +415,24 @@ submission_folder/ ``` Specifically we require that: + 1. There exist subdirectories in the the submission folder named after the ruleset: `external_tuning` or `self_tuning`. -2. The ruleset subdirectories contain directories named according to -some identifier of the algorithm. -3. Each algorithm subdirectory contains a `submission.py` module. Additional helper modules are allowed if prefer to you organize your code into multiple files. If there are additional python packages that have to be installed for the algorithm also include a `requirements.txt` with package names and versions in the algorithm subdirectory. +2. The ruleset subdirectories contain directories named according to +some identifier of the algorithm. +3. Each algorithm subdirectory contains a `submission.py` module. Additional helper modules are allowed if prefer to you organize your code into multiple files. If there are additional python packages that have to be installed for the algorithm also include a `requirements.txt` with package names and versions in the algorithm subdirectory. 4. For `external_tuning` algorithms the algorithm subdirectory should contain a `tuning_search_space.json`. To check that your submission folder meets the above requirements you can run the `submissions/repo_checker.py` script. ## Package Logs for Self-Reporting Submissions + To prepare your submission for self reporting run: -``` +```bash python3 package_logs.py --experiment_dir --destination_dir ``` -The destination directiory will contain the logs packed in studies and trials required for self-reporting. +The destination directiory will contain the logs packed in studies and trials required for self-reporting. **Good Luck!** diff --git a/algoperf/__init__.py b/algoperf/__init__.py new file mode 100644 index 000000000..7d54f8290 --- /dev/null +++ b/algoperf/__init__.py @@ -0,0 +1,5 @@ +"""Algorithmic Efficiency.""" + +from ._version import version as __version__ + +__all__ = ["__version__"] diff --git a/algorithmic_efficiency/checkpoint_utils.py b/algoperf/checkpoint_utils.py similarity index 98% rename from algorithmic_efficiency/checkpoint_utils.py rename to algoperf/checkpoint_utils.py index 29c1a821e..f4cb6c2db 100644 --- a/algorithmic_efficiency/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -16,8 +16,8 @@ from tensorflow.io import gfile # pytype: disable=import-error import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup _, _, DEVICE, _ = pytorch_setup() CheckpointReturn = Tuple[spec.OptimizerState, @@ -231,7 +231,7 @@ def save_checkpoint(framework: str, target=checkpoint_state, step=global_step, overwrite=True, - keep=np.Inf if save_intermediate_checkpoints else 1) + keep=np.inf if save_intermediate_checkpoints else 1) else: if not save_intermediate_checkpoints: checkpoint_files = gfile.glob( diff --git a/algorithmic_efficiency/data_utils.py b/algoperf/data_utils.py similarity index 99% rename from algorithmic_efficiency/data_utils.py rename to algoperf/data_utils.py index 901f0b582..37d1bd20f 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algoperf/data_utils.py @@ -11,7 +11,7 @@ from torch.utils.data import DistributedSampler from torch.utils.data import Sampler -from algorithmic_efficiency import spec +from algoperf import spec def shard_and_maybe_pad_np( @@ -65,7 +65,7 @@ def _prepare(x): # Assumes that `global_batch_size % local_device_count == 0`. return x.reshape((local_device_count, -1, *x.shape[1:])) - return jax.tree_map(_prepare, batch) + return jax.tree.map(_prepare, batch) def pad(tensor: np.ndarray, diff --git a/algorithmic_efficiency/halton.py b/algoperf/halton.py similarity index 97% rename from algorithmic_efficiency/halton.py rename to algoperf/halton.py index 9eb30861d..1f36b07bf 100644 --- a/algorithmic_efficiency/halton.py +++ b/algoperf/halton.py @@ -10,13 +10,13 @@ import functools import itertools import math -from typing import Any, Callable, Dict, List, Sequence, Text, Tuple, Union +from typing import Any, Callable, Dict, List, Sequence, Tuple, Union from absl import logging from numpy import random -_SweepSequence = List[Dict[Text, Any]] -_GeneratorFn = Callable[[float], Tuple[Text, float]] +_SweepSequence = List[Dict[str, Any]] +_GeneratorFn = Callable[[float], Tuple[str, float]] def generate_primes(n: int) -> List[int]: @@ -195,10 +195,10 @@ def generate_sequence(num_samples: int, return halton_sequence -def _generate_double_point(name: Text, +def _generate_double_point(name: str, min_val: float, max_val: float, - scaling: Text, + scaling: str, halton_point: float) -> Tuple[str, float]: """Generate a float hyperparameter value from a Halton sequence point.""" if scaling not in ['linear', 'log']: @@ -234,7 +234,7 @@ def interval(start: int, end: int) -> Tuple[int, int]: return start, end -def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn: +def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn: min_val, max_val = range_endpoints return functools.partial(_generate_double_point, name, @@ -244,8 +244,8 @@ def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn: def uniform( - name: Text, search_points: Union[_DiscretePoints, - Tuple[int, int]]) -> _GeneratorFn: + name: str, search_points: Union[_DiscretePoints, + Tuple[int, int]]) -> _GeneratorFn: if isinstance(search_points, _DiscretePoints): return functools.partial(_generate_discrete_point, name, diff --git a/algorithmic_efficiency/init_utils.py b/algoperf/init_utils.py similarity index 86% rename from algorithmic_efficiency/init_utils.py rename to algoperf/init_utils.py index 66ed041ce..185480cc7 100644 --- a/algorithmic_efficiency/init_utils.py +++ b/algoperf/init_utils.py @@ -13,6 +13,6 @@ def pytorch_default_init(module: nn.Module) -> None: # Perform lecun_normal initialization. fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) std = math.sqrt(1. / fan_in) / .87962566103423978 - nn.init.trunc_normal_(module.weight, std=std) + nn.init.trunc_normal_(module.weight, std=std, a=-2 * std, b=2 * std) if module.bias is not None: nn.init.constant_(module.bias, 0.) diff --git a/algorithmic_efficiency/interop_utils.py b/algoperf/interop_utils.py similarity index 90% rename from algorithmic_efficiency/interop_utils.py rename to algoperf/interop_utils.py index e307042a9..0c6535d7a 100644 --- a/algorithmic_efficiency/interop_utils.py +++ b/algoperf/interop_utils.py @@ -1,7 +1,7 @@ import jax.dlpack import torch -from algorithmic_efficiency import spec +from algoperf import spec def jax_to_pytorch(x: spec.Tensor, take_ownership: bool = False) -> spec.Tensor: diff --git a/algorithmic_efficiency/logger_utils.py b/algoperf/logger_utils.py similarity index 98% rename from algorithmic_efficiency/logger_utils.py rename to algoperf/logger_utils.py index 609d996e6..c988956dc 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algoperf/logger_utils.py @@ -18,8 +18,8 @@ import psutil import torch.distributed as dist -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP, RANK, DEVICE, _ = pytorch_setup() @@ -211,7 +211,7 @@ def _get_system_software_info() -> Dict: system_software_info['os_platform'] = \ platform.platform() # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29' system_software_info['python_version'] = platform.python_version( - ) # Ex. '3.8.10' + ) # Ex. '3.11.10' system_software_info['python_compiler'] = platform.python_compiler( ) # Ex. 'GCC 9.3.0' # Note: do not store hostname as that may be sensitive diff --git a/algorithmic_efficiency/param_utils.py b/algoperf/param_utils.py similarity index 98% rename from algorithmic_efficiency/param_utils.py rename to algoperf/param_utils.py index b430366b1..05d882404 100644 --- a/algorithmic_efficiency/param_utils.py +++ b/algoperf/param_utils.py @@ -6,7 +6,7 @@ import jax from torch import nn -from algorithmic_efficiency import spec +from algoperf import spec def pytorch_param_shapes(model: nn.Module) -> Dict[str, spec.ShapeTuple]: @@ -66,7 +66,7 @@ def pytorch_param_types( def jax_param_shapes( params: spec.ParameterContainer) -> spec.ParameterShapeTree: - return jax.tree_map(lambda x: spec.ShapeTuple(x.shape), params) + return jax.tree.map(lambda x: spec.ShapeTuple(x.shape), params) def jax_param_types(param_shapes: spec.ParameterShapeTree, diff --git a/algorithmic_efficiency/profiler.py b/algoperf/profiler.py similarity index 100% rename from algorithmic_efficiency/profiler.py rename to algoperf/profiler.py diff --git a/algorithmic_efficiency/pytorch_utils.py b/algoperf/pytorch_utils.py similarity index 89% rename from algorithmic_efficiency/pytorch_utils.py rename to algoperf/pytorch_utils.py index 590f500fa..4a674985d 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -7,11 +7,11 @@ import torch import torch.distributed as dist -from algorithmic_efficiency import spec -from algorithmic_efficiency.profiler import Profiler -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.models import \ +from algoperf import spec +from algoperf.profiler import Profiler +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ BatchNorm as ConformerBatchNorm -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ BatchNorm as DeepspeechBatchNorm diff --git a/algorithmic_efficiency/random_utils.py b/algoperf/random_utils.py similarity index 92% rename from algorithmic_efficiency/random_utils.py rename to algoperf/random_utils.py index cf1ea6c32..a579976ad 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algoperf/random_utils.py @@ -16,21 +16,21 @@ FLAGS = flags.FLAGS -# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an +# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 31 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_INT32 = 2**31 -MIN_INT32 = -MAX_INT32 +MAX_INT32 = 2**31 - 1 +MIN_INT32 = 0 SeedType = Union[int, list, np.ndarray] def _signed_to_unsigned(seed: SeedType) -> SeedType: if isinstance(seed, int): - return seed % 2**32 + return seed % MAX_INT32 if isinstance(seed, list): - return [s % 2**32 for s in seed] + return [s % MAX_INT32 for s in seed] if isinstance(seed, np.ndarray): - return np.array([s % 2**32 for s in seed.tolist()]) + return np.array([s % MAX_INT32 for s in seed.tolist()]) def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: diff --git a/algorithmic_efficiency/spec.py b/algoperf/spec.py similarity index 93% rename from algorithmic_efficiency/spec.py rename to algoperf/spec.py index 7bc86b505..cf4f1a14e 100644 --- a/algorithmic_efficiency/spec.py +++ b/algoperf/spec.py @@ -206,7 +206,7 @@ def eval_period_time_sec(self) -> int: @property @abc.abstractmethod def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" + """Approx. steps the baseline can do in the allowed runtime budget.""" @property def param_shapes(self): @@ -431,6 +431,36 @@ def update_params(workload: Workload, pass +PrepareForEvalFn = Callable[[ + Workload, + ParameterContainer, + ParameterTypeTree, + ModelAuxiliaryState, + Hyperparameters, + LossType, + OptimizerState, + List[Tuple[int, float]], + int, + RandomState +], + UpdateReturn] + + +# Prepare model and optimizer for evaluation. +def prepare_for_eval(workload: Workload, + current_param_container: ParameterContainer, + current_params_types: ParameterTypeTree, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + loss_type: LossType, + optimizer_state: OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: RandomState) -> UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + pass + + DataSelectionFn = Callable[[ Workload, Iterator[Dict[str, Any]], diff --git a/algorithmic_efficiency/workloads/__init__.py b/algoperf/workloads/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/__init__.py rename to algoperf/workloads/__init__.py diff --git a/algorithmic_efficiency/workloads/cifar/__init__.py b/algoperf/workloads/cifar/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/cifar/__init__.py rename to algoperf/workloads/cifar/__init__.py diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/__init__.py b/algoperf/workloads/cifar/cifar_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/cifar/cifar_jax/__init__.py rename to algoperf/workloads/cifar/cifar_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py similarity index 98% rename from algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py rename to algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 3e6a68844..728d05f29 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -13,8 +13,8 @@ import tensorflow as tf import tensorflow_datasets as tfds -from algorithmic_efficiency import spec -from algorithmic_efficiency.data_utils import shard_and_maybe_pad_np +from algoperf import spec +from algoperf.data_utils import shard_and_maybe_pad_np def preprocess_for_train(image: spec.Tensor, diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py b/algoperf/workloads/cifar/cifar_jax/models.py similarity index 93% rename from algorithmic_efficiency/workloads/cifar/cifar_jax/models.py rename to algoperf/workloads/cifar/cifar_jax/models.py index 059352fb6..957079272 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py +++ b/algoperf/workloads/cifar/cifar_jax/models.py @@ -10,9 +10,8 @@ from flax import linen as nn import jax.numpy as jnp -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.models import \ - ResNetBlock +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_jax.models import ResNetBlock ModuleDef = nn.Module diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py similarity index 92% rename from algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py rename to algoperf/workloads/cifar/cifar_jax/workload.py index 8268c6ca3..ad43bc62f 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -5,18 +5,18 @@ from flax import jax_utils from flax import linen as nn +from flax.core import pop import jax from jax import lax import jax.numpy as jnp import optax import tensorflow_datasets as tfds -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.cifar.cifar_jax import models -from algorithmic_efficiency.workloads.cifar.cifar_jax.input_pipeline import \ - create_input_iter -from algorithmic_efficiency.workloads.cifar.workload import BaseCifarWorkload +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.cifar.cifar_jax import models +from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter +from algoperf.workloads.cifar.workload import BaseCifarWorkload class CifarWorkload(BaseCifarWorkload): @@ -75,8 +75,8 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics # and we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + new_model_state = model_state.copy() + new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state def init_model_fn( @@ -93,7 +93,7 @@ def init_model_fn( input_shape = (1, 32, 32, 3) variables = jax.jit(model.init)({'params': rng}, jnp.ones(input_shape, model.dtype)) - model_state, params = variables.pop('params') + model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) model_state = jax_utils.replicate(model_state) @@ -206,4 +206,4 @@ def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics) + return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/__init__.py b/algoperf/workloads/cifar/cifar_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/cifar/cifar_pytorch/__init__.py rename to algoperf/workloads/cifar/cifar_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/models.py b/algoperf/workloads/cifar/cifar_pytorch/models.py similarity index 92% rename from algorithmic_efficiency/workloads/cifar/cifar_pytorch/models.py rename to algoperf/workloads/cifar/cifar_pytorch/models.py index b592e10ab..e6a7a8a81 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/models.py +++ b/algoperf/workloads/cifar/cifar_pytorch/models.py @@ -10,14 +10,13 @@ import torch from torch import nn -from algorithmic_efficiency import spec -from algorithmic_efficiency.init_utils import pytorch_default_init -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.models import \ +from algoperf import spec +from algoperf.init_utils import pytorch_default_init +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ BasicBlock -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.models import \ +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ Bottleneck -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.models import \ - conv1x1 +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import conv1x1 class ResNet(nn.Module): diff --git a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py similarity index 95% rename from algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py rename to algoperf/workloads/cifar/cifar_pytorch/workload.py index 7abcf4d6c..d05131c27 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py +++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py @@ -12,13 +12,12 @@ from torchvision import transforms from torchvision.datasets import CIFAR10 -from algorithmic_efficiency import data_utils -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import pytorch_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.cifar.cifar_pytorch.models import \ - resnet18 -from algorithmic_efficiency.workloads.cifar.workload import BaseCifarWorkload +from algoperf import data_utils +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +from algoperf.workloads.cifar.cifar_pytorch.models import resnet18 +from algoperf.workloads.cifar.workload import BaseCifarWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -82,7 +81,7 @@ def _build_dataset( } if split == 'eval_train': train_indices = indices_split['train'] - random.Random(data_rng[0]).shuffle(train_indices) + random.Random(int(data_rng[0])).shuffle(train_indices) indices_split['eval_train'] = train_indices[:self.num_eval_train_examples] if split in indices_split: dataset = torch.utils.data.Subset(dataset, indices_split[split]) diff --git a/algorithmic_efficiency/workloads/cifar/workload.py b/algoperf/workloads/cifar/workload.py similarity index 97% rename from algorithmic_efficiency/workloads/cifar/workload.py rename to algoperf/workloads/cifar/workload.py index 9e36cb291..c0d565108 100644 --- a/algorithmic_efficiency/workloads/cifar/workload.py +++ b/algoperf/workloads/cifar/workload.py @@ -7,9 +7,9 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup -import algorithmic_efficiency.random_utils as prng +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup +import algoperf.random_utils as prng USE_PYTORCH_DDP, _, _, _ = pytorch_setup() diff --git a/algorithmic_efficiency/workloads/criteo1tb/__init__.py b/algoperf/workloads/criteo1tb/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/criteo1tb/__init__.py rename to algoperf/workloads/criteo1tb/__init__.py diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/__init__.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/__init__.py rename to algoperf/workloads/criteo1tb/criteo1tb_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py similarity index 100% rename from algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py rename to algoperf/workloads/criteo1tb/criteo1tb_jax/models.py diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py rename to algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 3743dc1ff..91761e458 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -8,10 +8,10 @@ import jax.numpy as jnp import numpy as np -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax import models -from algorithmic_efficiency.workloads.criteo1tb.workload import \ +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.criteo1tb.criteo1tb_jax import models +from algoperf.workloads.criteo1tb.workload import \ BaseCriteo1TbDlrmSmallWorkload diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/__init__.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/__init__.py rename to algoperf/workloads/criteo1tb/criteo1tb_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py similarity index 100% rename from algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py rename to algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py similarity index 97% rename from algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py rename to algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 446267440..726aa8705 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -7,11 +7,11 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch import models -from algorithmic_efficiency.workloads.criteo1tb.workload import \ +from algoperf import param_utils +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.criteo1tb.criteo1tb_pytorch import models +from algoperf.workloads.criteo1tb.workload import \ BaseCriteo1TbDlrmSmallWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() diff --git a/algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py b/algoperf/workloads/criteo1tb/input_pipeline.py similarity index 98% rename from algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py rename to algoperf/workloads/criteo1tb/input_pipeline.py index cb091b3a5..7e254336a 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py +++ b/algoperf/workloads/criteo1tb/input_pipeline.py @@ -12,7 +12,7 @@ import tensorflow as tf -from algorithmic_efficiency import data_utils +from algoperf import data_utils _NUM_DAY_23_FILES = 36 diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py similarity index 95% rename from algorithmic_efficiency/workloads/criteo1tb/workload.py rename to algoperf/workloads/criteo1tb/workload.py index f18f2656f..617b2e987 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/workload.py +++ b/algoperf/workloads/criteo1tb/workload.py @@ -7,8 +7,8 @@ from absl import flags import torch.distributed as dist -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.criteo1tb import input_pipeline +from algoperf import spec +from algoperf.workloads.criteo1tb import input_pipeline FLAGS = flags.FLAGS @@ -93,7 +93,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 7703 # ~2 hours. + return 7_703 # ~2.1 hours. @property def eval_period_time_sec(self) -> int: @@ -123,7 +123,7 @@ def _build_input_queue( @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" + """Approx. steps the baseline can do in the allowed runtime budget.""" return 10_666 def _eval_model_on_split(self, diff --git a/algorithmic_efficiency/workloads/fastmri/__init__.py b/algoperf/workloads/fastmri/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/fastmri/__init__.py rename to algoperf/workloads/fastmri/__init__.py diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/__init__.py b/algoperf/workloads/fastmri/fastmri_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/fastmri/fastmri_jax/__init__.py rename to algoperf/workloads/fastmri/fastmri_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py similarity index 100% rename from algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py rename to algoperf/workloads/fastmri/fastmri_jax/models.py diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py b/algoperf/workloads/fastmri/fastmri_jax/ssim.py similarity index 100% rename from algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py rename to algoperf/workloads/fastmri/fastmri_jax/ssim.py diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py similarity index 94% rename from algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py rename to algoperf/workloads/fastmri/fastmri_jax/workload.py index a5dfe8c22..1156cf30a 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -8,13 +8,12 @@ import jax import jax.numpy as jnp -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -import algorithmic_efficiency.random_utils as prng -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.models import UNet -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.ssim import ssim -from algorithmic_efficiency.workloads.fastmri.workload import \ - BaseFastMRIWorkload +from algoperf import param_utils +from algoperf import spec +import algoperf.random_utils as prng +from algoperf.workloads.fastmri.fastmri_jax.models import UNet +from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim +from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload class FastMRIWorkload(BaseFastMRIWorkload): diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/__init__.py b/algoperf/workloads/fastmri/fastmri_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/__init__.py rename to algoperf/workloads/fastmri/fastmri_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py b/algoperf/workloads/fastmri/fastmri_pytorch/models.py similarity index 99% rename from algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py rename to algoperf/workloads/fastmri/fastmri_pytorch/models.py index 6c0ab19e2..28f20bf20 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models.py @@ -12,7 +12,7 @@ from torch import Tensor from torch.nn import functional as F -from algorithmic_efficiency import init_utils +from algoperf import init_utils class UNet(nn.Module): diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py b/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py similarity index 98% rename from algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py rename to algoperf/workloads/fastmri/fastmri_pytorch/ssim.py index eff6fb62f..45b61bea4 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from torchvision.transforms.functional import pad as pad_fn -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf.pytorch_utils import pytorch_setup DEVICE = pytorch_setup()[2] diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py rename to algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 74f6aa13d..58943de2f 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -9,15 +9,13 @@ import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel as DDP -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import pytorch_utils -from algorithmic_efficiency import spec -import algorithmic_efficiency.random_utils as prng -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.models import \ - UNet -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.ssim import ssim -from algorithmic_efficiency.workloads.fastmri.workload import \ - BaseFastMRIWorkload +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +import algoperf.random_utils as prng +from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet +from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim +from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algorithmic_efficiency/workloads/fastmri/input_pipeline.py b/algoperf/workloads/fastmri/input_pipeline.py similarity index 99% rename from algorithmic_efficiency/workloads/fastmri/input_pipeline.py rename to algoperf/workloads/fastmri/input_pipeline.py index 8f6ddafd1..f20611f43 100644 --- a/algorithmic_efficiency/workloads/fastmri/input_pipeline.py +++ b/algoperf/workloads/fastmri/input_pipeline.py @@ -9,7 +9,7 @@ import jax import tensorflow as tf -from algorithmic_efficiency import data_utils +from algoperf import data_utils _TRAIN_DIR = 'knee_singlecoil_train' _VAL_DIR = 'knee_singlecoil_val' diff --git a/algorithmic_efficiency/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py similarity index 93% rename from algorithmic_efficiency/workloads/fastmri/workload.py rename to algoperf/workloads/fastmri/workload.py index a8fd1abbb..051749cc3 100644 --- a/algorithmic_efficiency/workloads/fastmri/workload.py +++ b/algoperf/workloads/fastmri/workload.py @@ -3,8 +3,8 @@ import math from typing import Optional -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.fastmri import input_pipeline +from algoperf import spec +from algoperf.workloads.fastmri import input_pipeline class BaseFastMRIWorkload(spec.Workload): @@ -95,7 +95,7 @@ def accelerations(self): @property def max_allowed_runtime_sec(self) -> int: - return 8859 # ~2.5 hours + return 4_430 # ~1.2 hours @property def eval_period_time_sec(self) -> int: @@ -103,8 +103,8 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 36_189 + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 18_094 def _build_input_queue(self, data_rng: spec.RandomState, diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/__init__.py b/algoperf/workloads/imagenet_resnet/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/imagenet_resnet/__init__.py rename to algoperf/workloads/imagenet_resnet/__init__.py diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/__init__.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/__init__.py rename to algoperf/workloads/imagenet_resnet/imagenet_jax/__init__.py diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py new file mode 100644 index 000000000..3d6939218 --- /dev/null +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -0,0 +1,438 @@ +""" +Note: +The following code is adapted from: +https://github.com/tensorflow/addons/tree/master/tensorflow_addons/image + + +""" + +from typing import List, Optional, Union + +import numpy as np +import tensorflow as tf + +_IMAGE_DTYPES = { + tf.dtypes.uint8, + tf.dtypes.int32, + tf.dtypes.int64, + tf.dtypes.float16, + tf.dtypes.float32, + tf.dtypes.float64, +} + +Number = Union[float, + int, + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64,] + +TensorLike = Union[List[Union[Number, list]], + tuple, + Number, + np.ndarray, + tf.Tensor, + tf.SparseTensor, + tf.Variable,] + + +def get_ndims(image): + return image.get_shape().ndims or tf.rank(image) + + +def to_4d_image(image): + """Convert 2/3/4D image to 4D image. + + Args: + image: 2/3/4D `Tensor`. + + Returns: + 4D `Tensor` with the same type. + """ + with tf.control_dependencies([ + tf.debugging.assert_rank_in( + image, [2, 3, 4], message="`image` must be 2/3/4D tensor") + ]): + ndims = image.get_shape().ndims + if ndims is None: + return _dynamic_to_4d_image(image) + elif ndims == 2: + return image[None, :, :, None] + elif ndims == 3: + return image[None, :, :, :] + else: + return image + + +def _dynamic_to_4d_image(image): + shape = tf.shape(image) + original_rank = tf.rank(image) + # 4D image => [N, H, W, C] or [N, C, H, W] + # 3D image => [1, H, W, C] or [1, C, H, W] + # 2D image => [1, H, W, 1] + left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) + right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) + new_shape = tf.concat( + [ + tf.ones(shape=left_pad, dtype=tf.int32), + shape, + tf.ones(shape=right_pad, dtype=tf.int32), + ], + axis=0, + ) + return tf.reshape(image, new_shape) + + +def from_4d_image(image, ndims): + """Convert back to an image with `ndims` rank. + + Args: + image: 4D `Tensor`. + ndims: The original rank of the image. + + Returns: + `ndims`-D `Tensor` with the same type. + """ + with tf.control_dependencies( + [tf.debugging.assert_rank(image, 4, + message="`image` must be 4D tensor")]): + if isinstance(ndims, tf.Tensor): + return _dynamic_from_4d_image(image, ndims) + elif ndims == 2: + return tf.squeeze(image, [0, 3]) + elif ndims == 3: + return tf.squeeze(image, [0]) + else: + return image + + +def _dynamic_from_4d_image(image, original_rank): + shape = tf.shape(image) + # 4D image <= [N, H, W, C] or [N, C, H, W] + # 3D image <= [1, H, W, C] or [1, C, H, W] + # 2D image <= [1, H, W, 1] + begin = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) + end = 4 - tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) + new_shape = shape[begin:end] + return tf.reshape(image, new_shape) + + +def transform( + images: TensorLike, + transforms: TensorLike, + interpolation: str = "nearest", + fill_mode: str = "constant", + output_shape: Optional[list] = None, + name: Optional[str] = None, + fill_value: TensorLike = 0.0, +) -> tf.Tensor: + """Applies the given transform(s) to the image(s). + + Args: + images: A tensor of shape (num_images, num_rows, num_columns, + num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or + (num_rows, num_columns) (HW). + transforms: Projective transform matrix/matrices. A vector of length 8 or + tensor of size N x 8. If one row of transforms is + [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point + `(x, y)` to a transformed *input* point + `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, + where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to + the transform mapping input points to output points. Note that + gradients are not backpropagated into transformation parameters. + interpolation: Interpolation mode. + Supported values: "nearest", "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + output_shape: Output dimesion after the transform, [height, width]. + If None, output is the same size as input image. + + name: The name of the op. + + Returns: + Image(s) with the same type and shape as `images`, with the given + transform(s) applied. Transformed coordinates outside of the input image + will be filled with zeros. + + Raises: + TypeError: If `image` is an invalid type. + ValueError: If output shape is not 1-D int32 Tensor. + """ + with tf.name_scope(name or "transform"): + image_or_images = tf.convert_to_tensor(images, name="images") + transform_or_transforms = tf.convert_to_tensor( + transforms, name="transforms", dtype=tf.dtypes.float32) + if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: + raise TypeError("Invalid dtype %s." % image_or_images.dtype) + images = to_4d_image(image_or_images) + original_ndims = get_ndims(image_or_images) + + if output_shape is None: + output_shape = tf.shape(images)[1:3] + + output_shape = tf.convert_to_tensor( + output_shape, tf.dtypes.int32, name="output_shape") + + if not output_shape.get_shape().is_compatible_with([2]): + raise ValueError("output_shape must be a 1-D Tensor of 2 elements: " + "new_height, new_width") + + if len(transform_or_transforms.get_shape()) == 1: + transforms = transform_or_transforms[None] + elif transform_or_transforms.get_shape().ndims is None: + raise ValueError("transforms rank must be statically known") + elif len(transform_or_transforms.get_shape()) == 2: + transforms = transform_or_transforms + else: + transforms = transform_or_transforms + raise ValueError("transforms should have rank 1 or 2, but got rank %d" % + len(transforms.get_shape())) + + fill_value = tf.convert_to_tensor( + fill_value, dtype=tf.float32, name="fill_value") + output = tf.raw_ops.ImageProjectiveTransformV3( + images=images, + transforms=transforms, + output_shape=output_shape, + interpolation=interpolation.upper(), + fill_mode=fill_mode.upper(), + fill_value=fill_value, + ) + return from_4d_image(output, original_ndims) + + +def angles_to_projective_transforms( + angles: TensorLike, + image_height: TensorLike, + image_width: TensorLike, + name: Optional[str] = None, +) -> tf.Tensor: + """Returns projective transform(s) for the given angle(s). + + Args: + angles: A scalar angle to rotate all images by, or (for batches of + images) a vector with an angle to rotate each image in the batch. The + rank must be statically known (the shape is not `TensorShape(None)`. + image_height: Height of the image(s) to be transformed. + image_width: Width of the image(s) to be transformed. + + Returns: + A tensor of shape (num_images, 8). Projective transforms which can be + given to `transform` op. + """ + with tf.name_scope(name or "angles_to_projective_transforms"): + angle_or_angles = tf.convert_to_tensor( + angles, name="angles", dtype=tf.dtypes.float32) + + if len(angle_or_angles.get_shape()) not in (0, 1): + raise ValueError("angles should have rank 0 or 1.") + + if len(angle_or_angles.get_shape()) == 0: + angles = angle_or_angles[None] + else: + angles = angle_or_angles + + cos_angles = tf.math.cos(angles) + sin_angles = tf.math.sin(angles) + x_offset = ((image_width - 1) - + (cos_angles * (image_width - 1) - sin_angles * + (image_height - 1))) / 2.0 + y_offset = ((image_height - 1) - + (sin_angles * (image_width - 1) + cos_angles * + (image_height - 1))) / 2.0 + num_angles = tf.shape(angles)[0] + return tf.concat( + values=[ + cos_angles[:, None], + -sin_angles[:, None], + x_offset[:, None], + sin_angles[:, None], + cos_angles[:, None], + y_offset[:, None], + tf.zeros((num_angles, 2), tf.dtypes.float32), + ], + axis=1, + ) + + +def rotate_img( + images: TensorLike, + angles: TensorLike, + interpolation: str = "nearest", + fill_mode: str = "constant", + name: Optional[str] = None, + fill_value: TensorLike = 0.0, +) -> tf.Tensor: + """Rotate image(s) counterclockwise by the passed angle(s) in radians. + + Args: + images: A tensor of shape + `(num_images, num_rows, num_columns, num_channels)` + (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or + `(num_rows, num_columns)` (HW). + angles: A scalar angle to rotate all images by (if `images` has rank 4) + a vector of length num_images, with an angle for each image in the + batch. + interpolation: Interpolation mode. Supported values: "nearest", + "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + name: The name of the op. + + Returns: + Image(s) with the same type and shape as `images`, rotated by the given + angle(s). Empty space due to the rotation will be filled with zeros. + + Raises: + TypeError: If `images` is an invalid type. + """ + with tf.name_scope(name or "rotate"): + image_or_images = tf.convert_to_tensor(images) + if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: + raise TypeError("Invalid dtype %s." % image_or_images.dtype) + images = to_4d_image(image_or_images) + original_ndims = get_ndims(image_or_images) + + image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None] + image_width = tf.cast(tf.shape(images)[2], tf.dtypes.float32)[None] + output = transform( + images, + angles_to_projective_transforms(angles, image_height, image_width), + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, + ) + return from_4d_image(output, original_ndims) + + +def translations_to_projective_transforms(translations: TensorLike, + name: Optional[str] = None + ) -> tf.Tensor: + """Returns projective transform(s) for the given translation(s). + + Args: + translations: A 2-element list representing `[dx, dy]` or a matrix of + 2-element lists representing `[dx, dy]` to translate for each image + (for a batch of images). The rank must be statically known + (the shape is not `TensorShape(None)`). + name: The name of the op. + Returns: + A tensor of shape `(num_images, 8)` projective transforms which can be + given to `tfa.image.transform`. + """ + with tf.name_scope(name or "translations_to_projective_transforms"): + translation_or_translations = tf.convert_to_tensor( + translations, name="translations", dtype=tf.dtypes.float32) + if translation_or_translations.get_shape().ndims is None: + raise TypeError( + "translation_or_translations rank must be statically known") + + if len(translation_or_translations.get_shape()) not in (1, 2): + raise TypeError("Translations should have rank 1 or 2.") + + if len(translation_or_translations.get_shape()) == 1: + translations = translation_or_translations[None] + else: + translations = translation_or_translations + + num_translations = tf.shape(translations)[0] + # The translation matrix looks like: + # [[1 0 -dx] + # [0 1 -dy] + # [0 0 1]] + # where the last entry is implicit. + # Translation matrices are always float32. + return tf.concat( + values=[ + tf.ones((num_translations, 1), tf.dtypes.float32), + tf.zeros((num_translations, 1), tf.dtypes.float32), + -translations[:, 0, None], + tf.zeros((num_translations, 1), tf.dtypes.float32), + tf.ones((num_translations, 1), tf.dtypes.float32), + -translations[:, 1, None], + tf.zeros((num_translations, 2), tf.dtypes.float32), + ], + axis=1, + ) + + +@tf.function +def translate( + images: TensorLike, + translations: TensorLike, + interpolation: str = "nearest", + fill_mode: str = "constant", + name: Optional[str] = None, + fill_value: TensorLike = 0.0, +) -> tf.Tensor: + """Translate image(s) by the passed vectors(s). + + Args: + images: A tensor of shape + `(num_images, num_rows, num_columns, num_channels)` (NHWC), + `(num_rows, num_columns, num_channels)` (HWC), or + `(num_rows, num_columns)` (HW). The rank must be statically known (the + shape is not `TensorShape(None)`). + translations: A vector representing `[dx, dy]` or (if `images` has rank 4) + a matrix of length num_images, with a `[dx, dy]` vector for each image + in the batch. + interpolation: Interpolation mode. Supported values: "nearest", + "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + name: The name of the op. + Returns: + Image(s) with the same type and shape as `images`, translated by the + given vector(s). Empty space due to the translation will be filled with + zeros. + Raises: + TypeError: If `images` is an invalid type. + """ + with tf.name_scope(name or "translate"): + return transform( + images, + translations_to_projective_transforms(translations), + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, + ) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py similarity index 98% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py rename to algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py index 422eb9f7a..66105335b 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py @@ -12,10 +12,9 @@ import tensorflow as tf import tensorflow_datasets as tfds -from algorithmic_efficiency import data_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax import \ - randaugment +from algoperf import data_utils +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_jax import randaugment TFDS_SPLIT_NAME = { 'train': 'train', 'eval_train': 'train', 'validation': 'validation' diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py similarity index 98% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py rename to algoperf/workloads/imagenet_resnet/imagenet_jax/models.py index 34cd17440..ffa60b260 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py @@ -10,7 +10,7 @@ from flax import linen as nn import jax.numpy as jnp -from algorithmic_efficiency import spec +from algoperf import spec ModuleDef = nn.Module diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py similarity index 96% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py rename to algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 5f92b1482..c68e2de33 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,7 +8,13 @@ import math import tensorflow as tf -from tensorflow_addons import image as contrib_image + +from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + rotate_img +from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + transform +from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + translate # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. @@ -176,19 +182,19 @@ def rotate(image, degrees, replace): # In practice, we should randomize the rotation degrees by flipping # it negatively half the time, but that's done on 'degrees' outside # of the function. - image = contrib_image.rotate(wrap(image), radians) + image = rotate_img(wrap(image), radians) return unwrap(image, replace) def translate_x(image, pixels, replace): """Equivalent of PIL Translate in X dimension.""" - image = contrib_image.translate(wrap(image), [-pixels, 0]) + image = translate(wrap(image), [-pixels, 0]) return unwrap(image, replace) def translate_y(image, pixels, replace): """Equivalent of PIL Translate in Y dimension.""" - image = contrib_image.translate(wrap(image), [0, -pixels]) + image = translate(wrap(image), [0, -pixels]) return unwrap(image, replace) @@ -198,8 +204,7 @@ def shear_x(image, level, replace): # with a matrix form of: # [1 level # 0 1]. - image = contrib_image.transform( - wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) + image = transform(wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) return unwrap(image, replace) @@ -209,8 +214,7 @@ def shear_y(image, level, replace): # with a matrix form of: # [1 0 # level 1]. - image = contrib_image.transform( - wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) + image = transform(wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) return unwrap(image, replace) @@ -478,13 +482,13 @@ def _parse_policy_info(name, # Check to see if prob is passed into function. This is used for operations # where we alter bboxes independently. - if 'prob' in inspect.getargspec(func)[0]: + if 'prob' in inspect.getfullargspec(func)[0]: args = tuple([prob] + list(args)) # Add in replace arg if it is required for the function that is being called. - if 'replace' in inspect.getargspec(func)[0]: + if 'replace' in inspect.getfullargspec(func)[0]: # Make sure replace is the final argument - assert 'replace' == inspect.getargspec(func)[0][-1] + assert 'replace' == inspect.getfullargspec(func)[0][-1] args = tuple(list(args) + [replace_value]) return (func, prob, args) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py similarity index 93% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py rename to algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 2747fc2db..4ec3937b8 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -11,21 +11,20 @@ from flax import jax_utils from flax import linen as nn +from flax.core import pop import jax from jax import lax import jax.numpy as jnp import optax import tensorflow_datasets as tfds -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import random_utils as prng -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet import imagenet_v2 -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax import \ - input_pipeline -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax import \ - models -from algorithmic_efficiency.workloads.imagenet_resnet.workload import \ +from algoperf import param_utils +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.workloads.imagenet_resnet import imagenet_v2 +from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline +from algoperf.workloads.imagenet_resnet.imagenet_jax import models +from algoperf.workloads.imagenet_resnet.workload import \ BaseImagenetResNetWorkload @@ -79,8 +78,8 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics and # we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + new_model_state = model_state.copy() # Create a shallow copy + new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state def init_model_fn( @@ -111,7 +110,7 @@ def init_model_fn( input_shape = (1, 224, 224, 3) variables = jax.jit(model.init)({'params': rng}, jnp.ones(input_shape, model.dtype)) - model_state, params = variables.pop('params') + model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) model_state = jax_utils.replicate(model_state) @@ -263,7 +262,7 @@ def _eval_model_on_split(self, eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_metrics = jax.tree_map(lambda x: float(x[0] / num_examples), + eval_metrics = jax.tree.map(lambda x: float(x[0] / num_examples), eval_metrics) return eval_metrics diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/__init__.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/__init__.py rename to algoperf/workloads/imagenet_resnet/imagenet_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py similarity index 98% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/models.py rename to algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py index 2b9093940..aba9e671f 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py @@ -11,8 +11,8 @@ from torch import nn from torch import Tensor -from algorithmic_efficiency import spec -from algorithmic_efficiency.init_utils import pytorch_default_init +from algoperf import spec +from algoperf.init_utils import pytorch_default_init def conv3x3(in_planes: int, diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py similarity index 99% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py rename to algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py index 829d82d74..c7a98e77a 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py @@ -14,7 +14,7 @@ from torchvision.transforms import functional as F from torchvision.transforms import InterpolationMode -from algorithmic_efficiency import spec +from algoperf import spec def cutout(img: spec.Tensor, pad_size: int) -> spec.Tensor: diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py similarity index 94% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py rename to algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 3549911fa..ed29271f3 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -16,17 +16,15 @@ from torchvision import transforms from torchvision.datasets.folder import ImageFolder -from algorithmic_efficiency import data_utils -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import pytorch_utils -from algorithmic_efficiency import spec -import algorithmic_efficiency.random_utils as prng -from algorithmic_efficiency.workloads.imagenet_resnet import imagenet_v2 -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch import \ - randaugment -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.models import \ - resnet50 -from algorithmic_efficiency.workloads.imagenet_resnet.workload import \ +from algoperf import data_utils +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +import algoperf.random_utils as prng +from algoperf.workloads.imagenet_resnet import imagenet_v2 +from algoperf.workloads.imagenet_resnet.imagenet_pytorch import randaugment +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import resnet50 +from algoperf.workloads.imagenet_resnet.workload import \ BaseImagenetResNetWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -120,7 +118,7 @@ def _build_dataset( if split == 'eval_train': indices = list(range(self.num_train_examples)) - random.Random(data_rng[0]).shuffle(indices) + random.Random(int(data_rng[0])).shuffle(indices) dataset = torch.utils.data.Subset(dataset, indices[:self.num_eval_train_examples]) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_v2.py b/algoperf/workloads/imagenet_resnet/imagenet_v2.py similarity index 88% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_v2.py rename to algoperf/workloads/imagenet_resnet/imagenet_v2.py index 05ab12eb1..84d364586 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_v2.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_v2.py @@ -8,10 +8,9 @@ import tensorflow_datasets as tfds -from algorithmic_efficiency import data_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax import \ - input_pipeline +from algoperf import data_utils +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline def get_imagenet_v2_iter(data_dir: str, diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/imagenet_resnet/workload.py rename to algoperf/workloads/imagenet_resnet/workload.py index 2e06805f7..83fe97108 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -3,7 +3,7 @@ import math from typing import Dict, Iterator, Optional, Tuple -from algorithmic_efficiency import spec +from algoperf import spec class BaseImagenetResNetWorkload(spec.Workload): @@ -102,7 +102,7 @@ def resize_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 63_008 # ~17.5 hours + return 66_159 # ~18.4 hours @property def eval_period_time_sec(self) -> int: @@ -144,5 +144,5 @@ def _build_input_queue( @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 186_666 + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 195_999 diff --git a/algorithmic_efficiency/workloads/imagenet_vit/__init__.py b/algoperf/workloads/imagenet_vit/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/imagenet_vit/__init__.py rename to algoperf/workloads/imagenet_vit/__init__.py diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/__init__.py b/algoperf/workloads/imagenet_vit/imagenet_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/__init__.py rename to algoperf/workloads/imagenet_vit/imagenet_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py similarity index 98% rename from algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py rename to algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 639800b44..7ce3a0395 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -10,7 +10,7 @@ from flax import linen as nn import jax.numpy as jnp -from algorithmic_efficiency import spec +from algoperf import spec def posemb_sincos_2d(h: int, @@ -70,7 +70,7 @@ class Encoder1DBlock(nn.Module): def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: if not self.use_post_layer_norm: y = nn.LayerNorm(name='LayerNorm_0')(x) - y = nn.SelfAttention( + y = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=train, @@ -89,7 +89,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: x = x + y else: y = x - y = nn.SelfAttention( + y = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=train, diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py similarity index 89% rename from algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py rename to algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 2ad71ffd0..35a6c46be 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -4,18 +4,17 @@ from flax import jax_utils from flax import linen as nn +from flax.core import pop import jax import jax.numpy as jnp -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax import models -from algorithmic_efficiency.workloads.imagenet_vit.workload import \ - BaseImagenetVitWorkload -from algorithmic_efficiency.workloads.imagenet_vit.workload import \ - decode_variant +from algoperf.workloads.imagenet_vit.imagenet_jax import models +from algoperf.workloads.imagenet_vit.workload import BaseImagenetVitWorkload +from algoperf.workloads.imagenet_vit.workload import decode_variant # Make sure we inherit from the ViT base workload first. @@ -28,7 +27,7 @@ def initialized(self, key: spec.RandomState, variables = jax.jit( model.init)({'params': params_rng, 'dropout': dropout_rng}, jnp.ones(input_shape)) - model_state, params = variables.pop('params') + model_state, params = pop(variables, "params") return params, model_state def init_model_fn( diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/__init__.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/__init__.py rename to algoperf/workloads/imagenet_vit/imagenet_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py similarity index 98% rename from algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py rename to algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index 02d708da8..fcf0992d3 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -12,10 +12,9 @@ from torch import nn import torch.nn.functional as F -from algorithmic_efficiency import init_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import \ - MultiheadAttention +from algoperf import init_utils +from algoperf import spec +from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py similarity index 85% rename from algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py rename to algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 703d40b07..97bb38515 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -6,17 +6,14 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import pytorch_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ ImagenetResNetWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch import \ - models -from algorithmic_efficiency.workloads.imagenet_vit.workload import \ - BaseImagenetVitWorkload -from algorithmic_efficiency.workloads.imagenet_vit.workload import \ - decode_variant +from algoperf.workloads.imagenet_vit.imagenet_pytorch import models +from algoperf.workloads.imagenet_vit.workload import BaseImagenetVitWorkload +from algoperf.workloads.imagenet_vit.workload import decode_variant USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algorithmic_efficiency/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py similarity index 92% rename from algorithmic_efficiency/workloads/imagenet_vit/workload.py rename to algoperf/workloads/imagenet_vit/workload.py index ed0118ca0..f249ddee8 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -2,8 +2,8 @@ from typing import Dict, Iterator, Optional -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_resnet.workload import \ BaseImagenetResNetWorkload @@ -81,7 +81,7 @@ def eval_batch_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 77_520 # ~22 hours + return 69_768 # ~19.4 hours @property def eval_period_time_sec(self) -> int: @@ -110,5 +110,5 @@ def _build_dataset( @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 186_666 + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 167_999 diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/__init__.py b/algoperf/workloads/librispeech_conformer/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/__init__.py rename to algoperf/workloads/librispeech_conformer/__init__.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/input_pipeline.py b/algoperf/workloads/librispeech_conformer/input_pipeline.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/input_pipeline.py rename to algoperf/workloads/librispeech_conformer/input_pipeline.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/__init__.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/__init__.py rename to algoperf/workloads/librispeech_conformer/librispeech_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py rename to algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py similarity index 98% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py rename to algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index cb6287c5e..593d463c3 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -22,9 +22,9 @@ import jax.numpy as jnp import numpy as np -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \ +from algoperf.workloads.librispeech_conformer.librispeech_jax import \ librispeech_preprocessor as preprocessor -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \ +from algoperf.workloads.librispeech_conformer.librispeech_jax import \ spectrum_augmenter @@ -396,10 +396,9 @@ def __call__(self, inputs, paddings, train): mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32) inputs = LayerNorm(dim=config.encoder_dim)(inputs) - attention_fn = functools.partial( dot_product_attention, temperature=config.attention_temperature) - result = nn.SelfAttention( + result = nn.MultiHeadDotProductAttention( num_heads=config.num_attention_heads, qkv_features=config.encoder_dim, decode=False, @@ -410,7 +409,8 @@ def __call__(self, inputs, paddings, train): broadcast_dropout=False, attention_fn=attention_fn, dropout_rate=config.attention_dropout_rate, - deterministic=not train)(inputs, attention_mask) + deterministic=not train)( + inputs_q=inputs, mask=attention_mask) if config.attention_residual_dropout_rate is None: attention_residual_dropout_rate = 0.1 diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py rename to algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py similarity index 94% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py rename to algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index e362f973b..39012a20d 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -3,6 +3,7 @@ from typing import Dict, Iterator, Optional, Tuple from flax import jax_utils +from flax.core import pop import flax.linen as nn import jax from jax import lax @@ -11,15 +12,14 @@ import optax import torch -from algorithmic_efficiency import data_utils -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_conformer import metrics -from algorithmic_efficiency.workloads.librispeech_conformer import workload -from algorithmic_efficiency.workloads.librispeech_conformer.input_pipeline import \ +from algoperf import data_utils +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.librispeech_conformer import metrics +from algoperf.workloads.librispeech_conformer import workload +from algoperf.workloads.librispeech_conformer.input_pipeline import \ LibriSpeechDataset -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \ - models +from algoperf.workloads.librispeech_conformer.librispeech_jax import models class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): @@ -89,7 +89,7 @@ def init_model_fn( variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, *fake_input_batch) - model_state, params = variables.pop('params') + model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -226,11 +226,12 @@ def ctc_loss(self, labels: spec.Tensor, label_paddings: spec.Tensor, blank_id: int = 0) -> spec.Tensor: - return optax.ctc_loss(logits, - logit_paddings, - labels, - label_paddings, - blank_id) + return optax.ctc_loss( + logits=logits, + logit_paddings=logit_paddings, + labels=labels, + label_paddings=label_paddings, + blank_id=blank_id) # Adapted from lingvo's greedy decoding logic here: # https://github.com/tensorflow/lingvo/blob/2ee26814c57b7dcead3f0382170f2f3da006f810/lingvo/jax/layers/ctc_objectives.py#L138. @@ -378,8 +379,8 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics and # we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + new_model_state = model_state.copy() + new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/__init__.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/__init__.py rename to algoperf/workloads/librispeech_conformer/librispeech_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py similarity index 98% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py rename to algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py index 61400806a..db1e24521 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -12,9 +12,9 @@ from torch.nn import init import torch.nn.functional as F -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ preprocessor -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ SpecAug diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py rename to algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py rename to algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py similarity index 95% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py rename to algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 155b30920..5ed37957e 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -10,17 +10,16 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from algorithmic_efficiency import data_utils -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import pytorch_utils -from algorithmic_efficiency import spec -import algorithmic_efficiency.random_utils as prng -from algorithmic_efficiency.workloads.librispeech_conformer import metrics -from algorithmic_efficiency.workloads.librispeech_conformer import workload -from algorithmic_efficiency.workloads.librispeech_conformer.input_pipeline import \ +from algoperf import data_utils +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +import algoperf.random_utils as prng +from algoperf.workloads.librispeech_conformer import metrics +from algoperf.workloads.librispeech_conformer import workload +from algoperf.workloads.librispeech_conformer.input_pipeline import \ LibriSpeechDataset -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch import \ - models +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import models USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -166,7 +165,7 @@ def _build_input_queue( ds = LibriSpeechDataset(split=ds_split, data_dir=data_dir) if split == 'eval_train': indices = list(range(len(ds))) - random.Random(data_rng[0]).shuffle(indices) + random.Random(int(data_rng[0])).shuffle(indices) ds = torch.utils.data.Subset(ds, indices[:self.num_eval_train_examples]) sampler = None diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/metrics.py b/algoperf/workloads/librispeech_conformer/metrics.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/metrics.py rename to algoperf/workloads/librispeech_conformer/metrics.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py similarity index 92% rename from algorithmic_efficiency/workloads/librispeech_conformer/workload.py rename to algoperf/workloads/librispeech_conformer/workload.py index c2413c076..94f01dd97 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py +++ b/algoperf/workloads/librispeech_conformer/workload.py @@ -1,7 +1,7 @@ import math from typing import Dict -from algorithmic_efficiency import spec +from algoperf import spec class BaseLibrispeechWorkload(spec.Workload): @@ -79,7 +79,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 61_068 # ~17 hours + return 58_015 # ~16.1 hours @property def eval_period_time_sec(self) -> int: @@ -87,5 +87,5 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 80_000 + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 76_000 diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/__init__.py b/algoperf/workloads/librispeech_deepspeech/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_deepspeech/__init__.py rename to algoperf/workloads/librispeech_deepspeech/__init__.py diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/__init__.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/__init__.py rename to algoperf/workloads/librispeech_deepspeech/librispeech_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py similarity index 99% rename from algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py rename to algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index f9eb732e9..b116f44cd 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -16,9 +16,9 @@ from jax.experimental import rnn import jax.numpy as jnp -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \ +from algoperf.workloads.librispeech_conformer.librispeech_jax import \ librispeech_preprocessor as preprocessor -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \ +from algoperf.workloads.librispeech_conformer.librispeech_jax import \ spectrum_augmenter Array = jnp.ndarray diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py similarity index 92% rename from algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py rename to algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index a0db6d607..d3b616f43 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -6,12 +6,11 @@ import jax.numpy as jnp import numpy as np -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax import \ - models +from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): @@ -100,12 +99,12 @@ def test_target_value(self) -> float: @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 48_000 + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 38_400 @property def max_allowed_runtime_sec(self) -> int: - return 55_506 # ~15.4 hours + return 44_405 # ~12.3 hours @property def use_tanh(self) -> bool: diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/__init__.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/__init__.py rename to algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py similarity index 98% rename from algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py rename to algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index bdf556f1c..84d317326 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -11,9 +11,9 @@ import torch.distributed.nn as dist_nn import torch.nn.functional as F -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ preprocessor -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ SpecAug USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py similarity index 85% rename from algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py rename to algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 626bac278..e5387f5cb 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -3,16 +3,16 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.models import \ +from algoperf import param_utils +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ initialize -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ LibriSpeechConformerWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ DeepspeechConfig -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ DeepspeechEncoderDecoder USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() @@ -76,12 +76,12 @@ def test_target_value(self) -> float: @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 48_000 + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 38_400 @property def max_allowed_runtime_sec(self) -> int: - return 55_506 # ~15.4 hours + return 44_405 # ~12.3 hours @property def use_tanh(self) -> bool: diff --git a/algorithmic_efficiency/workloads/mnist/__init__.py b/algoperf/workloads/mnist/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/mnist/__init__.py rename to algoperf/workloads/mnist/__init__.py diff --git a/algorithmic_efficiency/workloads/mnist/mnist_jax/__init__.py b/algoperf/workloads/mnist/mnist_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/mnist/mnist_jax/__init__.py rename to algoperf/workloads/mnist/mnist_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py similarity index 95% rename from algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py rename to algoperf/workloads/mnist/mnist_jax/workload.py index efbd73e33..5a4382da1 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -10,9 +10,9 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.mnist.workload import BaseMnistWorkload class _Model(nn.Module): @@ -132,4 +132,4 @@ def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics) + return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) diff --git a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/__init__.py b/algoperf/workloads/mnist/mnist_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/mnist/mnist_pytorch/__init__.py rename to algoperf/workloads/mnist/mnist_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py b/algoperf/workloads/mnist/mnist_pytorch/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py rename to algoperf/workloads/mnist/mnist_pytorch/workload.py index e638df078..780e1bca0 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py +++ b/algoperf/workloads/mnist/mnist_pytorch/workload.py @@ -10,11 +10,11 @@ import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel as DDP -from algorithmic_efficiency import init_utils -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup -from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload +from algoperf import init_utils +from algoperf import param_utils +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.mnist.workload import BaseMnistWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algoperf/workloads/mnist/workload.py similarity index 97% rename from algorithmic_efficiency/workloads/mnist/workload.py rename to algoperf/workloads/mnist/workload.py index dcc195170..f53aadd0b 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algoperf/workloads/mnist/workload.py @@ -10,10 +10,10 @@ import tensorflow_datasets as tfds import torch -from algorithmic_efficiency import data_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup -import algorithmic_efficiency.random_utils as prng +from algoperf import data_utils +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup +import algoperf.random_utils as prng USE_PYTORCH_DDP, _, _, _ = pytorch_setup() diff --git a/algorithmic_efficiency/workloads/ogbg/__init__.py b/algoperf/workloads/ogbg/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/ogbg/__init__.py rename to algoperf/workloads/ogbg/__init__.py diff --git a/algorithmic_efficiency/workloads/ogbg/input_pipeline.py b/algoperf/workloads/ogbg/input_pipeline.py similarity index 97% rename from algorithmic_efficiency/workloads/ogbg/input_pipeline.py rename to algoperf/workloads/ogbg/input_pipeline.py index a301d677a..3cb6f51de 100644 --- a/algorithmic_efficiency/workloads/ogbg/input_pipeline.py +++ b/algoperf/workloads/ogbg/input_pipeline.py @@ -51,7 +51,7 @@ def _load_dataset(split, should_shuffle, data_rng, data_dir): def _to_jraph(example): """Converts an example graph to jraph.GraphsTuple.""" - example = jax.tree_map(lambda x: x._numpy(), example) # pylint: disable=protected-access + example = jax.tree.map(lambda x: x._numpy(), example) # pylint: disable=protected-access edge_feat = example['edge_feat'] node_feat = example['node_feat'] edge_index = example['edge_index'] @@ -150,7 +150,7 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): if count == num_shards: def f(x): - return jax.tree_map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:]) + return jax.tree.map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:]) graphs_shards = f(graphs_shards) labels_shards = f(labels_shards) diff --git a/algorithmic_efficiency/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py similarity index 98% rename from algorithmic_efficiency/workloads/ogbg/metrics.py rename to algoperf/workloads/ogbg/metrics.py index a654eb2ae..55f83d905 100644 --- a/algorithmic_efficiency/workloads/ogbg/metrics.py +++ b/algoperf/workloads/ogbg/metrics.py @@ -11,7 +11,7 @@ import torch import torch.distributed as dist -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/__init__.py b/algoperf/workloads/ogbg/ogbg_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/ogbg/ogbg_jax/__init__.py rename to algoperf/workloads/ogbg/ogbg_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py similarity index 100% rename from algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py rename to algoperf/workloads/ogbg/ogbg_jax/models.py diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py similarity index 95% rename from algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py rename to algoperf/workloads/ogbg/ogbg_jax/workload.py index ec0c0658d..e895d15a7 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -8,11 +8,11 @@ import jraph import optax -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.ogbg import metrics -from algorithmic_efficiency.workloads.ogbg.ogbg_jax import models -from algorithmic_efficiency.workloads.ogbg.workload import BaseOgbgWorkload +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.ogbg import metrics +from algoperf.workloads.ogbg.ogbg_jax import models +from algoperf.workloads.ogbg.workload import BaseOgbgWorkload class OgbgWorkload(BaseOgbgWorkload): diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/__init__.py b/algoperf/workloads/ogbg/ogbg_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/__init__.py rename to algoperf/workloads/ogbg/ogbg_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algoperf/workloads/ogbg/ogbg_pytorch/models.py similarity index 99% rename from algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py rename to algoperf/workloads/ogbg/ogbg_pytorch/models.py index d93013b87..fe9b29bc1 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/models.py @@ -8,7 +8,7 @@ import torch from torch import nn -from algorithmic_efficiency import init_utils +from algoperf import init_utils def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py similarity index 95% rename from algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py rename to algoperf/workloads/ogbg/ogbg_pytorch/workload.py index d4817226d..45295ac7f 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py @@ -8,20 +8,20 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import pytorch_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.ogbg import metrics -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.models import GNN -from algorithmic_efficiency.workloads.ogbg.workload import BaseOgbgWorkload +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +from algoperf.workloads.ogbg import metrics +from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN +from algoperf.workloads.ogbg.workload import BaseOgbgWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() def _pytorch_map(inputs: Any) -> Any: if USE_PYTORCH_DDP: - return jax.tree_map(lambda a: torch.as_tensor(a, device=DEVICE), inputs) - return jax.tree_map( + return jax.tree.map(lambda a: torch.as_tensor(a, device=DEVICE), inputs) + return jax.tree.map( lambda a: torch.as_tensor(a, device=DEVICE).view(-1, a.shape[-1]) if len(a.shape) == 3 else torch.as_tensor(a, device=DEVICE).view(-1), inputs) @@ -30,7 +30,7 @@ def _pytorch_map(inputs: Any) -> Any: def _shard(inputs: Any) -> Any: if not USE_PYTORCH_DDP: return inputs - return jax.tree_map(lambda tensor: tensor[RANK], inputs) + return jax.tree.map(lambda tensor: tensor[RANK], inputs) def _graph_map(function: Callable, graph: GraphsTuple) -> GraphsTuple: diff --git a/algorithmic_efficiency/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py similarity index 94% rename from algorithmic_efficiency/workloads/ogbg/workload.py rename to algoperf/workloads/ogbg/workload.py index a32f385cb..971e7f0f6 100644 --- a/algorithmic_efficiency/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -7,10 +7,10 @@ import jax -from algorithmic_efficiency import random_utils as prng -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.ogbg import input_pipeline -from algorithmic_efficiency.workloads.ogbg import metrics +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.workloads.ogbg import input_pipeline +from algoperf.workloads.ogbg import metrics class BaseOgbgWorkload(spec.Workload): @@ -88,7 +88,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 18_477 # ~5 hours + return 12_011 # ~3.3 hours @property def eval_period_time_sec(self) -> int: @@ -140,8 +140,8 @@ def loss_fn( @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 80_000 + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 52_000 @abc.abstractmethod def _normalize_eval_metrics( diff --git a/algorithmic_efficiency/workloads/utils.py b/algoperf/workloads/utils.py similarity index 100% rename from algorithmic_efficiency/workloads/utils.py rename to algoperf/workloads/utils.py diff --git a/algorithmic_efficiency/workloads/wmt/__init__.py b/algoperf/workloads/wmt/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/wmt/__init__.py rename to algoperf/workloads/wmt/__init__.py diff --git a/algoperf/workloads/wmt/bleu.py b/algoperf/workloads/wmt/bleu.py new file mode 100644 index 000000000..ad314a7d3 --- /dev/null +++ b/algoperf/workloads/wmt/bleu.py @@ -0,0 +1,462 @@ +""" +Removing the dependency on sacrebleu, we reimplement the BLEU score computation +in this file. +Reference: +https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py. +""" + +from collections import Counter +from collections import namedtuple +from itertools import zip_longest +import logging +import math +import re +import sys +from typing import List, Sequence +import unicodedata + +from absl import logging +import torch +import torch.distributed as dist + +from algoperf.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP, _, DEVICE, N_GPUS = pytorch_setup() + +NGRAM_ORDER = 4 +# The default floor value to use with `--smooth floor` +SMOOTH_VALUE_DEFAULT = 0.0 + + +def my_log(num): + """ + Floors the log function + + :param num: the number + :return: log(num) floored to a very low number + """ + + if num == 0.0: + return -9999999999 + return math.log(num) + + +def tokenize_13a(line): + """ + Tokenizes an input line using a relatively minimal tokenization that is + however equivalent to mteval-v13a, used by WMT. + + :param line: a segment to tokenize + :return: the tokenized line + """ + + norm = line + + # language-independent part: + norm = norm.replace('', '') + norm = norm.replace('-\n', '') + norm = norm.replace('\n', ' ') + norm = norm.replace('"', '"') + norm = norm.replace('&', '&') + norm = norm.replace('<', '<') + norm = norm.replace('>', '>') + + # language-dependent part (assuming Western languages): + norm = " {} ".format(norm) + norm = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', ' \\1 ', norm) + norm = re.sub(r'([^0-9])([\.,])', '\\1 \\2 ', + norm) # tokenize period and comma unless preceded by a digit + norm = re.sub(r'([\.,])([^0-9])', ' \\1 \\2', + norm) # tokenize period and comma unless followed by a digit + norm = re.sub(r'([0-9])(-)', '\\1 \\2 ', + norm) # tokenize dash when preceded by a digit + norm = re.sub(r'\s+', ' ', norm) # one space only between words + norm = re.sub(r'^\s+', '', norm) # no leading space + norm = re.sub(r'\s+$', '', norm) # no trailing space + + return norm + + +class UnicodeRegex: + """Ad-hoc hack to recognize all punctuation and symbols. + + without depending on https://pypi.python.org/pypi/regex/.""" + + @staticmethod + def _property_chars(prefix): + return ''.join( + chr(x) + for x in range(sys.maxunicode) + if unicodedata.category(chr(x)).startswith(prefix)) + + punctuation = _property_chars('P') + nondigit_punct_re = re.compile(r'([^\d])([' + punctuation + r'])') + punct_nondigit_re = re.compile(r'([' + punctuation + r'])([^\d])') + symbol_re = re.compile('([' + _property_chars('S') + '])') + + +def tokenize_v14_international(string): + r"""Tokenize a string following the official BLEU implementation. + + See + https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 + In our case, the input string is expected to be just one line + and no HTML entities de-escaping is needed. + So we just tokenize on punctuation and symbols, + except when a punctuation is preceded and followed by a digit + (e.g. a comma/dot as a thousand/decimal separator). + + Note that a number (e.g., a year) followed by a dot at the end of sentence + is NOT tokenized, + i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g` + does not match this case (unless we add a space after each sentence). + However, this error is already in the original mteval-v14.pl + and we want to be consistent with it. + The error is not present in the non-international version, + which uses, + `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python). + + :param string: the input string + :return: a list of tokens + """ + string = UnicodeRegex.nondigit_punct_re.sub(r'\1 \2 ', string) + string = UnicodeRegex.punct_nondigit_re.sub(r' \1 \2', string) + string = UnicodeRegex.symbol_re.sub(r' \1 ', string) + return string.strip() + + +def tokenize_zh(sentence): + """MIT License + Copyright (c) 2017 - Shujian Huang + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files + (the "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to the + following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, + DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE + USE OR OTHER DEALINGS IN THE SOFTWARE. + + The tokenization of Chinese text in this script contains two steps: + separate each Chinese characters (by utf-8 encoding); + tokenize the non Chinese part (following the mteval script). + Author: Shujian Huang huangsj@nju.edu.cn + + :param sentence: input sentence + :return: tokenized sentence + """ + + def is_chinese_char(uchar): + """ + :param uchar: input char in unicode + :return: whether the input char is a Chinese character. + """ + if "\u3400" <= uchar <= "\u4db5": + return True + elif "\u4e00" <= uchar <= "\u9fa5": + return True + elif "\u9fa6" <= uchar <= "\u9fbb": + return True + elif "\uf900" <= uchar <= "\ufa2d": + return True + elif "\ufa30" <= uchar <= "\ufa6a": + return True + elif "\ufa70" <= uchar <= "\ufad9": + return True + elif "\u20000" <= uchar <= "\u2a6d6": + return True + elif "\u2f800" <= uchar <= "\u2fa1d": + return True + elif "\uff00" <= uchar <= "\uffef": + return True + elif "\u2e80" <= uchar <= "\u2eff": + return True + elif "\u3000" <= uchar <= "\u303f": + return True + elif "\u31c0" <= uchar <= "\u31ef": + return True + elif "\u2f00" <= uchar <= "\u2fdf": + return True + elif "\u2ff0" <= uchar <= "\u2fff": + return True + elif "\u3100" <= uchar <= "\u312f": + return True + elif "\u31a0" <= uchar <= "\u31bf": + return True + elif "\ufe10" <= uchar <= "\ufe1f": + return True + elif "\ufe30" <= uchar <= "\ufe4f": + return True + elif "\u2600" <= uchar <= "\u26ff": + return True + elif "\u2700" <= uchar <= "\u27bf": + return True + elif "\u3200" <= uchar <= "\u32ff": + return True + elif "\u3300" <= uchar <= "\u33ff": + return True + return False + + sentence = sentence.strip() + sentence_in_chars = "" + for char in sentence: + if is_chinese_char(char): + sentence_in_chars += " " + sentence_in_chars += char + sentence_in_chars += " " + else: + sentence_in_chars += char + sentence = sentence_in_chars + + # TODO: the code above could probably be replaced with the following line: + # import regex + # sentence = regex.sub(r'(\p{Han})', r' \1 ', sentence) + + # tokenize punctuation + sentence = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', r' \1 ', sentence) + + # tokenize period and comma unless preceded by a digit + sentence = re.sub(r'([^0-9])([\.,])', r'\1 \2 ', sentence) + + # tokenize period and comma unless followed by a digit + sentence = re.sub(r'([\.,])([^0-9])', r' \1 \2', sentence) + + # tokenize dash when preceded by a digit + sentence = re.sub(r'([0-9])(-)', r'\1 \2 ', sentence) + + # one space only between words + sentence = re.sub(r'\s+', r' ', sentence) + + # no leading or trailing spaces + sentence = sentence.strip() + + return sentence + + +TOKENIZERS = { + '13a': tokenize_13a, + 'intl': tokenize_v14_international, + 'zh': tokenize_zh, + 'none': lambda x: x, +} +DEFAULT_TOKENIZER = '13a' + + +def extract_ngrams(line, min_order=1, max_order=NGRAM_ORDER) -> Counter: + """Extracts all the ngrams (1 <= n <= NGRAM_ORDER) from a sequence of tokens. + + :param line: a segment containing a sequence of words + :param max_order: collect n-grams from 1<=n<=max + :return: a dictionary containing ngrams and counts + """ + + ngrams = Counter() + tokens = line.split() + for n in range(min_order, max_order + 1): + for i in range(0, len(tokens) - n + 1): + ngram = ' '.join(tokens[i:i + n]) + ngrams[ngram] += 1 + + return ngrams + + +def ref_stats(output, refs): + ngrams = Counter() + closest_diff = None + closest_len = None + for ref in refs: + tokens = ref.split() + reflen = len(tokens) + diff = abs(len(output.split()) - reflen) + if closest_diff is None or diff < closest_diff: + closest_diff = diff + closest_len = reflen + elif diff == closest_diff: + if reflen < closest_len: + closest_len = reflen + + ngrams_ref = extract_ngrams(ref) + for ngram in ngrams_ref: + ngrams[ngram] = max(ngrams[ngram], ngrams_ref[ngram]) + + return ngrams, closest_diff, closest_len + + +BLEU = namedtuple('BLE', + 'score, counts, totals, precisions, bp, sys_len, ref_len') + + +def compute_bleu(correct: List[int], + total: List[int], + sys_len: int, + ref_len: int, + smooth_method='none', + smooth_value=SMOOTH_VALUE_DEFAULT, + use_effective_order=False) -> BLEU: + """Computes BLEU score from its sufficient statistics. Adds smoothing. + + Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques + for Sentence-Level BLEU", Boxing Chen and Colin Cherry, + WMT 2014: http://aclweb.org/anthology/W14-3346) + + - exp: NIST smoothing method (Method 3) + - floor: Method 1 + - add-k: Method 2 (generalizing Lin and Och, 2004) + - none: do nothing. + + :param correct: List of counts of correct ngrams, 1 <= n <= NGRAM_ORDER + :param total: List of counts of total ngrams, 1 <= n <= NGRAM_ORDER + :param sys_len: The cumulative system length + :param ref_len: The cumulative reference length + :param smooth: The smoothing method to use + :param smooth_value: The smoothing value added, if smooth is 'floor' + :param use_effective_order: Use effective order. + :return: A BLEU object with the score (100-based) and other statistics. + """ + + precisions = [0 for x in range(NGRAM_ORDER)] + + smooth_mteval = 1. + effective_order = NGRAM_ORDER + for n in range(NGRAM_ORDER): + if smooth_method == 'add-k' and n > 1: + correct[n] += smooth_value + total[n] += smooth_value + if total[n] == 0: + break + + if use_effective_order: + effective_order = n + 1 + + if correct[n] == 0: + if smooth_method == 'exp': + smooth_mteval *= 2 + precisions[n] = 100. / (smooth_mteval * total[n]) + elif smooth_method == 'floor': + precisions[n] = 100. * smooth_value / total[n] + else: + precisions[n] = 100. * correct[n] / total[n] + + # If the system guesses no i-grams, 1 <= i <= NGRAM_ORDER, the BLEU + # score is 0 (technically undefined). This is a problem for sentence-level + # BLEU or a corpus of short sentences, where systems will get no credit + # if sentence lengths fall under the NGRAM_ORDER threshold. This fix scales + # NGRAM_ORDER to the observed maximum order. + # It is only available through the API and off by default + + brevity_penalty = 1.0 + if sys_len < ref_len: + brevity_penalty = math.exp(1 - ref_len / sys_len) if sys_len > 0 else 0.0 + + bleu = brevity_penalty * math.exp( + sum(map(my_log, precisions[:effective_order])) / effective_order) + + return BLEU._make( + [bleu, correct, total, precisions, brevity_penalty, sys_len, ref_len]) + + +def corpus_bleu(sys_stream: Sequence[str], + ref_streams: Sequence[str], + smooth_method: str = 'exp', + smooth_value: float = 0.0, + force: bool = False, + lowercase: bool = False, + tokenize: str = '13a', + use_effective_order: bool = False) -> BLEU: + """Produces BLEU scores along with its sufficient statistics from a source + against one or more references. + :param sys_stream: The system stream (a sequence of segments). + :param ref_streams: A list of one or more reference streams + (each a sequence of segments). + :param smooth: The smoothing method to use. + :param smooth_value: For 'floor' smoothing, the floor to use. + :param force: Ignore data that looks already tokenized. + :param lowercase: Lowercase the data. + :param tokenize: The tokenizer to use. + :return: A BLEU object containing everything yo'd want. + """ + + # Add some robustness to the input arguments. + if isinstance(sys_stream, str): + sys_stream = [sys_stream] + if isinstance(ref_streams, str): + ref_streams = [[ref_streams]] + + sys_len = 0 + ref_len = 0 + + correct = [0 for _ in range(NGRAM_ORDER)] + total = [0 for _ in range(NGRAM_ORDER)] + + # Look for already-tokenized sentences. + tokenized_count = 0 + + fhs = [sys_stream] + ref_streams + for lines in zip_longest(*fhs): + if None in lines: + raise EOFError('Source and reference streams have different lengths!') + + if lowercase: + lines = [x.lower() for x in lines] + + if not (force or tokenize == 'none') and lines[0].rstrip().endswith(' .'): + tokenized_count += 1 + + if tokenized_count == 100: + logging.warning( + 'That\'s 100 lines that end in a tokenized period (\'.\')') + logging.warning('It looks like you forgot to detokenize your test ' + 'data, which may hurt your score.') + logging.warning('If you insist your data is detokenized, ' + 'or don\'t care, you can suppress this message with ' + '\'--force\'.') + + output, *refs = [TOKENIZERS[tokenize](x.rstrip()) for x in lines] + + ref_ngrams, _, closest_len = ref_stats(output, refs) + + sys_len += len(output.split()) + ref_len += closest_len + + sys_ngrams = extract_ngrams(output) + for ngram, sys_ngram in sys_ngrams.items(): + n = len(ngram.split()) + correct[n - 1] += min(sys_ngram, ref_ngrams.get(ngram, 0)) + total[n - 1] += sys_ngram + + # When using PyTorch DDP, get stats from all processes and sum them. + if USE_PYTORCH_DDP: + # Sum `sys_len` and `ref_len` integers from all processes. + sys_len = torch.tensor(sys_len, dtype=torch.int64, device=DEVICE) + dist.all_reduce(sys_len) + sys_len = sys_len.item() + ref_len = torch.tensor(ref_len, dtype=torch.int64, device=DEVICE) + dist.all_reduce(ref_len) + ref_len = ref_len.item() + # Sum `correct` and `total` sequences from all processes. + correct = torch.tensor(correct, dtype=torch.int64, device=DEVICE) + dist.all_reduce(correct) + correct = correct.cpu().numpy().tolist() + total = torch.tensor(total, dtype=torch.int64, device=DEVICE) + dist.all_reduce(total) + total = total.cpu().numpy().tolist() + + return compute_bleu( + correct, + total, + sys_len, + ref_len, + smooth_method=smooth_method, + smooth_value=smooth_value, + use_effective_order=use_effective_order) diff --git a/algorithmic_efficiency/workloads/wmt/input_pipeline.py b/algoperf/workloads/wmt/input_pipeline.py similarity index 98% rename from algorithmic_efficiency/workloads/wmt/input_pipeline.py rename to algoperf/workloads/wmt/input_pipeline.py index af1c54994..d743b43b0 100644 --- a/algorithmic_efficiency/workloads/wmt/input_pipeline.py +++ b/algoperf/workloads/wmt/input_pipeline.py @@ -6,9 +6,9 @@ import tensorflow as tf import tensorflow_datasets as tfds -from algorithmic_efficiency import data_utils -from algorithmic_efficiency.pytorch_utils import pytorch_setup -from algorithmic_efficiency.workloads.wmt import tokenizer +from algoperf import data_utils +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.wmt import tokenizer RANK = pytorch_setup()[1] # Avoid multithreading in all processes but the first (rank 0). diff --git a/algorithmic_efficiency/workloads/wmt/tokenizer.py b/algoperf/workloads/wmt/tokenizer.py similarity index 100% rename from algorithmic_efficiency/workloads/wmt/tokenizer.py rename to algoperf/workloads/wmt/tokenizer.py diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/__init__.py b/algoperf/workloads/wmt/wmt_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/wmt/wmt_jax/__init__.py rename to algoperf/workloads/wmt/wmt_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py b/algoperf/workloads/wmt/wmt_jax/decode.py similarity index 98% rename from algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py rename to algoperf/workloads/wmt/wmt_jax/decode.py index 85d0eaac4..dfead5918 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py +++ b/algoperf/workloads/wmt/wmt_jax/decode.py @@ -86,7 +86,7 @@ def gather_fn(x): return x return x[batch_indices, beam_indices] - return jax.tree_map(gather_fn, nested) + return jax.tree.map(gather_fn, nested) def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size): @@ -139,7 +139,7 @@ def beam_init(batch_size, beam_size, max_decode_len, cache): finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_) # add beam dimension to attention cache pytree elements - beam_cache0 = jax.tree_map(lambda x: add_beam_dim(x, beam_size), cache) + beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache) return BeamState( cur_index=cur_index0, live_logprobs=live_logprobs0, @@ -225,7 +225,7 @@ def beam_search_loop_body_fn(state): (batch_size, beam_size, 1))) # Flatten beam dimension into batch to be compatible with model. # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} - flat_cache = jax.tree_map(flatten_beam_dim, state.cache) + flat_cache = jax.tree.map(flatten_beam_dim, state.cache) # Call fast-decoder model on current tokens to get next-position logits. # --> [batch * beam, vocab] @@ -236,7 +236,7 @@ def beam_search_loop_body_fn(state): logits = unflatten_beam_dim(flat_logits, batch_size, beam_size) # Unflatten beam dimension in attention cache arrays # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} - new_cache = jax.tree_map( + new_cache = jax.tree.map( lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache) # Gather log probabilities from logits diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py similarity index 97% rename from algorithmic_efficiency/workloads/wmt/wmt_jax/models.py rename to algoperf/workloads/wmt/wmt_jax/models.py index e4b5cd014..97fee032f 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py +++ b/algoperf/workloads/wmt/wmt_jax/models.py @@ -222,9 +222,8 @@ def __call__(self, inputs, encoder_mask=None): use_bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)(cfg.attention_temp * x, - x, - encoder_mask) + deterministic=cfg.deterministic)( + cfg.attention_temp * x, x, mask=encoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 @@ -288,7 +287,8 @@ def __call__(self, broadcast_dropout=False, dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic, - decode=cfg.decode)(cfg.attention_temp * x, x, decoder_mask) + decode=cfg.decode)( + cfg.attention_temp * x, x, mask=decoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 else: @@ -309,9 +309,8 @@ def __call__(self, use_bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)(cfg.attention_temp * y, - encoded, - encoder_decoder_mask) + deterministic=cfg.deterministic)( + cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) y = y + x diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py rename to algoperf/workloads/wmt/wmt_jax/workload.py index 046d5e469..cdfcb91df 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -13,12 +13,12 @@ import numpy as np import optax -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt import bleu -from algorithmic_efficiency.workloads.wmt.wmt_jax import decode -from algorithmic_efficiency.workloads.wmt.wmt_jax import models -from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.wmt import bleu +from algoperf.workloads.wmt.wmt_jax import decode +from algoperf.workloads.wmt.wmt_jax import models +from algoperf.workloads.wmt.workload import BaseWmtWorkload def _to_host(x: spec.Tensor) -> spec.Tensor: @@ -94,7 +94,7 @@ def eval_step(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: replicated_eval_metrics = self.eval_step_pmapped(params, batch) - return jax.tree_map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) + return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) @functools.partial( jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) @@ -291,7 +291,7 @@ def _normalize_eval_metrics( """Normalize eval metrics.""" del num_examples eval_denominator = total_metrics.pop('denominator') - return jax.tree_map(lambda x: float(x / eval_denominator), total_metrics) + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) class WmtWorkloadPostLN(WmtWorkload): diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/__init__.py b/algoperf/workloads/wmt/wmt_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/wmt/wmt_pytorch/__init__.py rename to algoperf/workloads/wmt/wmt_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py b/algoperf/workloads/wmt/wmt_pytorch/decode.py similarity index 98% rename from algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py rename to algoperf/workloads/wmt/wmt_pytorch/decode.py index 0488a144f..26ff36650 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py +++ b/algoperf/workloads/wmt/wmt_pytorch/decode.py @@ -10,7 +10,7 @@ import torch import torch.nn.functional as F -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf.pytorch_utils import pytorch_setup DEVICE = pytorch_setup()[2] @@ -98,7 +98,7 @@ def gather_fn(x): return x return x[batch_indices, beam_indices] - return jax.tree_map(gather_fn, nested) + return jax.tree.map(gather_fn, nested) def gather_topk_beams(nested: Dict[str, Any], @@ -164,7 +164,7 @@ def beam_init(batch_size: int, dtype=torch.bool, device=DEVICE) # add beam dimension to attention cache pytree elements - beam_cache0 = jax.tree_map(lambda x: add_beam_dim(x, beam_size), cache) + beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache) return BeamState( cur_index=cur_index0, live_logprobs=live_logprobs0, @@ -251,7 +251,7 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: state.live_seqs[:batch_size, :beam_size, cur_index:cur_index + 1]) # Flatten beam dimension into batch to be compatible with model. # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} - flat_cache = jax.tree_map(flatten_beam_dim, state.cache) + flat_cache = jax.tree.map(flatten_beam_dim, state.cache) # Call fast-decoder model on current tokens to get next-position logits. # --> [batch * beam, vocab] @@ -262,7 +262,7 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: logits = unflatten_beam_dim(flat_logits, batch_size, beam_size) # Unflatten beam dimension in attention cache arrays # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} - new_cache = jax.tree_map( + new_cache = jax.tree.map( lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache) # Gather log probabilities from logits diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algoperf/workloads/wmt/wmt_pytorch/models.py similarity index 100% rename from algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py rename to algoperf/workloads/wmt/wmt_pytorch/models.py diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py rename to algoperf/workloads/wmt/wmt_pytorch/workload.py index 0ba49c2f6..d0716d6c8 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py @@ -12,13 +12,13 @@ import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel as DDP -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import pytorch_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt import bleu -from algorithmic_efficiency.workloads.wmt.wmt_pytorch import decode -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import Transformer -from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +from algoperf.workloads.wmt import bleu +from algoperf.workloads.wmt.wmt_pytorch import decode +from algoperf.workloads.wmt.wmt_pytorch.models import Transformer +from algoperf.workloads.wmt.workload import BaseWmtWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() @@ -347,7 +347,7 @@ def _normalize_eval_metrics( dist.all_reduce(metric) total_metrics = {k: v.item() for k, v in total_metrics.items()} eval_denominator = total_metrics.pop('denominator') - return jax.tree_map(lambda x: float(x / eval_denominator), total_metrics) + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) class WmtWorkloadPostLN(WmtWorkload): diff --git a/algorithmic_efficiency/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/wmt/workload.py rename to algoperf/workloads/wmt/workload.py index 68ebdc94b..51b33373d 100644 --- a/algorithmic_efficiency/workloads/wmt/workload.py +++ b/algoperf/workloads/wmt/workload.py @@ -9,9 +9,9 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt import input_pipeline -from algorithmic_efficiency.workloads.wmt.wmt_jax import decode +from algoperf import spec +from algoperf.workloads.wmt import input_pipeline +from algoperf.workloads.wmt.wmt_jax import decode VOCAB_PATH = './wmt_256/sentencepiece_model' WORKDIR = './wmt_256' @@ -88,7 +88,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 48_151 # ~13.5 hours + return 43_336 # ~12.0 hours @property def eval_period_time_sec(self) -> int: @@ -96,8 +96,8 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 133_333 + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 120_000 @property def pre_ln(self) -> bool: diff --git a/algorithmic_efficiency/workloads/workloads.py b/algoperf/workloads/workloads.py similarity index 98% rename from algorithmic_efficiency/workloads/workloads.py rename to algoperf/workloads/workloads.py index bb57f598e..4712f4e25 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -4,9 +4,9 @@ import inspect import os -from algorithmic_efficiency import spec +from algoperf import spec -BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/' +BASE_WORKLOADS_DIR = 'algoperf/workloads/' WORKLOADS = { 'cifar': { diff --git a/algorithmic_efficiency/__init__.py b/algorithmic_efficiency/__init__.py deleted file mode 100644 index a0e473e1d..000000000 --- a/algorithmic_efficiency/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Algorithmic Efficiency.""" - -__version__ = '0.1.0' diff --git a/algorithmic_efficiency/workloads/wmt/bleu.py b/algorithmic_efficiency/workloads/wmt/bleu.py deleted file mode 100644 index 1efc87381..000000000 --- a/algorithmic_efficiency/workloads/wmt/bleu.py +++ /dev/null @@ -1,110 +0,0 @@ -from itertools import zip_longest -from typing import Sequence - -from absl import logging -import sacrebleu -import torch -import torch.distributed as dist - -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP, _, DEVICE, N_GPUS = pytorch_setup() - - -# Modified (added sync for PyTorch DDP) from -# https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py. -# Assumes that sacrebleu==1.3.1 is installed. -def corpus_bleu(sys_stream: Sequence[str], - ref_streams: Sequence[str], - smooth_method: str = 'exp', - smooth_value: float = 0.0, - force: bool = False, - lowercase: bool = False, - tokenize: str = '13a', - use_effective_order: bool = False) -> sacrebleu.BLEU: - """Produces BLEU scores along with its sufficient statistics from a source - against one or more references. - :param sys_stream: The system stream (a sequence of segments). - :param ref_streams: A list of one or more reference streams - (each a sequence of segments). - :param smooth: The smoothing method to use. - :param smooth_value: For 'floor' smoothing, the floor to use. - :param force: Ignore data that looks already tokenized. - :param lowercase: Lowercase the data. - :param tokenize: The tokenizer to use. - :return: A BLEU object containing everything you'd want. - """ - - # Add some robustness to the input arguments. - if isinstance(sys_stream, str): - sys_stream = [sys_stream] - if isinstance(ref_streams, str): - ref_streams = [[ref_streams]] - - sys_len = 0 - ref_len = 0 - - correct = [0 for _ in range(sacrebleu.NGRAM_ORDER)] - total = [0 for _ in range(sacrebleu.NGRAM_ORDER)] - - # Look for already-tokenized sentences. - tokenized_count = 0 - - fhs = [sys_stream] + ref_streams - for lines in zip_longest(*fhs): - if None in lines: - raise EOFError('Source and reference streams have different lengths!') - - if lowercase: - lines = [x.lower() for x in lines] - - if not (force or tokenize == 'none') and lines[0].rstrip().endswith(' .'): - tokenized_count += 1 - - if tokenized_count == 100: - logging.warning( - 'That\'s 100 lines that end in a tokenized period (\'.\')') - logging.warning('It looks like you forgot to detokenize your test ' - 'data, which may hurt your score.') - logging.warning('If you insist your data is detokenized, ' - 'or don\'t care, you can suppress this message with ' - '\'--force\'.') - - output, *refs = [sacrebleu.TOKENIZERS[tokenize](x.rstrip()) for x in lines] - - ref_ngrams, _, closest_len = sacrebleu.ref_stats(output, refs) - - sys_len += len(output.split()) - ref_len += closest_len - - sys_ngrams = sacrebleu.extract_ngrams(output) - for ngram, sys_ngram in sys_ngrams.items(): - n = len(ngram.split()) - correct[n - 1] += min(sys_ngram, ref_ngrams.get(ngram, 0)) - total[n - 1] += sys_ngram - - # When using PyTorch DDP, get stats from all processes and sum them. - if USE_PYTORCH_DDP: - # Sum `sys_len` and `ref_len` integers from all processes. - sys_len = torch.tensor(sys_len, dtype=torch.int64, device=DEVICE) - dist.all_reduce(sys_len) - sys_len = sys_len.item() - ref_len = torch.tensor(ref_len, dtype=torch.int64, device=DEVICE) - dist.all_reduce(ref_len) - ref_len = ref_len.item() - # Sum `correct` and `total` sequences from all processes. - correct = torch.tensor(correct, dtype=torch.int64, device=DEVICE) - dist.all_reduce(correct) - correct = correct.cpu().numpy().tolist() - total = torch.tensor(total, dtype=torch.int64, device=DEVICE) - dist.all_reduce(total) - total = total.cpu().numpy().tolist() - - return sacrebleu.compute_bleu( - correct, - total, - sys_len, - ref_len, - smooth_method=smooth_method, - smooth_value=smooth_value, - use_effective_order=use_effective_order) diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 5b43a3f87..efe923dbe 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -71,8 +71,8 @@ import tensorflow_datasets as tfds from torchvision.datasets import CIFAR10 -from algorithmic_efficiency.workloads.wmt import tokenizer -from algorithmic_efficiency.workloads.wmt.input_pipeline import \ +from algoperf.workloads.wmt import tokenizer +from algoperf.workloads.wmt.input_pipeline import \ normalize_feature_names from datasets import librispeech_preprocess from datasets import librispeech_tokenizer diff --git a/docker/Dockerfile b/docker/Dockerfile index 47277d440..76bc5cfe0 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,7 +11,6 @@ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 RUN echo "Setting up machine" RUN apt-get update RUN apt-get install -y curl tar - RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git ffmpeg # Install prerequisites @@ -25,7 +24,8 @@ RUN apt-get update && apt-get install -y \ libffi-dev \ curl \ libbz2-dev \ - liblzma-dev + liblzma-dev \ + vim # Download and install Python 3.11 RUN cd /tmp \ @@ -56,6 +56,8 @@ RUN echo "Setting up directories for data and experiment_runs" RUN mkdir -p data/ RUN mkdir -p experiment_runs/ +RUN pip install --upgrade pip + # Install Algorithmic efficiency repo RUN pip install --upgrade pip diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 527e8306a..1dbba9565 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -156,7 +156,7 @@ fi if [[ ${TEST} == "true" ]]; then cd algorithmic-efficiency - COMMAND="python3 tests/test_traindiffs.py" + COMMAND="python tests/test_traindiffs.py" echo $COMMAND eval $COMMAND exit @@ -209,7 +209,7 @@ TUNING_RULESET_FLAG="--tuning_ruleset=${TUNING_RULESET}" # Set run command prefix depending on framework if [[ "${FRAMEWORK}" == "jax" ]]; then - COMMAND_PREFIX="python3" + COMMAND_PREFIX="python" else COMMAND_PREFIX="torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8" fi diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index a235c50cd..c451a18ac 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -22,7 +22,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -111,8 +111,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -123,7 +123,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -139,14 +139,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -188,7 +188,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -236,7 +236,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -244,7 +244,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) @@ -302,6 +302,27 @@ def update_params( return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 06413f681..b8ac10f33 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -22,7 +22,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -111,8 +111,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -123,7 +123,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -139,14 +139,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -188,7 +188,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -236,7 +236,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -244,7 +244,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) @@ -302,6 +302,27 @@ def update_params( return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index 0e654d43c..a2f9fb4c5 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -11,8 +11,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] @@ -304,6 +304,27 @@ def update_params( return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index dd0b8b076..a37b0d341 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -11,8 +11,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] @@ -304,6 +304,27 @@ def update_params( return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index a9f048f03..78c3b5b3e 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -22,7 +22,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -120,8 +120,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -132,7 +132,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -148,14 +148,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -200,7 +200,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters['beta2'], eps=1e-8, weight_decay=hyperparameters['weight_decay']) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -248,7 +248,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -256,7 +256,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) @@ -317,6 +317,27 @@ def update_params( return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 4d3d2b341..ffe854a0e 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -22,7 +22,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -120,8 +120,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -132,7 +132,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -148,14 +148,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -200,7 +200,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters['beta2'], eps=1e-8, weight_decay=hyperparameters['weight_decay']) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -248,7 +248,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -256,7 +256,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) @@ -317,6 +317,27 @@ def update_params( return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index 5a5319957..3ef286877 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -11,8 +11,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] @@ -319,6 +319,27 @@ def update_params( return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index 699b11268..e9f8810a6 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -11,8 +11,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] @@ -319,6 +319,27 @@ def update_params( return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..cc404f4b5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,334 @@ +############################################################################### +# MLCommons Algorithmic Efficiency. # +############################################################################### + +[project] +name = "algoperf" +dynamic = ["version"] +description = "Codebase for the AlgoPerf: Training Algorithms benchmark" +authors = [ + { name = "MLCommons Algorithms Working Group", email = "algorithms@mlcommons.org" }, +] +license = { text = "Apache 2.0" } +readme = "README.md" +requires-python = ">=3.11" +keywords = [ + "algoperf", + "algorithmic-efficiency", + "machine-learning", + "deep-learning", + "optimization", + "benchmarking", + "training-methods", +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +dependencies = [ + "absl-py==2.1.0", + "networkx==3.2.1", + "docker==7.1.0", + "numpy>=2.0.2", + "pandas>=2.0.1", + "tensorflow==2.18.0", + "tensorflow-datasets==4.9.7", + "tensorflow-probability==0.20.0", + "tensorflow-addons==0.20.0", + "gputil==1.4.0", + "psutil==6.1.0", + "clu==0.0.12", + "matplotlib>=3.9.2", + "tabulate==0.9.0", + +] + +[build-system] +requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +py-modules = ["submission_runner"] +include-package-data = true +zip-safe = false + +[tool.setuptools.packages] +find = {} # Scanning implicit namespaces is active by default + +[tool.setuptools_scm] +version_file = "algoperf/_version.py" + +############################################################################### +# (Optional) Dependencies # +############################################################################### +[project.optional-dependencies] +# All workloads +full = [ + "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]", +] +# All workloads plus development dependencies +full_dev = ["algoperf[full,dev]"] +# Dependencies for developing the package +dev = [ + "isort==5.12.0", + "pylint==2.17.4", + "pytest==8.3.3", + "yapf==0.32.0", + "pre-commit==4.0.1", +] + +wandb = ["wandb==0.19.6"] + +# Workloads +criteo1tb = ["scikit-learn==1.5.2"] +fastmri = ["h5py==3.12.0", "scikit-image==0.24.0"] +ogbg = ["jraph==0.0.6.dev0", "scikit-learn==1.5.2"] +librispeech_conformer = [ + "sentencepiece==0.2.0", + "tensorflow-text==2.18.0", + "pydub==0.25.1", +] +wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.18.0"] + +# Frameworks +jax_core_deps = [ + "flax==0.8.4", + "optax==0.2.2", + "chex==0.1.86", + "ml_dtypes==0.4.1", + "protobuf==4.25.5", +] +jax_cpu = [ + "jax==0.4.28", + "jaxlib==0.4.28", + "algoperf[jax_core_deps]", +] +jax_gpu = [ + "jax==0.4.28", + "jaxlib==0.4.28", + "jax-cuda12-plugin[with_cuda]==0.4.28", + "jax-cuda12-pjrt==0.4.28", + "algoperf[jax_core_deps]", +] +pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] +pytorch_gpu = [ + "torch==2.5.1", + "torchvision==0.20.1", +] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA. + +############################################################################### +# Linting Configurations # +############################################################################### + +# yapf configuration +[tool.yapf] +based_on_style = "yapf" +each_dict_entry_on_separate_line = false +split_all_top_level_comma_separated_values = true +[tool.yapfignore] +ignore_patterns = ["algoperf/_version.py"] + +# isort configuration +[tool.isort] +profile = "google" + +# pylint configuration +[tool.pylint.MASTER] +persistent = false +ignore = "get_references_web.py,get_references_web_single_group.py,_version.py" + +[tool.pylint.REPORTS] +reports = false +msg-template = "{msg_id}:{line:3} {obj}: {msg} [{symbol}]" + +[tool.pylint.MESSAGES_CONTROL] +enable = "indexing-exception,old-raise-syntax" + +[tool.pylint.BASIC] +# Required attributes for module, separated by a comma +#required-attributes= +# Regular expression which should only match the name +# of functions or classes which do not require a docstring. +no-docstring-rgx = "(__.*__|main)" +# Min length in lines of a function that requires a docstring. +docstring-min-length = 10 +# Regular expression which should only match correct module names. The +# leading underscore is sanctioned for private modules by Google's style +# guide. +# +# There are exceptions to the basic rule (_?[a-z][a-z0-9_]*) to cover +# requirements of Python's module system. +module-rgx = "^(_?[a-z][a-z0-9_]*)|__init__$" +# Regular expression which should only match correct module level names +const-rgx = "^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$" +# Regular expression which should only match correct class attribute +class-attribute-rgx = "^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$" +# Regular expression which should only match correct class names +class-rgx = "^_?[A-Z][a-zA-Z0-9]*$" +# Regular expression which should only match correct function names. +# 'camel_case' and 'snake_case' group names are used for consistency of naming +# styles across functions and methods. +function-rgx = "^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$" +# Regular expression which should only match correct method names. +# 'camel_case' and 'snake_case' group names are used for consistency of naming +# styles across functions and methods. 'exempt' indicates a name which is +# consistent with all naming styles. +method-rgx = "(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|_testDatasetSize|setUpClass|test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|(?:test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$" +# Regular expression which should only match correct instance attribute names +attr-rgx = "^_{0,2}[a-z][a-z0-9_]*$" +# Regular expression which should only match correct argument names +argument-rgx = "^[a-z][a-z0-9_]*$" +# Regular expression which should only match correct variable names +variable-rgx = "^[a-z][a-z0-9_]*$" +# Regular expression which should only match correct list comprehension / +# generator expression variable names +inlinevar-rgx = "^[a-z][a-z0-9_]*$" +# Good variable names which should always be accepted, separated by a comma +good-names = "main,_" +# Bad variable names which should always be refused, separated by a comma +bad-names = "" +# List of builtins function names that should not be used, separated by a comma +#bad-functions=input,apply,reduce +# List of decorators that define properties, such as abc.abstractproperty. +property-classes = "abc.abstractproperty" + +[tool.pylint.typecheck] +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members = true + +# List of decorators that create context managers from functions, such as +# contextlib.contextmanager. +contextmanager-decorators = [ + "contextlib.contextmanager", + "contextlib2.contextmanager", +] + +[tool.pylint.VARIABLES] +# Tells whether we should check for unused import in __init__ files. +init-import = false + +# A regular expression matching names used for dummy variables (i.e. not used). +dummy-variables-rgx = "^\\*{0,2}(_$|unused_|dummy_)" + +# List of additional names supposed to be defined in builtins. +additional-builtins = [] + +[tool.pylint.CLASSES] +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods = ["__init__", "__new__", "setUp"] + +# Valid names for the first argument to a class method. +valid-classmethod-first-arg = ["cls", "class_"] + +[tool.pylint.EXCEPTIONS] +overgeneral-exceptions = [ + "builtins.StandardError", + "builtins.Exception", + "builtins.BaseException", +] + +[tool.pylint.IMPORTS] +# Deprecated modules which should not be used, separated by a comma +deprecated-modules = ["regsub", "TERMIOS", "Bastion", "rexec", "sets"] + +[tool.pylint.FORMAT] +# List of checkers and warnings to disable. +disable = [ + "abstract-method", + "access-member-before-definition", + "arguments-differ", + "assignment-from-no-return", + "attribute-defined-outside-init", + "bad-mcs-classmethod-argument", + "bad-option-value", + "c-extension-no-member", + "consider-merging-isinstance", + "consider-using-dict-comprehension", + "consider-using-enumerate", + "consider-using-in", + "consider-using-set-comprehension", + "consider-using-ternary", + "deprecated-method", + "design", + "file-ignored", + "fixme", + "global-statement", + "import-error", + "inconsistent-return-statements", + "invalid-unary-operand-type", + "len-as-condition", + "locally-disabled", + "locally-enabled", + "misplaced-comparison-constant", + "missing-docstring", + "multiple-imports", + "no-else-return", + "no-member", + "no-name-in-module", + "no-self-use", + "no-value-for-parameter", + "not-an-iterable", + "not-context-manager", + "pointless-except", + "protected-access", + "redefined-argument-from-local", + "signature-differs", + "similarities", + "simplifiable-if-expression", + "star-args", + "super-init-not-called", + "suppressed-message", + "too-many-function-args", + "trailing-comma-tuple", + "trailing-newlines", + "ungrouped-imports", + "unnecessary-pass", + "unsubscriptable-object", + "unused-argument", + "useless-object-inheritance", + "useless-return", + "useless-suppression", + "wrong-import-order", + "wrong-import-position", + "unneeded-not", + "unexpected-keyword-arg", + "redundant-keyword-arg", + "unspecified-encoding", + "logging-fstring-interpolation", + "consider-using-f-string", + "use-dict-literal", +] +# Maximum number of characters on a single line. +max-line-length = 80 +ignore-long-lines = "(?x)(^\\s*(import|from)\\s|^\\s*(\\#\\ )??$|^[a-zA-Z_][a-zA-Z0-9_]*\\s*=\\s*('[^']\\S+'|\"[^\"]\\S+\"))" +# Maximum number of lines in a module +max-module-lines = 99999 +# String used as indentation unit. We differ from PEP8's normal 4 spaces. +indent-string = ' ' +single-line-if-stmt = true +# Do not warn about multiple statements on a single line for constructs like +# if test: stmt +[tool.pylint.LOGGING] +logging-modules = "logging,absl.logging" +# Add logging modules. +[tool.pylint.MISCELLANEOUS] +# Maximum line length for lambdas +#short-func-length=1 +# List of module members that should be marked as deprecated. +# All of the string functions are listed in 4.1.4 Deprecated string functions +# in the Python 2.4 docs. +#deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc,sys.maxint +# List of exceptions that do not need to be mentioned in the Raises section of +# a docstring. +#ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError +# Number of spaces of indent required when the last token on the preceding line +# is an open (, [, or {. +indent-after-paren = 4 diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index 97d6df9f1..3d8e35eaa 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec def get_batch_size(workload_name): @@ -60,7 +60,7 @@ def init_optimizer_state(workload: spec.Workload, del model_params del model_state del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = optimizer(hyperparameters, workload.num_train_examples) @@ -108,8 +108,6 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. def update_params( workload: spec.Workload, current_param_container: spec.ParameterContainer, @@ -137,6 +135,29 @@ def update_params( return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +# Not allowed to update the model parameters, hyperparameters, global step, or +# optimzier state. def data_selection(workload: spec.Workload, input_queue: Iterator[Dict[str, spec.Tensor]], optimizer_state: spec.OptimizerState, diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index 853064957..d8b91f83a 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -7,7 +7,7 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec +from algoperf import spec def get_batch_size(workload_name): @@ -99,6 +99,27 @@ def update_params( return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. def data_selection(workload: spec.Workload, diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index 6d05954a1..c1f54597d 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec def get_batch_size(workload_name): @@ -26,7 +26,7 @@ def init_optimizer_state(workload: spec.Workload, del model_params del model_state del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = optax.chain( optax.scale_by_adam( @@ -109,6 +109,27 @@ def update_params( return (new_optimizer_state, opt_update_fn), updated_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. def data_selection(workload: spec.Workload, diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index d27d7f742..dedd96793 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -4,7 +4,7 @@ import torch -from algorithmic_efficiency import spec +from algoperf import spec def get_batch_size(workload_name): @@ -75,6 +75,27 @@ def update_params( return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. def data_selection(workload: spec.Workload, diff --git a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py index 9f4da9132..ff98464ae 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py @@ -316,11 +316,11 @@ def to_state(self, count, result_tree): """Maps from a tree of (factored) values to separate trees of values.""" return ShardedAdafactorState( count=count, - m=jax.tree_map(lambda o: o.m, result_tree), - m_scale=jax.tree_map(lambda o: o.m_scale, result_tree), - vr=jax.tree_map(lambda o: o.vr, result_tree), - vc=jax.tree_map(lambda o: o.vc, result_tree), - v=jax.tree_map(lambda o: o.v, result_tree)) + m=jax.tree.map(lambda o: o.m, result_tree), + m_scale=jax.tree.map(lambda o: o.m_scale, result_tree), + vr=jax.tree.map(lambda o: o.vr, result_tree), + vc=jax.tree.map(lambda o: o.vc, result_tree), + v=jax.tree.map(lambda o: o.v, result_tree)) def init(self, param): """Initializes the optimizer state for a given param.""" @@ -667,7 +667,7 @@ def init_fn(params): """Initializes the optimizer's state.""" return sharded_adafactor_helper.to_state( jnp.zeros([], jnp.int32), - jax.tree_map(sharded_adafactor_helper.init, params)) + jax.tree.map(sharded_adafactor_helper.init, params)) def update_fn(updates, state, params=None): if params is None: @@ -677,7 +677,7 @@ def update_fn(updates, state, params=None): compute_var_and_slot_update_fn = functools.partial( sharded_adafactor_helper.compute_var_and_slot_update, state.count) - output = jax.tree_map(compute_var_and_slot_update_fn, + output = jax.tree.map(compute_var_and_slot_update_fn, updates, state.m, state.m_scale, @@ -685,7 +685,7 @@ def update_fn(updates, state, params=None): state.vc, state.v, params) - updates = jax.tree_map(lambda o: o.update, output) + updates = jax.tree.map(lambda o: o.update, output) count_plus_one = state.count + jnp.array(1, jnp.int32) updated_states = sharded_adafactor_helper.to_state(count_plus_one, output) return updates, updated_states diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index efe238f26..1833ab8af 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.paper_baselines.adafactor.jax.sharded_adafactor import \ sharded_adafactor @@ -46,7 +46,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): learning_rate=lr_schedule_fn, beta1=1.0 - hyperparameters.one_minus_beta1, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -94,7 +94,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -102,7 +102,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) @@ -160,6 +160,27 @@ def update_params( return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index 377468612..7aa457a25 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -10,8 +10,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] @@ -268,6 +268,27 @@ def update_params( return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 31e0a6801..dde41fa6d 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -46,7 +46,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -94,7 +94,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -102,7 +102,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) @@ -160,6 +160,27 @@ def update_params( return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 27ceaeef7..21d9b6b57 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -9,8 +9,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] @@ -128,6 +128,27 @@ def update_params( return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index be13ab540..70e305514 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -53,7 +53,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -102,7 +102,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -110,7 +110,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) @@ -168,6 +168,27 @@ def update_params( return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index d3b491e75..c1c6cec0a 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -10,7 +10,7 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec +from algoperf import spec # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py @@ -261,6 +261,27 @@ def update_params( return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index 3eef23942..cbb6d6dcd 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -28,7 +28,7 @@ def init_optimizer_state(workload: spec.Workload, lr_schedule_fn = create_lr_schedule_fn(workload.step_hint, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, @@ -128,7 +128,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -136,7 +136,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) @@ -194,6 +194,27 @@ def update_params( return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index cf474ebdd..c3760d20e 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -8,8 +8,8 @@ import torch.distributed.nn as dist_nn from torch.optim.lr_scheduler import LambdaLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] @@ -147,6 +147,27 @@ def update_params( return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index a235c50cd..c451a18ac 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -22,7 +22,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -111,8 +111,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -123,7 +123,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -139,14 +139,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -188,7 +188,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -236,7 +236,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -244,7 +244,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) @@ -302,6 +302,27 @@ def update_params( return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index 0e654d43c..a2f9fb4c5 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -11,8 +11,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] @@ -304,6 +304,27 @@ def update_params( return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 553b3e478..0e53aae42 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -28,7 +28,7 @@ def init_optimizer_state(workload: spec.Workload, lr_schedule_fn = create_lr_schedule_fn(workload.step_hint, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, @@ -128,7 +128,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -136,7 +136,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) @@ -194,6 +194,27 @@ def update_params( return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index ba8c69e6c..b4432fbff 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -8,8 +8,8 @@ import torch.distributed.nn as dist_nn from torch.optim.lr_scheduler import LambdaLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] @@ -147,6 +147,27 @@ def update_params( return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index b5c7069cb..b76589705 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -24,7 +24,7 @@ def dual_vector(y: jnp.ndarray) -> jnp.ndarray: """ gradient_norm = jnp.sqrt( sum(jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y))) - normalized_gradient = jax.tree_map(lambda x: x / gradient_norm, y) + normalized_gradient = jax.tree.map(lambda x: x / gradient_norm, y) return normalized_gradient @@ -73,12 +73,12 @@ def update_fn(updates, state, grad_fn_params_tuple): # Get correct global mean grad. (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), axis_name=batch_axis_name) - updates = jax.tree_map(lambda x: x / n_valid_examples, updates) + updates = jax.tree.map(lambda x: x / n_valid_examples, updates) if grad_clip: updates_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) - scaled_updates = jax.tree_map( + scaled_updates = jax.tree.map( lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, @@ -136,7 +136,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): base_opt_update_fn=opt_update_fn) # Initialize optimizer state. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -186,7 +186,7 @@ def _loss_fn(params, update_batch_norm=True): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -247,6 +247,27 @@ def update_params( return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index b69945d51..92603f036 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -9,8 +9,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] @@ -219,6 +219,27 @@ def _loss_fn(params, update_batch_norm=True): return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index 725529cae..a5c2732ac 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -342,7 +342,7 @@ def init_training_metrics( """Initialize TrainingMetrics, masked if disabled.""" if not generate_training_metrics: return optax.MaskedNode() - return jax.tree_map( + return jax.tree.map( functools.partial(jnp.repeat, repeats=num_statistics), default_training_metrics()) @@ -356,14 +356,14 @@ def init_training_metrics_shapes( num_statistics, generate_training_metrics, ) - return jax.tree_map(lambda arr: [list(arr.shape), arr.dtype], seed) + return jax.tree.map(lambda arr: [list(arr.shape), arr.dtype], seed) def init_training_metrics_pspec(generate_training_metrics,): """Initialize training metrics partition specification.""" if not generate_training_metrics: return optax.MaskedNode() - return jax.tree_map(lambda _: jax.sharding.PartitionSpec(), + return jax.tree.map(lambda _: jax.sharding.PartitionSpec(), default_training_metrics()) @@ -1253,7 +1253,7 @@ def _add_metrics_into_local_stats(local_stats, metrics, keep_old): index_start = int(local_stat.index_start) index_end = int(len(local_stat.sizes)) + index_start # pylint:disable=cell-var-from-loop Used immediately. - per_stat_metrics = jax.tree_map(lambda x: x[index_start:index_end], metrics) + per_stat_metrics = jax.tree.map(lambda x: x[index_start:index_end], metrics) # We don't want to update the metrics if we didn't do a new inverse p-th # root calculation to find a new preconditioner, so that TensorBoard curves # look consistent (otherwise they'd oscillate between NaN and measured @@ -1808,7 +1808,7 @@ def sharded_update_fn(grads, state, params): local_stat, )) - new_stats_flat = jax.tree_map( + new_stats_flat = jax.tree.map( lambda g, s, p: _compute_stats(g, s, p, state.count), @@ -1816,7 +1816,7 @@ def sharded_update_fn(grads, state, params): stats_flat, params_flat) - outputs = jax.tree_map( + outputs = jax.tree.map( lambda g, s, p: _transform_grad(g, s, p, state.count), @@ -1981,7 +1981,7 @@ def _init(param): )) return ShampooState( - count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)) + count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params)) def _skip_preconditioning(param): return len(param.shape) < skip_preconditioning_rank_lt or any( @@ -2140,7 +2140,7 @@ def _internal_inverse_pth_root_all(): preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name) metrics = jax.lax.all_gather(metrics, batch_axis_name) preconditioners_flat = unbatch(preconditioners) - metrics_flat = jax.tree_map(unbatch, metrics) + metrics_flat = jax.tree.map(unbatch, metrics) else: preconditioners, metrics = _matrix_inverse_pth_root_vmap( all_statistics[0], @@ -2149,9 +2149,9 @@ def _internal_inverse_pth_root_all(): _maybe_ix(all_preconditioners, 0), ) preconditioners_flat = unbatch(jnp.stack([preconditioners])) - metrics = jax.tree_map( + metrics = jax.tree.map( functools.partial(jnp.expand_dims, axis=0), metrics) - metrics_flat = jax.tree_map(unbatch, metrics) + metrics_flat = jax.tree.map(unbatch, metrics) return preconditioners_flat, metrics_flat @@ -2166,7 +2166,7 @@ def _internal_inverse_pth_root_all(): s[:, :precond_dim(s.shape[0])] for s in packed_statistics ] n = len(packed_statistics) - metrics_init = jax.tree_map( + metrics_init = jax.tree.map( lambda x: [x] * n, default_training_metrics().replace( inverse_pth_root_errors=inverse_failure_threshold)) @@ -2215,12 +2215,12 @@ def _select_preconditioner(error, new_p, old_p): if generate_training_metrics: # pylint:disable=cell-var-from-loop Used immediately. - metrics_for_state = jax.tree_map( + metrics_for_state = jax.tree.map( lambda x: jnp.stack(x[idx:idx + num_statistics]), metrics_flat, is_leaf=lambda x: isinstance(x, list)) assert jax.tree_util.tree_all( - jax.tree_map(lambda x: len(state.statistics) == len(x), + jax.tree.map(lambda x: len(state.statistics) == len(x), metrics_for_state)) # If we skipped preconditioner computation, record old metrics. metrics_for_state = efficient_cond(perform_step, @@ -2441,7 +2441,7 @@ def update_fn(grads, state, params): if custom_preconditioner and grads_custom is not None: stats_grads = treedef.flatten_up_to(grads_custom) - new_stats_flat = jax.tree_map( + new_stats_flat = jax.tree.map( lambda g, s, p: _compute_stats(g, s, p, state.count), @@ -2452,7 +2452,7 @@ def update_fn(grads, state, params): new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat, state.count) - outputs = jax.tree_map( + outputs = jax.tree.map( lambda g, s, p: _transform_grad(g, s, p, state.count), diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 8f0b311a0..2cd054062 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import \ distributed_shampoo @@ -49,7 +49,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): weight_decay=hyperparameters.weight_decay, batch_axis_name='batch', eigh=False) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -97,7 +97,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -105,7 +105,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) @@ -163,6 +163,27 @@ def update_params( return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/target_setting_algorithms/data_selection.py b/reference_algorithms/target_setting_algorithms/data_selection.py index ce24482fc..5e70f9f8b 100644 --- a/reference_algorithms/target_setting_algorithms/data_selection.py +++ b/reference_algorithms/target_setting_algorithms/data_selection.py @@ -1,6 +1,6 @@ from typing import Dict, Iterator, Tuple -from algorithmic_efficiency import spec +from algoperf import spec def data_selection( diff --git a/reference_algorithms/target_setting_algorithms/jax_adamw.py b/reference_algorithms/target_setting_algorithms/jax_adamw.py index 6d2cfe245..b64f0dfd6 100644 --- a/reference_algorithms/target_setting_algorithms/jax_adamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_adamw.py @@ -4,7 +4,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import @@ -29,7 +29,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) epsilon = ( hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) diff --git a/reference_algorithms/target_setting_algorithms/jax_momentum.py b/reference_algorithms/target_setting_algorithms/jax_momentum.py index 08a0f7e9d..a6c3d853b 100644 --- a/reference_algorithms/target_setting_algorithms/jax_momentum.py +++ b/reference_algorithms/target_setting_algorithms/jax_momentum.py @@ -7,7 +7,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import from reference_algorithms.target_setting_algorithms.get_batch_size import \ @@ -32,7 +32,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 21f2a7b2b..597a43c9e 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -8,7 +8,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import @@ -96,8 +96,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -108,7 +108,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -124,14 +124,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -156,7 +156,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) epsilon = ( hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) diff --git a/reference_algorithms/target_setting_algorithms/jax_nesterov.py b/reference_algorithms/target_setting_algorithms/jax_nesterov.py index 6b27e0e2a..0c11044fc 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/jax_nesterov.py @@ -7,7 +7,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import from reference_algorithms.target_setting_algorithms.get_batch_size import \ @@ -32,7 +32,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 51b20181b..217228935 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -7,7 +7,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -53,7 +53,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -61,7 +61,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) @@ -112,3 +112,24 @@ def update_params( 'grad_norm': grad_norm[0], }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) diff --git a/reference_algorithms/target_setting_algorithms/pytorch_adamw.py b/reference_algorithms/target_setting_algorithms/pytorch_adamw.py index 0dcb5ab14..c87bdfb7d 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_adamw.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_adamw.py @@ -2,7 +2,7 @@ import torch -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import diff --git a/reference_algorithms/target_setting_algorithms/pytorch_momentum.py b/reference_algorithms/target_setting_algorithms/pytorch_momentum.py index 1a2df449a..584caff39 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_momentum.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_momentum.py @@ -3,7 +3,7 @@ import torch from torch.optim.lr_scheduler import LambdaLR -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import from reference_algorithms.target_setting_algorithms.get_batch_size import \ diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py index 71b819e66..a9dee1d79 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py @@ -6,7 +6,7 @@ import torch from torch import Tensor -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py b/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py index 830e5eac9..8e10db4ef 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py @@ -3,7 +3,7 @@ import torch from torch.optim.lr_scheduler import LambdaLR -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import from reference_algorithms.target_setting_algorithms.get_batch_size import \ diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index 6203c58b3..bbfd8b0f2 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -6,8 +6,8 @@ import torch import torch.distributed.nn as dist_nn -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] @@ -92,3 +92,24 @@ def update_params( grad_norm.item()) return (optimizer_state, current_param_container, new_model_state) + + +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) diff --git a/scoring/compute_speedups.py b/scoring/compute_speedups.py index 5fb5f259d..d0e5bf70b 100644 --- a/scoring/compute_speedups.py +++ b/scoring/compute_speedups.py @@ -25,6 +25,7 @@ 'Whether to save the results to disk.') FLAGS = flags.FLAGS +# These are the old budgets, used in the first iteration of the competition. MAX_BUDGETS = { 'criteo1tb': 7703, 'fastmri': 8859, diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index f4f2d5679..615ac6fe1 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -38,14 +38,14 @@ import pandas as pd from tabulate import tabulate -from algorithmic_efficiency.workloads.workloads import get_base_workload_name -import algorithmic_efficiency.workloads.workloads as workloads_registry +from algoperf.workloads.workloads import get_base_workload_name +import algoperf.workloads.workloads as workloads_registry from scoring import scoring_utils WORKLOADS = workloads_registry.WORKLOADS BASE_WORKLOADS = workloads_registry.BASE_WORKLOADS WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)' -BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/' +BASE_WORKLOADS_DIR = 'algoperf/workloads/' # Open json file to read heldout workloads # TODO: This probably shouldn't be hardcoded but passed as an argument. with open("held_out_workloads_algoperf_v05.json", "r") as f: diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index e474b6910..99c6e810e 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -20,8 +20,8 @@ from absl import flags from absl import logging -from algorithmic_efficiency import random_utils as prng -from algorithmic_efficiency.workloads.workloads import get_base_workload_name +from algoperf import random_utils as prng +from algoperf.workloads.workloads import get_base_workload_name import docker flags.DEFINE_string( diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index 0dd997ab9..ac513816e 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -7,7 +7,7 @@ from absl import logging import pandas as pd -import algorithmic_efficiency.workloads.workloads as workloads_registry +import algoperf.workloads.workloads as workloads_registry TRIAL_LINE_REGEX = '(.*) --- Tuning run (\d+)/(\d+) ---' METRICS_LINE_REGEX = '(.*) Metrics: ({.*})' @@ -17,7 +17,7 @@ WORKLOADS = workloads_registry.WORKLOADS WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)' -BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/' +BASE_WORKLOADS_DIR = 'algoperf/workloads/' #### File IO helper functions ### diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 4afefd164..000000000 --- a/setup.cfg +++ /dev/null @@ -1,314 +0,0 @@ -############################################################################### -# MLCommons Algorithmic Efficiency. # -############################################################################### - -[metadata] -name = algorithmic_efficiency -version = attr: algorithmic_efficiency.__version__ -description = MLCommons Algorithmic Efficiency -url = https://github.com/mlcommons/algorithmic-efficiency -author = MLCommons Algorithmic Efficiency -author_email = algorithms@mlcommons.org -license = Apache 2.0 -long_description = file: README.md -long_description_content_type = text/markdown -keywords = algorithmic-efficiency, machine-learning, deep-learning, - optimization, benchmarking, training-methods -platforms = any -classifiers = - Development Status :: 3 - Alpha - Intended Audience :: Developers - Intended Audience :: Science/Research - License :: OSI Approved :: Apache Software License - Operating System :: OS Independent - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 - Topic :: Scientific/Engineering :: Artificial Intelligence - -[options] -zip_safe = False -packages = find: -include_package_data = True -setup_requires = - setuptools_scm -# Dependencies of the project: -install_requires = - absl-py==1.4.0 - # Pin to avoid unpinned install in dependencies that requires Python>=3.9. - networkx==3.1 - docker==7.0.0 - numpy>=1.23 - pandas>=2.0.1 - tensorflow==2.12.0 - tensorflow-datasets==4.9.2 - tensorflow-probability==0.20.0 - tensorflow-addons==0.20.0 - gputil==1.4.0 - psutil==5.9.5 - clu==0.0.7 - matplotlib>=3.7.2 - tabulate==0.9.0 -python_requires = >=3.8 - - -############################################################################### -# Additional Dependencies # -############################################################################### - -[options.extras_require] -# Add extra dependencies, e.g. to run tests or for the different frameworks. -# Use as `pip install -e '.[jax_gpu]' -f https://storage.googleapis.com/jax-releases/jax_releases.html` -# or `pip install -e '.[dev]'` - -# Bundled installs # - -# All workloads -full = - %(criteo1tb)s - %(fastmri)s - %(ogbg)s - %(librispeech_conformer)s - %(wmt)s - -# All workloads plus development dependencies -full_dev = - %(full)s - %(dev)s - - -# Dependencies for developing the package -dev = - isort==5.12.0 - pylint==2.17.4 - pytest==7.3.1 - yapf==0.33.0 - pre-commit==3.3.1 - -# Workloads # -criteo1tb = - scikit-learn==1.2.2 - -fastmri = - h5py==3.8.0 - scikit-image==0.20.0 - -ogbg = - jraph==0.0.6.dev0 - scikit-learn==1.2.2 - -librispeech_conformer = - sentencepiece==0.1.99 - tensorflow-text==2.12.1 - pydub==0.25.1 - -wmt = - sentencepiece==0.1.99 - tensorflow-text==2.12.1 - sacrebleu==1.3.1 - -# Frameworks # - -# JAX Core -jax_core_deps = - flax==0.6.10 - optax==0.1.5 - # Fix chex (optax dependency) version. - # Not fixing it can raise dependency issues with our - # jax version. - # Todo(kasimbeg): verify if this is necessary after we - # upgrade jax. - chex==0.1.7 - ml_dtypes==0.2.0 - protobuf==4.25.3 - - -# JAX CPU -jax_cpu = - jax==0.4.10 - jaxlib==0.4.10 - %(jax_core_deps)s - -# JAX GPU -# Note this installs both jax and jaxlib. -jax_gpu = - jax==0.4.10 - jaxlib==0.4.10+cuda12.cudnn88 - %(jax_core_deps)s - -# PyTorch CPU -pytorch_cpu = - torch==2.1.0 - torchvision==0.16.0 - -# PyTorch GPU -# Note: omit the cuda suffix and installing from the appropriate -# wheel will result in using locally installed CUDA. -pytorch_gpu = - torch==2.1.0 - torchvision==0.16.0 - -# wandb -wandb = - wandb==0.16.5 - -############################################################################### -# Linting Configurations # -############################################################################### - -# yapf configuration -[yapf] -based_on_style = yapf -each_dict_entry_on_separate_line = false -split_all_top_level_comma_separated_values = true - - -# isort configuration -[isort] -profile=google - - -# pylint configuration -[pylint.MASTER] -persistent=no # Pickle collected data for later comparisons. -#cache-size=500 # Set the cache size for astng objects. -# Ignore Py3 files -ignore=get_references_web.py,get_references_web_single_group.py -[pylint.REPORTS] -# Set the output format. -# output-format=sorted-text -# Put messages in a separate file for each module / package specified on the -# command line instead of printing them on stdout. Reports (if any) will be -# written in a file name "pylint_global.[txt|html]". -#files-output=no -# Tells whether to display a full report or only the messages. -reports=no -# Disable the report(s) with the given id(s). -#disable-report=R0001,R0002,R0003,R0004,R0101,R0102,R0201,R0202,R0220,R0401,R0402,R0701,R0801,R0901,R0902,R0903,R0904,R0911,R0912,R0913,R0914,R0915,R0921,R0922,R0923 -# Error message template (continued on second line) -msg-template={msg_id}:{line:3} {obj}: {msg} [{symbol}] -[pylint.'MESSAGES CONTROL'] -# List of checkers and warnings to enable. -enable=indexing-exception,old-raise-syntax - - -[pylint.BASIC] -# Required attributes for module, separated by a comma -#required-attributes= -# Regular expression which should only match the name -# of functions or classes which do not require a docstring. -no-docstring-rgx=(__.*__|main) -# Min length in lines of a function that requires a docstring. -docstring-min-length=10 -# Regular expression which should only match correct module names. The -# leading underscore is sanctioned for private modules by Google's style -# guide. -# -# There are exceptions to the basic rule (_?[a-z][a-z0-9_]*) to cover -# requirements of Python's module system. -module-rgx=^(_?[a-z][a-z0-9_]*)|__init__$ -# Regular expression which should only match correct module level names -const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ -# Regular expression which should only match correct class attribute -class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ -# Regular expression which should only match correct class names -class-rgx=^_?[A-Z][a-zA-Z0-9]*$ -# Regular expression which should only match correct function names. -# 'camel_case' and 'snake_case' group names are used for consistency of naming -# styles across functions and methods. -function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ -# Regular expression which should only match correct method names. -# 'camel_case' and 'snake_case' group names are used for consistency of naming -# styles across functions and methods. 'exempt' indicates a name which is -# consistent with all naming styles. -method-rgx=(?x) - ^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase - |tearDownTestCase|setupSelf|tearDownClass|_testDatasetSize|setUpClass - |(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next) - |(?P_{0,2}[A-Z][a-zA-Z0-9_]*) - |(?P_{0,2}[a-z][a-z0-9_]*))$ -# Regular expression which should only match correct instance attribute names -attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ -# Regular expression which should only match correct argument names -argument-rgx=^[a-z][a-z0-9_]*$ -# Regular expression which should only match correct variable names -variable-rgx=^[a-z][a-z0-9_]*$ -# Regular expression which should only match correct list comprehension / -# generator expression variable names -inlinevar-rgx=^[a-z][a-z0-9_]*$ -# Good variable names which should always be accepted, separated by a comma -good-names=main,_ -# Bad variable names which should always be refused, separated by a comma -bad-names= -# List of builtins function names that should not be used, separated by a comma -#bad-functions=input,apply,reduce -# List of decorators that define properties, such as abc.abstractproperty. -property-classes=abc.abstractproperty -[pylint.TYPECHECK] -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes -# List of decorators that create context managers from functions, such as -# contextlib.contextmanager. -contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager -[pylint.VARIABLES] -# Tells whether we should check for unused import in __init__ files. -init-import=no -# A regular expression matching names used for dummy variables (i.e. not used). -dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid to define new builtins when possible. -additional-builtins= -[pylint.CLASSES] -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__,__new__,setUp -# "class_" is also a valid for the first argument to a class method. -valid-classmethod-first-arg=cls,class_ -[pylint.EXCEPTIONS] -overgeneral-exceptions=builtins.StandardError,builtins.Exception,builtins.BaseException -[pylint.IMPORTS] -# Deprecated modules which should not be used, separated by a comma -deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets -[pylint.FORMAT] -# List of checkers and warnings to disable. -disable=abstract-method,access-member-before-definition,arguments-differ,assignment-from-no-return,attribute-defined-outside-init,bad-mcs-classmethod-argument,bad-option-value,c-extension-no-member,consider-merging-isinstance,consider-using-dict-comprehension,consider-using-enumerate,consider-using-in,consider-using-set-comprehension,consider-using-ternary,deprecated-method,design,file-ignored,fixme,global-statement,import-error,inconsistent-return-statements,invalid-unary-operand-type,len-as-condition,locally-disabled,locally-enabled,misplaced-comparison-constant,missing-docstring,multiple-imports,no-else-return,no-member,no-name-in-module,no-self-use,no-value-for-parameter,not-an-iterable,not-context-manager,pointless-except,protected-access,redefined-argument-from-local,signature-differs,similarities,simplifiable-if-expression,star-args,super-init-not-called,suppressed-message,too-many-function-args,trailing-comma-tuple,trailing-newlines,ungrouped-imports,unnecessary-pass,unsubscriptable-object,unused-argument,useless-object-inheritance,useless-return,useless-suppression,wrong-import-order,wrong-import-position,unneeded-not,unexpected-keyword-arg,redundant-keyword-arg,unspecified-encoding,logging-fstring-interpolation,consider-using-f-string,use-dict-literal - -# Maximum number of characters on a single line. -max-line-length=80 -# Regexp for a line that is allowed to be longer than the limit. -# This "ignore" regex is today composed of several independent parts: -# (1) Long import lines -# (2) URLs in comments or pydocs. Detecting URLs by regex is a hard problem and -# no amount of tweaking will make a perfect regex AFAICT. This one is a good -# compromise. -# (3) Constant string literals at the start of files don't need to be broken -# across lines. Allowing long paths and urls to be on a single -# line. Also requires that the string not be a triplequoted string. -ignore-long-lines=(?x) - (^\s*(import|from)\s - |^\s*(\#\ )??$ - |^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*("[^"]\S+"|'[^']\S+') - ) -# Maximum number of lines in a module -max-module-lines=99999 -# String used as indentation unit. We differ from PEP8's normal 4 spaces. -indent-string=' ' -# Do not warn about multiple statements on a single line for constructs like -# if test: stmt -single-line-if-stmt=y -[pylint.LOGGING] -# Add logging modules. -logging-modules=logging,absl.logging -[pylint.MISCELLANEOUS] -# Maximum line length for lambdas -#short-func-length=1 -# List of module members that should be marked as deprecated. -# All of the string functions are listed in 4.1.4 Deprecated string functions -# in the Python 2.4 docs. -#deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc,sys.maxint -# List of exceptions that do not need to be mentioned in the Raises section of -# a docstring. -#ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError -# Number of spaces of indent required when the last token on the preceding line -# is an open (, [, or {. -indent-after-paren=4 diff --git a/setup.py b/setup.py deleted file mode 100644 index a4ead8f48..000000000 --- a/setup.py +++ /dev/null @@ -1,4 +0,0 @@ -from setuptools import setup - -if __name__ == "__main__": - setup() diff --git a/submission_runner.py b/submission_runner.py index 1a66acc58..a2521e77b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -30,29 +30,29 @@ from absl import flags from absl import logging import jax +import tensorflow as tf import torch import torch.distributed as dist -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. -import tensorflow as tf - # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.set_visible_devices([], 'GPU') -from algorithmic_efficiency import checkpoint_utils -from algorithmic_efficiency import halton -from algorithmic_efficiency import logger_utils -from algorithmic_efficiency import random_utils as prng -from algorithmic_efficiency import spec -from algorithmic_efficiency.profiler import PassThroughProfiler -from algorithmic_efficiency.profiler import Profiler -from algorithmic_efficiency.pytorch_utils import pytorch_init -from algorithmic_efficiency.pytorch_utils import pytorch_setup -from algorithmic_efficiency.pytorch_utils import sync_ddp_time -from algorithmic_efficiency.workloads import workloads - -# disable only for deepspeech if it works fine for other workloads. +from algoperf import checkpoint_utils +from algoperf import halton +from algoperf import logger_utils +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.profiler import PassThroughProfiler +from algoperf.profiler import Profiler +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup +from algoperf.pytorch_utils import sync_ddp_time +from algoperf.workloads import workloads + +# Environment variables +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +# disable only for deepspeech if it works fine for other workloads os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' # TODO(znado): make a nicer registry of workloads that lookup in. @@ -202,6 +202,7 @@ def train_once( init_optimizer_state: spec.InitOptimizerFn, update_params: spec.UpdateParamsFn, data_selection: spec.DataSelectionFn, + prepare_for_eval: Optional[spec.PrepareForEvalFn], hyperparameters: Optional[spec.Hyperparameters], rng_seed: int, rng: spec.RandomState, @@ -241,8 +242,9 @@ def train_once( 'ogbg', 'criteo1tb', 'imagenet_vit', + 'librispeech_deepspeech' ] - eager_backend_workloads = ['librispeech_deepspeech'] + eager_backend_workloads = [] aot_eager_backend_workloads = [] loss_compilation_workloads = [ 'fastmri', 'librispeech_deepspeech', 'ogbg', 'wmt' @@ -341,7 +343,9 @@ def train_once( not train_state['training_complete']: step_rng = prng.fold_in(rng, global_step) - data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) + + data_select_rng, update_rng, prep_eval_rng, eval_rng = \ + prng.split(step_rng, 4) with profiler.profile('Data selection'): batch = data_selection(workload, @@ -378,101 +382,131 @@ def train_once( train_state['accumulated_submission_time'] += ( train_step_end_time - train_state['last_step_end_time']) - # Use 3x the runtime budget for the self-tuning ruleset. - max_allowed_runtime_sec = ( - workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external' - else 3 * workload.max_allowed_runtime_sec) - train_state['is_time_remaining'] = ( - train_state['accumulated_submission_time'] < max_allowed_runtime_sec) + # Check if submission is eligible for an untimed eval. if ((train_step_end_time - train_state['last_eval_time']) >= workload.eval_period_time_sec or train_state['training_complete']): - with profiler.profile('Evaluation'): - del batch - _reset_cuda_mem() - - try: - eval_start_time = get_time() - latest_eval_result = workload.eval_model(global_eval_batch_size, - model_params, - model_state, - eval_rng, - data_dir, - imagenet_v2_data_dir, - global_step) - # Check if targets reached. - # Note that this is one of the stopping conditions for the length of - # a training run. To score the run we only consider the time - # to validation target retrospectively. - train_state['validation_goal_reached'] = ( - workload.has_reached_validation_target(latest_eval_result) or - train_state['validation_goal_reached']) - train_state['test_goal_reached'] = ( - workload.has_reached_test_target(latest_eval_result) or - train_state['test_goal_reached']) - goals_reached = ( - train_state['validation_goal_reached'] and - train_state['test_goal_reached']) - # Save last eval time. - eval_end_time = get_time() - train_state['last_eval_time'] = eval_end_time - - # Accumulate eval time. - train_state[ - 'accumulated_eval_time'] += eval_end_time - eval_start_time - - # Add times to eval results for logging. - latest_eval_result['score'] = ( - train_state['accumulated_submission_time']) - latest_eval_result[ - 'total_duration'] = eval_end_time - global_start_time - latest_eval_result['accumulated_submission_time'] = train_state[ - 'accumulated_submission_time'] - latest_eval_result['accumulated_eval_time'] = train_state[ - 'accumulated_eval_time'] - latest_eval_result['accumulated_logging_time'] = train_state[ - 'accumulated_logging_time'] - time_since_start = latest_eval_result['total_duration'] - logging.info(f'Time since start: {time_since_start:.2f}s, ' - f'\tStep: {global_step}, \t{latest_eval_result}') - eval_results.append((global_step, latest_eval_result)) - - logging_start_time = get_time() - - if log_dir is not None and RANK == 0: - metrics_logger.append_scalar_metrics( - latest_eval_result, - global_step=global_step, - preemption_count=preemption_count, - is_eval=True, - ) - if save_checkpoints: - checkpoint_utils.save_checkpoint( - framework=FLAGS.framework, - optimizer_state=optimizer_state, - model_params=model_params, - model_state=model_state, - train_state=train_state, - eval_results=eval_results, - global_step=global_step, - preemption_count=preemption_count, - checkpoint_dir=log_dir, - save_intermediate_checkpoints=FLAGS - .save_intermediate_checkpoints) - - logging_end_time = get_time() - train_state['accumulated_logging_time'] += ( - logging_end_time - logging_start_time) + # Prepare for evaluation (timed). + if prepare_for_eval is not None: + + with profiler.profile('Prepare for eval'): + del batch + prepare_for_eval_start_time = get_time() + optimizer_state, model_params, model_state = prepare_for_eval( + workload=workload, + current_param_container=model_params, + current_params_types=workload.model_params_types, + model_state=model_state, + hyperparameters=hyperparameters, + loss_type=workload.loss_type, + optimizer_state=optimizer_state, + eval_results=eval_results, + global_step=global_step, + rng=prep_eval_rng) + prepare_for_eval_end_time = get_time() + + # Update sumbission time. + train_state['accumulated_submission_time'] += ( + prepare_for_eval_end_time - prepare_for_eval_start_time) + + # Check if time is remaining, + # use 1.5x the runtime budget for the self-tuning ruleset. + max_allowed_runtime_sec = ( + workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external' + else 1.5 * workload.max_allowed_runtime_sec) + train_state['is_time_remaining'] = ( + train_state['accumulated_submission_time'] < max_allowed_runtime_sec) + + # Eval if time is remaining (untimed). + if train_state['is_time_remaining']: + + with profiler.profile('Evaluation'): _reset_cuda_mem() - except RuntimeError as e: - logging.exception(f'Eval step {global_step} error.\n') - if 'out of memory' in str(e): - logging.warning('Error: GPU out of memory during eval during step ' - f'{global_step}, error : {str(e)}.') + try: + eval_start_time = get_time() + latest_eval_result = workload.eval_model(global_eval_batch_size, + model_params, + model_state, + eval_rng, + data_dir, + imagenet_v2_data_dir, + global_step) + # Check if targets reached. + # Note that this is one of the stopping conditions for the length of + # a training run. To score the run we only consider the time + # to validation target retrospectively. + train_state['validation_goal_reached'] = ( + workload.has_reached_validation_target(latest_eval_result) or + train_state['validation_goal_reached']) + train_state['test_goal_reached'] = ( + workload.has_reached_test_target(latest_eval_result) or + train_state['test_goal_reached']) + goals_reached = ( + train_state['validation_goal_reached'] and + train_state['test_goal_reached']) + # Save last eval time. + eval_end_time = get_time() + train_state['last_eval_time'] = eval_end_time + + # Accumulate eval time. + train_state[ + 'accumulated_eval_time'] += eval_end_time - eval_start_time + + # Add times to eval results for logging. + latest_eval_result['score'] = ( + train_state['accumulated_submission_time']) + latest_eval_result[ + 'total_duration'] = eval_end_time - global_start_time + latest_eval_result['accumulated_submission_time'] = train_state[ + 'accumulated_submission_time'] + latest_eval_result['accumulated_eval_time'] = train_state[ + 'accumulated_eval_time'] + latest_eval_result['accumulated_logging_time'] = train_state[ + 'accumulated_logging_time'] + time_since_start = latest_eval_result['total_duration'] + logging.info(f'Time since start: {time_since_start:.2f}s, ' + f'\tStep: {global_step}, \t{latest_eval_result}') + eval_results.append((global_step, latest_eval_result)) + + logging_start_time = get_time() + + if log_dir is not None and RANK == 0: + metrics_logger.append_scalar_metrics( + latest_eval_result, + global_step=global_step, + preemption_count=preemption_count, + is_eval=True, + ) + if save_checkpoints: + checkpoint_utils.save_checkpoint( + framework=FLAGS.framework, + optimizer_state=optimizer_state, + model_params=model_params, + model_state=model_state, + train_state=train_state, + eval_results=eval_results, + global_step=global_step, + preemption_count=preemption_count, + checkpoint_dir=log_dir, + save_intermediate_checkpoints=FLAGS + .save_intermediate_checkpoints) + + logging_end_time = get_time() + train_state['accumulated_logging_time'] += ( + logging_end_time - logging_start_time) + _reset_cuda_mem() + except RuntimeError as e: + logging.exception(f'Eval step {global_step} error.\n') + if 'out of memory' in str(e): + logging.warning( + 'Error: GPU out of memory during eval during step ' + f'{global_step}, error : {str(e)}.') + _reset_cuda_mem() + train_state['last_step_end_time'] = get_time() metrics = {'eval_results': eval_results, 'global_step': global_step} @@ -526,6 +560,7 @@ def score_submission_on_workload(workload: spec.Workload, init_optimizer_state = submission_module.init_optimizer_state update_params = submission_module.update_params data_selection = submission_module.data_selection + prepare_for_eval = getattr(submission_module, 'prepare_for_eval', None) try: global_batch_size = submission_module.get_batch_size(workload_name) except ValueError: @@ -598,6 +633,7 @@ def score_submission_on_workload(workload: spec.Workload, data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params, data_selection, + prepare_for_eval, hyperparameters, rng_seed, rng, @@ -644,6 +680,14 @@ def main(_): else: profiler = PassThroughProfiler() + # Set PyTorch environment variables before initializing w DDP + base_workload = workloads.get_base_workload_name(FLAGS.workload) + if base_workload == 'librispeech_conformer': + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' + + if FLAGS.set_pytorch_max_split_size: + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' + if FLAGS.framework == 'pytorch': pytorch_init(USE_PYTORCH_DDP, RANK, profiler) @@ -655,13 +699,13 @@ def main(_): workload_metadata = WORKLOADS[FLAGS.workload] - # Prevent OOM on librispeech conformer. - base_workload = workloads.get_base_workload_name(FLAGS.workload) - if base_workload == 'librispeech_conformer': - os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85' - - if FLAGS.set_pytorch_max_split_size: - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' + if base_workload in [ + 'librispeech_conformer', + 'librispeech_deepspeech', + 'imagenet_vit', + 'criteo1tb' + ]: + os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' # Extend path according to framework. workload_metadata['workload_path'] = os.path.join( diff --git a/submissions/template/submission.py b/submissions/template/submission.py index ab98c9958..a4fdc62b4 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -6,7 +6,7 @@ """ from typing import Any, Dict, Iterator, List, Optional, Tuple -from algorithmic_efficiency import spec +from algoperf import spec def init_optimizer_state(workload: spec.Workload, @@ -44,6 +44,25 @@ def update_params( pass +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """ + Returns: + new_optimizer_state + new_params + new_model_state + """ + pass + + def get_batch_size(workload_name): """ Gets batch size for workload. diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index adbade983..d280803af 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -7,10 +7,10 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ +from algoperf import spec +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallWorkload as JaxWorkload -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ Criteo1TbDlrmSmallWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/criteo1tb_embed_init/compare.py b/tests/modeldiffs/criteo1tb_embed_init/compare.py index 0748e2d71..73744c667 100644 --- a/tests/modeldiffs/criteo1tb_embed_init/compare.py +++ b/tests/modeldiffs/criteo1tb_embed_init/compare.py @@ -7,10 +7,10 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ +from algoperf import spec +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index 0a6e5c5ac..96e3cc5cc 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -7,10 +7,10 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ +from algoperf import spec +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 288442594..188e4cac3 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -8,10 +8,10 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ +from algoperf import spec +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallResNetWorkload as JaxWorkload -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/fastmri/compare.py b/tests/modeldiffs/fastmri/compare.py index 56b74b32d..da5f0ba0a 100644 --- a/tests/modeldiffs/fastmri/compare.py +++ b/tests/modeldiffs/fastmri/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ +from algoperf import spec +from algoperf.workloads.fastmri.fastmri_jax.workload import \ FastMRIWorkload as JaxWorkload -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ +from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ FastMRIWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/fastmri_layernorm/compare.py b/tests/modeldiffs/fastmri_layernorm/compare.py index 23ccf26d7..5f1eb1842 100644 --- a/tests/modeldiffs/fastmri_layernorm/compare.py +++ b/tests/modeldiffs/fastmri_layernorm/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ +from algoperf import spec +from algoperf.workloads.fastmri.fastmri_jax.workload import \ FastMRILayerNormWorkload as JaxWorkload -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ +from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ FastMRILayerNormWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/fastmri_model_size/compare.py b/tests/modeldiffs/fastmri_model_size/compare.py index b61516c29..ebb8669f8 100644 --- a/tests/modeldiffs/fastmri_model_size/compare.py +++ b/tests/modeldiffs/fastmri_model_size/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ +from algoperf import spec +from algoperf.workloads.fastmri.fastmri_jax.workload import \ FastMRIModelSizeWorkload as JaxWorkload -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ +from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ FastMRIModelSizeWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/fastmri_tanh/compare.py b/tests/modeldiffs/fastmri_tanh/compare.py index 0f455387c..558bc2ba1 100644 --- a/tests/modeldiffs/fastmri_tanh/compare.py +++ b/tests/modeldiffs/fastmri_tanh/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ +from algoperf import spec +from algoperf.workloads.fastmri.fastmri_jax.workload import \ FastMRITanhWorkload as JaxWorkload -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ +from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ FastMRITanhWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/imagenet_resnet/compare.py b/tests/modeldiffs/imagenet_resnet/compare.py index fb730f1bf..0a6a1b7c5 100644 --- a/tests/modeldiffs/imagenet_resnet/compare.py +++ b/tests/modeldiffs/imagenet_resnet/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ ImagenetResNetWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/imagenet_resnet/gelu_compare.py b/tests/modeldiffs/imagenet_resnet/gelu_compare.py index 6c8adbec2..4f20873b7 100644 --- a/tests/modeldiffs/imagenet_resnet/gelu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/gelu_compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetGELUWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ ImagenetResNetGELUWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff from tests.modeldiffs.imagenet_resnet.compare import key_transform diff --git a/tests/modeldiffs/imagenet_resnet/silu_compare.py b/tests/modeldiffs/imagenet_resnet/silu_compare.py index 7668cdbd9..e94fdcd4c 100644 --- a/tests/modeldiffs/imagenet_resnet/silu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/silu_compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetSiLUWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ ImagenetResNetSiLUWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff from tests.modeldiffs.imagenet_resnet.compare import key_transform diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index ba21e63da..b7b9af794 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetVitWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/imagenet_vit_glu/compare.py b/tests/modeldiffs/imagenet_vit_glu/compare.py index 2c0aa546d..11edcd84e 100644 --- a/tests/modeldiffs/imagenet_vit_glu/compare.py +++ b/tests/modeldiffs/imagenet_vit_glu/compare.py @@ -9,10 +9,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitGluWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetVitGluWorkload as PyTorchWorkload sd_transform = None diff --git a/tests/modeldiffs/imagenet_vit_map/compare.py b/tests/modeldiffs/imagenet_vit_map/compare.py index e7c4c2ee8..70bcd2e04 100644 --- a/tests/modeldiffs/imagenet_vit_map/compare.py +++ b/tests/modeldiffs/imagenet_vit_map/compare.py @@ -9,10 +9,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitMapWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetVitMapWorkload as PytWorkload diff --git a/tests/modeldiffs/imagenet_vit_postln/compare.py b/tests/modeldiffs/imagenet_vit_postln/compare.py index 8a9063cac..113a65a2a 100644 --- a/tests/modeldiffs/imagenet_vit_postln/compare.py +++ b/tests/modeldiffs/imagenet_vit_postln/compare.py @@ -9,10 +9,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitPostLNWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetViTPostLNWorkload as PyTorchWorkload sd_transform = None diff --git a/tests/modeldiffs/librispeech_conformer/compare.py b/tests/modeldiffs/librispeech_conformer/compare.py index cfe6c7381..5bfbf915a 100644 --- a/tests/modeldiffs/librispeech_conformer/compare.py +++ b/tests/modeldiffs/librispeech_conformer/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ +from algoperf import spec +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerWorkload as JaxWorkload -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ LibriSpeechConformerWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py index 8480fca02..bb9a8fae1 100644 --- a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py +++ b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ +from algoperf import spec +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerAttentionTemperatureWorkload as JaxWorkload -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ LibriSpeechConformerAttentionTemperatureWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/librispeech_conformer_gelu/compare.py b/tests/modeldiffs/librispeech_conformer_gelu/compare.py index caa9b09b9..629418488 100644 --- a/tests/modeldiffs/librispeech_conformer_gelu/compare.py +++ b/tests/modeldiffs/librispeech_conformer_gelu/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ +from algoperf import spec +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerGeluWorkload as JaxWorkload -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ LibriSpeechConformerGeluWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py index 1a94d3c77..48fe991f7 100644 --- a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py +++ b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ +from algoperf import spec +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerLayerNormWorkload as JaxWorkload -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ LibriSpeechConformerLayerNormWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/librispeech_deepspeech/compare.py b/tests/modeldiffs/librispeech_deepspeech/compare.py index edcc3ba87..81e12b15d 100644 --- a/tests/modeldiffs/librispeech_deepspeech/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ +from algoperf import spec +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \ LibriSpeechDeepSpeechWorkload as JaxWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ LibriSpeechDeepSpeechWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py index 6c00bdf69..ea106ebe4 100644 --- a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ +from algoperf import spec +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \ LibriSpeechDeepSpeechTanhWorkload as JaxWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ LibriSpeechDeepSpeechTanhWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff from tests.modeldiffs.librispeech_deepspeech.compare import key_transform diff --git a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py index c68d6adf9..ecb6d28af 100644 --- a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ +from algoperf import spec +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \ LibriSpeechDeepSpeechNormAndSpecAugWorkload as JaxWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ LibriSpeechDeepSpeechNormAndSpecAugWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff from tests.modeldiffs.librispeech_deepspeech.compare import key_transform diff --git a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py index 4cfdf4f21..31d9029b4 100644 --- a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ +from algoperf import spec +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \ LibriSpeechDeepSpeechNoResNetWorkload as JaxWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ LibriSpeechDeepSpeechNoResNetWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff from tests.modeldiffs.librispeech_deepspeech.compare import key_transform diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 56316ba12..43ca48764 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -8,10 +8,10 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ +from algoperf import spec +from algoperf.workloads.ogbg.ogbg_jax.workload import \ OgbgWorkload as JaxWorkload -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ +from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ OgbgWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index b58bcd461..062588fe2 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -8,10 +8,10 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ +from algoperf import spec +from algoperf.workloads.ogbg.ogbg_jax.workload import \ OgbgGeluWorkload as JaxWorkload -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ +from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ OgbgGeluWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index 62443bbb5..2eb70d097 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -8,10 +8,10 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ +from algoperf import spec +from algoperf.workloads.ogbg.ogbg_jax.workload import \ OgbgModelSizeWorkload as JaxWorkload -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ +from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ OgbgModelSizeWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 2922b7046..19e446030 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -8,10 +8,10 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ +from algoperf import spec +from algoperf.workloads.ogbg.ogbg_jax.workload import \ OgbgSiluWorkload as JaxWorkload -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ +from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ OgbgSiluWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/vanilla_sgd_jax.py b/tests/modeldiffs/vanilla_sgd_jax.py index d45694bcb..5595894e6 100644 --- a/tests/modeldiffs/vanilla_sgd_jax.py +++ b/tests/modeldiffs/vanilla_sgd_jax.py @@ -3,7 +3,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import from reference_algorithms.target_setting_algorithms.jax_submission_base import \ @@ -21,7 +21,7 @@ def init_optimizer_state(workload: spec.Workload, del rng # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = optax.sgd(learning_rate=0.001) diff --git a/tests/modeldiffs/vanilla_sgd_pytorch.py b/tests/modeldiffs/vanilla_sgd_pytorch.py index 254ef6018..a6a0c5fa6 100644 --- a/tests/modeldiffs/vanilla_sgd_pytorch.py +++ b/tests/modeldiffs/vanilla_sgd_pytorch.py @@ -1,6 +1,6 @@ import torch -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 41fc5ee17..64401ef7f 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -6,10 +6,9 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ - WmtWorkload as JaxWorkload -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ +from algoperf import spec +from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWorkload +from algoperf.workloads.wmt.wmt_pytorch.workload import \ WmtWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index 92ce4eb44..01dc2895c 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ +from algoperf import spec +from algoperf.workloads.wmt.wmt_jax.workload import \ WmtWorkloadAttentionTemp as JaxWorkload -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ +from algoperf.workloads.wmt.wmt_pytorch.workload import \ WmtWorkloadAttentionTemp as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index b8d860479..77e71c826 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ +from algoperf import spec +from algoperf.workloads.wmt.wmt_jax.workload import \ WmtWorkloadGLUTanH as JaxWorkload -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ +from algoperf.workloads.wmt.wmt_pytorch.workload import \ WmtWorkloadGLUTanH as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index 3f5469d8d..909fcd672 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ +from algoperf import spec +from algoperf.workloads.wmt.wmt_jax.workload import \ WmtWorkloadPostLN as JaxWorkload -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ +from algoperf.workloads.wmt.wmt_pytorch.workload import \ WmtWorkloadPostLN as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index f107be8d7..c4ca514a8 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -40,15 +40,13 @@ import torch import torch.distributed as dist -from algorithmic_efficiency import halton -from algorithmic_efficiency import pytorch_utils -from algorithmic_efficiency import random_utils as prng -from algorithmic_efficiency.profiler import PassThroughProfiler -from algorithmic_efficiency.workloads import workloads -from algorithmic_efficiency.workloads.ogbg import \ - input_pipeline as ogbg_input_pipeline -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ - _graph_map +from algoperf import halton +from algoperf import pytorch_utils +from algoperf import random_utils as prng +from algoperf.profiler import PassThroughProfiler +from algoperf.workloads import workloads +from algoperf.workloads.ogbg import input_pipeline as ogbg_input_pipeline +from algoperf.workloads.ogbg.ogbg_pytorch.workload import _graph_map import submission_runner from tests.modeldiffs import diff as diff_utils @@ -97,9 +95,9 @@ def _make_fake_image_batch(batch_shape, data_shape, num_classes): def _pytorch_map(inputs): if USE_PYTORCH_DDP: - return jax.tree_map( + return jax.tree.map( lambda a: torch.as_tensor(a[RANK], device=PYTORCH_DEVICE), inputs) - return jax.tree_map( + return jax.tree.map( lambda a: torch.as_tensor(a, device=PYTORCH_DEVICE).view(-1, a.shape[-1]) if len(a.shape) == 3 else torch.as_tensor(a, device=PYTORCH_DEVICE).view( -1), diff --git a/tests/submission_runner_test.py b/tests/submission_runner_test.py index cc98e603e..ff724b201 100644 --- a/tests/submission_runner_test.py +++ b/tests/submission_runner_test.py @@ -13,7 +13,7 @@ from absl.testing import absltest from absl.testing import parameterized -from algorithmic_efficiency.profiler import PassThroughProfiler +from algoperf.profiler import PassThroughProfiler import submission_runner FLAGS = flags.FLAGS diff --git a/tests/test_baselines.py b/tests/test_baselines.py index f79e629e7..b2be8aa11 100644 --- a/tests/test_baselines.py +++ b/tests/test_baselines.py @@ -12,8 +12,8 @@ from absl.testing import absltest from absl.testing import parameterized -from algorithmic_efficiency.profiler import PassThroughProfiler -from algorithmic_efficiency.workloads import workloads +from algoperf.profiler import PassThroughProfiler +from algoperf.workloads import workloads import submission_runner FLAGS = flags.FLAGS diff --git a/tests/test_num_params.py b/tests/test_num_params.py index 574fd0aa5..b0633025e 100644 --- a/tests/test_num_params.py +++ b/tests/test_num_params.py @@ -5,42 +5,37 @@ import pytest import torch -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.models import \ +from algoperf.workloads.criteo1tb.criteo1tb_jax.models import \ DlrmSmall as JaxDlrmSmall -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.models import \ +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import \ DlrmSmall as PyTorchDlrmSmall -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.models import \ +from algoperf.workloads.imagenet_resnet.imagenet_jax.models import \ ResNet18 as JaxResNet_c10 -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.models import \ +from algoperf.workloads.imagenet_resnet.imagenet_jax.models import \ ResNet50 as JaxResNet -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.models import \ +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ resnet18 as PyTorchResNet_c10 -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.models import \ +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ resnet50 as PyTorchResNet -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.models import \ - ViT as JaxViT -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.models import \ +from algoperf.workloads.imagenet_vit.imagenet_jax.models import ViT as JaxViT +from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \ ViT as PyTorchViT -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.models import \ +from algoperf.workloads.librispeech_conformer.librispeech_jax.models import \ Conformer as JaxConformer -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.models import \ +from algoperf.workloads.librispeech_conformer.librispeech_jax.models import \ ConformerConfig as JaxConformerConfig -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.models import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ ConformerConfig as PytorchConformerConfig -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.models import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ ConformerEncoderDecoder as PytorchConformer -from algorithmic_efficiency.workloads.mnist.mnist_jax.workload import \ - _Model as JaxMLP -from algorithmic_efficiency.workloads.mnist.mnist_pytorch.workload import \ +from algoperf.workloads.mnist.mnist_jax.workload import _Model as JaxMLP +from algoperf.workloads.mnist.mnist_pytorch.workload import \ _Model as PyTorchMLP -from algorithmic_efficiency.workloads.ogbg.ogbg_jax.models import GNN as JaxGNN -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.models import \ - GNN as PyTorchGNN -from algorithmic_efficiency.workloads.wmt.wmt_jax.models import \ - Transformer as JaxTransformer -from algorithmic_efficiency.workloads.wmt.wmt_jax.models import \ - TransformerConfig -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import \ +from algoperf.workloads.ogbg.ogbg_jax.models import GNN as JaxGNN +from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as PyTorchGNN +from algoperf.workloads.wmt.wmt_jax.models import Transformer as JaxTransformer +from algoperf.workloads.wmt.wmt_jax.models import TransformerConfig +from algoperf.workloads.wmt.wmt_pytorch.models import \ Transformer as PyTorchTransformer WORKLOADS = [ diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py index b67625213..df4c798d8 100644 --- a/tests/test_param_shapes.py +++ b/tests/test_param_shapes.py @@ -3,29 +3,30 @@ import jax import numpy as np import pytest +from flax.core import FrozenDict # isort: skip_file # pylint:disable=line-too-long -from algorithmic_efficiency.workloads.cifar.cifar_jax.workload import CifarWorkload as JaxCifarWorkload -from algorithmic_efficiency.workloads.cifar.cifar_pytorch.workload import CifarWorkload as PyTorchCifarWorkload -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import FastMRIWorkload as JaxFastMRIWorkload -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import FastMRIWorkload as PyTorchFastMRIWorkload -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import ImagenetResNetWorkload as JaxImagenetResNetWorkload -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import ImagenetResNetWorkload as PyTorchImagenetResNetWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import ImagenetVitWorkload as JaxImagenetViTWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import ImagenetVitWorkload as PyTorchImagenetViTWorkload -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload -from algorithmic_efficiency.workloads.mnist.mnist_jax.workload import MnistWorkload as JaxMnistWorkload -from algorithmic_efficiency.workloads.mnist.mnist_pytorch.workload import MnistWorkload as PyTorchMnistWorkload -from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import OgbgWorkload as JaxOgbgWorkload -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import OgbgWorkload as PyTorchOgbgWorkload -from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWmtWorkload -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import WmtWorkload as PyTorchWmtWorkload +from algoperf.workloads.cifar.cifar_jax.workload import CifarWorkload as JaxCifarWorkload +from algoperf.workloads.cifar.cifar_pytorch.workload import CifarWorkload as PyTorchCifarWorkload +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload +from algoperf.workloads.fastmri.fastmri_jax.workload import FastMRIWorkload as JaxFastMRIWorkload +from algoperf.workloads.fastmri.fastmri_pytorch.workload import FastMRIWorkload as PyTorchFastMRIWorkload +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ImagenetResNetWorkload as JaxImagenetResNetWorkload +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ImagenetResNetWorkload as PyTorchImagenetResNetWorkload +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ImagenetVitWorkload as JaxImagenetViTWorkload +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ImagenetVitWorkload as PyTorchImagenetViTWorkload +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload +from algoperf.workloads.mnist.mnist_jax.workload import MnistWorkload as JaxMnistWorkload +from algoperf.workloads.mnist.mnist_pytorch.workload import MnistWorkload as PyTorchMnistWorkload +from algoperf.workloads.ogbg.ogbg_jax.workload import OgbgWorkload as JaxOgbgWorkload +from algoperf.workloads.ogbg.ogbg_pytorch.workload import OgbgWorkload as PyTorchOgbgWorkload +from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWmtWorkload +from algoperf.workloads.wmt.wmt_pytorch.workload import WmtWorkload as PyTorchWmtWorkload # pylint:enable=line-too-long WORKLOADS = [ @@ -51,8 +52,11 @@ def test_param_shapes(workload): jax_workload, pytorch_workload = get_workload(workload) # Compare number of parameter tensors of both models. + jax_workload_param_shapes = jax_workload.param_shapes + if isinstance(jax_workload_param_shapes, dict): + jax_workload_param_shapes = FrozenDict(jax_workload_param_shapes) jax_param_shapes = jax.tree_util.tree_leaves( - jax_workload.param_shapes.unfreeze()) + jax_workload_param_shapes.unfreeze()) pytorch_param_shapes = jax.tree_util.tree_leaves( pytorch_workload.param_shapes) if workload == 'wmt': diff --git a/tests/test_param_types.py b/tests/test_param_types.py index 7cf8f63c3..d3722ae86 100644 --- a/tests/test_param_types.py +++ b/tests/test_param_types.py @@ -2,30 +2,30 @@ import pytest from absl import logging -from algorithmic_efficiency import spec +from algoperf import spec # isort: skip_file # pylint:disable=line-too-long -from algorithmic_efficiency.workloads.cifar.cifar_jax.workload import CifarWorkload as JaxCifarWorkload -from algorithmic_efficiency.workloads.cifar.cifar_pytorch.workload import CifarWorkload as PyTorchCifarWorkload -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import FastMRIWorkload as JaxFastMRIWorkload -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import FastMRIWorkload as PyTorchFastMRIWorkload -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import ImagenetResNetWorkload as JaxImagenetResNetWorkload -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import ImagenetResNetWorkload as PyTorchImagenetResNetWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import ImagenetVitWorkload as JaxImagenetViTWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import ImagenetVitWorkload as PyTorchImagenetViTWorkload -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload -from algorithmic_efficiency.workloads.mnist.mnist_jax.workload import MnistWorkload as JaxMnistWorkload -from algorithmic_efficiency.workloads.mnist.mnist_pytorch.workload import MnistWorkload as PyTorchMnistWorkload -from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import OgbgWorkload as JaxOgbgWorkload -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import OgbgWorkload as PyTorchOgbgWorkload -from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWmtWorkload -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import WmtWorkload as PyTorchWmtWorkload +from algoperf.workloads.cifar.cifar_jax.workload import CifarWorkload as JaxCifarWorkload +from algoperf.workloads.cifar.cifar_pytorch.workload import CifarWorkload as PyTorchCifarWorkload +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload +from algoperf.workloads.fastmri.fastmri_jax.workload import FastMRIWorkload as JaxFastMRIWorkload +from algoperf.workloads.fastmri.fastmri_pytorch.workload import FastMRIWorkload as PyTorchFastMRIWorkload +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ImagenetResNetWorkload as JaxImagenetResNetWorkload +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ImagenetResNetWorkload as PyTorchImagenetResNetWorkload +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ImagenetVitWorkload as JaxImagenetViTWorkload +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ImagenetVitWorkload as PyTorchImagenetViTWorkload +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload +from algoperf.workloads.mnist.mnist_jax.workload import MnistWorkload as JaxMnistWorkload +from algoperf.workloads.mnist.mnist_pytorch.workload import MnistWorkload as PyTorchMnistWorkload +from algoperf.workloads.ogbg.ogbg_jax.workload import OgbgWorkload as JaxOgbgWorkload +from algoperf.workloads.ogbg.ogbg_pytorch.workload import OgbgWorkload as PyTorchOgbgWorkload +from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWmtWorkload +from algoperf.workloads.wmt.wmt_pytorch.workload import WmtWorkload as PyTorchWmtWorkload # pylint:enable=line-too-long WORKLOADS = [ diff --git a/tests/test_ssim.py b/tests/test_ssim.py index fadf41f64..920556964 100644 --- a/tests/test_ssim.py +++ b/tests/test_ssim.py @@ -9,14 +9,13 @@ import numpy as np import torch -from algorithmic_efficiency.pytorch_utils import pytorch_setup -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.ssim import \ +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.fastmri.fastmri_jax.ssim import \ _uniform_filter as _jax_uniform_filter -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.ssim import \ - ssim as jax_ssim -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.ssim import \ +from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim as jax_ssim +from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \ _uniform_filter as _pytorch_uniform_filter -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.ssim import \ +from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \ ssim as pytorch_ssim # Make sure no GPU memory is preallocated to Jax. diff --git a/tests/test_version.py b/tests/test_version.py new file mode 100644 index 000000000..d1bfbd18f --- /dev/null +++ b/tests/test_version.py @@ -0,0 +1,17 @@ +"""Check whether the __version__ attribute is set correctly.""" + +import algoperf + + +def test_version_attribute(): + """Check whether __version__ exists and is a valid string.""" + + assert hasattr(algoperf, "__version__") + version = algoperf.__version__ + assert isinstance(version, str) + version_elements = version.split(".") + print(version_elements) + # Only check the first two elements, i.e. major, minor + # (patch is not checked as it is not required). + # The remaining elements contain commit hash and dirty status. + assert all(el.isnumeric() for el in version_elements[0:2]) diff --git a/tests/version_test.py b/tests/version_test.py deleted file mode 100644 index 9f7006aab..000000000 --- a/tests/version_test.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Check whether the __version__ attribute is set correctly.""" - -import algorithmic_efficiency - - -def test_version_attribute(): - """Check whether __version__ exists and is a valid string.""" - - assert hasattr(algorithmic_efficiency, "__version__") - version = algorithmic_efficiency.__version__ - assert isinstance(version, str) - version_elements = version.split(".") - assert all(el.isnumeric() for el in version_elements) diff --git a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py index 6a85c2196..d44234927 100644 --- a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py +++ b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py @@ -4,13 +4,13 @@ import jax import jax.numpy as jnp -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload def _pytree_total_diff(pytree_a, pytree_b): - pytree_diff = jax.tree_map(lambda a, b: jnp.sum(a - b), pytree_a, pytree_b) + pytree_diff = jax.tree.map(lambda a, b: jnp.sum(a - b), pytree_a, pytree_b) pytree_diff = jax.tree_util.tree_leaves(pytree_diff) return jnp.sum(jnp.array(pytree_diff))