
Commit d0993b3

vmpuri authored
Multiturn mm single image (pytorch#1270)
* initial test
* Pad causal mask with zeroes and set decoder max_seq_len to the max sequence length so their shapes are both set to the max_seq_len
* Fix control bug for image inputs
* Clear image input after submitting a chat
* Include empty assistant message for chat
* Pipe image input from CLI

Co-authored-by: Jack-Khuu <[email protected]>
Co-authored-by: vmpuri <[email protected]>
1 parent 766bee9 commit d0993b3

3 files changed: +151, -107 lines

Diff for: torchchat/generate.py

+96, -45 lines
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+import base64
 import itertools
 import logging
 import os
@@ -12,6 +13,7 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from io import BytesIO
 from os import PathLike
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -600,9 +602,8 @@ def generate(
 
         if len(prompt.shape) > 1:
             prompt = prompt.squeeze(0)
-        T = prompt.size(0)
-        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
-        T_new = T + max_new_tokens
+        prompt_length = prompt.size(0)
+        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
             model = model.to(device=device)
@@ -616,7 +617,7 @@ def generate(
                     batch_size=1,
                     dtype=self.dtype,
                     encoder_max_seq_len=6404,
-                    decoder_max_seq_len=T_new,
+                    decoder_max_seq_len=max_seq_length,
                 )
             else:
                 model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
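
Setting decoder_max_seq_len to the full max_seq_length (rather than this turn's prompt-plus-response length) means the decoder's KV cache is allocated once at its final size, so later chat turns can keep appending tokens at an advancing start_pos without reallocating or hitting a shape mismatch. A rough standalone sketch of that idea with plain tensors and toy sizes (this is not torchchat's or torchtune's cache implementation):

import torch

# Toy sizes for illustration only.
max_seq_len, n_heads, head_dim = 2048, 8, 64

# Allocate the decoder KV cache once, at the full context length,
# so its shape never changes between chat turns.
k_cache = torch.zeros(1, n_heads, max_seq_len, head_dim)
v_cache = torch.zeros(1, n_heads, max_seq_len, head_dim)


def write_turn(k_new, v_new, start_pos):
    """Copy this turn's keys/values into the preallocated cache slots."""
    turn_len = k_new.size(2)
    k_cache[:, :, start_pos : start_pos + turn_len] = k_new
    v_cache[:, :, start_pos : start_pos + turn_len] = v_new
    return start_pos + turn_len  # the next turn resumes from here


# Turn 1 writes 10 positions; turn 2 resumes at start_pos=10.
pos = write_turn(torch.randn(1, n_heads, 10, head_dim),
                 torch.randn(1, n_heads, 10, head_dim), start_pos=0)
pos = write_turn(torch.randn(1, n_heads, 4, head_dim),
                 torch.randn(1, n_heads, 4, head_dim), start_pos=pos)
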
@@ -629,7 +630,7 @@ def generate(
                 model.reset_caches()
 
         input_pos = torch.arange(
-            start_pos, T + start_pos, device=device, dtype=torch.int
+            start_pos, prompt_length + start_pos, device=device, dtype=torch.int
         )
 
         prefill_t0 = time.perf_counter()
@@ -655,7 +656,9 @@ def generate(
         # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
         callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)
 
-        input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
+        input_pos = torch.tensor(
+            [start_pos + prompt_length], device=device, dtype=torch.int
+        )
         accept_counts = [0] * (
             speculate_k + 1
         )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -678,7 +681,7 @@ def generate(
                 )
 
                 accept_counts[len(next_tokens) - 1] += 1
-                num_added = min(T_new - input_pos - 1, len(next_tokens))
+                num_added = min(max_new_tokens - input_pos - 1, len(next_tokens))
                 for token in next_tokens[:num_added,]:
                     callback(token)
                     yield token, None
@@ -741,6 +744,7 @@ def _gen_model_input(
         prompt: Union[str | List[Any]],
         image_prompts: Optional[List[str | Image.Image]] = None,
         max_new_tokens: Optional[int] = None,
+        max_seq_len: Optional[int] = 2048,
     ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
         """
         Convert prompt and image prompts into consumable model input args.
@@ -757,7 +761,7 @@ def _gen_model_input(
             Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
         """
 
-        # Not Llama 3.2 11B
+        # Text-Only model
         if self.model.config.model_type != ModelType.Flamingo:
             # Single String prompt
             if isinstance(prompt, str):
@@ -778,32 +782,69 @@ def _gen_model_input(
         assert (
             image_prompts is None or len(image_prompts) == 1
         ), "At most one image is supported at the moment"
+
         if image_prompts and isinstance(image_prompts[0], str):
             images = [Image.open(image_prompts[0])]
         else:
-            images = image_prompts
+            images = None
 
         assert (
             max_new_tokens is not None
         ), "max_new_tokens must be specified for Flamingo models"
-        assert isinstance(
-            prompt, str
-        ), "(Currently) prompt must be a str for Flamingo models"
 
-        is_multimodal = images is not None
-        content = [{"type": "text", "content": prompt}]
+        image_found = False
+        messages = []
+        for message in prompt:
+            if isinstance(message["content"], str):
+                if not image_found and image_prompts:
+                    messages.append(
+                        Message(
+                            role=message["role"],
+                            content=[
+                                {"type": "image", "content": images[0]},
+                                {"type": "text", "content": message["content"]},
+                            ],
+                        )
+                    )
+                    image_found = True
+                else:
+                    messages.append(Message(**message))
+
+            elif isinstance(message["content"], list):
+                images = None
+                for content_dict in message["content"]:
+                    if content_dict["type"] == "text":
+                        prompt_arg = content_dict["text"]
+                    elif content_dict["type"] == "image_url":
+                        assert (
+                            images is None
+                        ), "At most one image is supported at the moment"
+
+                        base64_decoded = base64.b64decode(
+                            content_dict["image_url"].split(";base64,")[1]
+                        )
+                        images = [Image.open(BytesIO(base64_decoded))]
+                        image_found = True
+
+                is_multimodal = images is not None
+                content = [{"type": "text", "content": prompt_arg}]
+
+                if is_multimodal:
+                    content = [{"type": "image", "content": images[0]}] + content
 
-        if is_multimodal:
-            content = [{"type": "image", "content": images[0]}] + content
+                messages.append(
+                    Message(
+                        role=message["role"],
+                        content=content,
+                    )
+                )
 
-        messages = [
+        messages.append(
             Message(
-                role="user",
-                content=content,
-                eot=True,
-            ),
-            Message(role="assistant", content=""),
-        ]
+                role="assistant",
+                content="",
+            )
+        )
 
         transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
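
For context, the loop above accepts two prompt shapes: plain chat turns whose content is a string (the image supplied separately via image_prompts is attached to the first such user message, guarded by image_found), and OpenAI-style content lists where the image arrives inline as a base64 data URL that the parser splits on ";base64,". A hedged sketch of both shapes; the variable names, strings, and the tiny stand-in image are illustrative only:

import base64
from io import BytesIO

from PIL import Image

# Shape 1: plain chat turns; the image from image_prompts is attached to the
# first user message with string content.
cli_style_prompt = [
    {"role": "user", "content": "What is in this image?"},
    {"role": "assistant", "content": "A photo."},
    {"role": "user", "content": "Describe it in more detail."},
]

# Shape 2: OpenAI-style content lists, with the image inlined as a base64 data
# URL; _gen_model_input splits the string on ";base64," and decodes the rest.
buf = BytesIO()
Image.new("RGB", (8, 8), color="red").save(buf, format="JPEG")  # stand-in image
data_url = "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode()

api_style_prompt = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": data_url},
        ],
    },
]
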
@@ -812,7 +853,7 @@ def _gen_model_input(
         with device, set_default_dtype(self.dtype):
             data = transform({"messages": messages}, inference=True)
 
-        if is_multimodal:
+        if image_found:
             batch = padded_collate_tiled_images_and_mask(
                 [data], pad_direction="left", pad_max_images=1
             )
@@ -822,17 +863,27 @@ def _gen_model_input(
             batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
                 self.dtype
            )
+
        else:
            encoded = torch.tensor(data["tokens"], device=device).view(-1)
            seq_len = encoded.size(0)
            batch = {}
 
        total_response_length = seq_len + max_new_tokens
-        batch["causal_mask"] = torch.tril(
-            torch.ones(
-                size=(total_response_length, total_response_length),
-                dtype=torch.bool,
-            )
+        batch["causal_mask"] = torch.nn.functional.pad(
+            torch.tril(
+                torch.ones(
+                    size=(total_response_length, total_response_length),
+                    dtype=torch.bool,
+                )
+            ),
+            (
+                0,
+                max_seq_len - total_response_length,
+                0,
+                max_seq_len - total_response_length,
+            ),
+            value=0,
        )
 
        logging.debug(encoded)
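
Padding the lower-triangular mask out to max_seq_len, instead of sizing it to this turn's total_response_length, keeps batch["causal_mask"] at a fixed (max_seq_len, max_seq_len) shape across turns, matching the decoder_max_seq_len change earlier in this commit; the padded rows and columns are all False and are never attended to. A standalone sketch of the same construction with toy sizes:

import torch
import torch.nn.functional as F

total_response_length = 6  # this turn: prompt tokens + max_new_tokens (toy value)
max_seq_len = 10           # fixed decoder sequence length (toy value)

causal_mask = F.pad(
    torch.tril(
        torch.ones(total_response_length, total_response_length, dtype=torch.bool)
    ),
    # pad the last dim (columns), then the second-to-last dim (rows), up to max_seq_len
    (0, max_seq_len - total_response_length, 0, max_seq_len - total_response_length),
    value=0,
)
print(causal_mask.shape)  # torch.Size([10, 10]); only the top-left 6x6 block is populated
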
@@ -845,12 +896,6 @@ def chat(
         if generator_args.chat_mode:
             print("Starting Interactive Chat")
 
-        encoded, batch = self._gen_model_input(
-            generator_args.prompt,
-            generator_args.image_prompts,
-            generator_args.max_new_tokens,
-        )
-
         model_size = sum(
             [
                 p.numel() * p.dtype.itemsize
@@ -896,6 +941,12 @@ def chat(
         max_seq_length = (
             text_transformer_args.max_seq_length if text_transformer_args else 2048
         )
+        encoded, batch = self._gen_model_input(
+            [{"role": "user", "content": generator_args.prompt}],
+            generator_args.image_prompts,
+            generator_args.max_new_tokens,
+            max_seq_length,
+        )
 
         if generator_args.chat_mode:
             print(
@@ -907,16 +958,16 @@ def chat(
             if get_system_prompt == "y" or get_system_prompt == "Y":
                 self.system_prompt = input("What is your system prompt? \n")
 
-        elif not generator_args.is_torchtune_model:
-            max_seq_length = min(
-                encoded.size(0) + generator_args.max_new_tokens,
-                (
-                    text_transformer_args.block_size
-                    if text_transformer_args is not None
-                    else 2048
-                ),
-                max_seq_length,
-            )
+        # elif not generator_args.is_torchtune_model:
+        #     max_seq_length = min(
+        #         encoded.size(0) + generator_args.max_new_tokens,
+        #         (
+        #             text_transformer_args.block_size
+        #             if text_transformer_args is not None
+        #             else 2048
+        #         ),
+        #         max_seq_length,
+        #     )
 
         max_seq_length = (
             max_seq_length + self.speculative_builder_args.speculate_k + 1
