Add Multi-Image Support with phi-3.5 in chat app demo (#1036)
Changes:
- Switch from the Gradio `Image` component to the `File` component to support uploading multiple images (see the Gradio sketch after the first diff below)
- Auto-populate `<|image_N|>` tags from the number of uploaded images, so the tag count always matches what the prompt image processor expects (see the sketch below):
  https://github.com/microsoft/onnxruntime-genai/blob/main/src/models/prompt_image_processor.cpp#L47
- Increase the maximum token limits
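
As a rough illustration of the tag auto-population (the helper name here is hypothetical; the real logic lives in `generate_prompt_with_history`, shown in the second diff below):

```python
def build_image_tags(num_images: int) -> str:
    # One <|image_N|> tag per uploaded image, numbered from 1, so the tag
    # count always matches the number of images handed to the processor.
    return "".join(f"<|image_{i + 1}|>\n" for i in range(num_images))

# Phi-3.5 vision chat format: image tags come before the user text.
chat_template = "<|user|>\n{tags}\n{input}<|end|>\n<|assistant|>\n"
prompt = chat_template.format(tags=build_image_tags(2), input="Compare these two photos.")
# prompt now starts with "<|user|>\n<|image_1|>\n<|image_2|>\n..."
```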

Validation:
- [x] Tested the demo UI locally
- [x] Ran the Python unit test for multi-image support in GenAI within the chat app:
  https://github.com/microsoft/onnxruntime-genai/pull/796/files#diff-3d2f537fb637f846715772cb3ac1a9b32e1289bdcc3be8e36c868c55cc016e9dR284

Sample:

![image](https://github.com/user-attachments/assets/f37472cb-a0aa-4a7e-b820-e5f344abf9c9)

---------

Co-authored-by: Sayan Shaw <[email protected]>
sayanshaw24 and Sayan Shaw authored Nov 6, 2024
1 parent 4bf84d8 commit 8965fed
Showing 2 changed files with 24 additions and 21 deletions.
examples/chat_app/app.py (12 additions, 12 deletions)

```diff
@@ -119,16 +119,16 @@ def launch_chat_app(expose_locally: bool = False, model_name: str = "", model_pa
         )
         max_length_tokens = gr.Slider(
             minimum=0,
-            maximum=4096,
-            value=2048,
-            step=8,
+            maximum=131072,
+            value=8192,
+            step=128,
             interactive=True,
             label="Max Token Length",
         )
         max_context_length_tokens = gr.Slider(
             minimum=0,
-            maximum=4096,
-            value=2048,
+            maximum=131072,
+            value=8192,
             step=128,
             interactive=True,
             label="Max History Token Length",
@@ -142,18 +142,18 @@ def launch_chat_app(expose_locally: bool = False, model_name: str = "", model_pa
             label="Token Printing Step",
             visible=False
         )
-        image = gr.Image(type="filepath", visible=False)
-        image.change(
+        images = gr.File(file_count="multiple", file_types=["image"], label="Upload image(s)", visible=False)
+        images.change(
            reset_state,
            outputs=[chatbot, history, status_display],
            show_progress=True,
        )
-        image.change(**reset_args)
+        images.change(**reset_args)

        model_name.change(
            change_model_listener,
            inputs=[model_name],
-           outputs=[model_name, image, chatbot, history, user_input, status_display],
+           outputs=[model_name, images, chatbot, history, user_input, status_display],
        )
        gr.Markdown(description)

@@ -166,7 +166,7 @@ def launch_chat_app(expose_locally: bool = False, model_name: str = "", model_pa
             max_length_tokens,
             max_context_length_tokens,
             token_printing_step,
-            image,
+            images,
         ],
         "outputs": [chatbot, history, status_display],
         "show_progress": True,
@@ -179,7 +179,7 @@ def launch_chat_app(expose_locally: bool = False, model_name: str = "", model_pa
             max_length_tokens,
             max_context_length_tokens,
             token_printing_step,
-            image
+            images
         ],
         "outputs": [chatbot, history, status_display],
         "show_progress": True,
@@ -219,7 +219,7 @@ def launch_chat_app(expose_locally: bool = False, model_name: str = "", model_pa
         cancels=[predict_event1, predict_event2, predict_event3],
     )

-    demo.load(change_model_listener, inputs=[model_name], outputs=[model_name, image], concurrency_limit=1)
+    demo.load(change_model_listener, inputs=[model_name], outputs=[model_name, images], concurrency_limit=1)

    demo.title = "Local Model UI"
```
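
For reference, a minimal standalone sketch of the `gr.File` multi-upload pattern used above (component names are illustrative; in recent Gradio versions the handler receives the selection as a list of local file paths, while older versions pass tempfile wrappers):

```python
import gradio as gr

def on_images_change(files):
    # With file_count="multiple", the handler receives a list of file
    # paths (or None when the selection is cleared).
    if not files:
        return "No images uploaded."
    return "\n".join(str(f) for f in files)

with gr.Blocks() as demo:
    images = gr.File(file_count="multiple", file_types=["image"], label="Upload image(s)")
    status = gr.Textbox(label="Status")
    images.change(on_images_change, inputs=[images], outputs=[status])

demo.launch()
```

Swapping `gr.Image` for `gr.File` trades the inline preview for the ability to accept an arbitrary number of files in a single component.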
examples/chat_app/interface/multimodal_onnx_interface.py (12 additions, 9 deletions)

```diff
@@ -20,9 +20,9 @@ def __init__(self, model_path):
         self.enable_history_max = 2
         self.template_header = "<s>"
         self.history_template = "[INST] {input} [/INST]{response}</s>"
-        self.chat_template = "<|user|>\n<|image_1|>\n{input}<|end|>\n<|assistant|>\n"
+        self.chat_template = "<|user|>\n{tags}\n{input}<|end|>\n<|assistant|>\n"

-    def generate_prompt_with_history(self, image, history, text=default_prompt, max_length=3072):
+    def generate_prompt_with_history(self, images, history, text=default_prompt, max_length=3072):

         prompt = ""

@@ -31,16 +31,19 @@ def generate_prompt_with_history(self, image, history, text=default_prompt, max_

         prompt = self.template_header + prompt

-        prompt += f'{self.chat_template.format(input=text)}'
+        image_tags = ""
+        for i in range(len(images)):
+            image_tags += f"<|image_{i+1}|>\n"
+
+        prompt += f'{self.chat_template.format(input=text, tags=image_tags)}'
         if len(prompt) > max_length:
             history.clear()
-            prompt = f'{self.chat_template.format(input=text)}'
+            prompt = f'{self.chat_template.format(input=text, tags=image_tags)}'

         logging.info("Loading image ...")
-        self.image = og.Images.open(image)
+        self.images = og.Images.open(*images)

-        logging.info("Preprocessing image and prompt ...")
-        input_ids = self.processor(prompt, images=self.image)
+        logging.info("Preprocessing images and prompt ...")
+        input_ids = self.processor(prompt, images=self.images)

         return input_ids

@@ -74,7 +77,7 @@ def predict(self, text, chatbot, history, max_length_tokens, max_context_length_
         input_ids = self.generate_prompt_with_history(
             text=text,
             history=history,
-            image=args[0],
+            images=args[0],
             max_length=max_context_length_tokens
         )
```
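
For completeness, a minimal end-to-end sketch of multi-image generation with the onnxruntime-genai Python API (the model path, image files, and prompt are placeholders, and the exact generator calls vary slightly across genai versions; this follows the library's Phi-3 vision examples):

```python
import onnxruntime_genai as og

model = og.Model("path/to/phi-3.5-vision-model")   # placeholder model folder
processor = model.create_multimodal_processor()
tokenizer_stream = processor.create_stream()

image_paths = ["photo1.png", "photo2.png"]         # placeholder images
images = og.Images.open(*image_paths)              # variadic: one call loads all images

tags = "".join(f"<|image_{i + 1}|>\n" for i in range(len(image_paths)))
prompt = f"<|user|>\n{tags}\nDescribe the images.<|end|>\n<|assistant|>\n"

inputs = processor(prompt, images=images)          # preprocess images + prompt together

params = og.GeneratorParams(model)
params.set_inputs(inputs)
params.set_search_options(max_length=8192)

generator = og.Generator(model, params)
while not generator.is_done():
    # Older genai versions also require generator.compute_logits() here.
    generator.generate_next_token()
    print(tokenizer_stream.decode(generator.get_next_tokens()[0]), end="", flush=True)
```

Because the tags are derived from `len(image_paths)`, the prompt can never reference more images than were actually passed to the processor, which is the mismatch the prompt image processor check guards against.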
