Update TRL pin to 0.9.3+ (#213)
Add sft config hack for trl upgrade



code formatter

Signed-off-by: Alex-Brooks <[email protected]>
alex-jw-brooks authored Jun 26, 2024
1 parent 8ddd68c commit 0949699
Showing 2 changed files with 21 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -33,7 +33,7 @@ dependencies = [
     "sentencepiece>=0.1.99,<0.3",
     "tokenizers>=0.13.3,<1.0",
     "tqdm>=4.66.2,<5.0",
-    "trl==0.8.6",
+    "trl>=0.9.3,<1.0",
     "peft>=0.8.0,<0.13",
     "datasets>=2.15.0,<3.0",
     "fire>=0.5.0,<1.0",
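
The pin above only changes what pip is allowed to resolve. As an illustration only (not part of this commit), a small runtime guard could confirm that the environment actually picked up a compatible trl release; this sketch assumes the packaging library is available alongside the other dependencies:

# Sketch (not part of the commit): verify the installed trl release satisfies
# the new pin before importing SFTConfig. Assumes "packaging" is installed.
from importlib.metadata import version

from packaging.specifiers import SpecifierSet

trl_version = version("trl")
if trl_version not in SpecifierSet(">=0.9.3,<1.0"):
    raise RuntimeError(f"trl=={trl_version} does not satisfy trl>=0.9.3,<1.0")
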
22 changes: 20 additions & 2 deletions tuning/sft_trainer.py
@@ -14,6 +14,7 @@
 
 # Standard
 from typing import Dict, List, Optional, Union
+import dataclasses
 import json
 import sys
 import time
@@ -33,7 +34,7 @@
     TrainerCallback,
 )
 from transformers.utils import is_accelerate_available, logging
-from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
+from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer
 import datasets
 import fire
 import transformers
@@ -315,6 +316,23 @@ def train(
         model, train_args, modifiable_args=(peft_config,)
     )
 
+    # HACK - The SFT Trainer has internal validation which inspects the name of the class
+    # being used for the HF training args; if it's a TrainingArguments class, which is
+    # presumably from transformers, it tries to build it into an SFT Config.
+    #
+    # This is unfortunately a naming collision with one of our own classes, which has extra
+    # fields, and therefore can't be used to initialize the SFT Config. For now, to sidestep
+    # this validation, we just drop the things that aren't part of the SFT Config and build one
+    # from our object directly. In the future, we should consider renaming this class and / or
+    # not adding things that are not directly used by the trainer instance to it.
+    transformer_train_arg_fields = [x.name for x in dataclasses.fields(SFTConfig)]
+    transformer_kwargs = {
+        k: v
+        for k, v in train_args.to_dict().items()
+        if k in transformer_train_arg_fields
+    }
+    training_args = SFTConfig(**transformer_kwargs)
+
     trainer = SFTTrainer(
         model=model,
         tokenizer=tokenizer,
@@ -323,7 +341,7 @@
         packing=packing,
         data_collator=data_collator,
         dataset_text_field=data_args.dataset_text_field,
-        args=train_args,
+        args=training_args,
         max_seq_length=max_seq_length,
         callbacks=trainer_callbacks,
         peft_config=peft_config,
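
For context outside the repo, here is a standalone sketch of the field-filtering pattern the hack above relies on. TrainerOnlyConfig and ExtendedTrainingArguments are hypothetical stand-ins for trl's SFTConfig and the project's extended TrainingArguments class, and dataclasses.asdict stands in for the to_dict() call used in the diff:

# Standalone sketch of the field-filtering pattern; the dataclasses below are
# hypothetical stand-ins, not the real SFTConfig / TrainingArguments classes.
import dataclasses


@dataclasses.dataclass
class TrainerOnlyConfig:
    # Fields the trainer's config actually declares (stand-in for SFTConfig).
    output_dir: str = "out"
    learning_rate: float = 5e-5


@dataclasses.dataclass
class ExtendedTrainingArguments(TrainerOnlyConfig):
    # Extra fields the trainer does not know about; passing them straight to
    # TrainerOnlyConfig(**...) would raise a TypeError.
    tracker_name: str = "aim"
    experiment_id: str = "run-1"


train_args = ExtendedTrainingArguments(output_dir="save", learning_rate=1e-4)

# Keep only the fields declared on the trainer config, then rebuild it cleanly,
# mirroring the dataclasses.fields(SFTConfig) filter in the diff above.
allowed = {f.name for f in dataclasses.fields(TrainerOnlyConfig)}
kwargs = {k: v for k, v in dataclasses.asdict(train_args).items() if k in allowed}
training_args = TrainerOnlyConfig(**kwargs)
print(training_args)  # TrainerOnlyConfig(output_dir='save', learning_rate=0.0001)

This mirrors why the commit rebuilds an SFTConfig instead of passing train_args through directly: trl 0.9.x inspects the class of the args it receives, and the project's extended class carries fields that SFTConfig does not accept.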
