Skip to content

Commit 0dfa106

Browse files
updates
1 parent 65afe7b commit 0dfa106

File tree

5 files changed

+9
-2
lines changed

5 files changed

+9
-2
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,5 @@ var/
2525
# virtual environment
2626
.myenv
2727
.venv
28+
29+
.vscode

mlx_vlm/lora.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
find_all_linear_names,
1313
get_peft_model,
1414
unfreeze_modules,
15+
supported_for_training
1516
)
1617
from .utils import load, load_image_processor
1718

@@ -25,9 +26,8 @@ def main(args):
2526
args.model_path, processor_config={"trust_remote_code": True}
2627
)
2728

28-
unsupported_for_training = {"lfm2-vl", "", ""}
2929
model_type = getattr(getattr(model, "config", None), "model_type", None)
30-
if model_type in unsupported_for_training:
30+
if model_type not in supported_for_training:
3131
raise ValueError(
3232
f"{Colors.FAIL}Model type {model_type} not supported for training. "
3333
f"Please choose a different model or remove it from the unsupported list.{Colors.ENDC}"

mlx_vlm/models/qwen2_vl/qwen2_vl.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ def __call__(
154154
return logits
155155

156156
def sanitize(self, weights):
157+
weights = {
158+
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
159+
}
157160
def transform_key(key):
158161
if "vision_tower" not in key:
159162
key = key.replace("visual", "vision_tower")

mlx_vlm/trainer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
get_peft_model,
99
print_trainable_parameters,
1010
Colors,
11+
supported_for_training
1112
)

mlx_vlm/trainer/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class Colors:
2020
BOLD = '\033[1m'
2121
UNDERLINE = '\033[4m'
2222

23+
supported_for_training = {"qwen2_vl"}
2324

2425
def grad_checkpoint(layer):
2526
"""

0 commit comments

Comments
 (0)