@@ -791,7 +791,6 @@ def _get_per_token_logps_and_entropies(
         image_grid_thw=None,
         pixel_attention_mask=None,
         image_sizes=None,
-        image_split_sizes=None,
     ) -> dict[str, Optional[torch.Tensor]]:
         """Compute log-probs and (optionally) entropies for each token."""
         batch_size = batch_size or input_ids.size(0)  # Chunk inputs into smaller batches to reduce memory peak
@@ -804,15 +803,13 @@ def _get_per_token_logps_and_entropies(
             # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
             model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
 
-            if image_grid_thw is not None:
+            if image_grid_thw is not None and pixel_values is not None:
                 model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size]
-            if pixel_values is not None:
-                if image_split_sizes is not None:
-                    start_pixel_idx = sum(image_split_sizes[:start])
-                    end_pixel_idx = sum(image_split_sizes[: start + batch_size])
-                    model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
-                else:
-                    model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
+                start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item()
+                end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item()
+                model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx]
+            elif pixel_values is not None:
+                model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
             if pixel_attention_mask is not None:
                 model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size]
             if image_sizes is not None:
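
The new slicing assumes Qwen2-VL-style image inputs: each row of image_grid_thw is one image's (t, h, w) patch grid, and pixel_values stacks t*h*w patch rows per image, so the row product recovers each image's patch count and cumulative sums give the slice bounds. A minimal sketch of that invariant (the grid values and the hidden size 1176 are illustrative assumptions, not taken from this diff):

import torch

# Two images: 1*4*6 = 24 and 1*2*8 = 16 patch rows, stacked in pixel_values.
image_grid_thw = torch.tensor([[1, 4, 6], [1, 2, 8]])
pixel_values = torch.randn(24 + 16, 1176)

# The second chunk of the loop (start=1, batch_size=1) should get image 2's rows.
start, batch_size = 1, 1
start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item()              # 24
end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item()  # 40
assert pixel_values[start_pixel_idx:end_pixel_idx].shape == (16, 1176)
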
@@ -1078,19 +1075,13 @@ def _generate_and_score_completions(
         # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
         kwargs = {}
         has_images = "image" in inputs[0]
-        image_split_sizes = None
         if has_images:
             images = [example.get("image") for example in inputs]
             kwargs = {"images": [[img] for img in images]}
             for prompt in prompts:
                 if isinstance(prompt, list):  # i.e., when using conversational data
                     prepare_multimodal_messages(prompt, num_images=1)
 
-            if hasattr(self.processing_class, "_get_num_multimodal_tokens"):
-                image_sizes = [(image.height, image.width) for image in images]
-                multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes)
-                image_split_sizes = multimodal_extra_data.num_image_patches
-
         prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
 
         prompt_inputs = self.processing_class(
@@ -1104,10 +1095,6 @@ def _generate_and_score_completions(
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
 
-        if "image_grid_thw" in prompt_inputs and image_split_sizes is None:
-            # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens
-            image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist()
-
         if self.max_prompt_length is not None:
             # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
             # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
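
Dropping the processor-specific path and its fallback is safe under the same assumption: the split sizes they produced are exactly the row products of image_grid_thw, which _get_per_token_logps_and_entropies now computes inline. A quick equivalence check with made-up grid values:

import torch

image_grid_thw = torch.tensor([[1, 4, 6], [1, 2, 8], [1, 6, 6]])
# What the removed fallback computed: per-image patch counts [24, 16, 36].
image_split_sizes = image_grid_thw.prod(dim=1).tolist()

# The old sum-based offsets match the new inline computation for any chunk.
start, batch_size = 1, 2
assert sum(image_split_sizes[:start]) == image_grid_thw[:start].prod(-1).sum().item()
assert sum(image_split_sizes[: start + batch_size]) == image_grid_thw[: start + batch_size].prod(-1).sum().item()
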
@@ -1392,7 +1379,6 @@ def _generate_and_score_completions(
                     image_grid_thw=prompt_inputs.get("image_grid_thw"),
                     pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                     image_sizes=prompt_inputs.get("image_sizes"),
-                    image_split_sizes=image_split_sizes,
                 )
             else:
                 old_per_token_logps = None
@@ -1417,7 +1403,6 @@ def _generate_and_score_completions(
                         image_grid_thw=prompt_inputs.get("image_grid_thw"),
                         pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                         image_sizes=prompt_inputs.get("image_sizes"),
-                        image_split_sizes=image_split_sizes,
                     )
                 else:
                     with self.accelerator.unwrap_model(self.model).disable_adapter():
@@ -1431,7 +1416,6 @@ def _generate_and_score_completions(
                             image_grid_thw=prompt_inputs.get("image_grid_thw"),
                             pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
                             image_sizes=prompt_inputs.get("image_sizes"),
-                            image_split_sizes=image_split_sizes,
                         )
         else:
             ref_per_token_logps = None
@@ -1580,8 +1564,6 @@ def _generate_and_score_completions(
             output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"]
         if "image_sizes" in prompt_inputs:
             output["image_sizes"] = prompt_inputs["image_sizes"]
-        if image_split_sizes is not None:
-            output["image_split_sizes"] = image_split_sizes
         return output
 
     def compute_liger_loss(self, unwrapped_model, inputs):
@@ -1656,7 +1638,6 @@ def _compute_loss(self, model, inputs):
             image_grid_thw=inputs.get("image_grid_thw"),
             pixel_attention_mask=inputs.get("pixel_attention_mask"),
             image_sizes=inputs.get("image_sizes"),
-            image_split_sizes=inputs.get("image_split_sizes"),
         )
 
         if self.top_entropy_quantile < 1.0: