Skip to content

Commit 5640a60

Browse files
authored
Fix typos
1 parent 9736016 commit 5640a60

File tree

2 files changed

+58
-44
lines changed

2 files changed

+58
-44
lines changed

examples.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,6 @@
22

33
import phi_3_vision_mlx as pv
44

5-
# Decoding Strategies
6-
7-
## Multiple Choice Question Answering 1
8-
prompts = [
9-
"A 20-year-old woman presents with menorrhagia for the past several years. She says that her menses “have always been heavy”, and she has experienced easy bruising for as long as she can remember. Family history is significant for her mother, who had similar problems with bruising easily. The patient's vital signs include: heart rate 98/min, respiratory rate 14/min, temperature 36.1°C (96.9°F), and blood pressure 110/87 mm Hg. Physical examination is unremarkable. Laboratory tests show the following: platelet count 200,000/mm3, PT 12 seconds, and PTT 43 seconds. Which of the following is the most likely cause of this patient’s symptoms? A: Factor V Leiden B: Hemophilia A C: Lupus anticoagulant D: Protein C deficiency E: Von Willebrand disease",
10-
"A 25-year-old primigravida presents to her physician for a routine prenatal visit. She is at 34 weeks gestation, as confirmed by an ultrasound examination. She has no complaints, but notes that the new shoes she bought 2 weeks ago do not fit anymore. The course of her pregnancy has been uneventful and she has been compliant with the recommended prenatal care. Her medical history is unremarkable. She has a 15-pound weight gain since the last visit 3 weeks ago. Her vital signs are as follows: blood pressure, 148/90 mm Hg; heart rate, 88/min; respiratory rate, 16/min; and temperature, 36.6℃ (97.9℉). The blood pressure on repeat assessment 4 hours later is 151/90 mm Hg. The fetal heart rate is 151/min. The physical examination is significant for 2+ pitting edema of the lower extremity. Which of the following tests o should confirm the probable condition of this patient? A: Bilirubin assessment B: Coagulation studies C: Hematocrit assessment D: Leukocyte count with differential E: 24-hour urine protein"]
11-
pv.choose("What is the capital of France? A: London B: Berlin C: Paris D: Madrid E: Rome")
12-
13-
## Multiple Choice Question Answering 2
14-
pv.constrain(prompts, constraints=[(30, ' The correct answer is'), (10, 'X.')], blind_model=True, quantize_model=True)
15-
16-
## Code Generation
17-
pv.constrain("Write a Python function to calculate the Fibonacci sequence up to a given number n.", [(100, "\n```python\n"), (100, " return "), (200, "\n```")])
18-
195
# Train
206
pv.train_lora(
217
lora_layers=5, # Number of layers to apply LoRA to
@@ -50,5 +36,20 @@
5036
pv.add_text('How to inspect API endpoints? @https://raw.githubusercontent.com/gradio-app/gradio/main/guides/08_gradio-clients-and-lite/01_getting-started-with-the-python-client.md')
5137
pv.rag('Comparison of Sortino Ratio for Bitcoin and Ethereum.')
5238

39+
# Decoding Strategies
40+
41+
## Multiple Choice Question Answering 1
42+
prompts = [
43+
"A 20-year-old woman presents with menorrhagia for the past several years. She says that her menses “have always been heavy”, and she has experienced easy bruising for as long as she can remember. Family history is significant for her mother, who had similar problems with bruising easily. The patient's vital signs include: heart rate 98/min, respiratory rate 14/min, temperature 36.1°C (96.9°F), and blood pressure 110/87 mm Hg. Physical examination is unremarkable. Laboratory tests show the following: platelet count 200,000/mm3, PT 12 seconds, and PTT 43 seconds. Which of the following is the most likely cause of this patient’s symptoms? A: Factor V Leiden B: Hemophilia A C: Lupus anticoagulant D: Protein C deficiency E: Von Willebrand disease",
44+
"A 25-year-old primigravida presents to her physician for a routine prenatal visit. She is at 34 weeks gestation, as confirmed by an ultrasound examination. She has no complaints, but notes that the new shoes she bought 2 weeks ago do not fit anymore. The course of her pregnancy has been uneventful and she has been compliant with the recommended prenatal care. Her medical history is unremarkable. She has a 15-pound weight gain since the last visit 3 weeks ago. Her vital signs are as follows: blood pressure, 148/90 mm Hg; heart rate, 88/min; respiratory rate, 16/min; and temperature, 36.6℃ (97.9℉). The blood pressure on repeat assessment 4 hours later is 151/90 mm Hg. The fetal heart rate is 151/min. The physical examination is significant for 2+ pitting edema of the lower extremity. Which of the following tests o should confirm the probable condition of this patient? A: Bilirubin assessment B: Coagulation studies C: Hematocrit assessment D: Leukocyte count with differential E: 24-hour urine protein"]
45+
46+
pv.choose(prompts, choices='ABCDE')
47+
48+
## Multiple Choice Question Answering 2
49+
pv.constrain(prompts, constraints=[(30, ' The correct answer is'), (10, 'X.')], blind_model=True, quantize_model=True)
50+
51+
## Code Generation
52+
pv.constrain("Write a Python function to calculate the Fibonacci sequence up to a given number n.", [(100, "\n```python\n"), (100, " return "), (200, "\n```")])
53+
5354
# Benchmark
5455
pv.benchmark()

phi_3_vision_mlx.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -462,22 +462,21 @@ def _get_adapter_path(model_path):
462462
print(f'{PATH_ADAPTERS}/{Path(model_path).name}')
463463
return f'{PATH_ADAPTERS}/{Path(model_path).name}'
464464

465-
def _score(model, processor, prompts):
466-
dict_input = processor(prompts)
467-
logits, _ = model(**dict_input, max_tokens=0)
468-
logits = nn.log_softmax(logits)
469-
input_ids = dict_input['input_ids']
470-
mask = dict_input['mask']
471-
batch_size, seq_length, vocab_size = logits.shape
472-
row_indices = mx.arange(batch_size)[:, None]
473-
col_indices = mx.arange(seq_length - 1)[None, :]
474-
token_indices = input_ids[:, 1:]
475-
next_token_logits = logits[row_indices, col_indices, token_indices]
476-
masked_logits = next_token_logits * mask[:, 1:]
477-
logit_sums = masked_logits.sum(axis=1)
478-
return logit_sums
479-
480465
def _choose(model, processor, prompts, appends=None, return_idx=False):
466+
def _process_and_score(model, processor, prompts):
467+
dict_input = processor(prompts)
468+
logits, _ = model(**dict_input, max_tokens=0)
469+
logits = nn.log_softmax(logits)
470+
input_ids = dict_input['input_ids']
471+
mask = dict_input['mask']
472+
batch_size, seq_length, vocab_size = logits.shape
473+
row_indices = mx.arange(batch_size)[:, None]
474+
col_indices = mx.arange(seq_length - 1)[None, :]
475+
token_indices = input_ids[:, 1:]
476+
next_token_logits = logits[row_indices, col_indices, token_indices]
477+
masked_logits = next_token_logits * mask[:, 1:]
478+
logit_sums = masked_logits.sum(axis=1)
479+
return logit_sums
481480
if isinstance(appends, list):
482481
prompts = [prompt + str(a) for prompt in prompts for a in appends]
483482
scores = _score(model, processor, prompts)
@@ -495,7 +494,7 @@ def _choose(model, processor, prompts, appends=None, return_idx=False):
495494
return scores
496495
return [choices[i] for i in scores]
497496

498-
def _choose_from(model, processor, prompt, choices='ABCDE', mute=True):
497+
def _choose_from(model, processor, prompt, choices='ABCDE', mute=False):
499498
def _ord(s):
500499
return processor([f' {i}' for i in s])['input_ids'][:,-1]
501500
if isinstance(prompt, str):
@@ -595,8 +594,20 @@ def _constrain(model, processor, prompt, constraints, return_full_text=False, mu
595594
output = output[0]
596595
return output
597596

597+
def _score(input_ids, mask, logits):
598+
batch_size, seq_length, vocab_size = logits.shape
599+
row_indices = mx.arange(batch_size)[:, None]
600+
col_indices = mx.arange(seq_length - 1)[None, :]
601+
token_indices = input_ids[:, 1:]
602+
next_token_logits = logits[row_indices, col_indices, token_indices]
603+
masked_logits = next_token_logits * mask[:, :-1]
604+
logit_sums = masked_logits.sum(axis=1)
605+
mask_counts = mask[:, :-1].sum(axis=1)
606+
return logit_sums, mask_counts
607+
598608
def _beam(model, processor, prompt, constraints, return_full_text=False, mute=False):
599609
def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
610+
token = mx.argmax(logits[:, 0, :], axis=-1)
600611
_B, _S, _V = logits.shape
601612
_arg_beam = mx.argpartition(-logits[:, beam_idx, :], kth=n_beam, axis=-1)[:,:n_beam]
602613
_beam = _arg_beam.reshape(-1)[:,None]
@@ -607,8 +618,8 @@ def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
607618
_beam_score = mx.concatenate([logits[mx.arange(_arg_beam.shape[0])[:,None],beam_idx,_arg_beam].reshape(-1)[:,None], _beam_score], axis=1).mean(axis=1)
608619
_beam_score = _beam_score.reshape(-1,n_beam)
609620
_argmax_beam = mx.argmax(_beam_score, axis=-1)
610-
token = _arg_beam[mx.arange(len(_argmax_beam)), _argmax_beam]
611-
return token, cache
621+
beam_token = _arg_beam[mx.arange(len(_argmax_beam)), _argmax_beam]
622+
return token, beam_token, cache
612623
if isinstance(prompt, str):
613624
_was_prompt_str = True
614625
prompt = [prompt]
@@ -623,13 +634,14 @@ def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
623634
logits, cache = model(**dict_input, max_tokens=constraint[0] + id_constraint.shape[0]+10)
624635
logits = nn.log_softmax(logits, axis=-1)
625636
token = mx.argmax(logits[:, -1, :], axis=-1)[:,None]
637+
_sum_init, _mask_init = _score(dict_input['input_ids'], dict_input['mask'], logits)
626638
running_score = mx.max(logits[:, -1, :], axis=-1)[:,None]
627639
_score_0 = logits[:, -1, id_constraint[0]]
628640
tiled_id_constraint = mx.tile(id_constraint, (token.shape[0], 1))
629641
logits_rest, cache = model(input_ids=tiled_id_constraint, cache=cache, advance_offset=0)
630642
logits_rest = nn.log_softmax(logits_rest, axis=-1)
631643
_score_1 = logits_rest[mx.arange(tiled_id_constraint.shape[0])[:,None], mx.arange(tiled_id_constraint.shape[1]-1)[None,:], tiled_id_constraint[:,1:]]
632-
score_sofar = mx.concatenate([_score_0[:,None], _score_1], axis = 1).mean(axis=1)
644+
score_sofar = mx.concatenate([_sum_init[:,None], _score_0[:,None], _score_1], axis=1).sum(axis=1) / (_mask_init + id_constraint.shape[0])
633645
synth_sofar = tiled_id_constraint
634646
synth_pad = mx.tile(mx.array([ID_EOS]), (tiled_id_constraint.shape[0], 1))
635647
synth = tiled_id_constraint
@@ -641,12 +653,9 @@ def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
641653
token_plus = mx.concatenate([token, tiled_id_constraint], axis=1)
642654
logits, cache = model(input_ids=token_plus, cache=cache, advance_offset=1)
643655
logits = nn.log_softmax(logits)
644-
token, cache = _get_beam(logits, cache, id_constraint)
645-
finished_rows *= token != ID_EOS
656+
token, beam_token, cache = _get_beam(logits, cache, id_constraint)
646657
_synth_score = logits[mx.arange(tiled_id_constraint.shape[0])[:,None], mx.arange(tiled_id_constraint.shape[1])[None,:], tiled_id_constraint]
647658
score = mx.concatenate([running_score, _synth_score], axis=1).mean(axis=1)
648-
running_score = mx.concatenate([running_score, logits[mx.arange(token.shape[0]),0,token][:,None]], axis=1)
649-
token = token[:,None]
650659
synth_sofar = mx.concatenate([synth_sofar, synth_pad], axis=1)
651660
finished_rows *= _already(mx.concatenate(tokens, axis=1), id_constraint)
652661
if finished_rows.sum() < 1:
@@ -655,6 +664,10 @@ def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
655664
rows_to_update *= finished_rows
656665
synth_sofar = mx.where(rows_to_update[:,None], synth, synth_sofar)
657666
score_sofar = mx.where(rows_to_update, score, score_sofar)
667+
token = mx.where(rows_to_update, beam_token, token)
668+
running_score = mx.concatenate([running_score, logits[mx.arange(token.shape[0]),0,token][:,None]], axis=1)
669+
finished_rows *= token != ID_EOS
670+
token = token[:,None]
658671
output = mx.concatenate([dict_input['input_ids'], synth_sofar], axis=1).tolist()
659672
S = dict_input['input_ids'].shape[1]
660673
output = [(i[:i.index(ID_EOS,S)] if ID_EOS in i[S:] else i) for i in output]
@@ -1200,7 +1213,7 @@ def _map(ds, map_args):
12001213
'q_col':'input',
12011214
'q_until':None,
12021215
'q_format':'',
1203-
'fxn':partial(_constrain, model=model, processor=processor, constraints=[(30, ' The correct answer is'), (10, 'X.')], mute=True),
1216+
'fxn':partial(_constrain, model=model, processor=processor, constraints=[(100, ' The correct answer is'), (10, 'X.')], mute=True),
12041217
'a_format':'The correct answer is ',
12051218
'a_col':'constrained_attempt',
12061219
'c_col':'output',
@@ -1210,7 +1223,7 @@ def _map(ds, map_args):
12101223
'q_col':'input',
12111224
'q_until':None,
12121225
'q_format':'',
1213-
'fxn':partial(_beam, model=model, processor=processor, constraints=[(30, ' The correct answer is'), (10, 'X.')], mute=True),
1226+
'fxn':partial(_beam, model=model, processor=processor, constraints=[(100, ' The correct answer is'), (10, 'X.')], mute=True),
12141227
'a_format':'The correct answer is ',
12151228
'a_col':'beamed_attempt',
12161229
'c_col':'output',
@@ -1425,7 +1438,7 @@ def generate(prompt, images=None, preload=None, blind_model=False, quantize_mode
14251438
preload = load(blind_model=blind_model, quantize_model=quantize_model, quantize_cache=quantize_cache, use_adapter=use_adapter)
14261439
return _generate(*preload, *_apply_chat_template(prompt, images, verbose, apply_chat_template), max_tokens=max_tokens, verbose=verbose, return_tps=return_tps, early_stop=early_stop, stream=stream)
14271440

1428-
def choose(prompt, choices='ABCDE', images=None, preload=None, blind_model=False, quantize_model=False, quantize_cache=False, use_adapter=False, apply_chat_template=True):
1441+
def choose(prompt, choices='ABCDE', images=None, preload=None, blind_model=False, quantize_model=False, quantize_cache=False, use_adapter=False, verbose=True, apply_chat_template=True):
14291442
"""
14301443
Choose the best option from a set of choices for a given prompt.
14311444
@@ -1469,10 +1482,10 @@ def choose(prompt, choices='ABCDE', images=None, preload=None, blind_model=False
14691482
if preload is None:
14701483
preload = load(blind_model=blind_model, quantize_model=quantize_model, quantize_cache=quantize_cache, use_adapter=use_adapter)
14711484
if apply_chat_template:
1472-
prompt, _ = _apply_chat_template(prompt, images, False)
1473-
return _choose_from(model, processor, prompt, choices)
1485+
prompt, _ = _apply_chat_template(prompt, images, verbose)
1486+
return _choose_from(*preload, prompt=prompt, choices=choices)
14741487

1475-
def constrain(prompt, constraints=[(30, ' The correct answer is'), (10, 'X.')], images=None, preload=None, blind_model=False, quantize_model=False, quantize_cache=False, use_adapter=False, apply_chat_template=True, use_beam=False):
1488+
def constrain(prompt, constraints=[(30, ' The correct answer is'), (10, 'X.')], images=None, preload=None, blind_model=False, quantize_model=False, quantize_cache=False, use_adapter=False, verbose=True, apply_chat_template=True, use_beam=False):
14761489
"""
14771490
Perform constrained decoding on the given prompt using specified constraints.
14781491
@@ -1528,7 +1541,7 @@ def constrain(prompt, constraints=[(30, ' The correct answer is'), (10, 'X.')],
15281541
if preload is None:
15291542
preload = load(blind_model=blind_model, quantize_model=quantize_model, quantize_cache=quantize_cache, use_adapter=use_adapter)
15301543
if apply_chat_template:
1531-
prompt = _apply_chat_template(prompt, None, False)[0]
1544+
prompt = _apply_chat_template(prompt, None, verbose)[0]
15321545
if use_beam:
15331546
return _beam(*preload, prompt=prompt, constraints=constraints)
15341547
return _constrain(*preload, prompt=prompt, constraints=constraints)

0 commit comments

Comments
 (0)