Skip to content

Commit 830017c

Browse files
nits
1 parent d523293 commit 830017c

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

mlx_vlm/trainer/trainer.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from mlx.utils import tree_flatten, tree_map
1515
from tqdm import tqdm
1616

17-
from .utils import grad_checkpoint, Colors
17+
from .utils import grad_checkpoint, Colors, get_learning_rate
1818

1919
@dataclass
2020
class TrainingArgs:
@@ -52,20 +52,26 @@ class TrainingArgs:
5252
metadata={"help": "Learning rate."},
5353
)
5454
grad_clip: float = field(
55-
default=None,
55+
default=1.0,
5656
metadata={"help": "Gradient clipping value."},
5757
)
58+
warmup_steps: int = field(
59+
default=100,
60+
metadata={"help": "Number of warmup steps for learning rate."},
61+
)
62+
min_learning_rate: float = field(
63+
default=1e-6,
64+
metadata={"help": "Minimum learning rate after decay."},
65+
)
5866

5967

6068
def default_loss(model, inputs, targets, lengths, train_on_completions=False, assistant_id=77091):
6169
outputs = model(inputs)
6270
logits = outputs.logits.astype(mx.float32)
6371

64-
batch_size, seq_len = targets.shape
72+
_, seq_len = targets.shape
6573
steps = mx.arange(seq_len)[None, :]
66-
6774
base_mask = steps < lengths[:, None]
68-
6975
if train_on_completions:
7076
eq = (inputs == assistant_id)
7177
idxs = mx.arange(seq_len)[None, :]
@@ -82,7 +88,6 @@ def default_loss(model, inputs, targets, lengths, train_on_completions=False, as
8288
ce = ce.sum() / ntoks
8389
return ce, ntoks
8490

85-
8691
def iterate_batches(dataset, batch_size, max_seq_length, train=False):
8792
# Simple indices without sorting
8893
indices = list(range(len(dataset)))

mlx_vlm/trainer/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
from pathlib import Path
3+
import math
34

45
import mlx.nn as nn
56
import mlx.core as mx
@@ -36,6 +37,15 @@ def inner_fn(params, *args, **kwargs):
3637
type(layer).__call__ = checkpointed_fn
3738

3839

40+
def get_learning_rate(
    iters: int,
    step: int,
    warmup_steps: int,
    learning_rate: float,
    min_learning_rate: float,
) -> float:
    """Return the learning rate for *step*: linear warmup, then cosine decay.

    Args:
        iters: Total number of training iterations (decay finishes here).
        step: Current training step (0-based).
        warmup_steps: Steps spent linearly ramping from 0 to ``learning_rate``.
        learning_rate: Peak learning rate reached at the end of warmup.
        min_learning_rate: Floor the cosine schedule decays to at ``iters``.

    Returns:
        The scheduled learning rate for this step.
    """
    if step < warmup_steps:
        # Linear warmup: 0 at step 0, learning_rate at step == warmup_steps.
        return learning_rate * (step / warmup_steps)

    if iters <= warmup_steps:
        # No decay phase exists; avoid division by zero below and hold the
        # peak rate (the original raised ZeroDivisionError here).
        return learning_rate

    # Clamp so steps past `iters` stay at min_learning_rate instead of the
    # cosine swinging back upward (cos is periodic beyond progress == 1).
    progress = min((step - warmup_steps) / (iters - warmup_steps), 1.0)
    cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
    return min_learning_rate + (learning_rate - min_learning_rate) * cosine_decay
47+
48+
3949
def get_module_by_name(model, name):
4050
parts = name.split(".")
4151
module = model

0 commit comments

Comments
 (0)