Skip to content

Commit d991e05

Browse files
authored
Fix typos (e.g., np.inf)
1 parent c0b83af commit d991e05

File tree

3 files changed: 44 additions and 35 deletions

phi.py

Lines changed: 17 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -429,7 +429,7 @@ def __init__(self, config):
429429
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
430430
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
431431

432-
def __call__(self, x, cache, cos, sin, mask):
432+
def __call__(self, x, cache, cos, sin, mask, n_beam):
433433
@mx.compile
434434
def _rotate_half(x, cos, sin):
435435
midpoint = x.shape[-1] // 2
@@ -441,11 +441,18 @@ def _rotate_half(x, cos, sin):
441441
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
442442
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
443443
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
444+
445+
if n_beam > 1:
446+
sin = mx.repeat(sin, repeats=n_beam, axis=0)
447+
cos = mx.repeat(cos, repeats=n_beam, axis=0)
448+
mask = mx.repeat(mask, repeats=n_beam, axis=0)
449+
444450
queries = _rotate_half(queries, cos, sin)
445451
keys = _rotate_half(keys, cos, sin)
446-
keys, values = cache(keys, values)
452+
keys, values = cache(keys, values, n_beam)
447453
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
448454
scores += mask
455+
449456
scores = mx.softmax(scores, axis=-1)
450457
output = scores @ values
451458
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
@@ -470,8 +477,8 @@ def __init__(self, config):
470477
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
471478
self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
472479

473-
def __call__(self, x, cache, cos, sin, mask):
474-
r = self.self_attn(self.input_layernorm(x), cache, cos, sin, mask)
480+
def __call__(self, x, cache, cos, sin, mask, n_beam):
481+
r = self.self_attn(self.input_layernorm(x), cache, cos, sin, mask, n_beam)
475482
h = x + r
476483
r = self.mlp(self.post_attention_layernorm(h))
477484
return h + r
@@ -507,15 +514,14 @@ def __init__(self, config, x, max_tokens):
507514
self.kv = None
508515
else:
509516
self.kv = mx.zeros(shape, mx.float32)
510-
self.n_beam = None
511517

512-
def __call__(self, keys, values):
518+
def __call__(self, keys, values, n_beam):
513519
if self.max_tokens < 1:
514520
return keys, values
515-
if self.n_beam:
521+
if n_beam > 1:
516522
if self.use_quantized_cache:
517523
raise NotImplementedError('Beam Search is not yet compatible with Quantized Cache')
518-
kv = mx.repeat(self.kv[:,:,:,:self.offset,:], repeats=self.n_beam, axis=1)
524+
kv = mx.repeat(self.kv[:,:,:,:self.offset,:], repeats=n_beam, axis=1)
519525
return mx.concatenate([kv[0], keys], axis=-2), mx.concatenate([kv[1], values], axis=-2)
520526
B, N, L, D = keys.shape
521527
new_offset = self.offset + L
@@ -534,9 +540,6 @@ def __call__(self, keys, values):
534540
self.offset = new_offset
535541
return self.kv[0,:,:,:new_offset,:], self.kv[1,:,:,:new_offset,:]
536542

537-
def beam(self, n_beam):
538-
self.n_beam = n_beam
539-
540543
class Mask4D:
541544
def __init__(self, L_all, mask):
542545
mask_4d = mx.triu(mx.full((L_all, L_all), -mx.inf), k=1)[None, None]
@@ -545,9 +548,8 @@ def __init__(self, L_all, mask):
545548
mask = mx.pad(mask, ((0,0),(0,pad_len)), 1)
546549
mask = mx.expand_dims(mask, (1,2))
547550
mask = mask*mask.transpose(0,1,3,2)
548-
mask = mx.where(mask==1, 0, -np.inf)
551+
mask = mx.where(mask==1, 0, -mx.inf)
549552
mask_4d += mask
550-
mask_4d = mx.repeat(mask_4d, 32, axis=1)
551553
self.mask_4d = mask_4d
552554

