 from mlx.utils import tree_flatten, tree_map
 from tqdm import tqdm

-from .utils import grad_checkpoint, Colors
+from .utils import grad_checkpoint, Colors, get_learning_rate

 @dataclass
 class TrainingArgs:
@@ -52,20 +52,26 @@ class TrainingArgs:
         metadata={"help": "Learning rate."},
     )
     grad_clip: float = field(
-        default=None,
+        default=1.0,
         metadata={"help": "Gradient clipping value."},
     )
+    warmup_steps: int = field(
+        default=100,
+        metadata={"help": "Number of warmup steps for learning rate."},
+    )
+    min_learning_rate: float = field(
+        default=1e-6,
+        metadata={"help": "Minimum learning rate after decay."},
+    )


 def default_loss(model, inputs, targets, lengths, train_on_completions=False, assistant_id=77091):
     outputs = model(inputs)
     logits = outputs.logits.astype(mx.float32)

-    batch_size, seq_len = targets.shape
+    _, seq_len = targets.shape
     steps = mx.arange(seq_len)[None, :]
-
     base_mask = steps < lengths[:, None]
-
     if train_on_completions:
         eq = (inputs == assistant_id)
         idxs = mx.arange(seq_len)[None, :]
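
The `get_learning_rate` helper imported above is not shown in this diff. As a point of reference, here is a minimal sketch of what such a schedule could look like, assuming linear warmup over `warmup_steps` followed by cosine decay down to `min_learning_rate`; the signature and schedule shape are assumptions, not taken from the PR:

```python
import math

def get_learning_rate(step, base_lr, warmup_steps, min_lr, total_steps):
    # Hypothetical schedule: linear warmup to base_lr, then cosine decay to min_lr.
    # The real helper imported from .utils may differ.
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```
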
@@ -82,7 +88,6 @@ def default_loss(model, inputs, targets, lengths, train_on_completions=False, as
     ce = ce.sum() / ntoks
     return ce, ntoks

-
 def iterate_batches(dataset, batch_size, max_seq_length, train=False):
     # Simple indices without sorting
     indices = list(range(len(dataset)))
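
Also not part of the diff: a rough sketch of how the new fields might be wired into a training step, assuming MLX's `nn.value_and_grad` and `mlx.optimizers.clip_grad_norm` are used for the loss/gradient computation and for clipping to `grad_clip` (now defaulting to 1.0); `total_steps`, the batch layout, and the `get_learning_rate` call signature are assumptions carried over from the sketch above:

```python
import mlx.core as mx
import mlx.nn as nn
from mlx.optimizers import clip_grad_norm

def train_step(model, optimizer, batch, args, step, total_steps):
    # Hypothetical wiring of default_loss, grad_clip, and the LR schedule.
    inputs, targets, lengths = batch
    loss_value_and_grad = nn.value_and_grad(model, default_loss)
    (loss, ntoks), grads = loss_value_and_grad(model, inputs, targets, lengths)

    # Clip the global gradient norm to args.grad_clip (default 1.0 after this change).
    grads, _ = clip_grad_norm(grads, args.grad_clip)

    # Apply the scheduled learning rate before the optimizer update.
    optimizer.learning_rate = get_learning_rate(
        step, args.learning_rate, args.warmup_steps, args.min_learning_rate, total_steps
    )
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)
    return loss, ntoks
```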