diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 19db5b2e3..1f0b35dd2 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -22,3 +22,6 @@ MLX LM was developed with contributions from the following individuals:
   HuggingFace's `Starcoder2`, Cohere's `Cohere (1 and 2)`, Alibaba Qwen's `Qwen (2, 3 and MoE)`,
   Microsoft's `Phi (3 and 3.5 MoE)`, `BitNet1.58`, Meta's `Llama (3 and 4)`,
   Google DeepMind's `Gemma 3`, and InterLM's `InternLM 2.5`.
+- Chimezie Ogbuji: Added support for `YAML` configuration for mlx_lm.lora, few-shot and apply-chat-template
+  options for lm_eval, HF dataset collections, prompt masking, `Min P` sampling, a parameterized batching
+  function for the trainer, LoRA on all linear layers (or by pattern), and configurable LR schedulers.
diff --git a/mlx_lm/examples/lora_config.yaml b/mlx_lm/examples/lora_config.yaml
index b5fe2d524..aae2df1e7 100644
--- a/mlx_lm/examples/lora_config.yaml
+++ b/mlx_lm/examples/lora_config.yaml
@@ -72,10 +72,13 @@ grad_checkpoint: false
 lora_parameters:
   # The layer keys to apply LoRA to.
   # These will be applied for the last lora_layers
+  # Use ["all"] to have LoRA applied to all linear layers
   keys: ["self_attn.q_proj", "self_attn.v_proj"]
   rank: 8
   scale: 20.0
   dropout: 0.0
+  # Alternatively, match layer keys with the regex patterns in key_patterns
+  #key_patterns: [".+gate_proj.*", ".+down_proj.+", ".+up_proj.+", ".+(q|v|k|o)_proj.+"]

 # Schedule can only be specified in a config file, uncomment to use.
 #lr_schedule:
diff --git a/mlx_lm/tuner/utils.py b/mlx_lm/tuner/utils.py
index 39eb0a9fd..e0e46222d 100644
--- a/mlx_lm/tuner/utils.py
+++ b/mlx_lm/tuner/utils.py
@@ -1,8 +1,10 @@
 # Copyright © 2024 Apple Inc.
 import json
+import re
 import types
+from functools import partial
 from pathlib import Path
-from typing import Dict
+from typing import Dict, Pattern, Set

 import mlx.core as mx
 import mlx.nn as nn
@@ -34,6 +36,54 @@ def build_schedule(schedule_config: Dict):
     return bound_schedule_fn


+def should_convert_to_lora(
+    layer_key: str,
+    module: nn.Module,
+    keys: Set[str] = None,
+    all_linear_layers: bool = False,
+    key_patterns: Set[Pattern] = None,
+) -> bool:
+    """
+    Determine whether a given module should be converted to a LoRA layer.
+
+    Returns True if the module is a convertible (linear) layer and either `layer_key`
+    is in `keys`, `all_linear_layers` is True, or any pattern in `key_patterns` matches `layer_key`.
+
+    Args:
+        layer_key (str): The layer key for the module
+        module (nn.Module): The corresponding module
+        keys (set): The layer keys to convert (used when all_linear_layers is False)
+        all_linear_layers (bool): Whether to convert all linear layers (defaults to False).
+        key_patterns (set): Regex patterns matched against layer keys that should be converted to LoRA.
+
+    Returns:
+        bool: A boolean indicating whether the module should be converted to LoRA
+    """
+    if key_patterns is None:
+        key_patterns = set()
+    if keys is None:
+        keys = set()
+
+    convertible_layers = (
+        nn.QuantizedLinear,
+        nn.Linear,
+        LoRASwitchLinear,
+        SwitchLinear,
+        QuantizedSwitchLinear,
+    )
+    is_convertible = isinstance(module, convertible_layers) or hasattr(
+        module, "to_lora"
+    )
+    should_convert = (
+        all_linear_layers
+        or layer_key in keys
+        or (any(p.match(layer_key) for p in key_patterns))
+    )
+    if (key_patterns or all_linear_layers) and should_convert:
+        print(f"Converting {layer_key} to LoRA")
+    return is_convertible and should_convert
+
+
 def linear_to_lora_layers(
     model: nn.Module,
     num_layers: int,
@@ -82,6 +132,7 @@ def to_lora(layer):
         )

     keys = config.get("keys", None)
+    key_patterns = set([re.compile(p) for p in config.get("key_patterns", [])])
     if keys is not None:
         keys = set(keys)
     elif model.model_type in {
@@ -215,13 +266,27 @@ def to_lora(layer):
         keys.add("mixer.o_proj")
     else:
         raise ValueError(f"Lora does not support {model.model_type}")
-
+    all_linear_layers = "all" in keys
+    should_convert = partial(
+        should_convert_to_lora,
+        keys=keys,
+        all_linear_layers=all_linear_layers,
+        key_patterns=key_patterns,
+    )
+    if all_linear_layers:
+        print("Applying LoRA to all linear layers")
     for l in model.layers[-max(num_layers, 0) :]:
-        lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
+        lora_layers = [
+            (k, to_lora(m)) for k, m in l.named_modules() if should_convert(k, m)
+        ]
         if lora_layers:
             l.update_modules(tree_unflatten(lora_layers))

-    lora_modules = [(k, to_lora(m)) for k, m in model.named_modules() if k in keys]
+    lora_modules = [
+        (k, to_lora(m))
+        for k, m in model.named_modules()
+        if should_convert_to_lora(k, m, keys, all_linear_layers=all_linear_layers)
+    ]
     if lora_modules:
         model.update_modules(tree_unflatten(lora_modules))

diff --git a/tests/test_finetune.py b/tests/test_finetune.py
index 76639ad5e..d42c374ad 100644
--- a/tests/test_finetune.py
+++ b/tests/test_finetune.py
@@ -1,6 +1,6 @@
 # Copyright © 2024 Apple Inc.
-
 import math
+import re
 import sys
 import unittest
 from contextlib import contextmanager
@@ -12,11 +12,11 @@
 import mlx.optimizers as opt
 from mlx.utils import tree_flatten

-from mlx_lm import lora, tuner
+from mlx_lm import tuner
 from mlx_lm.tuner.dora import DoRAEmbedding, DoRALinear
-from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
+from mlx_lm.tuner.lora import LoRAEmbedding, LoRASwitchLinear
 from mlx_lm.tuner.trainer import evaluate
-from mlx_lm.tuner.utils import build_schedule
+from mlx_lm.tuner.utils import build_schedule, should_convert_to_lora


 @contextmanager
@@ -27,6 +27,59 @@ def swapped_with_identity(obj, func):
     setattr(obj, func, old_func)


+class TestShouldConvertToLoRa(unittest.TestCase):
+    def setUp(self):
+        self.capturedOutput = StringIO()
+        sys.stdout = self.capturedOutput
+
+    def tearDown(self):
+        sys.stdout = sys.__stdout__
+
+    def test_all_linear_with_non_linear(self):
+        quantized_linear = MagicMock(spec=nn.QuantizedEmbedding)
+        self.assertFalse(
+            should_convert_to_lora("", quantized_linear, set(), all_linear_layers=True)
+        )
+
+    def test_all_linear_with_linear_and_switch(self):
+        linear = MagicMock(spec=nn.Linear)
+        switch_linear = MagicMock(spec=LoRASwitchLinear)
+        for layer in [linear, switch_linear]:
+            self.assertTrue(
+                should_convert_to_lora("", layer, set(), all_linear_layers=True)
+            )
+
+    def test_not_all_linear_with_empty_keys(self):
+        switch_linear = MagicMock(spec=LoRASwitchLinear)
+        self.assertFalse(
+            should_convert_to_lora("self_attn.q_proj", switch_linear, set())
+        )
+
+    def test_not_all_linear_pattern_key_matching(self):
+        linear = MagicMock(spec=nn.Linear)
+        self.assertFalse(
+            should_convert_to_lora(
+                "self_attn.q_proj", linear, set(), key_patterns=[re.compile(r"^.+mlp")]
+            )
+        )
+        self.assertTrue(
+            should_convert_to_lora(
+                "self_attn.q_proj",
+                linear,
+                set(),
+                key_patterns=[re.compile(r"^.+q_proj")],
+            )
+        )
+        self.assertTrue(
+            should_convert_to_lora(
+                "mlp.up_proj",
+                linear,
+                set(),
+                key_patterns=[re.compile(r"^.*mlp\.up.+$")],
+            )
+        )
+
+
 class TestLora(unittest.TestCase):
     def setUp(self):
         self.capturedOutput = StringIO()
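
To make the new options concrete, the following is a minimal usage sketch, not part of the patch itself. It assumes a checkout with this change applied and calls the `should_convert_to_lora` helper added in `mlx_lm/tuner/utils.py` directly, mirroring the behaviour exercised by the new tests; the layer keys, dimensions, and regex patterns are illustrative only.

# Usage sketch (illustrative; keys and patterns are examples, not a prescribed configuration).
import re

import mlx.nn as nn

from mlx_lm.tuner.utils import should_convert_to_lora

layer = nn.Linear(16, 16)

# keys: ["all"] in the YAML config maps to all_linear_layers=True, so every linear layer qualifies.
assert should_convert_to_lora("mlp.up_proj", layer, set(), all_linear_layers=True)

# key_patterns in the YAML config is compiled into a set of regexes matched against each layer key.
patterns = {re.compile(r".*(q|v)_proj")}
assert should_convert_to_lora("self_attn.q_proj", layer, key_patterns=patterns)
assert not should_convert_to_lora("mlp.gate_proj", layer, key_patterns=patterns)

# Non-linear modules are never converted, even when all_linear_layers is True.
assert not should_convert_to_lora("embed_tokens", nn.Embedding(10, 16), set(), all_linear_layers=True)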