Commit 4fe4e8f

Typo fix (Streamer mute)
1 parent 2220a05 commit 4fe4e8f

7 files changed: +41 -35 lines changed

README.md

Lines changed: 16 additions & 12 deletions

@@ -56,17 +56,7 @@ generate(prompts, max_tokens=100)
 generate(prompts, max_tokens=100, blind_model=True)
 ```
 
-### Model and Cache Quantization
-
-```python
-# Model quantization
-generate("Describe the water cycle.", quantize_model=True)
-
-# Cache quantization
-generate("Explain quantum computing.", quantize_cache=True)
-```
-
-### Constrained Decoding (WIP)
+### Constrained (Beam Search) Decoding
 
 The `constrain` function allows for structured generation, which can be useful for tasks like code generation, function calling, chain-of-thought prompting, or multiple-choice question answering.
 
@@ -93,7 +83,11 @@ prompts = [
     "A 20-year-old woman presents with menorrhagia for the past several years. She says that her menses “have always been heavy”, and she has experienced easy bruising for as long as she can remember. Family history is significant for her mother, who had similar problems with bruising easily. The patient's vital signs include: heart rate 98/min, respiratory rate 14/min, temperature 36.1°C (96.9°F), and blood pressure 110/87 mm Hg. Physical examination is unremarkable. Laboratory tests show the following: platelet count 200,000/mm3, PT 12 seconds, and PTT 43 seconds. Which of the following is the most likely cause of this patient’s symptoms? A: Factor V Leiden B: Hemophilia A C: Lupus anticoagulant D: Protein C deficiency E: Von Willebrand disease",
     "A 25-year-old primigravida presents to her physician for a routine prenatal visit. She is at 34 weeks gestation, as confirmed by an ultrasound examination. She has no complaints, but notes that the new shoes she bought 2 weeks ago do not fit anymore. The course of her pregnancy has been uneventful and she has been compliant with the recommended prenatal care. Her medical history is unremarkable. She has a 15-pound weight gain since the last visit 3 weeks ago. Her vital signs are as follows: blood pressure, 148/90 mm Hg; heart rate, 88/min; respiratory rate, 16/min; and temperature, 36.6℃ (97.9℉). The blood pressure on repeat assessment 4 hours later is 151/90 mm Hg. The fetal heart rate is 151/min. The physical examination is significant for 2+ pitting edema of the lower extremity. Which of the following tests o should confirm the probable condition of this patient? A: Bilirubin assessment B: Coagulation studies C: Hematocrit assessment D: Leukocyte count with differential E: 24-hour urine protein"]
 
-constrain(prompts, constraints=[(30, ' The correct answer is'), (10, 'X.')], blind_model=True, quantize_model=True)
+# Apply vanilla constrained decoding
+constrain(prompts, constraints=[(30, ' The correct answer is'), (10, 'X.')], blind_model=True, quantize_model=True, use_beam=False)
+
+# Apply constrained beam decoding (ACB)
+constrain(prompts, constraints=[(30, ' The correct answer is'), (10, 'X.')], blind_model=True, quantize_model=True, use_beam=True)
 ```
 
 The constraints encourage a structured response that includes the thought process, making the output more informative and transparent:
@@ -131,6 +125,16 @@ batch_results = choose(prompts)
 print(batch_results) # Output: ['C', 'B']
 ```
 
+### Model and Cache Quantization
+
+```python
+# Model quantization
+generate("Describe the water cycle.", quantize_model=True)
+
+# Cache quantization
+generate("Explain quantum computing.", quantize_cache=True)
+```
+
 ### (Q)LoRA Fine-tuning
 
 Training a LoRA Adapter
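For a quick end-to-end check of the renamed section, the two decoding modes can be exercised as below. This is a condensed sketch: it assumes the module is imported under the `pv` alias used in examples.py, and the comments simply mirror the labels in the README hunk above.

```python
import phi_3_vision_mlx as pv  # assumed alias; examples.py refers to the module as `pv`

prompt = "Write a Python function to calculate the Fibonacci sequence up to a given number n."
constraints = [(100, "\n```python\n"), (100, " return "), (200, "\n```")]

# Apply vanilla constrained decoding
pv.constrain(prompt, constraints, use_beam=False)

# Apply constrained beam decoding (ACB)
pv.constrain(prompt, constraints, use_beam=True)
```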

assets/ACB.pdf

106 Bytes
Binary file not shown.

assets/agent_toolchain.pdf

536 Bytes
Binary file not shown.

examples.py

Lines changed: 14 additions & 11 deletions

@@ -4,28 +4,31 @@
 
 # Decoding Strategies
 
+## Code Generation
+
+### Greedy Decoding
+pv.generate("Write a Python function to calculate the Fibonacci sequence up to a given number n.", blind_model=True, quantize_model=True)
+
+### Constrained Decoding
+pv.constrain("Write a Python function to calculate the Fibonacci sequence up to a given number n.", [(100, "\n```python\n"), (100, " return "), (200, "\n```")], use_beam=False)
+
+### Constrained Beam Search
+pv.constrain("Write a Python function to calculate the Fibonacci sequence up to a given number n.", [(100, "\n```python\n"), (100, " return "), (200, "\n```")], use_beam=True)
+
 ## Multiple Choice Question Answering
 prompts = [
     "A 20-year-old woman presents with menorrhagia for the past several years. She says that her menses “have always been heavy”, and she has experienced easy bruising for as long as she can remember. Family history is significant for her mother, who had similar problems with bruising easily. The patient's vital signs include: heart rate 98/min, respiratory rate 14/min, temperature 36.1°C (96.9°F), and blood pressure 110/87 mm Hg. Physical examination is unremarkable. Laboratory tests show the following: platelet count 200,000/mm3, PT 12 seconds, and PTT 43 seconds. Which of the following is the most likely cause of this patient’s symptoms? A: Factor V Leiden B: Hemophilia A C: Lupus anticoagulant D: Protein C deficiency E: Von Willebrand disease",
     "A 25-year-old primigravida presents to her physician for a routine prenatal visit. She is at 34 weeks gestation, as confirmed by an ultrasound examination. She has no complaints, but notes that the new shoes she bought 2 weeks ago do not fit anymore. The course of her pregnancy has been uneventful and she has been compliant with the recommended prenatal care. Her medical history is unremarkable. She has a 15-pound weight gain since the last visit 3 weeks ago. Her vital signs are as follows: blood pressure, 148/90 mm Hg; heart rate, 88/min; respiratory rate, 16/min; and temperature, 36.6℃ (97.9℉). The blood pressure on repeat assessment 4 hours later is 151/90 mm Hg. The fetal heart rate is 151/min. The physical examination is significant for 2+ pitting edema of the lower extremity. Which of the following tests o should confirm the probable condition of this patient? A: Bilirubin assessment B: Coagulation studies C: Hematocrit assessment D: Leukocyte count with differential E: 24-hour urine protein"
 ]
 
-### Multiple Choice Selection
-pv.choose(prompts, choices='ABCDE')
-
 ### Constrained Decoding
 pv.constrain(prompts, constraints=[(100, ' The correct answer is'), (1, 'X.')], blind_model=True, quantize_model=True, use_beam=False)
 
-### Constrained Beam Search (ACB)
+### Constrained Beam Search
 pv.constrain(prompts, constraints=[(100, ' The correct answer is'), (1, 'X.')], blind_model=True, quantize_model=True, use_beam=True)
 
-## Code Generation
-
-### Constrained Decoding
-pv.constrain("Write a Python function to calculate the Fibonacci sequence up to a given number n.", [(100, "\n```python\n"), (100, " return "), (200, "\n```")], use_beam=False)
-
-### Constrained Beam Search
-pv.constrain("Write a Python function to calculate the Fibonacci sequence up to a given number n.", [(100, "\n```python\n"), (100, " return "), (200, "\n```")], use_beam=True)
+### Multiple Choice Selection
+pv.choose(prompts, choices='ABCDE')
 
 # Train
 pv.train_lora(

phi.py

Lines changed: 2 additions & 2 deletions

@@ -535,8 +535,8 @@ def __call__(self, keys, values, n_beam):
             self.offset = new_offset
             return keys, values
         else:
-            self.kv[0,:,:,self.offset:new_offset,:] = keys
-            self.kv[1,:,:,self.offset:new_offset,:] = values
+            self.kv[0,:,:,self.offset:new_offset,:] = keys.astype(mx.float32)
+            self.kv[1,:,:,self.offset:new_offset,:] = values.astype(mx.float32)
             self.offset = new_offset
             return self.kv[0,:,:,:new_offset,:], self.kv[1,:,:,:new_offset,:]
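The two changed lines store freshly computed keys and values into a pre-allocated cache buffer, and the added `astype(mx.float32)` indicates the buffer stays in float32 even when the incoming arrays are lower precision. A minimal sketch of that write path, with shapes and dtypes chosen purely for illustration:

```python
import mlx.core as mx

# Illustrative layout only: K and V stacked on dim 0 -> (2, batch, n_heads, max_len, head_dim)
kv = mx.zeros((2, 1, 4, 128, 64), dtype=mx.float32)

# Incoming keys/values, e.g. float16 when the model runs in half precision
keys = mx.random.normal((1, 4, 8, 64)).astype(mx.float16)
values = mx.random.normal((1, 4, 8, 64)).astype(mx.float16)

offset, new_offset = 0, 8
# Cast to the buffer's dtype before the slice assignment, mirroring the change above
kv[0, :, :, offset:new_offset, :] = keys.astype(mx.float32)
kv[1, :, :, offset:new_offset, :] = values.astype(mx.float32)

# The cache handed back to attention is the filled prefix of the buffer
k_cache = kv[0, :, :, :new_offset, :]
v_cache = kv[1, :, :, :new_offset, :]
```

The explicit cast keeps the buffer's dtype handling independent of whatever precision the model itself runs in.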

phi_3_vision_mlx.py

Lines changed: 8 additions & 9 deletions

@@ -46,7 +46,7 @@ class Streamer:
     def __init__(self, processor, stream, mute):
         self.tokenizer = processor.tokenizer
         self.mute = mute
-        self.stream = stream and mute
+        self.stream = stream and (not mute)
         self.list_tokens = []
         self.idx_sofar = 0
     def __call__(self, token):
@@ -71,7 +71,7 @@ def end(self):
         else:
             arr_tokens = mx.concatenate(self.list_tokens, axis=1)
             list_txt = self.tokenizer.batch_decode([(i[:i.index(ID_EOS)+1] if ID_EOS in i else i) for i in arr_tokens.tolist()])
-            if self.mute is False:
+            if not self.mute:
                 for i, gen in enumerate(list_txt):
                     print(f'\n< Generated text for prompt #{i} >\n{gen}')
             return list_txt, arr_tokens.size
@@ -372,7 +372,7 @@ def _get_wt(model_path, model_cfg):
         return [(k, v) for wf in glob.glob(f"{model_path}/*.safetensors") for k, v in mx.load(wf).items()]
     return [(k, v.transpose(0, 2, 3, 1) if "patch_embedding.weight" in k else v) for wf in glob.glob(f"{model_path}/*.safetensors") for k, v in mx.load(wf).items()]
 
-def _generate(model, processor, prompt, images=None, max_tokens=1000, verbose=True, return_tps=False, early_stop=False, stream=True, mute=False):
+def _generate(model, processor, prompt, images=None, max_tokens=512, verbose=True, return_tps=False, early_stop=False, stream=True, mute=False):
     if images is not None and isinstance(prompt, list):
         raise ValueError('Images cannot be provided when prompt is a list')
     logit_stopper = LogitStopper(max_tokens, early_stop)
@@ -383,13 +383,13 @@ def _generate(model, processor, prompt, images=None, max_tokens=1000, verbose=Tr
     tic = Tic()
     logits, cache = model(**dict_input, max_tokens=max_tokens)
     token = mx.argmax(logits[:, -1, :], axis=-1)[:,None]
-    mx.eval(token, logits, cache)
+    mx.eval(token, logits)#, cache)
     streamer(token)
     prompt_time = tic()
     for i in range(max_tokens-1):
         logits, cache = model(input_ids=token, cache=cache, mask=mask, pids=pids)
         token = mx.argmax(logits[:, -1, :], axis=-1)[:,None]
-        mx.eval(token, logits, cache)
+        mx.eval(token, logits)#, cache)
         streamer(token)
         if logit_stopper(logits):
             break
@@ -529,7 +529,7 @@ def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
     dict_input = processor(prompt)
     logits, cache = model(**dict_input, max_tokens=constraint[0] + id_constraint.shape[0]+10)
     logits = nn.log_softmax(logits, axis=-1)
-    mx.eval(logits, cache)
+    mx.eval(logits)
     _score_0 = logits[:, -1, id_constraint[0]]
     tiled_id_constraint = mx.tile(id_constraint, (logits.shape[0], 1))
     logits_rest, _ = model(input_ids=tiled_id_constraint, cache=cache, advance_offset=0)
@@ -559,7 +559,7 @@ def _get_beam(logits, cache, id_constraint, beam_idx=0, n_beam=3):
     token_plus = mx.concatenate([token, tiled_id_constraint], axis=1)
     logits, cache = model(input_ids=token_plus, cache=cache, advance_offset=1)
     logits = nn.log_softmax(logits)
-    mx.eval(logits, cache)
+    mx.eval(logits)
     pre_beam_score = mx.concatenate([running_score, logits[mx.arange(logits.shape[0])[:,None], mx.arange(logits.shape[1]-1)[None,:], token_plus[:,1:]]], axis=1).mean(axis=1)
     pre_beam_synth = mx.concatenate(tokens + [tiled_id_constraint, synth_pad], axis=1)
     if use_beam:
@@ -1308,7 +1308,7 @@ def load(blind_model=False, quantize_model=False, quantize_cache=False, use_adap
     _setup()
     return _load(model_path=model_path, use_quantized_cache=quantize_cache, adapter_path=adapter_path)
 
-def generate(prompt, images=None, preload=None, blind_model=False, quantize_model=False, quantize_cache=False, use_adapter=False, max_tokens=1000, verbose=True, return_tps=False, early_stop=False, stream=True, apply_chat_template=True):
+def generate(prompt, images=None, preload=None, blind_model=False, quantize_model=False, quantize_cache=False, use_adapter=False, max_tokens=512, verbose=True, return_tps=False, early_stop=False, stream=True, apply_chat_template=True):
     """
     Generate text based on a given prompt, optionally with image input.
 
@@ -1464,7 +1464,6 @@ def constrain(prompt, constraints=[(30, ' The correct answer is'), (1, 'X.')], i
     prompt = _apply_chat_template(prompt, None, verbose)[0]
     return _constrain(*preload, prompt=prompt, constraints=constraints, use_beam=use_beam, verbose=verbose)
 
-
 def execute(code_strings, file_prefix=0, verbose=True):
     """
     Execute one or more Python code strings and capture the results.
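The `Streamer` change above is the typo the commit title refers to: `stream and mute` enabled token streaming only when output was muted, while the corrected `stream and (not mute)` streams only when output is not muted. A small self-contained sketch of the gating (a hypothetical stand-in class, not the repo's `Streamer`):

```python
class StreamGate:
    """Hypothetical stand-in illustrating Streamer's stream/mute gating."""
    def __init__(self, stream, mute):
        self.mute = mute
        # Corrected logic: stream tokens only if streaming is requested AND output is not muted
        self.stream = stream and (not mute)

print(StreamGate(stream=True, mute=False).stream)   # True  -> tokens printed as generated
print(StreamGate(stream=True, mute=True).stream)    # False -> muted, nothing printed
print(StreamGate(stream=False, mute=False).stream)  # False -> streaming disabled
```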

setup.py

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@
     url='https://github.com/JosefAlbers/Phi-3-Vision-MLX',
     py_modules=['phi_3_vision_mlx', 'gte', 'phi', 'api'],
     packages=find_packages(),
-    version='0.1.0-alpha',
+    version='0.1.1-alpha',
     readme="README.md",
     author_email="[email protected]",
     description="Phi-3-Vision on Apple silicon with MLX",
