|
| 1 | +# Instructions for training Llama3.1-70B-MaxText on TPU trillium (v6e-256) |
| 2 | + |
| 3 | +## XPK setup |
| 4 | +Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK |
| 5 | + |
| 6 | +## Prep for Maxtext |
| 7 | + |
| 8 | +### Install MaxText and Build Docker Image |
| 9 | +Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to install maxtext and build the docker image. The following variables should be set: |
| 10 | + |
| 11 | +In step 1, use the MaxText [tpu-recipes-v0.1.0](https://github.com/AI-Hypercomputer/maxtext/releases/tag/tpu-recipes-v0.1.0) tag to run this recipe: |
| 12 | +``` |
| 13 | +git checkout tpu-recipes-v0.1.0 |
| 14 | +``` |
| 15 | + |
| 16 | +In step 2, use the jax-stable-stack image containing JAX 0.5.2: |
| 17 | +``` |
| 18 | +BASE_IMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1 |
| 19 | +bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=${BASE_IMAGE} |
| 20 | +``` |
| 21 | + |
| 22 | +## Run MaxText Llama3.1-70B workloads on GKE |
| 23 | + |
| 24 | +### Starting workload |
| 25 | + |
| 26 | +From the MaxText root directory, start your Llama3.1-70B workload |
| 27 | +``` |
| 28 | +python3 benchmarks/benchmark_runner.py xpk \ |
| 29 | + --project=$PROJECT \ |
| 30 | + --zone=$ZONE \ |
| 31 | + --device_type=v6e-256 \ |
| 32 | + --num_slices=1 \ |
| 33 | + --cluster_name=${CLUSTER_NAME} \ |
| 34 | + --base_output_directory=${OUTPUT_DIR} \ |
| 35 | + --model_name="llama3_1_70b_8192" \ |
| 36 | + --base_docker_image=maxtext_base_image |
| 37 | +``` |
| 38 | + |
| 39 | +From your workload logs, you should start seeing step time logs like the following: |
| 40 | +``` |
| 41 | +completed step: 7, seconds: 34.562, TFLOP/s/device: 456.442, Tokens/s/device: 948.086, total_weights: 8388608, loss: 8.946 |
| 42 | +``` |
| 43 | +If you would like to run on multiple slices of v6e-256, you may modify the `--num_slices` flag. |
| 44 | + |
| 45 | +### Workload Details |
| 46 | + |
| 47 | +For reference, here are the `llama3_1_70b_8192` workload details as found in `[email protected]`: |
| 48 | + |
| 49 | +``` |
| 50 | + MaxTextModel( |
| 51 | + model_name="llama3_1-70b-8192", |
| 52 | + model_type="llama3.1-70b", |
| 53 | + tuning_params={ |
| 54 | + "per_device_batch_size": 4, |
| 55 | + "ici_fsdp_parallelism": -1, |
| 56 | + "remat_policy": "custom", |
| 57 | + "decoder_layer_input": "offload", |
| 58 | + "query_proj": "offload", |
| 59 | + "key_proj": "offload", |
| 60 | + "value_proj": "offload", |
| 61 | + "max_target_length": 8192, |
| 62 | + "attention": "flash", |
| 63 | + "use_iota_embed": True, |
| 64 | + "dataset_path": "gs://max-datasets-rogue", |
| 65 | + "dataset_type": "synthetic", |
| 66 | + "enable_checkpointing": False, |
| 67 | + "sa_block_q": 2048, |
| 68 | + "sa_block_kv": 2048, |
| 69 | + "sa_block_kv_compute": 2048, |
| 70 | + "sa_block_q_dkv": 2048, |
| 71 | + "sa_block_kv_dkv": 2048, |
| 72 | + "sa_block_kv_dkv_compute": 2048, |
| 73 | + "sa_block_q_dq": 2048, |
| 74 | + "sa_block_kv_dq": 2048, |
| 75 | + "sa_use_fused_bwd_kernel": True, |
| 76 | + "profiler": "xplane", |
| 77 | + "skip_first_n_steps_for_profiler": 10, |
| 78 | + "profiler_steps": 5, |
| 79 | + }, |
| 80 | + xla_flags=( |
| 81 | + xla_flags_library.DENSE_VMEM_LIMIT_FLAG |
| 82 | + + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER |
| 83 | + + xla_flags_library.DATA_PARALLEL_OVERLAP |
| 84 | + + xla_flags_library.CF_FOR_ALL_GATHER |
| 85 | + + xla_flags_library.HOST_OFFLOAD_FLAGS |
| 86 | + ), |
| 87 | + ) |
| 88 | +``` |
| 89 | + |
| 90 | +This equivalent workload code can be found in the [maxtext_trillium_model_configs.py](https://github.com/AI-Hypercomputer/maxtext/blob/243b25e480f7550a0c389fa95cd3adcc716fe0df/benchmarks/maxtext_trillium_model_configs.py#L932-L972) file within the MaxText repository. |
0 commit comments