BY571/sft-kl-lora-trainer
Custom SFT Trainer with KL divergence loss

This repo provides an extension of Hugging Face's trl.SFTTrainer that adds a KL divergence loss between a LoRA-adapted model and its base counterpart. It enables more stable and conservative fine-tuning by regularizing the adapted model's predictions against its original distribution.
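The idea behind the regularizer can be illustrated without any ML libraries. A minimal sketch (the toy distributions below are made up for illustration): the KL divergence between the base model's next-token distribution and the adapted model's grows as the adapted model drifts, so adding it to the loss penalizes large departures from the original behavior.

```python
import math

def kl_divergence(p, q):
    """Forward KL divergence KL(p || q) between two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Toy next-token distributions over a 3-word vocabulary:
base    = [0.7, 0.2, 0.1]   # base model's prediction
adapted = [0.4, 0.4, 0.2]   # LoRA-adapted model's prediction

penalty = kl_divergence(base, adapted)  # > 0 whenever the distributions differ
```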

Setup

conda env create -f environment.yml
conda activate custom_sft_loss

Custom Loss

The custom loss is implemented in custom_trainer.py. It subclasses SFTTrainer and overrides the compute_loss method to add a KL divergence term between the adapted model's predictions and the base model's.
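The KL term added inside such an override can be sketched as follows. This is a hedged illustration, not the repo's actual implementation (see custom_trainer.py for that); the function name `kl_to_base` and the `kl_coeff` weight are assumptions introduced here.

```python
import torch
import torch.nn.functional as F

def kl_to_base(adapted_logits: torch.Tensor,
               base_logits: torch.Tensor,
               kl_coeff: float = 0.1) -> torch.Tensor:
    """Weighted KL divergence from the base model's distribution to the
    adapted model's, averaged over the batch.

    In a compute_loss override, adapted_logits would come from the LoRA
    model's forward pass and base_logits from a forward pass with the
    adapters disabled (e.g. under peft's `with model.disable_adapter():`).
    """
    log_p_adapted = F.log_softmax(adapted_logits, dim=-1)
    p_base = F.softmax(base_logits, dim=-1)
    # F.kl_div expects log-probabilities as input and probabilities as target
    return kl_coeff * F.kl_div(log_p_adapted, p_base, reduction="batchmean")
```

The resulting scalar would be added to the standard cross-entropy loss returned by the parent trainer.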

Training

The training script is implemented in train.py. Run it to compare the KL-regularized loss against the standard SFT loss.
