
Commit ccb1247 (2 parents: f99f9ed + 9f084a3)
Merge pull request #52 from AI-Hypercomputer/llama31_70b: Add Llama 3.1 70B training recipe

File tree: 2 files changed (+92, -0)
File 1: 90 additions & 0 deletions
# Instructions for training Llama3.1-70B-MaxText on TPU Trillium (v6e-256)

## XPK setup

Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/XPK_README.md) to create your GKE cluster with XPK.
## Prep for MaxText

### Install MaxText and Build Docker Image

Please follow this [link](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/trillium/MAXTEXT_README.md) to install MaxText and build the Docker image. Use the following settings:

In step 1, use the MaxText [tpu-recipes-v0.1.0](https://github.com/AI-Hypercomputer/maxtext/releases/tag/tpu-recipes-v0.1.0) tag to run this recipe:
```
git checkout tpu-recipes-v0.1.0
```
In step 2, use the jax-stable-stack image containing JAX 0.5.2:
```
BASE_IMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.5.2-rev1
bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=${BASE_IMAGE}
```
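The workload command in the next section reads a few environment variables. A minimal sketch with placeholder values (the project ID, zone, cluster name, and output bucket below are assumptions; substitute your own):

```shell
# Placeholder values: replace with your own project, zone, cluster, and GCS bucket.
export PROJECT=my-gcp-project                  # assumed GCP project ID
export ZONE=us-east5-b                         # assumed zone with v6e-256 capacity
export CLUSTER_NAME=v6e-256-cluster            # GKE cluster created via XPK above
export OUTPUT_DIR=gs://my-bucket/llama3_1_70b  # assumed GCS path for run artifacts

echo "Launching in ${PROJECT}/${ZONE} on ${CLUSTER_NAME}, writing to ${OUTPUT_DIR}"
```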
## Run MaxText Llama3.1-70B workloads on GKE

### Starting workload

From the MaxText root directory, start your Llama3.1-70B workload:
```
python3 benchmarks/benchmark_runner.py xpk \
  --project=$PROJECT \
  --zone=$ZONE \
  --device_type=v6e-256 \
  --num_slices=1 \
  --cluster_name=${CLUSTER_NAME} \
  --base_output_directory=${OUTPUT_DIR} \
  --model_name="llama3_1_70b_8192" \
  --base_docker_image=maxtext_base_image
```
From your workload logs, you should start seeing step time logs like the following:
```
completed step: 7, seconds: 34.562, TFLOP/s/device: 456.442, Tokens/s/device: 948.086, total_weights: 8388608, loss: 8.946
```
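The numbers in that log line are internally consistent and easy to sanity-check: `total_weights` is the global token count per step (per-device batch 4 × 256 devices × 8192 sequence length), and dividing by the step time recovers the per-device token throughput. A quick sketch (the 256-device count comes from the single v6e-256 slice; everything else is read off the log line and config above):

```python
# Sanity-check the step-time log against the workload configuration.
per_device_batch = 4   # "per_device_batch_size" from the workload config
num_devices = 256      # one v6e-256 slice
seq_len = 8192         # "max_target_length"
step_seconds = 34.562  # "seconds" from the log line

tokens_per_step = per_device_batch * num_devices * seq_len
print(tokens_per_step)  # 8388608, matching total_weights

tokens_per_sec_per_device = tokens_per_step / num_devices / step_seconds
print(round(tokens_per_sec_per_device, 1))  # ~948, matching Tokens/s/device up to rounding
```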
If you would like to run on multiple slices of v6e-256, you may modify the `--num_slices` flag.
### Workload Details

For reference, here are the `llama3_1_70b_8192` workload details as found in `[email protected]`:

```
MaxTextModel(
    model_name="llama3_1-70b-8192",
    model_type="llama3.1-70b",
    tuning_params={
        "per_device_batch_size": 4,
        "ici_fsdp_parallelism": -1,
        "remat_policy": "custom",
        "decoder_layer_input": "offload",
        "query_proj": "offload",
        "key_proj": "offload",
        "value_proj": "offload",
        "max_target_length": 8192,
        "attention": "flash",
        "use_iota_embed": True,
        "dataset_path": "gs://max-datasets-rogue",
        "dataset_type": "synthetic",
        "enable_checkpointing": False,
        "sa_block_q": 2048,
        "sa_block_kv": 2048,
        "sa_block_kv_compute": 2048,
        "sa_block_q_dkv": 2048,
        "sa_block_kv_dkv": 2048,
        "sa_block_kv_dkv_compute": 2048,
        "sa_block_q_dq": 2048,
        "sa_block_kv_dq": 2048,
        "sa_use_fused_bwd_kernel": True,
        "profiler": "xplane",
        "skip_first_n_steps_for_profiler": 10,
        "profiler_steps": 5,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
        + xla_flags_library.DATA_PARALLEL_OVERLAP
        + xla_flags_library.CF_FOR_ALL_GATHER
        + xla_flags_library.HOST_OFFLOAD_FLAGS
    ),
)
```
This equivalent workload code can be found in the [maxtext_trillium_model_configs.py](https://github.com/AI-Hypercomputer/maxtext/blob/243b25e480f7550a0c389fa95cd3adcc716fe0df/benchmarks/maxtext_trillium_model_configs.py#L932-L972) file within the MaxText repository.
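Under the hood, `benchmark_runner.py` turns `tuning_params` like these into the `key=value` overrides that MaxText's training entry point accepts on the command line. A minimal sketch of that flattening (the helper name `to_overrides` is hypothetical, not part of MaxText; booleans are lowercased to match MaxText's yaml-style config values):

```python
# Hypothetical helper: flatten a tuning_params dict into MaxText-style
# key=value command-line overrides.
def to_overrides(tuning_params: dict) -> list[str]:
    args = []
    for key, value in tuning_params.items():
        if isinstance(value, bool):
            value = str(value).lower()  # MaxText configs use true/false
        args.append(f"{key}={value}")
    return args

overrides = to_overrides({
    "per_device_batch_size": 4,
    "max_target_length": 8192,
    "enable_checkpointing": False,
})
print(overrides)
# ['per_device_batch_size=4', 'max_target_length=8192', 'enable_checkpointing=false']
```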
File 2: 2 additions & 0 deletions

```
python3 benchmarks/benchmark_runner.py --project=$PROJECT --zone=$ZONE --device_type=v6e-256 --num_slices=1 --cluster_name=${CLUSTER_NAME} --base_output_directory=${OUTPUT_DIR} \
  --model_name="llama3_1_70b_8192" --base_docker_image maxtext_base_image
```