553555
def __call__(self, past_L, L):
@@ -575,16 +577,9 @@ def __call__(self, input_ids, pixel_values, image_sizes, positions, cache, pids,
575577
past_L, new_L = cache[0].offset, x.shape[1]
576578
mask = self.masker(past_L, new_L)
577579
cos, sin = self.roper(past_L, new_L)
578-
if n_beam is not None:
579-
mask = mx.repeat(mask, repeats=n_beam, axis=0)
580-
cos = mx.repeat(cos, repeats=n_beam, axis=0)
581-
sin = mx.repeat(sin, repeats=n_beam, axis=0)
582-
[c.beam(n_beam) for c in cache]
583580
for i, l in enumerate(self.layers):
584-
x = l(x, cache[i], cos, sin, mask)
581+
x = l(x, cache[i], cos, sin, mask, n_beam)
585582

586-
if n_beam is not None:
587-
[c.beam(None) for c in cache]
588583
if advance_offset is not None:
589584
for c in cache:
590585
c.offset = past_L + advance_offset
@@ -602,7 +597,7 @@ def __init__(self, config):
602597
self.model = Phi3F(config)
603598
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
604599

605-
def __call__(self, input_ids, pixel_values=None, image_sizes=None, positions=None, cache=None, pids=None, mask=None, max_tokens=0, advance_offset=None, n_beam=None):
600+
def __call__(self, input_ids, pixel_values=None, image_sizes=None, positions=None, cache=None, pids=None, mask=None, max_tokens=0, advance_offset=None, n_beam=1):
606601
x, cache = self.model(input_ids, pixel_values, image_sizes, positions, cache, pids, mask, max_tokens, advance_offset, n_beam)
607602
return self.lm_head(x), cache
608603

phi_3_vision_mlx.py

Lines changed: 26 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -378,13 +378,13 @@ def _generate(model, processor, prompt, images=None, max_tokens=1000, verbose=Tr
378378
logit_stopper = LogitStopper(max_tokens, early_stop)
379379
streamer = Streamer(processor, stream, mute)
380380
dict_input = processor(prompt, images)
381+
mask, pids = dict_input.get('mask', None), dict_input.get('pids', None)
381382
token_stopper = TokenStopper(processor, dict_input['input_ids'].shape[0])
382383
tic = Tic()
383384
logits, cache = model(**dict_input, max_tokens=max_tokens)
384385
token = mx.argmax(logits[:, -1, :], axis=-1)[:,None]
385386
mx.eval(token, logits, cache)
386387
streamer(token)
387-
mask, pids = dict_input.get('mask', None), dict_input.get('pids', None)
388388
prompt_time = tic()
389389
for i in range(max_tokens-1):
390390
logits, cache = model(input_ids=token, cache=cache, mask=mask, pids=pids)
@@ -402,7 +402,7 @@ def _generate(model, processor, prompt, images=None, max_tokens=1000, verbose=Tr
402402
gen_tps = (gen_len - 1) / gen_time
403403
if verbose:
404404
print(f"\nPrompt: {prompt_tps:.2f} tokens-per-sec ({prompt_len} tokens / {prompt_time:.1f} sec)")
405-
print(f"Generation: {gen_tps:.2f} tokens-per-sec ({gen_len} tokens / {gen_time:.1f} sec)")
405+
print(f"Generate: {gen_tps:.2f} tokens-per-sec ({gen_len} tokens / {gen_time:.1f} sec)")
406406
if return_tps:
407407
return prompt_tps, gen_tps
408408
return result
@@ -496,24 +496,30 @@ def _already(array_2d, array_1d):
496496
return mx.ones(array_2d.shape[0])
497497
return ~mx.all(array_2d[:, -len(array_1d):] == array_1d, axis=1)
498498

499-
def _beam(model, processor, prompt, constraints, return_full_text=False, mute=False, use_beam=False):
499+
def _constrain(model, processor, prompt, constraints, return_full_text=False, mute=False, use_beam=False, verbose=True):
500500
def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
501501
token = mx.argmax(logits[:, beam_idx, :], axis=-1)
502502
_arg_beam = mx.argpartition(-logits[:, beam_idx, :], kth=n_beam, axis=-1)[:,:n_beam]
503503
_beam = _arg_beam.reshape(-1)[:,None]
504504
_beam = mx.concatenate([_beam, mx.tile(id_constraint, (_beam.shape[0], 1))], axis=-1)
505-
_beam_logits, cache = model(input_ids=_beam, cache=cache, n_beam=n_beam)
505+
_beam_logits, _ = model(input_ids=_beam, cache=cache, n_beam=n_beam, advance_offset=0)
506506
_beam_logits = nn.log_softmax(_beam_logits)
507507
_beam_score = mx.concatenate([logits[mx.arange(_arg_beam.shape[0])[:,None], beam_idx, _arg_beam].reshape(-1)[:,None], _beam_logits[mx.arange(_beam_logits.shape[0])[:,None], mx.arange(_beam.shape[1]-1)[None,:], _beam[:,1:]]], axis=1)
508508
_argmax_beam = mx.argmax(_beam_score.mean(axis=1).reshape(-1,n_beam), axis=-1)
509509
beam_token = _arg_beam[mx.arange(_argmax_beam.shape[0]), _argmax_beam]
510510
beam_score = _beam_score.reshape(logits.shape[0],n_beam, -1)[mx.arange(_argmax_beam.shape[0]), _argmax_beam]
511-
return token, beam_token, beam_score, cache
511+
mx.eval(token, beam_token, beam_score)
512+
return token, beam_token, beam_score
512513
if isinstance(prompt, str):
513514
_was_prompt_str = True
514515
prompt = [prompt]
515516
else:
516517
_was_prompt_str = False
518+
519+
tic = Tic()
520+
prompt_time = 0
521+
constrain_time = 0
522+
517523
prompt = [_preprocess(s) for s in prompt]
518524
len_ps = [len(p) for p in prompt]
519525
synth_pad = mx.tile(mx.array([ID_EOS]), (len(prompt), 1))
@@ -523,17 +529,18 @@ def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
523529
dict_input = processor(prompt)
524530
logits, cache = model(**dict_input, max_tokens=constraint[0] + id_constraint.shape[0]+10)
525531
logits = nn.log_softmax(logits, axis=-1)
532+
mx.eval(logits, cache)
526533
_score_0 = logits[:, -1, id_constraint[0]]
527534
tiled_id_constraint = mx.tile(id_constraint, (logits.shape[0], 1))
528-
logits_rest, cache = model(input_ids=tiled_id_constraint, cache=cache, advance_offset=0)
535+
logits_rest, _ = model(input_ids=tiled_id_constraint, cache=cache, advance_offset=0)
529536
logits_rest = nn.log_softmax(logits_rest, axis=-1)
530537
_score_1 = logits_rest[mx.arange(tiled_id_constraint.shape[0])[:,None], mx.arange(tiled_id_constraint.shape[1]-1)[None,:], tiled_id_constraint[:,1:]]
531538
running_score = mx.max(logits[:, -1, :], axis=-1)[:,None]
532539
pre_beam_score = mx.concatenate([_score_0[:,None], _score_1], axis=1).mean(axis=1)
533540
pre_beam_synth = mx.concatenate([tiled_id_constraint, synth_pad], axis=1)
534541
if use_beam:
535-
token, beam_token, beam_score, cache = _get_beam(logits, cache, id_constraint, -1)
536-
post_beam_score = mx.concatenate([running_score, beam_score], axis=1).mean(axis=1)
542+
token, beam_token, beam_score = _get_beam(logits, cache, id_constraint, -1)
543+
post_beam_score = mx.concatenate([running_score, beam_score], axis=1).mean(axis=1)
537544
post_beam_synth = mx.concatenate([beam_token[:,None], tiled_id_constraint], axis=1)
538545
win = pre_beam_score > post_beam_score
539546
score_sofar = mx.where(win, pre_beam_score, post_beam_score)
@@ -543,17 +550,20 @@ def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
543550
score_sofar = pre_beam_score
544551
synth_sofar = pre_beam_synth
545552
token = token[:,None]
553+
mx.eval(token)
546554
tokens = []
547555
finished_rows = mx.ones(tiled_id_constraint.shape[0])
556+
prompt_time += tic()
548557
for i in range(constraint[0]):
549558
tokens.append(token)
550559
token_plus = mx.concatenate([token, tiled_id_constraint], axis=1)
551560
logits, cache = model(input_ids=token_plus, cache=cache, advance_offset=1)
552561
logits = nn.log_softmax(logits)
562+
mx.eval(logits, cache)
553563
pre_beam_score = mx.concatenate([running_score, logits[mx.arange(logits.shape[0])[:,None], mx.arange(logits.shape[1]-1)[None,:], token_plus[:,1:]]], axis=1).mean(axis=1)
554564
pre_beam_synth = mx.concatenate(tokens + [tiled_id_constraint, synth_pad], axis=1)
555565
if use_beam:
556-
token, beam_token, beam_score, cache = _get_beam(logits, cache, id_constraint)
566+
token, beam_token, beam_score = _get_beam(logits, cache, id_constraint)
557567
post_beam_score = mx.concatenate([running_score, beam_score], axis=1).mean(axis=1)
558568
post_beam_synth = mx.concatenate(tokens + [beam_token[:,None], tiled_id_constraint], axis=1)
559569
win = pre_beam_score > post_beam_score
@@ -574,6 +584,8 @@ def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
574584
running_score = mx.concatenate([running_score, logits[mx.arange(token.shape[0]),0,token][:,None]], axis=1)
575585
finished_rows *= token != ID_EOS
576586
token = token[:,None]
587+
mx.eval(token)
588+
constrain_time += tic()
577589
output = mx.concatenate([dict_input['input_ids'], synth_sofar], axis=1).tolist()
578590
S = dict_input['input_ids'].shape[1]
579591
output = [(i[:i.index(ID_EOS,S)] if ID_EOS in i[S:] else i) for i in output]
@@ -589,6 +601,8 @@ def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
589601
else:
590602
for i,o in enumerate(output):
591603
print(f'\n< Constrained text for prompt #{i} >\n{o}')
604+
if verbose:
605+
print(f'Prompt: {prompt_time:.2f} sec\nConstrain: {constrain_time:.2f} sec')
592606
if _was_prompt_str:
593607
output = output[0]
594608
return output
@@ -1119,7 +1133,7 @@ def _map(ds, map_args):
11191133
'q_col':'input',
11201134
'q_until':None,
11211135
'q_format':'',
1122-
'fxn':partial(_beam, model=model, processor=processor, constraints=[(100, ' The correct answer is'), (1, 'X.')], mute=True, use_beam=False),
1136+
'fxn':partial(_constrain, model=model, processor=processor, constraints=[(100, ' The correct answer is'), (1, 'X.')], verbose=False, mute=True, use_beam=False),
11231137
'a_format':'The correct answer is ',
11241138
'a_col':'constrained_attempt',
11251139
'c_col':'output',
@@ -1129,7 +1143,7 @@ def _map(ds, map_args):
11291143
'q_col':'input',
11301144
'q_until':None,
11311145
'q_format':'',
1132-
'fxn':partial(_beam, model=model, processor=processor, constraints=[(100, ' The correct answer is'), (1, 'X.')], mute=True, use_beam=True),
1146+
'fxn':partial(_constrain, model=model, processor=processor, constraints=[(100, ' The correct answer is'), (1, 'X.')], verbose=False, mute=True, use_beam=True),
11331147
'a_format':'The correct answer is ',
11341148
'a_col':'beamed_attempt',
11351149
'c_col':'output',
@@ -1448,7 +1462,7 @@ def constrain(prompt, constraints=[(30, ' The correct answer is'), (1, 'X.')], i
14481462
preload = load(blind_model=blind_model, quantize_model=quantize_model, quantize_cache=quantize_cache, use_adapter=use_adapter)
14491463
if apply_chat_template:
14501464
prompt = _apply_chat_template(prompt, None, verbose)[0]
1451-
return _beam(*preload, prompt=prompt, constraints=constraints, use_beam=use_beam)
1465+
return _constrain(*preload, prompt=prompt, constraints=constraints, use_beam=use_beam, verbose=verbose)
14521466

14531467

14541468
def execute(code_strings, file_prefix=0, verbose=True):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
url='https://github.com/JosefAlbers/Phi-3-Vision-MLX',
99
py_modules=['phi_3_vision_mlx', 'gte', 'phi', 'api'],
1010
packages=find_packages(),
11-
version='0.0.9',
11+
version='0.1.0-alpha',
1212
readme="README.md",
1313
author_email="[email protected]",
1414
description="Phi-3-Vision on Apple silicon with MLX",

0 commit comments

Comments (0)