Skip to content

Commit 9e3d755

Browse files
authored
chatui() fixed and interpolate_336 added back
1 parent 27746f7 commit 9e3d755

File tree

3 files changed

+64
-11
lines changed

3 files changed

+64
-11
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ This project brings the powerful phi-3-vision VLM to Apple's MLX framework, offe
44

55
## Key Features
66

7-
* **Su-scaled RoPE:** Implements Su-scaled Rotary Position Embeddings to manage sequences of up to 128K tokens.
7+
* **VLM Agent:** Leverages VLM's visual understanding for interactive code generation and refinement, enabling data visualization and image manipulation through a visual feedback loop. (WIP)
88
* **Batch Generation:** Accelerate inference by generating text for multiple prompts concurrently (107 tokens-per-sec batched vs 56 tokens-per-sec original)
99
* **Cache Quantization:** Optimize inference for processing long contexts with key-value cache quantization (5.3s quantized vs 5.1s original).
1010
* **Model Quantization:** Reduce model size for faster loading and deployment (2.3GB quantized vs 8.5GB original).
11+
* **Su-scaled RoPE:** Implements Su-scaled Rotary Position Embeddings to manage sequences of up to 128K tokens.
1112
* **Chat Template:** Uses a chat template to streamline interactions with the model.
1213
* **LoRA Training:** Easily customize the model for specific tasks or datasets using LoRA.
1314
* **Benchmarking:** Quickly assess model performance on any dataset. (WIP)
14-
* **VLM Agent:** Leverages VLM's visual understanding for interactive code generation and refinement, enabling data visualization and image manipulation through a visual feedback loop. (WIP)
1515
* **Long Context RAG:** Enables the integration of Retrieval-Augmented Generation to harness large amounts of external knowledge for complex tasks such as code understanding, leveraging the phi-3-vision model's 128K context window. (WIP)
1616

1717
## Quick Start
@@ -28,6 +28,8 @@ chatui()
2828

2929
![Alt text](assets/chatui_2.png)
3030

31+
![Alt text](assets/chatui_caption.png)
32+
3133
### **Image Captioning**
3234

