This repo contains the model training code for @smartliuhw's thesis. It uses the Hugging Face SFT trainer as the training framework, together with DeepSpeed, so that the RTX 4090 GPU can be used effectively.
The environment dependencies are listed in the requirements file. To install them, run the following command:
pip install -r requirements.txt
The data processing code is in the utils.py file. All data should be stored with the Dataset class. The function get_train_data is the most important part; modify it according to your needs.
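As a rough illustration of the kind of preprocessing a function like get_train_data might do, here is a minimal sketch. It is not the code in utils.py: the JSONL input format, the "instruction" and "response" field names, the PROMPT_TEMPLATE string, and the build_records helper are all assumptions made for this example.

```python
import json

# Hypothetical prompt template; the real templates are defined in the repo's own code.
PROMPT_TEMPLATE = "### Instruction:\n{instruction}\n\n### Response:\n{response}"


def build_records(jsonl_lines):
    """Turn JSONL lines into a list of {"text": ...} training records.

    Assumes each non-blank line is a JSON object with "instruction"
    and "response" keys (an assumption, not the repo's actual schema).
    """
    records = []
    for line in jsonl_lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        example = json.loads(line)
        text = PROMPT_TEMPLATE.format(
            instruction=example["instruction"],
            response=example["response"],
        )
        records.append({"text": text})
    return records
```

If the Hugging Face datasets library is installed, a list like this can be wrapped with datasets.Dataset.from_list(records) so that the data ends up in a Dataset object.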
The model training code is in the train.py file and uses the trl framework. Modify the args, templates, and special tokens according to your needs.
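To make the "args and special tokens" part more concrete, here is a small sketch of how such settings could be grouped and checked. The TrainSettings class, its field names and default values, and the missing_special_tokens helper are hypothetical and are not taken from train.py.

```python
from dataclasses import dataclass, field


@dataclass
class TrainSettings:
    """Hypothetical container for values you would adapt to your own run."""
    model_name_or_path: str = "path/to/base-model"  # placeholder, not a real path
    max_seq_length: int = 2048                       # assumed default
    learning_rate: float = 2e-5                      # assumed default
    extra_special_tokens: list = field(
        default_factory=lambda: ["<|user|>", "<|assistant|>"]  # example tokens only
    )


def missing_special_tokens(settings, vocab):
    """Return the configured special tokens that are not yet in the vocabulary."""
    return [tok for tok in settings.extra_special_tokens if tok not in vocab]
```

With a Hugging Face tokenizer, tokens reported as missing are usually registered with tokenizer.add_special_tokens({"additional_special_tokens": [...]}), followed by model.resize_token_embeddings(len(tokenizer)) so the embedding matrix matches the new vocabulary size.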
You can launch a training run by following the train example file, with only a few changes to the model and the data.
If you have any questions, feel free to ask me.