 from datasets import load_dataset
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.testing_utils import require_flash_attn, require_liger_kernel, require_peft, require_vision
+from transformers.testing_utils import (
+    require_bitsandbytes,
+    require_flash_attn,
+    require_liger_kernel,
+    require_peft,
+    require_vision,
+)
 from transformers.utils import is_peft_available
 
 from trl import SFTConfig, SFTTrainer
@@ -1400,6 +1406,107 @@ def test_prompt_tuning(self):
             else:
                 raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}")
 
+    @require_peft
+    @require_bitsandbytes
+    def test_peft_model_with_quantization(self):
+        """SFTTrainer should not freeze layers of existing PeftModel.
+
+        This test simulates a realistic QLoRA scenario where a quantized base model
+        is first converted to a PeftModel, then passed to SFTTrainer. The issue was
+        that prepare_model_for_kbit_training would freeze all parameters including
+        the LoRA adapters, making training impossible.
+        """
+        # Get the base model
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        # Simulate a realistic QLoRA setup by mocking quantization attributes
+        # This mimics what happens when loading a model with load_in_4bit=True
+        model.is_loaded_in_4bit = True
+        model.is_loaded_in_8bit = False
+
+        # Verify that this triggers the is_qlora condition
+        is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
+        self.assertTrue(is_qlora, "Model should be detected as QLoRA (quantized)")
+
+        # Create LoRA configuration suitable for QLoRA
+        lora_config = LoraConfig(
+            task_type=TaskType.CAUSAL_LM,
+            target_modules=["q_proj", "v_proj"],
+            r=16,
+            lora_alpha=32,
+            lora_dropout=0.1,
+        )
+
+        # Convert the quantized model to a PeftModel (typical QLoRA workflow)
+        peft_model = get_peft_model(model, lora_config)
+
+        # Verify the quantization attributes are preserved on the PeftModel
+        self.assertTrue(getattr(peft_model, "is_loaded_in_4bit", False), "PeftModel should preserve quantization flag")
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        # Analyze parameters before SFTTrainer initialization
+        trainable_params_before = []
+        base_params_before = []
+        lora_params_before = []
+
+        for name, param in peft_model.named_parameters():
+            if param.requires_grad:
+                trainable_params_before.append(name)
+                if "lora" in name.lower():
+                    lora_params_before.append(name)
+                else:
+                    base_params_before.append(name)
+
+        # Ensure we have the expected parameter distribution for QLoRA
+        self.assertTrue(len(trainable_params_before) > 0, "PeftModel should have trainable parameters initially")
+        self.assertTrue(len(lora_params_before) > 0, "PeftModel should have trainable LoRA parameters")
+        self.assertEqual(len(base_params_before), 0, "Base model parameters should already be frozen in PeftModel")
+
+        # Initialize the trainer with the already configured PeftModel
+        training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", max_steps=1)
+        trainer = SFTTrainer(model=peft_model, args=training_args, train_dataset=dataset)
+
+        # Analyze parameters after SFTTrainer initialization
+        trainable_params_after = []
+        lora_params_after = []
+
+        for name, param in trainer.model.named_parameters():
+            if param.requires_grad:
+                trainable_params_after.append(name)
+                if "lora" in name.lower():
+                    lora_params_after.append(name)
+
+        # LoRA parameters should remain trainable
+        self.assertTrue(
+            len(trainable_params_after) > 0,
+            f"PeftModel should still have trainable parameters after SFTTrainer initialization. "
+            f"Found {len(trainable_params_after)} trainable params. "
+            f"This test fails without the fix for issue #3926.",
+        )
+
+        self.assertTrue(
+            len(lora_params_after) > 0,
+            f"LoRA adapter parameters should remain trainable. "
+            f"Found {len(lora_params_after)} trainable LoRA params out of {len(lora_params_before)} original.",
+        )
+
+        # Ensure the parameter counts are preserved (no additional freezing occurred)
+        self.assertEqual(
+            len(trainable_params_before),
+            len(trainable_params_after),
+            "Number of trainable parameters should not change after SFTTrainer initialization",
+        )
+
+        # Verify that all original LoRA parameters are still trainable
+        self.assertEqual(
+            set(lora_params_before),
+            set(lora_params_after),
+            "All original LoRA parameters should remain trainable after SFTTrainer initialization",
+        )
+
     @require_peft
     def test_prompt_tuning_peft_model(self):
         """Test that SFT works with Prompt Tuning and a pre-converted PeftModel"""
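
For reference, the mocked `is_loaded_in_4bit` / `is_loaded_in_8bit` flags above stand in for a real QLoRA run, where transformers sets these attributes itself during quantized loading. Below is a minimal sketch of that workflow, assuming bitsandbytes, accelerate, and a CUDA device are available; the tiny model and dataset ids are reused from the test purely for illustration, and the output directory is arbitrary.

```python
import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

from trl import SFTConfig, SFTTrainer

# Quantize the base model to 4-bit on load; transformers then sets
# model.is_loaded_in_4bit = True, the flag the test mocks directly.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
    "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    quantization_config=bnb_config,
    device_map="auto",
)

# Attach LoRA adapters before handing the model to the trainer,
# mirroring the pre-converted PeftModel path exercised by the test.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "v_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
)
peft_model = get_peft_model(model, lora_config)

dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

# The LoRA adapters must remain trainable after SFTTrainer initialization.
trainer = SFTTrainer(
    model=peft_model,
    args=SFTConfig(output_dir="qlora_output", report_to="none", max_steps=1),
    train_dataset=dataset,
)
trainer.train()
```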