Concept-aware fine-tuning (CAFT) encourages stronger conceptual understanding by incorporating multi-token prediction into fine-tuning.
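In multi-token prediction, auxiliary heads are trained to predict tokens several positions ahead rather than only the immediate next token. The snippet below is a minimal, illustrative sketch of such a loss, not the exact objective implemented in this repo:

```python
import torch.nn.functional as F


def multi_token_loss(logits_per_head, labels, aux_weight=0.3):
    """Illustrative multi-token objective: head k is supervised on the token k steps further ahead.

    logits_per_head: list of [batch, seq_len, vocab] tensors; index 0 is the standard next-token head.
    labels:          [batch, seq_len] target token ids, with -100 marking ignored positions.
    aux_weight:      down-weighting factor for the auxiliary heads (hypothetical value).
    """
    total = 0.0
    for k, logits in enumerate(logits_per_head):
        shifted = labels[:, 1 + k:]                # head k targets token t + k + 1
        trimmed = logits[:, : shifted.size(1), :]  # drop positions with no valid target
        loss_k = F.cross_entropy(
            trimmed.reshape(-1, trimmed.size(-1)),
            shifted.reshape(-1),
            ignore_index=-100,
        )
        total = total + (1.0 if k == 0 else aux_weight) * loss_k
    return total
```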
Install from source:

```bash
git clone https://github.com/michaelchen-lab/caft-llm.git
cd caft-llm
pip install -e .
```

- Create a `.env` file with `HUGGINGFACE_TOKEN=<token>` and, optionally, `WANDB_TOKEN=<token>`.
- Add `train_set.jsonl` and `eval_set.jsonl` files to `scripts/datasets/`. Each instance should be of the format:
```json
{
    "id": "<int/str>", "status": "OK",
    "conversation": [
        {"role": "human", "content": "(prompt)"},
        {"role": "assistant", "content": "(ground truth answer)"}
    ]
}
```

Currently, only the auxiliary heads of `meta-llama/Llama-3.1-8B-Instruct` have been pretrained.
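For illustration, here is a minimal sketch (with made-up example data) of writing `train_set.jsonl` in the format above:

```python
import json

# Hypothetical example records; replace with your own prompts and ground-truth answers.
examples = [
    {
        "id": 1,
        "status": "OK",
        "conversation": [
            {"role": "human", "content": "Write a function that reverses a string."},
            {"role": "assistant", "content": "def reverse(s):\n    return s[::-1]"},
        ],
    },
]

with open("scripts/datasets/train_set.jsonl", "w") as f:
    for example in examples:
        f.write(json.dumps(example) + "\n")
```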
Method 1: Use the provided training script `scripts/train.py`
```bash
# LoRA fine-tuning with CAFT
torchrun --nproc-per-node 1 scripts/train.py -ftm lora
# LoRA fine-tuning, with the auxiliary heads first pretrained on your dataset
torchrun --nproc-per-node 1 scripts/train.py -ftm lora -ft-heads -hpretrain
# Full fine-tuning with the unembedding layer frozen
torchrun --nproc-per-node 1 scripts/train.py -ftm sft -lr 5e-6 -fr-unembed
# Full fine-tuning with auxiliary head pretraining
torchrun --nproc-per-node 1 scripts/train.py -ftm sft -lr 5e-6 -fr-unembed -ft-heads -hpretrain
```

Selected Arguments:
- `--model-name-or-path`, `-model`: Currently only `meta-llama/Llama-3.1-8B-Instruct` is supported.
- `--model-max-length`, `-maxlen`
- `--finetune-method`, `-ftm`: `lora` or `sft` (full fine-tuning)
- `--learning-rate`, `-lr`
- `--epochs`, `-e`
- `--freeze-unembedding`, `-fr-unembed`: Only applicable for full fine-tuning. Recommended: `True`
- `--per-device-batch-size`, `-micro-bs`
- `--gradient-accumulation-steps`, `-grad-acc`
- `--heads-pretraining`, `-hpretrain`: Train the auxiliary heads on your dataset for 1 epoch before applying CAFT to your model. `-ft-heads` must also be set to `True`.
The full list of arguments can be found using this command:

```bash
python scripts/train.py --help
```

Method 2: Integrate CAFT into your existing Transformers fine-tuning pipeline
```python
import transformers
from caft import *

# Import your pretrained Transformers model, tokenizer, TrainingArguments, and data_module

add_auxiliary_heads(model)
add_caft_loss(transformers)

# The additional CAFT functions track and save the auxiliary losses
trainer = transformers.trainer.Trainer(
    model=model, tokenizer=tokenizer, args=model_training_args,
    callbacks=[CAFTSaveLogging],
    compute_metrics=caft_compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    **data_module
)
```

Please refer to `scripts/train.py` for a complete implementation example.
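If you do not already have these objects, the following is a minimal sketch of one way to build `model`, `tokenizer`, `model_training_args`, and `data_module` from the jsonl format described above. The preprocessing and hyperparameters here are illustrative, not the repo's exact pipeline (see `scripts/train.py` for that):

```python
import datasets
import transformers

model_name = "meta-llama/Llama-3.1-8B-Instruct"
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token

model_training_args = transformers.TrainingArguments(
    output_dir="outputs",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    bf16=True,
)

raw = datasets.load_dataset(
    "json",
    data_files={
        "train": "scripts/datasets/train_set.jsonl",
        "eval": "scripts/datasets/eval_set.jsonl",
    },
)

def tokenize(example):
    # Naively concatenate the prompt and ground-truth answer into one training string;
    # a real pipeline would apply the model's chat template and mask the prompt tokens.
    text = "".join(turn["content"] for turn in example["conversation"])
    return tokenizer(text, truncation=True, max_length=2048)

tokenized = raw.map(tokenize, remove_columns=raw["train"].column_names)

# `data_module` is unpacked into the Trainer, so it maps onto the usual Trainer kwargs.
data_module = dict(
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["eval"],
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
```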
- Download the train and validation dataset from this Huggingface repo and save it to `scripts/datasets`
- Run the following command:

```bash
torchrun --nproc-per-node 4 scripts/train_aux_heads.py
```

We welcome community contributions and feature requests for caft-llm. Feel free to open an issue or submit a pull request. If you have any questions or wish to collaborate, please contact [email protected].
- Support all model architectures. Currently, the `LlamaDecoderLayer` is used to create auxiliary heads; in other words, only Llama-based models are supported. Edit `core.py` to copy the last hidden layer of the given model instead of inserting `LlamaDecoderLayer`, then reinitialize the weights (see the sketch after this list).
- Support speculative decoding. This can be implemented using the same method as Gloeckle et al. (2024) and Stern et al. (2018).
- Support FSDP and DeepSpeed.
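For the first item, a rough sketch of the idea, assuming the common `model.model.layers` layout; this is not existing code in `core.py`:

```python
import copy

import torch.nn as nn


def make_auxiliary_head(model: nn.Module) -> nn.Module:
    """Copy the model's own final decoder layer instead of hard-coding LlamaDecoderLayer."""
    # Assumes the usual `model.model.layers` list exposed by causal LMs in Transformers.
    last_layer = model.model.layers[-1]
    head = copy.deepcopy(last_layer)

    # Reinitialize the copied weights so the auxiliary head starts from fresh parameters.
    def _reset(module: nn.Module) -> None:
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()

    head.apply(_reset)
    return head
```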
This codebase adapts code from several amazing projects, including Medusa and Facebook Multi-Token.