nsys-jax: re-work to be more pip-install-able #3562
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Top-level CI workflow: runs on a nightly schedule, on pull requests
# (ignoring Markdown-only changes) and on manual dispatch.
name: CI

on:
  schedule:
    # 09:30 UTC daily, i.e. 01:30 AM Pacific Standard Time (02:30 AM during daylight saving)
    - cron: '30 9 * * *'
  pull_request:
    types:
      - opened
      - reopened
      - ready_for_review
      - synchronize
    paths-ignore:
      - '**.md'
  workflow_dispatch:
    inputs:
      PUBLISH:
        type: boolean
        description: Publish dated images and update the 'latest' tag?
        default: false
        required: false
      BUMP_MANIFEST:
        type: boolean
        description: Bump git repos in manifest.yaml to head of tree?
        default: true
        required: true
      MERGE_BUMPED_MANIFEST:
        type: boolean
        description: "(used if BUMP_MANIFEST=true) If true: attempt to PR/merge manifest branch"
        default: false
        required: false
      CUDA_IMAGE:
        type: string
        description: CUDA image to use as base, e.g. nvidia/cuda:X.Y.Z-devel-ubuntu22.04
        default: 'latest'
        required: false
      SOURCE_OVERRIDES:
        type: string
        description: |
          A comma-separated PACKAGE=URL#REF list to override sources used by build.
          PACKAGE∊{JAX,XLA,Flax,transformer-engine,T5X,paxml,praxis,maxtext,levanter,haliax,mujoco,mujoco-mpc,gemma,big-vision,common-loop-utils,flaxformer,panopticapi} (case-insensitive)
        default: ''
        required: false

# One run per PR head branch (or per run id for non-PR triggers); superseded
# runs are cancelled everywhere except on the main branch.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

permissions:
  contents: write  # to fetch code and push branch
  actions: write  # to cancel previous workflows
  packages: write  # to upload container
  pull-requests: write  # to make pull request for manifest bump

env:
  DEFAULT_MANIFEST_ARTIFACT_NAME: bumped-manifest

jobs:
# Computes the values every downstream job consumes: build date, publish flag,
# manifest bump/merge settings and the CUDA base image.
metadata:
  runs-on: ubuntu-22.04
  outputs:
    BUILD_DATE: ${{ steps.date.outputs.BUILD_DATE }}
    PUBLISH: ${{ steps.if-publish.outputs.PUBLISH }}
    BUMP_MANIFEST: ${{ steps.manifest-branch.outputs.BUMP_MANIFEST }}
    MANIFEST_ARTIFACT_NAME: ${{ steps.manifest-branch.outputs.MANIFEST_ARTIFACT_NAME }}
    MANIFEST_BRANCH: ${{ steps.manifest-branch.outputs.MANIFEST_BRANCH }}
    MERGE_BUMPED_MANIFEST: ${{ steps.manifest-branch.outputs.MERGE_BUMPED_MANIFEST }}
    CUDA_IMAGE: ${{ steps.cuda-image.outputs.CUDA_IMAGE }}
  steps:
    - name: Cancel workflow run if the trigger is a draft PR
      id: cancel-if-draft
      if: github.event_name == 'pull_request' && github.event.pull_request.draft == true
      run: |
        echo "Cancelling workflow for draft PR"
        curl -X POST -H "Authorization: token ${{ github.token }}" \
          -H "Accept: application/vnd.github.v3+json" \
          "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/cancel"
        while true; do sleep 1; done  # blocks execution in case workflow cancellation takes time

    - name: Set build date
      id: date
      shell: bash -x -e {0}
      run: |
        # The date is taken in US Pacific time, not UTC
        BUILD_DATE=$(TZ='US/Los_Angeles' date '+%Y-%m-%d')
        echo "BUILD_DATE=${BUILD_DATE}" >> "$GITHUB_OUTPUT"

    - name: Determine whether results will be 'published'
      id: if-publish
      shell: bash -x -e {0}
      run: |
        # Scheduled runs always publish; manual runs publish only when requested
        echo "PUBLISH=${{ github.event_name == 'schedule' || inputs.PUBLISH }}" >> "$GITHUB_OUTPUT"

    - name: Set manifest branch name
      id: manifest-branch
      shell: bash -x -e {0}
      run: |
        # Only manual dispatches can opt out of bumping; every other trigger bumps
        if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
          BUMP_MANIFEST="${{ inputs.BUMP_MANIFEST }}"
        else
          BUMP_MANIFEST="true"
        fi
        MERGE_BUMPED_MANIFEST=${{ github.event_name == 'schedule' || inputs.MERGE_BUMPED_MANIFEST || 'false' }}
        # Prepend nightly manifest branch with "z" to make it appear at the end
        if [[ "$BUMP_MANIFEST" == "true" ]]; then
          # This branch is for scheduled nightlies or manually run nightlies
          MANIFEST_BRANCH=znightly-${{ steps.date.outputs.BUILD_DATE }}-${{ github.run_id }}
          MANIFEST_ARTIFACT_NAME=${{ env.DEFAULT_MANIFEST_ARTIFACT_NAME }}
        else
          # This branch is for presubmits (no bumping needed)
          MANIFEST_BRANCH=${{ github.sha }}
          # Empty artifact name means to use the one in version control
          MANIFEST_ARTIFACT_NAME=""
        fi
        echo "MANIFEST_BRANCH=$MANIFEST_BRANCH" | tee -a "$GITHUB_OUTPUT"
        echo "MANIFEST_ARTIFACT_NAME=$MANIFEST_ARTIFACT_NAME" | tee -a "$GITHUB_OUTPUT"
        echo "BUMP_MANIFEST=$BUMP_MANIFEST" | tee -a "$GITHUB_OUTPUT"
        echo "MERGE_BUMPED_MANIFEST=$MERGE_BUMPED_MANIFEST" | tee -a "$GITHUB_OUTPUT"
        # Merging only makes sense when a bumped manifest was produced
        if [[ "$BUMP_MANIFEST" == "false" && "$MERGE_BUMPED_MANIFEST" == "true" ]]; then
          echo "Error: If BUMP_MANIFEST=false, MERGE_BUMPED_MANIFEST cannot be true" >&2
          exit 1
        fi

    - name: Determine CUDA image to use
      id: cuda-image
      shell: bash -x -e {0}
      run: |
        # Non-dispatch triggers have no inputs, so fall back to "latest"
        if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
          CUDA_IMAGE="${{ inputs.CUDA_IMAGE }}"
        else
          CUDA_IMAGE="latest"
        fi
        echo "CUDA_IMAGE=${CUDA_IMAGE}" >> "$GITHUB_OUTPUT"
