Skip to content

Commit ae2a542

Browse files
susanbaosusanbao
andauthored
update jax/flax version (#274)
* update jax/flax version * update script for WAN 2.1 * HIDDEN_STATE_WITH_OFFLOAD has not been supported * fix * update --------- Co-authored-by: susanbao <sanbao_google_com@t1v-n-216c02cd-w-0.europe-west4-b.c.cloud-tpu-multipod-dev.internal>
1 parent 9716fc5 commit ae2a542

File tree

3 files changed

+7
-8
lines changed

3 files changed

+7
-8
lines changed

README.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ After installation completes, run the training script.
233233
max_train_steps=1000 \
234234
enable_profiler=True \
235235
dataset_save_location=${SAVE_DATASET_DIR} \
236-
remat_policy='FULL' \
236+
remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \
237237
flash_min_seq_length=0 \
238238
seed=$RANDOM \
239239
skip_first_n_steps_for_profiler=3 \
@@ -352,12 +352,11 @@ After installation completes, run the training script.
352352
per_device_batch_size=0.25 \
353353
ici_data_parallelism=32 \
354354
ici_fsdp_parallelism=4 \
355-
ici_tensor_parallelism=1" \
355+
ici_tensor_parallelism=1 \
356356
max_train_steps=5000 \
357357
eval_every=100 \
358358
eval_data_dir=${EVAL_DATA_DIR} \
359-
enable_generate_video_for_eval=True \
360-
warmup_steps_fraction=0.025"
359+
enable_generate_video_for_eval=True" \
361360
--base-docker-image=${IMAGE_DIR} \
362361
--enable-debug-logs \
363362
--workload=${RUN_NAME} \

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
--extra-index-url https://download.pytorch.org/whl/cpu
2-
jax>=0.6.2
2+
jax>=0.7.2
33
jaxlib>=0.4.30
44
grain
55
google-cloud-storage>=2.17.0
66
absl-py
77
datasets
8-
flax>=0.11.0
8+
flax>=0.12.0
99
optax>=0.2.3
1010
torch>=2.6.0
1111
torchvision>=0.20.1

requirements_with_jax_ai_image.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
# Requirements for Building the MaxDifussion Docker Image
22
# These requirements are additional to the dependencies present in the JAX AI base image.
33
--extra-index-url https://download.pytorch.org/whl/cpu
4-
jax>=0.6.2
4+
jax>=0.7.2
55
jaxlib>=0.4.30
66
grain
77
google-cloud-storage>=2.17.0
88
absl-py
99
datasets
10-
flax>=0.10.2
10+
flax>=0.12.0
1111
optax>=0.2.3
1212
torch>=2.6.0
1313
torchvision>=0.20.1

0 commit comments

Comments
 (0)