
Commit 7e20753

Hoesu and kashif authored
Fix get_peft_model() so that prepare_model_for_kbit_training is not reapplied to an instance of PeftModel, which would freeze all the layers (#4081)
Co-authored-by: Hoesu <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
1 parent 20cc58d commit 7e20753
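
For context, the regression shows up in the usual QLoRA flow, where the caller quantizes the base model and attaches LoRA adapters themselves before handing the PeftModel to SFTTrainer. Below is a minimal sketch of that flow, not code from this commit: it assumes bitsandbytes and a CUDA-capable GPU are available, and it reuses the tiny model, dataset, and LoRA settings from the test added here purely for illustration.

import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

# Load the base model in 4-bit; this sets model.is_loaded_in_4bit = True,
# which is the flag the QLoRA path in prepare_peft_model() checks for.
model = AutoModelForCausalLM.from_pretrained(
    "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
    device_map="auto",
)

# Attach LoRA adapters up front, as many QLoRA recipes do.
peft_model = get_peft_model(
    model,
    LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "v_proj"], r=16, lora_alpha=32),
)

dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

# Before this fix, SFTTrainer re-ran prepare_model_for_kbit_training on the
# already-built PeftModel and froze the LoRA adapters; with the fix they stay trainable.
trainer = SFTTrainer(
    model=peft_model,
    args=SFTConfig(output_dir="sft-out", report_to="none", max_steps=1),
    train_dataset=dataset,
)
assert any(p.requires_grad for p in trainer.model.parameters())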

File tree

2 files changed: +109 -2 lines changed

  tests/test_sft_trainer.py
  trl/models/utils.py

tests/test_sft_trainer.py

Lines changed: 108 additions & 1 deletion

@@ -19,7 +19,13 @@
 from datasets import load_dataset
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.testing_utils import require_flash_attn, require_liger_kernel, require_peft, require_vision
+from transformers.testing_utils import (
+    require_bitsandbytes,
+    require_flash_attn,
+    require_liger_kernel,
+    require_peft,
+    require_vision,
+)
 from transformers.utils import is_peft_available

 from trl import SFTConfig, SFTTrainer

@@ -1400,6 +1406,107 @@ def test_prompt_tuning(self):
             else:
                 raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}")

+    @require_peft
+    @require_bitsandbytes
+    def test_peft_model_with_quantization(self):
+        """SFTTrainer should not freeze layers of existing PeftModel.
+
+        This test simulates a realistic QLoRA scenario where a quantized base model
+        is first converted to a PeftModel, then passed to SFTTrainer. The issue was
+        that prepare_model_for_kbit_training would freeze all parameters including
+        the LoRA adapters, making training impossible.
+        """
+        # Get the base model
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        # Simulate a realistic QLoRA setup by mocking quantization attributes
+        # This mimics what happens when loading a model with load_in_4bit=True
+        model.is_loaded_in_4bit = True
+        model.is_loaded_in_8bit = False
+
+        # Verify that this triggers the is_qlora condition
+        is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
+        self.assertTrue(is_qlora, "Model should be detected as QLoRA (quantized)")
+
+        # Create LoRA configuration suitable for QLoRA
+        lora_config = LoraConfig(
+            task_type=TaskType.CAUSAL_LM,
+            target_modules=["q_proj", "v_proj"],
+            r=16,
+            lora_alpha=32,
+            lora_dropout=0.1,
+        )
+
+        # Convert the quantized model to a PeftModel (typical QLoRA workflow)
+        peft_model = get_peft_model(model, lora_config)
+
+        # Verify the quantization attributes are preserved on the PeftModel
+        self.assertTrue(getattr(peft_model, "is_loaded_in_4bit", False), "PeftModel should preserve quantization flag")
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        # Analyze parameters before SFTTrainer initialization
+        trainable_params_before = []
+        base_params_before = []
+        lora_params_before = []
+
+        for name, param in peft_model.named_parameters():
+            if param.requires_grad:
+                trainable_params_before.append(name)
+                if "lora" in name.lower():
+                    lora_params_before.append(name)
+                else:
+                    base_params_before.append(name)
+
+        # Ensure we have the expected parameter distribution for QLoRA
+        self.assertTrue(len(trainable_params_before) > 0, "PeftModel should have trainable parameters initially")
+        self.assertTrue(len(lora_params_before) > 0, "PeftModel should have trainable LoRA parameters")
+        self.assertEqual(len(base_params_before), 0, "Base model parameters should already be frozen in PeftModel")
+
+        # Initialize the trainer with the already configured PeftModel
+        training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", max_steps=1)
+        trainer = SFTTrainer(model=peft_model, args=training_args, train_dataset=dataset)
+
+        # Analyze parameters after SFTTrainer initialization
+        trainable_params_after = []
+        lora_params_after = []
+
+        for name, param in trainer.model.named_parameters():
+            if param.requires_grad:
+                trainable_params_after.append(name)
+                if "lora" in name.lower():
+                    lora_params_after.append(name)
+
+        # LoRA parameters should remain trainable
+        self.assertTrue(
+            len(trainable_params_after) > 0,
+            f"PeftModel should still have trainable parameters after SFTTrainer initialization. "
+            f"Found {len(trainable_params_after)} trainable params. "
+            f"This test fails without the fix for issue #3926.",
+        )
+
+        self.assertTrue(
+            len(lora_params_after) > 0,
+            f"LoRA adapter parameters should remain trainable. "
+            f"Found {len(lora_params_after)} trainable LoRA params out of {len(lora_params_before)} original.",
+        )
+
+        # Ensure the parameter counts are preserved (no additional freezing occurred)
+        self.assertEqual(
+            len(trainable_params_before),
+            len(trainable_params_after),
+            "Number of trainable parameters should not change after SFTTrainer initialization",
+        )
+
+        # Verify that all original LoRA parameters are still trainable
+        self.assertEqual(
+            set(lora_params_before),
+            set(lora_params_after),
+            "All original LoRA parameters should remain trainable after SFTTrainer initialization",
+        )
+
     @require_peft
     def test_prompt_tuning_peft_model(self):
         """Test that SFT works with Prompt Tuning and a pre-converted PeftModel"""

trl/models/utils.py

Lines changed: 1 addition & 1 deletion

@@ -535,7 +535,7 @@ def prepare_peft_model(
                 break

     # Prepare model for kbit training if needed
-    if is_qlora and not is_sharded_qlora:
+    if is_qlora and not is_sharded_qlora and not isinstance(model, PeftModel):
         model = prepare_model_for_kbit_training(
             model,
             use_gradient_checkpointing=args.gradient_checkpointing,
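
The extra isinstance guard matters because peft's prepare_model_for_kbit_training sets requires_grad=False on every parameter of whatever model it receives. On a freshly loaded base model that is harmless, since the adapters are attached afterwards; on a model that is already a PeftModel it freezes the LoRA weights too. The sketch below shows that symptom in isolation, reusing the tiny test model and skipping real quantization (the freezing happens regardless of whether the model is actually quantized); it is illustrative, not the trainer's code path.

from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
peft_model = get_peft_model(
    base, LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "v_proj"])
)

# LoRA adapters start out trainable.
assert sum(p.requires_grad for p in peft_model.parameters()) > 0

# Re-running the kbit preparation on the PeftModel freezes every parameter,
# LoRA adapters included; this is the behaviour the new guard prevents.
prepare_model_for_kbit_training(peft_model, use_gradient_checkpointing=False)
assert sum(p.requires_grad for p in peft_model.parameters()) == 0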