# Runs the manifest bump script, optionally adopts its result, uploads the
# manifest/patches as an artifact and emits the {PACKAGE: URL#REF} build args.
bump-manifest:
  needs: metadata
  runs-on: ubuntu-22.04
  outputs:
    SOURCE_URLREFS: ${{ steps.source-urlrefs.outputs.SOURCE_URLREFS }}
  steps:
    - name: Check out the repository under ${GITHUB_WORKSPACE}
      uses: actions/checkout@v4

    # The bump always runs (even when its result is discarded below) so that a
    # broken bump script is detected on every trigger.
    - name: Test if manifest bump is functional, and save result to a new file
      working-directory: .github/container
      shell: bash -x -e {0}
      run: |
        bash bump.sh --input-manifest manifest.yaml --output-manifest manifest.yaml.new --base-patch-dir ./patches-new

    - name: Maybe replace current manifest/patches with the new one and show diff
      working-directory: .github/container
      shell: bash -x -e {0}
      run: |
        if [[ "${{ needs.metadata.outputs.BUMP_MANIFEST }}" == "true" ]]; then
          mv manifest.yaml.new manifest.yaml
          rm -rf patches
          mv patches-new patches
        else
          rm -rf patches-new manifest.yaml.new
        fi
        # The bumped manifest references the temporary patch dir; point it at the final one
        sed -i 's|file://patches-new/|file://patches/|g' manifest.yaml
        git diff

    - name: Upload bumped manifest/patches to be used in build-base
      if: needs.metadata.outputs.MANIFEST_ARTIFACT_NAME != ''
      uses: actions/upload-artifact@v4
      with:
        name: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
        path: |
          .github/container/manifest.yaml
          .github/container/patches

    - name: Create URL ref build args
      id: source-urlrefs
      shell: bash -x -e {0}
      env:
        # Handed over through the environment so the raw user input is never
        # spliced into the script text (avoids shell injection / quoting breakage).
        SOURCE_OVERRIDES: ${{ inputs.SOURCE_OVERRIDES }}
      run: |
        # converts manifest yaml to a json object of {SOFTWARE_NAME: URL#REF, ...}
        urlrefs=$(
          yq -o=json 'to_entries | .[] | select(.value.mode == "git-clone") | {( .key | upcase | sub("-", "_") ): .value.url + "#" + .value.latest_verified_commit}' \
            < .github/container/manifest.yaml |\
          jq -c -s 'add'
        )
        # SOURCE_OVERRIDES is a comma-separated list of package=urlref pairs
        IFS=, read -ra overrides <<< "${SOURCE_OVERRIDES}"
        for override in "${overrides[@]}"; do
          # ignore empty entries left by stray or trailing commas
          [[ -n "${override}" ]] || continue
          # package names are normalised to UPPER_SNAKE_CASE to match the manifest-derived keys
          PACKAGE=$(cut -d= -f 1 <<< "${override}" | tr '[:lower:]' '[:upper:]' | tr '-' '_')
          # everything after the first '=' is the URL#REF (it may itself contain '=')
          URLREF=$(cut -d= -f 2- <<< "${override}")
          # --arg lets jq do the JSON escaping instead of interpolating into the program
          urlrefs=$(jq -c --arg pkg "${PACKAGE}" --arg ref "${URLREF}" '. + {($pkg): $ref}' <<< "${urlrefs}")
        done
        echo "SOURCE_URLREFS=${urlrefs}" >> "$GITHUB_OUTPUT"
