11"""Interface to initialize and load HF models."""
22
33import os
4+ import re
45import types
56from contextlib import contextmanager , nullcontext
67from typing import Any , Dict , List , Optional , Tuple , Union
@@ -99,6 +100,11 @@ def __init__(self, *args, **kwargs):
         # set sharding config source to huggingface
         self._sharding_config["source"] = ShardingConfigSource.HUGGINGFACE
 
+        # The transformers implementation of some models has changed since their safetensors were
+        # produced and/or uploaded to the HuggingFace Hub. When building the model, we try to
+        # determine whether a parameter-name mapping exists and store it in this attribute.
+        self._checkpoint_conversion_mapping: Optional[Dict[str, str]] = None
+
     @property
     def autoconfig_from_pretrained(self):
         return AutoConfig.from_pretrained
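
For context, `_checkpoint_conversion_mapping` on a transformers model is a dict mapping regex patterns over checkpoint parameter names to the replacement names used by the current implementation. A minimal sketch of how such a mapping rewrites a key; the mapping entries and key below are illustrative, not taken from any particular model:

```python
import re

# Illustrative mapping and key only; real mappings come from the transformers model class.
conversion_mapping = {
    "^language_model.model": "model.language_model",
    "^vision_tower": "model.vision_tower",
}

old_key = "language_model.model.layers.0.self_attn.q_proj.weight"
new_key = old_key
for pattern, replacement in conversion_mapping.items():
    new_key = re.sub(pattern, replacement, new_key)

print(new_key)  # -> model.language_model.layers.0.self_attn.q_proj.weight
```
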
@@ -168,6 +174,7 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
 
         # if present, initialize sharding config. We need head_dim for colwise sharding.
         self._set_sharding_config(model.config)
+        self._checkpoint_conversion_mapping = getattr(model, "_checkpoint_conversion_mapping", None)
 
         # patch forward method
         model.forward = types.MethodType(self._simple_forward, model)
@@ -326,15 +333,30 @@ def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
326333 """Load the checkpoint into the model."""
327334 # identify the most relevant checkpoint file
328335 ckpt_file = self ._get_checkpoint_file (self .model )
336+
337+ load_handle = model .register_load_state_dict_pre_hook (self ._remap_param_names_load_hook )
338+ # Ensure it's the first one.
339+ model ._load_state_dict_pre_hooks .move_to_end (key = load_handle .id , last = False )
340+
341+ get_handle = model .register_state_dict_post_hook (
342+ _StateDictParamNameConverter (self ._checkpoint_conversion_mapping )
343+ )
344+ # Ensure it's the first one.
345+ model ._state_dict_hooks .move_to_end (key = get_handle .id , last = False )
346+
         # reuse the load checkpoint utility from accelerate
-        with hf_load_state_dict_with_device(device):
-            # Set `full_state_dict=False` to skip Accelerate's FSDP weight sync logic.
-            # Internally, load_checkpoint_in_model → set_model_state_dict → _load_model_state_dict,
-            # which collects local model params, syncs weights from checkpoint, and applies them via
-            # model.load_state_dict.
-            # This sync step can interfere with load_hooks by mixing raw checkpoint weights and
-            # model-transformed weights,leading to unexpected key mismatches or format issues.
-            load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)
+        try:
+            with hf_load_state_dict_with_device(device):
+                # Set `full_state_dict=False` to skip Accelerate's FSDP weight sync logic.
+                # Internally, load_checkpoint_in_model → set_model_state_dict → _load_model_state_dict,
+                # which collects local model params, syncs weights from checkpoint, and applies them via
+                # model.load_state_dict.
+                # This sync step can interfere with load hooks by mixing raw checkpoint weights and
+                # model-transformed weights, leading to unexpected key mismatches or format issues.
+                load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)
+        finally:
+            load_handle.remove()
+            get_handle.remove()
 
     def _load_quantization_config(self, fetched_dir: str):
         """Load the quantization config from the model directory if not done already."""
@@ -351,6 +373,63 @@ def _load_quantization_config(self, fetched_dir: str):
             self._quant_config_reader = reader
             self.model_kwargs = deep_merge_dicts(self.model_kwargs, extra_model_kwargs)
 
+    def _remap_param_names_load_hook(self, model, state_dict, *args, **kwargs) -> None:
+        """Hook to handle potential param name conversions.
+
+        The transformers implementation of a model can change after its safetensors were produced
+        and/or uploaded to the HuggingFace Hub. This hook applies the mapping (when present) to
+        account for those differences.
+        """
+        conversion_mapping = self._checkpoint_conversion_mapping
+        if conversion_mapping:
+            keys_to_process = list(state_dict.keys())
+            for key in keys_to_process:
+                new_key = key
+                for pattern, replacement in conversion_mapping.items():
+                    new_key = re.sub(pattern, replacement, new_key)
+
+                if new_key != key:
+                    state_dict[new_key] = state_dict.pop(key)
+
+
+class _StateDictParamNameConverter:
+    """Helper class for applying param name conversions to a state dict.
+
+    The reason this is a class rather than a method or factory like `_remap_param_names_load_hook`
+    is that PyTorch tries to set a `_from_public_api` attribute on hooks, and bound instance
+    methods cannot have attributes set on them without major hacks.
+    """
+
+    def __init__(self, conversion_mapping: Optional[Dict[str, str]]):
+        conversion_mapping = conversion_mapping or {}
+
+        # NOTE: most of the code in this class is forked from `PreTrainedModel.save_pretrained`.
+        reverse_key_mapping = {v: k for k, v in conversion_mapping.items()}
+        self._mapping = reverse_key_mapping
+
+    def __call__(self, module, state_dict, *args, **kwargs) -> None:
+        """Hook to handle potential param name conversions.
+
+        For the same reasons as the `load` hook, we define one for `state_dict`. This silences
+        potentially misleading warnings about certain parameter names not being used: `accelerate`
+        determines which keys are unexpected based on the keys returned by `module.state_dict()`,
+        not on what `module.load_state_dict()` reports.
+        """
+        if self._mapping:
+            keys_to_process = list(state_dict.keys())
+            for key in keys_to_process:
+                new_key = key
+                for pattern, replacement in self._mapping.items():
+                    replacement = replacement.lstrip("^")  # strip off unneeded chars and patterns
+                    replacement = re.sub(r"\(.*\)", "", replacement)
+                    new_key, n_replace = re.subn(pattern, replacement, key)
+                    # Early exit of the loop.
+                    if n_replace > 0:
+                        break
+
+                if new_key != key:
+                    state_dict[new_key] = state_dict.pop(key)
+
 
 @ModelFactoryRegistry.register("AutoModelForImageTextToText")
 class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
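
To make the reverse mapping in `_StateDictParamNameConverter.__init__` concrete: when the forward mapping is inverted, its regex patterns become replacement strings, so anchors like `^` and any capture groups have to be stripped before they are usable as replacements. A small illustration with a made-up mapping of the same shape as those carried by transformers models:

```python
import re

# Hypothetical forward mapping: checkpoint name pattern -> in-memory name.
forward_mapping = {"^language_model.model": "model.language_model"}

# Reversed for state_dict(): in-memory name pattern -> checkpoint name.
reverse_mapping = {v: k for k, v in forward_mapping.items()}

key = "model.language_model.layers.0.mlp.up_proj.weight"
for pattern, replacement in reverse_mapping.items():
    replacement = replacement.lstrip("^")  # "^" is regex syntax, not literal text
    replacement = re.sub(r"\(.*\)", "", replacement)  # likewise drop capture groups
    new_key, n_replace = re.subn(pattern, replacement, key)
    if n_replace > 0:
        break

print(new_key)  # -> language_model.model.layers.0.mlp.up_proj.weight
```
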
@@ -426,17 +505,19 @@ def _prep_seq(text, img1, img2):
                 }
             ]
 
-        # Create a batch of conversations (batch_size = 2)
+        # Create a batch of conversations (batch_size = 2).
+        # Note that we explicitly use 2 images in the examples to avoid potential shape
+        # specialization(s) in `torch.compile` / `torch.export`.
         batch_messages = [
             _prep_seq(
                 "Describe what you see in the two images and their differences.",
-                Image.new("RGB", (16, 16), color=(128, 128, 128)),
-                Image.new("RGB", (16, 16), color=(64, 64, 64)),
+                Image.new("RGB", self._example_image_dims, color=(128, 128, 128)),
+                Image.new("RGB", self._example_image_dims, color=(64, 64, 64)),
             ),
             _prep_seq(
                 "What are the main differences between these two images?",
-                Image.new("RGB", (16, 16), color=(255, 0, 0)),
-                Image.new("RGB", (16, 16), color=(0, 255, 0)),
+                Image.new("RGB", self._example_image_dims, color=(255, 0, 0)),
+                Image.new("RGB", self._example_image_dims, color=(0, 255, 0)),
             ),
         ]
@@ -451,10 +532,15 @@ def _prep_seq(text, img1, img2):
             return_attention_mask=False,
         )
 
-        return {
-            "input_ids": inputs["input_ids"],
-            "pixel_values": inputs["pixel_values"],
-        }
+        # We should have no need for the attention mask, and it can actually cause issues in
+        # downstream code.
+        inputs.pop("attention_mask", None)
+
+        # NOTES:
+        # 1. `inputs` is dict-like, but not a dict (hence the dict unpacking below).
+        # 2. Although `get_extra_inputs` allows implementations to specify "extra inputs", the
+        #    example values still need to be returned by `get_example_inputs`.
+        return {**inputs}
 
     def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, Optional[DynamicShapeCallback]]]:
         """Return a dictionary of extra inputs for the model.
@@ -476,3 +562,10 @@ def _get_dynamic_shape():
 
         none_pixel_values = torch.zeros(0, 3, 336, 336)
         return {"pixel_values": (none_pixel_values, _get_dynamic_shape)}
+
+    @property
+    def _example_image_dims(self) -> Tuple[int, int]:
+        # Some specializations (children) of this class may override this if their models make
+        # assumptions about the image dimensions. For example, they may have a lower bound due to
+        # the patch size they use.
+        return (16, 16)
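
As the comment notes, subclasses can override `_example_image_dims` when their vision encoder puts a lower bound on image size. A hypothetical example in the context of this module; the class name and dimensions are made up:

```python
class _PatchBoundImageTextFactory(AutoModelForImageTextToTextFactory):
    """Hypothetical factory for a model whose vision tower expects at least 336x336 inputs."""

    @property
    def _example_image_dims(self) -> Tuple[int, int]:
        # Assumption: the model works on 336x336 crops, so example images must be at least
        # that large to produce a valid patch grid.
        return (336, 336)
```
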