3 changes: 3 additions & 0 deletions ACKNOWLEDGMENTS.md
@@ -22,3 +22,6 @@ MLX LM was developed with contributions from the following individuals:
HuggingFace's `Starcoder2`, Cohere's `Cohere (1 and 2)`, Alibaba Qwen's `Qwen
(2, 3 and MoE)`, Microsoft's `Phi (3 and 3.5 MoE)`, `BitNet1.58`, Meta's `Llama
(3 and 4)`, Google DeepMind's `Gemma 3`, and InternLM's `InternLM 2.5`.
- Chimezie Ogbuji: Added support for `YAML` configuration for mlx_lm.lora, fewshot and apply chat template
  options for lm_eval, HF dataset collections, prompt-masking, `Min P` sampling, a parameterized batching
  function for the trainer, LoRA on all linear layers (or layers selected by pattern), and configurable LR schedulers.
3 changes: 3 additions & 0 deletions mlx_lm/examples/lora_config.yaml
@@ -72,10 +72,13 @@ grad_checkpoint: false
lora_parameters:
  # The layer keys to apply LoRA to.
  # These will be applied to the last lora_layers layers.
  # Use ["all"] to apply LoRA to all linear layers.
  keys: ["self_attn.q_proj", "self_attn.v_proj"]
  rank: 8
  scale: 20.0
  dropout: 0.0
  # Alternatively, match layer keys with regular expressions via key_patterns.
  #key_patterns: [".+gate_proj.*", ".+down_proj.*", ".+up_proj.*", ".+(q|v|k|o)_proj.*"]

# Schedule can only be specified in a config file, uncomment to use.
#lr_schedule:
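To see how the `lora_parameters` block above is consumed, the short sketch below reads the YAML and extracts the fields in the same shape that `linear_to_lora_layers` reads them (`config.get("keys")`, `config.get("key_patterns")`). The explicit file path and the use of PyYAML here are illustrative assumptions; `mlx_lm.lora` performs the actual config loading.

# Illustrative only: load the example config and pull out the LoRA selection
# fields the way the tuner reads them. Path and PyYAML usage are assumptions.
import re

import yaml

with open("mlx_lm/examples/lora_config.yaml") as f:
    config = yaml.safe_load(f)

lora_params = config["lora_parameters"]
keys = set(lora_params.get("keys", []))
key_patterns = {re.compile(p) for p in lora_params.get("key_patterns", [])}

print("explicit keys:", keys)
print("apply to all linear layers:", "all" in keys)
print("patterns:", [p.pattern for p in key_patterns])
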
73 changes: 69 additions & 4 deletions mlx_lm/tuner/utils.py
@@ -1,8 +1,10 @@
# Copyright © 2024 Apple Inc.
import json
import re
import types
from functools import partial
from pathlib import Path
from typing import Dict
from typing import Dict, Optional, Pattern, Set

import mlx.core as mx
import mlx.nn as nn
@@ -34,6 +36,54 @@ def build_schedule(schedule_config: Dict):
    return bound_schedule_fn


def should_convert_to_lora(
    layer_key: str,
    module: nn.Module,
    keys: Optional[Set[str]] = None,
    all_linear_layers: bool = False,
    key_patterns: Optional[Set[Pattern]] = None,
) -> bool:
    """
    Determines whether a given module should be converted to a LoRA layer.

    Returns True if the module is a convertible (linear-like) layer and either `layer_key`
    is in `keys`, `all_linear_layers` is True, or any pattern in `key_patterns` matches `layer_key`.

    Args:
        layer_key (str): The layer key for the module.
        module (nn.Module): The corresponding module.
        keys (set): The layer keys to convert (used when all_linear_layers is False).
        all_linear_layers (bool): Whether or not to convert all linear layers (defaults to False).
        key_patterns (set): Set of regex patterns to match against layer keys that should be converted to LoRA.

    Returns:
        bool: Whether the module should be converted to LoRA.
    """
    if key_patterns is None:
        key_patterns = set()
    if keys is None:
        keys = set()

    convertible_layers = (
        nn.QuantizedLinear,
        nn.Linear,
        LoRASwitchLinear,
        SwitchLinear,
        QuantizedSwitchLinear,
    )
    is_convertible = isinstance(module, convertible_layers) or hasattr(
        module, "to_lora"
    )
    should_convert = (
        all_linear_layers
        or layer_key in keys
        or any(p.match(layer_key) for p in key_patterns)
    )
    if (key_patterns or all_linear_layers) and should_convert:
        print(f"Converting {layer_key} to LoRA")
    return is_convertible and should_convert


def linear_to_lora_layers(
    model: nn.Module,
    num_layers: int,
@@ -82,6 +132,7 @@ def to_lora(layer):
        )

    keys = config.get("keys", None)
    key_patterns = {re.compile(p) for p in config.get("key_patterns", [])}
    if keys is not None:
        keys = set(keys)
    elif model.model_type in {
@@ -215,13 +266,27 @@ def to_lora(layer):
        keys.add("mixer.o_proj")
    else:
        raise ValueError(f"Lora does not support {model.model_type}")

    all_linear_layers = "all" in keys
    should_convert = partial(
        should_convert_to_lora,
        keys=keys,
        all_linear_layers=all_linear_layers,
        key_patterns=key_patterns,
    )
    if all_linear_layers:
        print("Applying LoRA to all linear layers")
    for l in model.layers[-max(num_layers, 0) :]:
        lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
        lora_layers = [
            (k, to_lora(m)) for k, m in l.named_modules() if should_convert(k, m)
        ]
        if lora_layers:
            l.update_modules(tree_unflatten(lora_layers))

    lora_modules = [(k, to_lora(m)) for k, m in model.named_modules() if k in keys]
    lora_modules = [
        (k, to_lora(m))
        for k, m in model.named_modules()
        if should_convert_to_lora(k, m, keys, all_linear_layers=all_linear_layers)
    ]
    if lora_modules:
        model.update_modules(tree_unflatten(lora_modules))

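The selection rule added in `should_convert_to_lora` can be exercised directly. Below is a minimal sketch, assuming the package is installed with this change; the layer keys and patterns are made-up examples, and only the predicate itself comes from the diff above.

# Illustrative calls to should_convert_to_lora; the layer keys are invented examples.
import re

import mlx.nn as nn

from mlx_lm.tuner.utils import should_convert_to_lora

layer = nn.Linear(16, 16)
patterns = {re.compile(r".+(q|v|k|o)_proj.*"), re.compile(r".+up_proj.*")}

# Explicit key match on a linear layer.
print(should_convert_to_lora("self_attn.q_proj", layer, keys={"self_attn.q_proj"}))  # True

# Regex match: "mlp.up_proj" is selected by the second pattern.
print(should_convert_to_lora("mlp.up_proj", layer, key_patterns=patterns))  # True

# Neither an explicit key nor all_linear_layers: not converted.
print(should_convert_to_lora("mlp.gate_proj", layer, keys={"self_attn.q_proj"}))  # False

# all_linear_layers converts any linear-like module regardless of its key.
print(should_convert_to_lora("mlp.gate_proj", layer, all_linear_layers=True))  # True
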
61 changes: 57 additions & 4 deletions tests/test_finetune.py
@@ -1,6 +1,6 @@
# Copyright © 2024 Apple Inc.

import math
import re
import sys
import unittest
from contextlib import contextmanager
@@ -12,11 +12,11 @@
import mlx.optimizers as opt
from mlx.utils import tree_flatten

from mlx_lm import lora, tuner
from mlx_lm import tuner
from mlx_lm.tuner.dora import DoRAEmbedding, DoRALinear
from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
from mlx_lm.tuner.lora import LoRAEmbedding, LoRASwitchLinear
from mlx_lm.tuner.trainer import evaluate
from mlx_lm.tuner.utils import build_schedule
from mlx_lm.tuner.utils import build_schedule, should_convert_to_lora


@contextmanager
@@ -27,6 +27,59 @@ def swapped_with_identity(obj, func):
    setattr(obj, func, old_func)


class TestShouldConvertToLoRA(unittest.TestCase):
    def setUp(self):
        self.capturedOutput = StringIO()
        sys.stdout = self.capturedOutput

    def tearDown(self):
        sys.stdout = sys.__stdout__

    def test_all_linear_with_non_linear(self):
        embedding = MagicMock(spec=nn.QuantizedEmbedding)
        self.assertFalse(
            should_convert_to_lora("", embedding, set(), all_linear_layers=True)
        )

    def test_all_linear_with_linear_and_switch(self):
        linear = MagicMock(spec=nn.Linear)
        switch_linear = MagicMock(spec=LoRASwitchLinear)
        for layer in [linear, switch_linear]:
            self.assertTrue(
                should_convert_to_lora("", layer, set(), all_linear_layers=True)
            )

    def test_not_all_linear_with_empty_keys(self):
        switch_linear = MagicMock(spec=LoRASwitchLinear)
        self.assertFalse(
            should_convert_to_lora("self_attn.q_proj", switch_linear, set())
        )

    def test_not_all_linear_pattern_key_matching(self):
        linear = MagicMock(spec=nn.Linear)
        self.assertFalse(
            should_convert_to_lora(
                "self_attn.q_proj", linear, set(), key_patterns=[re.compile(r"^.+mlp")]
            )
        )
        self.assertTrue(
            should_convert_to_lora(
                "self_attn.q_proj",
                linear,
                set(),
                key_patterns=[re.compile(r"^.+q_proj")],
            )
        )
        self.assertTrue(
            should_convert_to_lora(
                "mlp.up_proj",
                linear,
                set(),
                key_patterns=[re.compile(r"^.*mlp\.up.+$")],
            )
        )


class TestLora(unittest.TestCase):
    def setUp(self):
        self.capturedOutput = StringIO()
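The tests above exercise the predicate in isolation; a rough end-to-end sketch of pattern-based selection follows. The model name and the `linear_to_lora_layers(model, num_layers, config)` call signature are assumptions for illustration and are not confirmed by this diff.

# Hypothetical end-to-end sketch: select every attention projection by regex
# instead of listing explicit keys. Model name and call signature are assumptions.
from mlx.utils import tree_flatten

from mlx_lm import load
from mlx_lm.tuner.utils import linear_to_lora_layers

model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")
model.freeze()

config = {
    "rank": 8,
    "scale": 20.0,
    "dropout": 0.0,
    "keys": [],  # rely on key_patterns only
    "key_patterns": [r".+(q|v|k|o)_proj.*"],
}
linear_to_lora_layers(model, num_layers=8, config=config)

num_lora = sum(v.size for _, v in tree_flatten(model.trainable_parameters()))
print(f"Trainable LoRA parameters: {num_lora}")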