# Calls the reusable per-architecture CI workflow for amd64.
amd64:
  needs:
    - metadata
    - bump-manifest
  uses: ./.github/workflows/_ci.yaml
  with:
    ARCHITECTURE: amd64
    BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
    CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }}
    MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
    SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }}
  secrets: inherit
# Calls the reusable per-architecture CI workflow for arm64.
arm64:
  needs:
    - metadata
    - bump-manifest
  uses: ./.github/workflows/_ci.yaml
  with:
    ARCHITECTURE: arm64
    BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
    CUDA_IMAGE: ${{ needs.metadata.outputs.CUDA_IMAGE }}
    MANIFEST_ARTIFACT_NAME: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
    SOURCE_URLREFS: ${{ needs.bump-manifest.outputs.SOURCE_URLREFS }}
  secrets: inherit
# Only merge if everything succeeds.
# Commits the bumped manifest/patches on a local branch, fast-forwards the
# triggering ref when the builds passed, and otherwise (tests failed or merge
# failed) pushes the branch and opens a draft PR instead.
merge-new-manifest:
  runs-on: ubuntu-22.04
  if: ${{ !cancelled() && needs.metadata.outputs.MERGE_BUMPED_MANIFEST == 'true' && needs.metadata.outputs.MANIFEST_BRANCH != github.sha }}
  needs:
    - metadata
    - amd64
    - arm64
  steps:
    - name: "Tests Succeeded: ${{ !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') }}"
      id: test_result
      run: |
        echo "SUCCEEDED=${{ !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') }}" | tee -a $GITHUB_OUTPUT

    - name: Check out the repository under ${GITHUB_WORKSPACE}
      uses: actions/checkout@v4

    # Remove the tracked copies first so stale patches do not survive the download below
    - name: Delete checked-out manifest and patches
      run: |
        rm .github/container/manifest.yaml
        rm -rf .github/container/patches

    - name: Replace checked-out manifest file/patches with bumped one
      uses: actions/download-artifact@v4
      with:
        name: ${{ needs.metadata.outputs.MANIFEST_ARTIFACT_NAME }}
        path: .github/container/

    - name: "Create local manifest branch: ${{ needs.metadata.outputs.MANIFEST_BRANCH }}"
      id: local_branch
      shell: bash -x -e {0}
      run: |
        git config user.name "JAX-Toolbox CI"
        # NOTE(review): the address below looks redacted by the page this file was copied from — restore the real value
        git config user.email "[email protected]"
        git switch -c ${{ needs.metadata.outputs.MANIFEST_BRANCH }}
        git status
        # New patch files are untracked, so they must be added explicitly ("commit -a" only covers tracked files)
        git add .github/container/patches/
        git status
        # In the unusual situation where the manifest is the same even after bumping,
        # we will produce an empty commit with --allow-empty, which allows a PR to be
        # made and merged even with no changeset.
        git commit --allow-empty -a -m "Nightly Manifest Bump (${{ needs.metadata.outputs.BUILD_DATE }}) from: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}"

    - name: Try to merge manifest branch
      id: merge_local
      if: steps.test_result.outputs.SUCCEEDED == 'true'
      # Merge can fail
      continue-on-error: true
      shell: bash -x -e {0}
      run: |
        git switch ${{ github.ref_name }}
        # Pull this ref in case it was updated
        git pull --rebase
        git merge --ff-only ${{ needs.metadata.outputs.MANIFEST_BRANCH }}
        # Push the new change
        git push origin ${{ github.ref_name }}

    # We will create a Draft PR & remote branch if:
    #   1. The tests failed
    #   2. The merge failed
    - name: Create remote manifest branch
      id: create_remote_branch
      if: steps.test_result.outputs.SUCCEEDED == 'false' || steps.merge_local.outcome != 'success'
      shell: bash -x -e {0}
      run: |
        # Always abort in case in-progress merge
        git merge --abort || true
        git switch ${{ needs.metadata.outputs.MANIFEST_BRANCH }}
        # Since the merge failed, create a remote and follow up with a PR
        git push --set-upstream origin ${{ needs.metadata.outputs.MANIFEST_BRANCH }}

    - name: Creating Draft PR for MANIFEST_BRANCH=${{ needs.metadata.outputs.MANIFEST_BRANCH }}
      id: create_pr
      if: steps.test_result.outputs.SUCCEEDED == 'false' || steps.merge_local.outcome != 'success'
      # NOTE(review): the action version after "octokit/" looks redacted by the page this file was copied from — restore the real ref
      uses: octokit/[email protected]
      with:
        route: POST /repos/{owner_and_repo}/pulls
        owner_and_repo: ${{ github.repository }}
        head: ${{ needs.metadata.outputs.MANIFEST_BRANCH }}
        # Always try to merge back into the branch that triggered this workflow
        base: ${{ github.ref }}
        body: |
          ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
        title: Nightly Manifest Bump (${{ needs.metadata.outputs.BUILD_DATE }})
        draft: true
      env:
        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

    - name: "Log created PR: #${{ fromJson(steps.create_pr.outputs.data).number }}"
      if: steps.create_pr.outcome == 'success'
      run: |
        echo "${{ github.server_url }}/${{ github.repository }}/pull/${{ fromJson(steps.create_pr.outputs.data).number }}" | tee -a $GITHUB_STEP_SUMMARY

    # Guard delete in simple check to protect other branches
    - name: Check that the branch matches znightly- prefix
      run: |
        if [[ "${{ needs.metadata.outputs.MANIFEST_BRANCH }}" != znightly-* ]]; then
          echo Tried to delete MANIFEST_BRANCH=${{ needs.metadata.outputs.MANIFEST_BRANCH }}, but did not start with "znightly-"
          exit 1
        fi

    # If merging fails b/c upstream conflict, branch is deleted to avoid clutter since changeset is preserved in PR
    - name: Deleting remote MANIFEST_BRANCH=${{ needs.metadata.outputs.MANIFEST_BRANCH }}
      # Delete can fail if branch was already deleted or not created, e.g., if the PR successfully merges, then branch is also already deleted.
      continue-on-error: true
      # NOTE(review): the action version after "octokit/" looks redacted by the page this file was copied from — restore the real ref
      uses: octokit/[email protected]
      with:
        route: DELETE /repos/{owner_and_repo}/git/refs/heads/${{ needs.metadata.outputs.MANIFEST_BRANCH }}
        owner_and_repo: ${{ github.repository }}
      env:
        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# Builds the JSON matrix ({"config": [...]}) that drives publish-containers:
