@@ -378,13 +378,13 @@ def _generate(model, processor, prompt, images=None, max_tokens=1000, verbose=Tr
     logit_stopper = LogitStopper(max_tokens, early_stop)
     streamer = Streamer(processor, stream, mute)
     dict_input = processor(prompt, images)
+    mask, pids = dict_input.get('mask', None), dict_input.get('pids', None)
     token_stopper = TokenStopper(processor, dict_input['input_ids'].shape[0])
     tic = Tic()
     logits, cache = model(**dict_input, max_tokens=max_tokens)
     token = mx.argmax(logits[:, -1, :], axis=-1)[:, None]
     mx.eval(token, logits, cache)
     streamer(token)
-    mask, pids = dict_input.get('mask', None), dict_input.get('pids', None)
     prompt_time = tic()
     for i in range(max_tokens - 1):
         logits, cache = model(input_ids=token, cache=cache, mask=mask, pids=pids)
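Note: this hunk hoists the `mask`/`pids` lookup so both are bound right after preprocessing, before the first token is streamed, rather than between prefill and the decode loop; behavior is unchanged, but the data flow now reads top-down. For orientation, a minimal sketch of the prefill-then-decode pattern `_generate` implements, assuming only the interface visible in the diff (`model` returns `(logits, cache)` and accepts `cache`, `mask`, and `pids` keywords):

    import mlx.core as mx

    def greedy_decode_sketch(model, dict_input, max_tokens):
        # Prefill: run the whole prompt once and keep the KV cache.
        mask, pids = dict_input.get('mask', None), dict_input.get('pids', None)
        logits, cache = model(**dict_input, max_tokens=max_tokens)
        token = mx.argmax(logits[:, -1, :], axis=-1)[:, None]
        tokens = [token]
        for _ in range(max_tokens - 1):
            # Decode one token at a time, reusing the cache.
            logits, cache = model(input_ids=token, cache=cache, mask=mask, pids=pids)
            token = mx.argmax(logits[:, -1, :], axis=-1)[:, None]
            mx.eval(token)  # force MLX's lazy graph so the work happens here
            tokens.append(token)
        return mx.concatenate(tokens, axis=1)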
@@ -402,7 +402,7 @@ def _generate(model, processor, prompt, images=None, max_tokens=1000, verbose=Tr
     gen_tps = (gen_len - 1) / gen_time
     if verbose:
         print(f"\nPrompt: {prompt_tps:.2f} tokens-per-sec ({prompt_len} tokens / {prompt_time:.1f} sec)")
-        print(f"Generation: {gen_tps:.2f} tokens-per-sec ({gen_len} tokens / {gen_time:.1f} sec)")
+        print(f"Generate: {gen_tps:.2f} tokens-per-sec ({gen_len} tokens / {gen_time:.1f} sec)")
     if return_tps:
         return prompt_tps, gen_tps
     return result
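The label change just aligns the two report lines ("Prompt:" / "Generate:"). Both rates are wall-clock tokens-per-second measured with `Tic`, which is defined elsewhere in this file; a plausible stopwatch with the call-to-call semantics the code relies on (an assumption, not the repo's actual implementation) would be:

    import time

    class Tic:
        """Each call returns seconds elapsed since the previous call (assumed semantics)."""
        def __init__(self):
            self._last = time.perf_counter()
        def __call__(self):
            now = time.perf_counter()
            elapsed = now - self._last
            self._last = now
            return elapsed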
@@ -496,24 +496,30 @@ def _already(array_2d, array_1d):
         return mx.ones(array_2d.shape[0])
     return ~mx.all(array_2d[:, -len(array_1d):] == array_1d, axis=1)
 
-def _beam(model, processor, prompt, constraints, return_full_text=False, mute=False, use_beam=False):
+def _constrain(model, processor, prompt, constraints, return_full_text=False, mute=False, use_beam=False, verbose=True):
     def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
         token = mx.argmax(logits[:, beam_idx, :], axis=-1)
         _arg_beam = mx.argpartition(-logits[:, beam_idx, :], kth=n_beam, axis=-1)[:, :n_beam]
         _beam = _arg_beam.reshape(-1)[:, None]
         _beam = mx.concatenate([_beam, mx.tile(id_constraint, (_beam.shape[0], 1))], axis=-1)
-        _beam_logits, cache = model(input_ids=_beam, cache=cache, n_beam=n_beam)
+        _beam_logits, _ = model(input_ids=_beam, cache=cache, n_beam=n_beam, advance_offset=0)
         _beam_logits = nn.log_softmax(_beam_logits)
         _beam_score = mx.concatenate([logits[mx.arange(_arg_beam.shape[0])[:, None], beam_idx, _arg_beam].reshape(-1)[:, None], _beam_logits[mx.arange(_beam_logits.shape[0])[:, None], mx.arange(_beam.shape[1] - 1)[None, :], _beam[:, 1:]]], axis=1)
         _argmax_beam = mx.argmax(_beam_score.mean(axis=1).reshape(-1, n_beam), axis=-1)
         beam_token = _arg_beam[mx.arange(_argmax_beam.shape[0]), _argmax_beam]
         beam_score = _beam_score.reshape(logits.shape[0], n_beam, -1)[mx.arange(_argmax_beam.shape[0]), _argmax_beam]
-        return token, beam_token, beam_score, cache
+        mx.eval(token, beam_token, beam_score)
+        return token, beam_token, beam_score
     if isinstance(prompt, str):
         _was_prompt_str = True
         prompt = [prompt]
     else:
         _was_prompt_str = False
+
+    tic = Tic()
+    prompt_time = 0
+    constrain_time = 0
+
     prompt = [_preprocess(s) for s in prompt]
     len_ps = [len(p) for p in prompt]
     synth_pad = mx.tile(mx.array([ID_EOS]), (len(prompt), 1))
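Two things change in this hunk: `_beam` becomes `_constrain` (the old name described the optional beam scorer, not the function's job of constrained decoding), and `_get_beam` stops returning `cache` because its speculative pass now runs with `advance_offset=0` and must not advance the caller's live cache. The beam candidates come from an `argpartition` top-k; a standalone illustration of that idiom:

    import mlx.core as mx

    logits = mx.array([[0.1, 2.0, -1.0, 3.0, 0.5]])
    n_beam = 3
    # Negate so the n_beam largest logits land (unordered) in the first slots.
    arg_beam = mx.argpartition(-logits, kth=n_beam, axis=-1)[:, :n_beam]
    print(arg_beam)  # indices of the 3 largest logits, in no particular order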
@@ -523,17 +529,18 @@ def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
     dict_input = processor(prompt)
     logits, cache = model(**dict_input, max_tokens=constraint[0] + id_constraint.shape[0] + 10)
     logits = nn.log_softmax(logits, axis=-1)
+    mx.eval(logits, cache)
     _score_0 = logits[:, -1, id_constraint[0]]
     tiled_id_constraint = mx.tile(id_constraint, (logits.shape[0], 1))
-    logits_rest, cache = model(input_ids=tiled_id_constraint, cache=cache, advance_offset=0)
+    logits_rest, _ = model(input_ids=tiled_id_constraint, cache=cache, advance_offset=0)
     logits_rest = nn.log_softmax(logits_rest, axis=-1)
     _score_1 = logits_rest[mx.arange(tiled_id_constraint.shape[0])[:, None], mx.arange(tiled_id_constraint.shape[1] - 1)[None, :], tiled_id_constraint[:, 1:]]
     running_score = mx.max(logits[:, -1, :], axis=-1)[:, None]
     pre_beam_score = mx.concatenate([_score_0[:, None], _score_1], axis=1).mean(axis=1)
     pre_beam_synth = mx.concatenate([tiled_id_constraint, synth_pad], axis=1)
     if use_beam:
-        token, beam_token, beam_score, cache = _get_beam(logits, cache, id_constraint, -1)
-        post_beam_score = mx.concatenate([running_score, beam_score], axis=1).mean(axis=1)
+        token, beam_token, beam_score = _get_beam(logits, cache, id_constraint, -1)
+        post_beam_score = mx.concatenate([running_score, beam_score], axis=1).mean(axis=1)
         post_beam_synth = mx.concatenate([beam_token[:, None], tiled_id_constraint], axis=1)
         win = pre_beam_score > post_beam_score
         score_sofar = mx.where(win, pre_beam_score, post_beam_score)
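The added `mx.eval` materializes the prefill at a known point, and rebinding to `logits_rest, _` matches the `_get_beam` change: the throwaway scoring pass must never clobber the live cache. The `_score_1` gather is a teacher-forcing score, i.e. the log-prob the model assigned to each forced constraint token. In isolation (shapes are illustrative):

    import mlx.core as mx

    B, T, V = 2, 4, 10
    log_probs = mx.random.uniform(shape=(B, T, V))  # stand-in for nn.log_softmax output
    targets = mx.array([[1, 2, 3, 4], [5, 6, 7, 8]])
    # Log-prob assigned at step t to the forced token at step t+1, per row.
    scores = log_probs[mx.arange(B)[:, None], mx.arange(T - 1)[None, :], targets[:, 1:]]
    print(scores.shape)  # (2, 3)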
@@ -543,17 +550,20 @@ def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
         score_sofar = pre_beam_score
         synth_sofar = pre_beam_synth
     token = token[:, None]
+    mx.eval(token)
     tokens = []
     finished_rows = mx.ones(tiled_id_constraint.shape[0])
+    prompt_time += tic()
     for i in range(constraint[0]):
         tokens.append(token)
         token_plus = mx.concatenate([token, tiled_id_constraint], axis=1)
         logits, cache = model(input_ids=token_plus, cache=cache, advance_offset=1)
         logits = nn.log_softmax(logits)
+        mx.eval(logits, cache)
         pre_beam_score = mx.concatenate([running_score, logits[mx.arange(logits.shape[0])[:, None], mx.arange(logits.shape[1] - 1)[None, :], token_plus[:, 1:]]], axis=1).mean(axis=1)
         pre_beam_synth = mx.concatenate(tokens + [tiled_id_constraint, synth_pad], axis=1)
         if use_beam:
-            token, beam_token, beam_score, cache = _get_beam(logits, cache, id_constraint)
+            token, beam_token, beam_score = _get_beam(logits, cache, id_constraint)
             post_beam_score = mx.concatenate([running_score, beam_score], axis=1).mean(axis=1)
             post_beam_synth = mx.concatenate(tokens + [beam_token[:, None], tiled_id_constraint], axis=1)
             win = pre_beam_score > post_beam_score
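Inside the loop, `_get_beam` again returns no cache, and the `mx.eval` calls pin the lazy graph so `prompt_time`/`constrain_time` measure real work. The `win`/`mx.where` selection that follows keeps, row by row, whichever candidate continuation has the higher mean log-prob; the primitive is simply:

    import mlx.core as mx

    pre = mx.array([-1.2, -0.3])
    post = mx.array([-0.9, -0.8])
    win = pre > post
    best = mx.where(win, pre, post)  # per-row maximum of the two scores
    print(best)  # [-0.9, -0.3]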
@@ -574,6 +584,8 @@ def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
         running_score = mx.concatenate([running_score, logits[mx.arange(token.shape[0]), 0, token][:, None]], axis=1)
         finished_rows *= token != ID_EOS
         token = token[:, None]
+        mx.eval(token)
+    constrain_time += tic()
     output = mx.concatenate([dict_input['input_ids'], synth_sofar], axis=1).tolist()
     S = dict_input['input_ids'].shape[1]
     output = [(i[:i.index(ID_EOS, S)] if ID_EOS in i[S:] else i) for i in output]
@@ -589,6 +601,8 @@ def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
     else:
         for i, o in enumerate(output):
             print(f'\n<Constrained text for prompt #{i}>\n{o}')
+    if verbose:
+        print(f'Prompt: {prompt_time:.2f} sec\nConstrain: {constrain_time:.2f} sec')
     if _was_prompt_str:
         output = output[0]
     return output
@@ -1119,7 +1133,7 @@ def _map(ds, map_args):
         'q_col': 'input',
         'q_until': None,
         'q_format': '',
-        'fxn': partial(_beam, model=model, processor=processor, constraints=[(100, ' The correct answer is'), (1, 'X.')], mute=True, use_beam=False),
+        'fxn': partial(_constrain, model=model, processor=processor, constraints=[(100, ' The correct answer is'), (1, 'X.')], verbose=False, mute=True, use_beam=False),
         'a_format': 'The correct answer is ',
         'a_col': 'constrained_attempt',
         'c_col': 'output',
@@ -1129,7 +1143,7 @@ def _map(ds, map_args):
         'q_col': 'input',
         'q_until': None,
         'q_format': '',
-        'fxn': partial(_beam, model=model, processor=processor, constraints=[(100, ' The correct answer is'), (1, 'X.')], mute=True, use_beam=True),
+        'fxn': partial(_constrain, model=model, processor=processor, constraints=[(100, ' The correct answer is'), (1, 'X.')], verbose=False, mute=True, use_beam=True),
         'a_format': 'The correct answer is ',
         'a_col': 'beamed_attempt',
         'c_col': 'output',
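Both benchmark entries swap in the renamed `_constrain` and pass `verbose=False` so per-call timing stays out of batch runs. The `'fxn'` values are `functools.partial` bindings, so `_map` only has to supply the prompt; the pattern, with hypothetical stand-in names:

    from functools import partial

    def constrain_sketch(model, processor, prompt, constraints, mute=False, verbose=True):
        ...  # stand-in for _constrain

    fxn = partial(constrain_sketch, model='MODEL', processor='PROCESSOR',
                  constraints=[(100, ' The correct answer is'), (1, 'X.')],
                  verbose=False, mute=True)
    fxn(prompt='Which of the following ...')  # only the prompt varies per row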
@@ -1448,7 +1462,7 @@ def constrain(prompt, constraints=[(30, ' The correct answer is'), (1, 'X.')], i
     preload = load(blind_model=blind_model, quantize_model=quantize_model, quantize_cache=quantize_cache, use_adapter=use_adapter)
     if apply_chat_template:
         prompt = _apply_chat_template(prompt, None, verbose)[0]
-    return _beam(*preload, prompt=prompt, constraints=constraints, use_beam=use_beam)
+    return _constrain(*preload, prompt=prompt, constraints=constraints, use_beam=use_beam, verbose=verbose)
 
 
 def execute(code_strings, file_prefix=0, verbose=True):
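Finally, the public `constrain` wrapper forwards `verbose` along with the rename. A hedged usage sketch (the constraint defaults are taken from the signature above; the prompt text is illustrative):

    result = constrain(
        'What is the capital of France? A) London B) Paris C) Rome',
        constraints=[(30, ' The correct answer is'), (1, 'X.')],
        use_beam=True,
    )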