Skip to content

Commit

Permalink
Update notebook for auto flash attention detection
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragjn committed Jun 12, 2024
1 parent 6a13337 commit 063d828
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
8 changes: 6 additions & 2 deletions finetune.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,9 @@
"# This cannot be greater than model's max sequence length\n",
"max_sequence_length = launch_parameters.max_length\n",
"\n",
"# If to drop sequences that are longer than max_sequence_length\n",
"drop_long_sequences = False\n",
"\n",
"# Batch size per GPU. \n",
"# Increasing this will increase GPU memory requirement and training time\n",
"micro_batch_size = launch_parameters.batch_size\n",
Expand Down Expand Up @@ -353,7 +356,8 @@
"import torch\n",
"\n",
"# Mixed Precision Training. We automatically select the precision based on GPU capability\n",
"mixed_precision = \"bf16\" if torch.cuda.is_bf16_supported() else \"fp16\"\n",
"is_ampere_or_newer = torch.cuda.get_device_capability(device=0) >= (8, 0)\n",
"mixed_precision = \"bf16\" if is_ampere_or_newer and torch.cuda.is_bf16_supported() else \"fp16\"\n",
"\n",
"COMMAND = f\"\"\"\n",
"accelerate launch \\\n",
Expand All @@ -362,7 +366,6 @@
"train.py \\\n",
"config-base.yaml \\\n",
"--deepspeed ./deepspeed_configs/3_ds_z2_config.json \\\n",
"--flash_attention True \\\n",
"--gradient_checkpointing unsloth \\\n",
"--base_model {model_id} \\\n",
"--output_dir {output_dir} \\\n",
Expand All @@ -372,6 +375,7 @@
"--val_set_size {eval_size} \\\n",
"--max_steps {max_steps} \\\n",
"--sequence_len {max_sequence_length} \\\n",
"--drop_long_sequences {drop_long_sequences} \\\n",
"--train_on_inputs False \\\n",
"--sample_packing {sample_packing} \\\n",
"--pad_to_sequence_len True \\\n",
Expand Down
2 changes: 1 addition & 1 deletion notebook-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
-r requirements.txt
jupyter-server-proxy==4.1.1
jupyter-server-proxy==4.1.2
jupyter_app_launcher @ git+https://github.com/truefoundry/jupyter_app_launcher@9a959b894542995fc763ed07324bbd274e96610d

0 comments on commit 063d828

Please sign in to comment.