# one entry per (stage, flavor) pair for which at least one platform image exists.
make-publish-configs:
  runs-on: ubuntu-22.04
  # Runs even when a build job failed so that whatever images were produced still get published
  if: ${{ !cancelled() }}
  env:
    # Real repository names are used only when publishing; otherwise "mock-" repositories
    MEALKIT_IMAGE_REPO: ${{ needs.metadata.outputs.PUBLISH == 'true' && 'jax-mealkit' || 'mock-jax-mealkit' }}
    FINAL_IMAGE_REPO: ${{ needs.metadata.outputs.PUBLISH == 'true' && 'jax' || 'mock-jax' }}
  needs:
    - metadata
    - amd64
    - arm64
  outputs:
    PUBLISH_CONFIGS: ${{ steps.generate-configs.outputs.PUBLISH_CONFIGS }}
  steps:
    - id: generate-configs
      shell: bash -eu -o pipefail {0}
      run: |
        declare -a FLAVORS=(
          base
          jax
          triton
          equinox
          maxtext
          levanter
          upstream-t5x
          upstream-pax
          t5x
          pax
          gemma
        )
        declare -a STAGES=(
          mealkit
          final
        )

        ## create JSON specs for a 1D matrix of container publication jobs
        # "// []" keeps ALL_TAGS an array when neither architecture produced any
        # tags; otherwise `add` yields null and the `.[]` below aborts the script.
        ALL_TAGS=$(
          echo '${{ needs.amd64.outputs.DOCKER_TAGS }}' \
               '${{ needs.arm64.outputs.DOCKER_TAGS }}' |\
          jq -s 'add // []'
        )
        PUBLISH_CONFIGS='[]'

        for stage in "${STAGES[@]}"; do
          for flavor in "${FLAVORS[@]}"; do

            # collect images for different platforms, e.g. amd64 and arm64
            matching_tags=$(
              echo "$ALL_TAGS" |\
              jq -c --arg stage "${stage}" --arg flavor "${flavor}" \
                '.[] | select(.stage == $stage and .flavor == $flavor and .tag != "")'
            )

            # source_image is a list of all platform-specific tags
            source_image=$(echo "${matching_tags}" | jq -c "[.tag]" | jq -s 'add')
            # if the build job failed without producing any images, skip this flavor
            # (source_image is then null, whose jq length is 0)
            n_source_images=$(echo "$source_image" | jq 'length')
            if [[ $n_source_images -gt 0 ]]; then
              echo "PUBLISH image $flavor with $n_source_images $stage containers"

              # tag priority is the highest priority of all platform-specific tags
              priority=$(echo "${matching_tags}" | jq -r ".priority" | jq -s 'max')

              # final images go to the FINAL_IMAGE_REPO repository and
              # mealkit images go to the MEALKIT_IMAGE_REPO repository
              case ${stage} in
                mealkit)
                  target_image=${MEALKIT_IMAGE_REPO}
                  ;;
                final)
                  target_image=${FINAL_IMAGE_REPO}
                  ;;
              esac

              # --arg/--argjson let jq do the JSON escaping; priority stays a string as before
              PUBLISH_CONFIGS=$(
                echo "${PUBLISH_CONFIGS}" | jq -c \
                  --arg flavor "${flavor}" \
                  --arg target_image "${target_image}" \
                  --arg priority "${priority}" \
                  --argjson source_image "${source_image}" \
                  --arg stage "${stage}" \
                  '. + [{
                    "flavor": $flavor,
                    "target_image": $target_image,
                    "priority": $priority,
                    "source_image": $source_image,
                    "stage": $stage
                  }]'
              )
            else
              echo "SKIPPED image $flavor with 0 $stage containers"
            fi
          done
        done

        # Wrap as {"config": [...]} so it can be used directly as a job matrix
        PUBLISH_CONFIGS=$(echo "$PUBLISH_CONFIGS" | jq -c '{"config": .}')
        echo "${PUBLISH_CONFIGS}" | jq .
        echo "PUBLISH_CONFIGS=${PUBLISH_CONFIGS}" >> "$GITHUB_OUTPUT"