3335
```python

assets/chatui_caption.png

156 KB
Loading

main.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import re
66
import requests
7-
import torch
7+
# import torch
88
import datasets
99
import random
1010
import textwrap
@@ -32,13 +32,14 @@ class Agent:
3232
def __init__(self, executor=None, **kwargs):
3333
self.execute = execute if executor is None else executor
3434
self.kwargs = kwargs
35-
self.preload = load(**kwargs)
35+
self.preload = _preload(**kwargs)
3636
self.reset()
3737

3838
def reset(self):
3939
self.log = []
4040
self.step = 0
4141
self.ongoing = None
42+
self.user_since = 0
4243

4344
def __call__(self, prompt:str, images=None, **kwargs):
4445
if self.step > 0:
@@ -379,6 +380,51 @@ def _merge(self, images, texts):
379380
"image_sizes": mx.array(image_sizes),
380381
"positions": mx.array(positions)}
381382

383+
def interpolate_336(input):
    """Resize a batch of images to 336x336 with bicubic interpolation.

    Pure-NumPy replacement for
    ``torch.nn.functional.interpolate(x, size=(336, 336), mode='bicubic')``.
    It follows the same conventions as that call: half-pixel sample centres
    (``align_corners=False``), the cubic convolution kernel with ``a = -0.75``,
    four taps per axis, and border taps clamped to the nearest valid pixel.

    Args:
        input: Array of shape ``(N, C, H, W)``.

    Returns:
        Array of shape ``(N, C, 336, 336)`` with the same dtype as ``input``.

    Raises:
        ValueError: If ``input`` is not 4-dimensional (from shape unpacking).
    """
    def resampling_matrix(in_size, out_size):
        # Dense (out_size, in_size) matrix whose row i holds the four cubic
        # weights used to produce output sample i along one axis.
        a = -0.75
        # Position of each output sample centre in input pixel coordinates.
        src = (np.arange(out_size, dtype=np.float64) + 0.5) * (in_size / out_size) - 0.5
        base = np.floor(src)
        offsets = np.arange(-1, 3)  # taps at base-1, base, base+1, base+2
        dist = np.abs((src - base)[:, None] - offsets)  # always within [0, 2]
        near = ((a + 2) * dist - (a + 3)) * dist * dist + 1       # |x| <= 1
        far = ((a * dist - 5 * a) * dist + 8 * a) * dist - 4 * a  # 1 < |x| <= 2
        weights = np.where(dist <= 1, near, far)
        taps = np.clip(base.astype(np.int64)[:, None] + offsets, 0, in_size - 1)
        matrix = np.zeros((out_size, in_size), dtype=np.float64)
        # add.at accumulates: clamped taps at the border may repeat an index,
        # and their weights must be summed rather than overwritten.
        np.add.at(matrix, (np.arange(out_size)[:, None], taps), weights)
        return matrix

    N, C, H, W = input.shape
    out_hw = 336
    h_matrix = resampling_matrix(H, out_hw)  # (336, H)
    w_matrix = resampling_matrix(W, out_hw)  # (336, W)
    # Separable resize: rows first, then columns; matmul broadcasts over N, C.
    pixels = np.asarray(input, dtype=np.float64)
    output = h_matrix @ pixels @ w_matrix.T
    return output.astype(input.dtype)
427+
382428
class Phi3VImageProcessor:
383429
def __init__(self):
384430
self.num_crops=16
@@ -416,7 +462,8 @@ def pad_to_max_num_crops_tensor(images, max_crops=17):
416462
hd_images = [HD_transform(img) for img in images]
417463
shapes = [[im.shape[1], im.shape[2]] for im in hd_images]
418464
num_img_tokens = [int((h//336*w//336+1)*144 + 1 + (h//336+1)*12) for h, w in shapes]
419-
global_image = [torch.nn.functional.interpolate(torch.from_numpy(im[None]), size=(336, 336), mode='bicubic').numpy() for im in hd_images]
465+
# global_image = [torch.nn.functional.interpolate(torch.from_numpy(im[None]), size=(336, 336), mode='bicubic').numpy() for im in hd_images]
466+
global_image = [interpolate_336(im[None]) for im in hd_images]
420467
hd_images_reshape = [im
421468
.reshape(1, 3, h//336, 336, w//336, 336)
422469
.transpose(0,2,4,1,3,5)
@@ -694,7 +741,8 @@ def _apply_chat_template(prompt, images, verbose):
694741

695742
def _preload(quantize_model=False, quantize_cache=False, adapter_path=None, **kwargs):
696743
if (quantize_model is True) and (not os.path.exists('quantized_phi3v')):
697-
quantize()
744+
# quantize()
745+
snapshot_download(repo_id="JosefAlbers/Phi-3-vision-128k-instruct-mlx", allow_patterns=["*.safetensors", "*.json"], local_dir='quantized_phi3v')
698746
model_path='quantized_phi3v' if quantize_model is True else 'phi3v'
699747
return load(model_path=model_path, use_quantized_cache=quantize_cache, adapter_path=adapter_path)
700748

@@ -892,7 +940,7 @@ def chat(prompt, images=None, preload=None, quantize_model=False, quantize_cache
892940
return generate(*preload, *_apply_chat_template(prompt, images, verbose), max_tokens=max_tokens, verbose=verbose, return_tps=return_tps, early_stop=early_stop, stream=stream)
893941

894942
def chatui():
895-
agent = Agent()
943+
agent = Agent(max_tokens=1000, early_stop=False)
896944

897945
def add_message(history, message):
898946
for x in message["files"]:
@@ -902,11 +950,14 @@ def add_message(history, message):
902950
return history, gr.MultimodalTextbox(value=None, interactive=False)
903951

904952
def bot(history):
905-
print(history)
906-
response, file = agent(*history[-1])
953+
def _get_input(history):
954+
return history[-1][0], [i[0][0] for i in history[agent.user_since:-1]] if agent.user_since+1 < len(history) else None
955+
agent_input = _get_input(history)
956+
response, file = agent(*agent_input)
907957
history.append((None, response))
908958
if file is not None:
909959
history.append((None, (file,)))
960+
agent.user_since = len(history)
910961
return history
911962

912963
def reset():
@@ -918,7 +969,7 @@ def reset():
918969
[],
919970
elem_id="chatbot",
920971
bubble_full_width=False,
921-
height='75vh'
972+
height='70vh'
922973
)
923974

924975
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
@@ -996,4 +1047,4 @@ def benchmark():
9961047
agent('Modify the code to plot 3:4 frequency.')
9971048
agent('Modify the code to plot pi/4 phase difference.')
9981049
agent.end()
999-
"""
1050+
"""

0 commit comments

Comments
 (0)