Add Direct Preference Optimization (DPO) Support

Summary

This PR implements Direct Preference Optimization (DPO) for MLX-LM, enabling users to fine-tune language models using human preference data without requiring a separate reward model.

What is DPO?

DPO is a simpler alternative to RLHF that optimizes the model directly on preference pairs (chosen vs. rejected responses), avoiding the complexity of training a separate reward model and running PPO. It targets the same underlying objective as RLHF but is more stable and efficient to train.

Key Features Added

Core Implementation

  • DPO Loss Function: Bradley-Terry preference model with KL regularization (see the loss sketch after this list)
  • Preference Dataset: Handles both simple prompt format and chat conversation format
  • Dual Forward Passes: Policy and reference model evaluation during training
  • Beta Parameter: Configurable temperature for preference strength
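
For reference, the heart of DPO is only a few lines; below is a minimal sketch of the Bradley-Terry loss in MLX, assuming per-sequence log-probabilities have already been summed over the chosen and rejected completions (function and argument names here are illustrative, not necessarily this PR's exact API):

import mlx.core as mx

def dpo_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                    ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Log-ratio of policy vs. reference model for each preference pair
    policy_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    logits = beta * (policy_logratios - ref_logratios)
    # -log(sigmoid(x)) == logaddexp(0, -x), which is numerically stable
    loss = mx.mean(mx.logaddexp(0.0, -logits))
    # Implicit rewards, handy for logging preference accuracy and margins
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    return loss, chosen_rewards, rejected_rewards

A smaller beta keeps the policy closer to the reference model (stronger implicit KL regularization); larger values let it move further toward the preferred responses.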

Integration

  • CLI Interface: python -m mlx_lm dpo with full argument parsing
  • Fine-tuning Support: Compatible with LoRA, DoRA, and full fine-tuning
  • Memory Efficient: Works with quantized models (QLoRA)
  • MLX Native: Follows existing MLX-LM patterns and conventions

Data Format Support

{"prompt": "What is AI?", "chosen": "Detailed explanation...", "rejected": "Short answer."}
{"messages": [{"role": "user", "content": "Hello"}], "chosen": "Detailed response", "rejected": "Hi."}

Documentation & Testing

  • Comprehensive Documentation: mlx_lm/DPO.md with usage examples and best practices
  • Unit Tests: Full test coverage for loss functions, datasets, and integration
  • End-to-end Validation: Successfully trained models with real preference data

Usage Example

# Basic DPO training
mlx_lm.dpo \
    --model mlx-community/Meta-Llama-3-8B-Instruct-4bit \
    --train \
    --data preference_data/ \
    --beta 0.1 \
    --fine-tune-type lora

Files Added/Modified

  • mlx_lm/dpo.py - Main DPO module
  • mlx_lm/tuner/losses.py - DPO loss implementation
  • mlx_lm/tuner/datasets.py - PreferenceDataset class
  • mlx_lm/tuner/trainer.py - DPO training functions
  • mlx_lm/__main__.py - CLI registration
  • mlx_lm/DPO.md - Documentation
  • tests/test_dpo.py - DPO-specific tests
  • tests/test_losses.py - Enhanced with DPO loss tests

Testing

  • ✅ All existing tests pass
  • ✅ New DPO unit tests pass (9/9)
  • ✅ End-to-end training validation completed
  • ✅ Memory usage verified with quantized models

Benefits for MLX-LM Users

  1. Simplified Preference Training: No need for complex RLHF pipelines
  2. Better Alignment: Train models to follow human preferences effectively
  3. Resource Efficient: Works with existing LoRA/QLoRA infrastructure
  4. Production Ready: Stable, well-tested implementation

This implementation enables MLX-LM users to easily train more helpful, harmless, and honest models using preference data.

eghuzefa and others added 11 commits August 31, 2025 15:17
- Added DPO loss function in mlx_lm/tuner/losses.py
- Introduced PreferenceDataset in mlx_lm/tuner/datasets.py
- Extended trainer with DPO training step and dual forward passes
- Created mlx_lm/dpo.py with main training logic (mirroring LoRA structure)
- Added CLI entry point in mlx_lm/dpo.py and integrated the DPO subcommand in mlx_lm/__main__.py
- Implemented unit and integration tests in tests/test_dpo.py and tests/test_losses.py
- Updated documentation with data formats, configs, metrics, and guidelines
- Ensured compliance with MLX-LM contributing standards (formatting, testing, pre-commit)
  - Implement DPO loss function with Bradley-Terry preference model
  - Add PreferenceDataset for handling chosen/rejected response pairs
  - Create DPO training pipeline with dual forward passes
  - Add CLI interface: loading pretrained model
  - Support LoRA/DoRA/full fine-tuning with DPO
  - Include comprehensive test suite and documentation
  - Compatible with existing MLX-LM infrastructure
  - Add --reference-adapter-path parameter to DPO CLI
  - Enable using locally finetuned LoRA adapters as reference models
  - Leverages existing load() function's adapter_path parameter
  - Allows reusing MLX finetuned models for subsequent DPO training

  Usage:
  mlx_lm.dpo --model base_model --reference-model base_model \
      --reference-adapter-path path/to/adapters --data data.jsonl --train
  - Change --reference-adapter-path to --reference-model-adapters
  - Resolves parameter name confusion with adapter_path during config merge
  - Ensures reference model adapters are loaded correctly for DPO training
  - Update CONFIG_DEFAULTS and function references accordingly

  The original parameter name was too similar to adapter_path and caused
  the config loading logic to overwrite the reference adapter path with
  the current training adapter path.
…optimization

- Add preference accuracy and reward margin tracking to DPO evaluation (sketched below)
- Implement memory-efficient shared weights when no reference model specified
- Update validation and test reporting to include accuracy metrics
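
As a rough sketch of what accuracy and margin tracking can look like, using the implicit rewards from the loss sketch earlier (illustrative, not necessarily the PR's exact code):

import mlx.core as mx

def dpo_metrics_sketch(chosen_rewards, rejected_rewards):
    # Fraction of pairs where the chosen response gets the higher implicit reward
    accuracy = mx.mean((chosen_rewards > rejected_rewards).astype(mx.float32))
    # Average gap between chosen and rejected implicit rewards
    margin = mx.mean(chosen_rewards - rejected_rewards)
    return accuracy.item(), margin.item()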