# Fans out over the PUBLISH_CONFIGS matrix and calls the reusable container
# publication workflow once per (stage, flavor) entry.
publish-containers:
  needs:
    - metadata
    - make-publish-configs
  # PUBLISH_CONFIGS is a JSON *string*, so it must be compared as a whole:
  # dereferencing ".config" on a string never equals the literal, which made
  # this guard always true and let an empty matrix through.
  if: ${{ !cancelled() && needs.make-publish-configs.outputs.PUBLISH_CONFIGS != '{"config":[]}' }}
  strategy:
    fail-fast: false
    matrix: ${{ fromJson(needs.make-publish-configs.outputs.PUBLISH_CONFIGS) }}
  uses: ./.github/workflows/_publish_container.yaml
  with:
    ARTIFACT_NAME: ${{ matrix.config.stage }}-${{ matrix.config.flavor }}
    ARTIFACT_TAG: ${{ matrix.config.flavor }}-${{ needs.metadata.outputs.BUILD_DATE }}
    # source_image is a list of platform-specific tags; pass it space-separated
    SOURCE_IMAGE: ${{ join(matrix.config.source_image, ' ') }}
    TARGET_IMAGE: ${{ matrix.config.target_image }}
    # Two tags per image: the bare flavor and the flavor suffixed with the build date
    TARGET_TAGS: |
      type=raw,value=${{ matrix.config.flavor }},priority=${{ matrix.config.priority }}
      type=raw,value=${{ matrix.config.flavor }}-${{ needs.metadata.outputs.BUILD_DATE }},priority=${{ matrix.config.priority }}
# Final reporting step: calls the reusable finalize workflow after the builds
# and publication have finished, unless the run was cancelled.
finalize:
  needs:
    - metadata
    - amd64
    - arm64
    - publish-containers
  if: ${{ !cancelled() }}
  uses: ./.github/workflows/_finalize.yaml
  with:
    BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
    PUBLISH_BADGE: ${{ needs.metadata.outputs.PUBLISH == 'true' }}
  secrets: inherit