# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
+import base64
import itertools
import logging
import os

from abc import ABC, abstractmethod
from dataclasses import dataclass
+from io import BytesIO
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
@@ -600,9 +602,8 @@ def generate(

        if len(prompt.shape) > 1:
            prompt = prompt.squeeze(0)
-        T = prompt.size(0)
-        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
-        T_new = T + max_new_tokens
+        prompt_length = prompt.size(0)
+        max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
        # set up caches only if first inference
        if start_pos == 0:
            model = model.to(device=device)
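Note (not part of the patch): a minimal sketch of what the clamp above guarantees, using made-up sizes — the prompt plus the granted new tokens can never overrun the cache length.

# Illustrative only; every value below is invented for the example.
start_pos, prompt_length, max_seq_length = 0, 1000, 2048
requested_new_tokens = 4096
max_new_tokens = min(requested_new_tokens, max_seq_length - start_pos - prompt_length)
assert max_new_tokens == 1048  # 1000 prompt tokens + 1048 new tokens fill exactly 2048 positions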
@@ -616,7 +617,7 @@ def generate(
                        batch_size=1,
                        dtype=self.dtype,
                        encoder_max_seq_len=6404,
-                        decoder_max_seq_len=T_new,
+                        decoder_max_seq_len=max_seq_length,
                    )
                else:
                    model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
@@ -629,7 +630,7 @@ def generate(
            model.reset_caches()

        input_pos = torch.arange(
-            start_pos, T + start_pos, device=device, dtype=torch.int
+            start_pos, prompt_length + start_pos, device=device, dtype=torch.int
        )

        prefill_t0 = time.perf_counter()
@@ -655,7 +656,9 @@ def generate(
        # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
        callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)

-        input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
+        input_pos = torch.tensor(
+            [start_pos + prompt_length], device=device, dtype=torch.int
+        )
        accept_counts = [0] * (
            speculate_k + 1
        )  # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
@@ -678,7 +681,7 @@ def generate(
                )

                accept_counts[len(next_tokens) - 1] += 1
-                num_added = min(T_new - input_pos - 1, len(next_tokens))
+                num_added = min(max_new_tokens - input_pos - 1, len(next_tokens))
                for token in next_tokens[:num_added,]:
                    callback(token)
                    yield token, None
@@ -741,6 +744,7 @@ def _gen_model_input(
        prompt: Union[str | List[Any]],
        image_prompts: Optional[List[str | Image.Image]] = None,
        max_new_tokens: Optional[int] = None,
+        max_seq_len: Optional[int] = 2048,
    ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
        """
        Convert prompt and image prompts into consumable model input args.
@@ -757,7 +761,7 @@ def _gen_model_input(
            Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
        """

-        # Not Llama 3.2 11B
+        # Text-Only model
        if self.model.config.model_type != ModelType.Flamingo:
            # Single String prompt
            if isinstance(prompt, str):
@@ -778,32 +782,69 @@ def _gen_model_input(
        assert (
            image_prompts is None or len(image_prompts) == 1
        ), "At most one image is supported at the moment"
+
        if image_prompts and isinstance(image_prompts[0], str):
            images = [Image.open(image_prompts[0])]
        else:
-            images = image_prompts
+            images = None

        assert (
            max_new_tokens is not None
        ), "max_new_tokens must be specified for Flamingo models"
-        assert isinstance(
-            prompt, str
-        ), "(Currently) prompt must be a str for Flamingo models"

-        is_multimodal = images is not None
-        content = [{"type": "text", "content": prompt}]
+        image_found = False
+        messages = []
+        for message in prompt:
+            if isinstance(message["content"], str):
+                if not image_found and image_prompts:
+                    messages.append(
+                        Message(
+                            role=message["role"],
+                            content=[
+                                {"type": "image", "content": images[0]},
+                                {"type": "text", "content": message["content"]},
+                            ],
+                        )
+                    )
+                    image_found = True
+                else:
+                    messages.append(Message(**message))
+
+            elif isinstance(message["content"], list):
+                images = None
+                for content_dict in message["content"]:
+                    if content_dict["type"] == "text":
+                        prompt_arg = content_dict["text"]
+                    elif content_dict["type"] == "image_url":
+                        assert (
+                            images is None
+                        ), "At most one image is supported at the moment"
+
+                        base64_decoded = base64.b64decode(
+                            content_dict["image_url"].split(";base64,")[1]
+                        )
+                        images = [Image.open(BytesIO(base64_decoded))]
+                        image_found = True
+
+                is_multimodal = images is not None
+                content = [{"type": "text", "content": prompt_arg}]
+
+                if is_multimodal:
+                    content = [{"type": "image", "content": images[0]}] + content

-        if is_multimodal:
-            content = [{"type": "image", "content": images[0]}] + content
+                messages.append(
+                    Message(
+                        role=message["role"],
+                        content=content,
+                    )
+                )

-        messages = [
+        messages.append(
            Message(
-                role="user",
-                content=content,
-                eot=True,
-            ),
-            Message(role="assistant", content=""),
-        ]
+                role="assistant",
+                content="",
+            )
+        )

        transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))

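Note (not part of the patch): a minimal sketch of the data-URL handling introduced above. It builds an OpenAI-style "image_url" payload from an in-memory image, then decodes it back the same way the new branch does (split on ";base64,", base64-decode, open with PIL). All names and sizes here are illustrative only.

import base64
from io import BytesIO

from PIL import Image

# Stand-in for a client upload: encode a tiny image as a base64 data URL.
buf = BytesIO()
Image.new("RGB", (8, 8), color="red").save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")

# Same parsing as the patch: strip the header, decode, and reopen with PIL.
decoded = base64.b64decode(data_url.split(";base64,")[1])
image = Image.open(BytesIO(decoded))
assert image.size == (8, 8)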
@@ -812,7 +853,7 @@ def _gen_model_input(
        with device, set_default_dtype(self.dtype):
            data = transform({"messages": messages}, inference=True)

-            if is_multimodal:
+            if image_found:
                batch = padded_collate_tiled_images_and_mask(
                    [data], pad_direction="left", pad_max_images=1
                )
@@ -822,17 +863,27 @@ def _gen_model_input(
                batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
                    self.dtype
                )
+
            else:
                encoded = torch.tensor(data["tokens"], device=device).view(-1)
                seq_len = encoded.size(0)
                batch = {}

            total_response_length = seq_len + max_new_tokens
-            batch["causal_mask"] = torch.tril(
-                torch.ones(
-                    size=(total_response_length, total_response_length),
-                    dtype=torch.bool,
-                )
+            batch["causal_mask"] = torch.nn.functional.pad(
+                torch.tril(
+                    torch.ones(
+                        size=(total_response_length, total_response_length),
+                        dtype=torch.bool,
+                    )
+                ),
+                (
+                    0,
+                    max_seq_len - total_response_length,
+                    0,
+                    max_seq_len - total_response_length,
+                ),
+                value=0,
            )

        logging.debug(encoded)
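Note (not part of the patch): a minimal sketch of the mask padding added above, with made-up sizes. torch.nn.functional.pad grows the lower-triangular mask out to a fixed max_seq_len on both dimensions; the padded entries are False, so positions beyond total_response_length can never be attended to.

import torch
import torch.nn.functional as F

total_response_length, max_seq_len = 6, 10  # illustrative sizes only
mask = torch.tril(
    torch.ones(total_response_length, total_response_length, dtype=torch.bool)
)
padded = F.pad(
    mask,
    # (left, right, top, bottom) padding over the last two dimensions
    (0, max_seq_len - total_response_length, 0, max_seq_len - total_response_length),
    value=0,
)
assert padded.shape == (max_seq_len, max_seq_len)
assert padded[:6, :6].equal(mask) and not padded[6:].any()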
@@ -845,12 +896,6 @@ def chat(
        if generator_args.chat_mode:
            print("Starting Interactive Chat")

-        encoded, batch = self._gen_model_input(
-            generator_args.prompt,
-            generator_args.image_prompts,
-            generator_args.max_new_tokens,
-        )
-
        model_size = sum(
            [
                p.numel() * p.dtype.itemsize
@@ -896,6 +941,12 @@ def chat(
        max_seq_length = (
            text_transformer_args.max_seq_length if text_transformer_args else 2048
        )
+        encoded, batch = self._gen_model_input(
+            [{"role": "user", "content": generator_args.prompt}],
+            generator_args.image_prompts,
+            generator_args.max_new_tokens,
+            max_seq_length,
+        )

        if generator_args.chat_mode:
            print(
@@ -907,16 +958,16 @@ def chat(
            if get_system_prompt == "y" or get_system_prompt == "Y":
                self.system_prompt = input("What is your system prompt? \n")

-        elif not generator_args.is_torchtune_model:
-            max_seq_length = min(
-                encoded.size(0) + generator_args.max_new_tokens,
-                (
-                    text_transformer_args.block_size
-                    if text_transformer_args is not None
-                    else 2048
-                ),
-                max_seq_length,
-            )
+        # elif not generator_args.is_torchtune_model:
+        #     max_seq_length = min(
+        #         encoded.size(0) + generator_args.max_new_tokens,
+        #         (
+        #             text_transformer_args.block_size
+        #             if text_transformer_args is not None
+        #             else 2048
+        #         ),
+        #         max_seq_length,
+        #     )

        max_seq_length = (
            max_seq_length + self.speculative_builder_args.speculate_k + 1