Skip to content

Commit 4a8155e

Browse files
Migrated jax stable stack -> jax ai image (#177)
Changed requirements file with JAII name & updated Dockerfile Resolved comments. Added deprecation warning and notes
1 parent 013c2f8 commit 4a8155e

File tree

6 files changed

+40
-38
lines changed

6 files changed

+40
-38
lines changed

.github/workflows/UploadDockerImages.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ jobs:
2929
- uses: actions/checkout@v3
3030
- name: Cleanup old docker images
3131
run: docker system prune --all --force
32-
- name: build maxdiffusion jax stable stack image
32+
- name: build maxdiffusion jax ai image
3333
run: |
34-
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=stable_stack PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:latest
34+
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
3535
- name: build maxdiffusion jax nightly image
3636
run: |
3737
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly

docker_build_dependency_image.sh

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
# Each time you update the base image via a "bash docker_maxdiffusion_image_upload.sh", there will be a slow upload process
2121
# (minutes). However, if you are simply changing local code and not updating dependencies, uploading just takes a few seconds.
2222

23-
# bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK BASEIMAGE FROM ARTIFACT REGISTRY}}
23+
# bash docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE={{JAX_AI_IMAGE BASEIMAGE FROM ARTIFACT REGISTRY}}
24+
# Note: The mode stable_stack is marked for deprecation, please use MODE=jax_ai_image instead
25+
# bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK_IMAGE BASEIMAGE FROM ARTIFACT REGISTRY}}
2426
# bash docker_build_dependency_image.sh MODE=nightly
2527
# bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13
2628
# bash docker_build_dependency_image.sh MODE=stable
@@ -69,17 +71,17 @@ if [[ ${DEVICE} == "gpu" ]]; then
6971
fi
7072
docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE -f ./maxdiffusion_gpu_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} .
7173
else
72-
if [[ "${MODE}" == "stable_stack" ]]; then
74+
if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then
7375
if [[ ! -v BASEIMAGE ]]; then
7476
echo "Erroring out because BASEIMAGE is unset, please set it!"
7577
exit 1
7678
fi
7779
docker build --no-cache \
78-
--build-arg JAX_STABLE_STACK_BASEIMAGE=${BASEIMAGE} \
80+
--build-arg JAX_AI_IMAGE_BASEIMAGE=${BASEIMAGE} \
7981
--build-arg COMMIT_HASH=${COMMIT_HASH} \
8082
--network=host \
8183
-t ${LOCAL_IMAGE_NAME} \
82-
-f maxdiffusion_jax_stable_stack_tpu.Dockerfile .
84+
-f maxdiffusion_jax_ai_image_tpu.Dockerfile .
8385
else
8486
docker build --no-cache \
8587
--network=host \

docs/getting_started/run_maxdiffusion_via_xpk.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,23 +62,23 @@ after which log out and log back in to the machine.
6262
bash docker_build_dependency_image.sh
6363
```
6464

65-
#### New: Build MaxDiffusion Docker Image with JAX Stable Stack
66-
We're excited to announce that you can build the MaxDiffusion Docker image using the JAX Stable Stack base image. This provides a more reliable and consistent build environment.
65+
#### New: Build MaxDiffusion Docker Image with JAX AI Images (Formerly known as JAX Stable Stack)
66+
We're excited to announce that you can build the MaxDiffusion Docker image using the JAX AI base image. This provides a more reliable and consistent build environment.
6767
68-
###### What is JAX Stable Stack?
69-
JAX Stable Stack provides a consistent environment for MaxDiffusion by bundling JAX with core packages like `orbax`, `flax`, and `optax`, along with Google Cloud utilities and other essential tools. These libraries are tested to ensure compatibility, providing a stable foundation for building and running MaxDiffusion and eliminating potential conflicts due to incompatible package versions.
68+
###### What is JAX AI Images?
69+
JAX AI Images provide a consistent environment for MaxDiffusion by bundling JAX with core packages like `orbax`, `flax`, and `optax`, along with Google Cloud utilities and other essential tools. These libraries are tested to ensure compatibility, providing a stable foundation for building and running MaxDiffusion and eliminating potential conflicts due to incompatible package versions.
7070
7171
###### How to Use It
72-
To build the MaxDiffusion Docker image with JAX Stable Stack, simply set the MODE to `stable_stack` and specify the desired `BASEIMAGE` in the `docker_build_dependency_image.sh` script:
72+
To build the MaxDiffusion Docker image with JAX AI Images, simply set the MODE to `jax_ai_image` and specify the desired `BASEIMAGE` in the `docker_build_dependency_image.sh` script:
7373
7474
```
75-
# Example bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.33-rev1
76-
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK_BASEIMAGE}}
75+
# Example bash docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.5.2-rev2
76+
bash docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE={{JAX_AI_IMAGE_BASEIMAGE}}
7777
```
7878
79-
You can find a list of available JAX Stable Stack base images [here](https://us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu).
79+
You can find a list of available JAX AI base images [here](https://us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu).
8080
81-
**Important Note:** The JAX Stable Stack is currently in the experimental phase. We encourage you to try it out and provide feedback.
81+
**Important Note:** JAX AI Images is currently in the experimental phase. We encourage you to try it out and provide feedback.
8282
8383
3. After building the dependency image `maxdiffusion_base_image`, xpk can handle updates to the working directory when running `xpk workload create` and using `--base-docker-image`.
8484
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
ARG JAX_AI_IMAGE_BASEIMAGE
2+
3+
# JAX AI Base Image
4+
FROM $JAX_AI_IMAGE_BASEIMAGE
5+
6+
ARG COMMIT_HASH
7+
8+
ENV COMMIT_HASH=$COMMIT_HASH
9+
10+
RUN mkdir -p /deps
11+
12+
# Set the working directory in the container
13+
WORKDIR /deps
14+
15+
# Copy all files from local workspace into docker container
16+
COPY . .
17+
18+
# Install Maxdiffusion Jax AI Image requirements
19+
RUN pip install -r /deps/requirements_with_jax_ai_image.txt
20+
21+
# Run the script available in JAX-AI-Image base image to generate the manifest file
22+
RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH

maxdiffusion_jax_stable_stack_tpu.Dockerfile

Lines changed: 0 additions & 22 deletions
This file was deleted.

requirements_with_jax_stable_stack.txt renamed to requirements_with_jax_ai_image.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Requirements for Building the MaxDifussion Docker Image
2-
# These requirements are additional to the dependencies present in the JAX Stable Stack base image.
2+
# These requirements are additional to the dependencies present in the JAX AI base image.
33
absl-py
44
datasets
55
einops==0.8.0

0 commit comments

Comments
 (0)