
SFTTrainer with PEFT model #4029

@lylaiyy

Description

Reproduction

A model wrapped with get_peft_model() outside the trainer cannot be passed to SFTTrainer; PEFT only takes effect when peft_config is passed to SFTTrainer itself. In the repro below, the externally wrapped model reports 8,388,608 trainable parameters before the trainer is built, but trainer.model reports 0 trainable parameters afterwards, and grad_norm stays at 0.0 during training.

from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer
from peft import get_peft_model, prepare_model_for_kbit_training

# Project-local helpers (get_dataset, process_sft_dataset, config loading)
from utils import *
from federated_learning import *
from config import get_config, save_config, get_training_args, get_model_config

# ===== Define the arguments =====
script_args, fed_args, peft_config = get_config()
training_args = get_training_args(script_args, script_args.learning_rate)
save_config(script_args, fed_args)
print(script_args, fed_args)

# ===== Load the dataset =====
dataset = get_dataset(script_args.dataset_name, script_args.local_data_dir)
dataset = process_sft_dataset(script_args.dataset_name, dataset)

device_map, quantization_config, torch_dtype = get_model_config(script_args)
model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,
    quantization_config=quantization_config,
    device_map=device_map,
    dtype=torch_dtype
)

# Standard PEFT preparation for quantized (8-/4-bit) models
if script_args.load_in_8bit or script_args.load_in_4bit:
    model = prepare_model_for_kbit_training(model)

# Wrap the model with PEFT externally (the pattern this issue reports)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # first printout: nonzero trainable params

model.config.use_cache = False  # KV cache is incompatible with gradient checkpointing
model.gradient_checkpointing_enable()
model.enable_input_require_grads()


# ===== Define the tokenizer =====
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token


trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    # peft_config=peft_config,  # passing the config here instead works
)

trainer.model.print_trainable_parameters()  # second printout: 0 trainable params

trainer.train()

Output:

trainable params: 8,388,608 || all params: 6,746,804,224 || trainable%: 0.1243
trainable params: 0 || all params: 6,746,804,224 || trainable%: 0.0000
  0%|                                                                                                                                 | 0/10 [00:00<?, ?it/s]/home/asus/miniconda3/envs/llm_sft/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:186: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
/home/asus/miniconda3/envs/llm_sft/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:186: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
{'loss': 1.0142, 'grad_norm': 0.0, 'learning_rate': 2.2222222222222223e-05, 'entropy': 1.5035455226898193, 'num_tokens': 35638.0, 'mean_token_accuracy': 0.71231250166893, 'epoch': 0.0}

System Info

  • Platform: Linux-6.8.0-60-generic-x86_64-with-glibc2.35
  • Python version: 3.11.13
  • TRL version: 0.22.2
  • PyTorch version: 2.8.0
  • accelerator(s): NVIDIA GeForce RTX 5090 D, NVIDIA GeForce RTX 5090 D
  • Transformers version: 4.56.0
  • Accelerate version: 1.10.1
  • Accelerate config: not found
  • Datasets version: 3.6.0
  • HF Hub version: 0.34.4
  • bitsandbytes version: 0.47.0
  • DeepSpeed version: not installed
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: not installed
  • PEFT version: 0.17.1
  • vLLM version: not installed

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots, more on code blocks)
  • Any traceback provided is complete
