
Commit c53d181

[None][feat] Extend VLM factory and add Mistral3 factory (#7583)
This commit:
* extends existing factory interfaces to enable Mistral3 in AutoDeploy.
* adds a Mistral3 VLM factory.
* adds various model patches for pixtral (the vision model) and mistral3 to make the VLM export-compliant.
* adjusts checkpoint loading code to take possible parameter name conversions into account.
* fixes a sampling bug (the `end_id` needs to be taken into account when sampling, but it is not included in the stop words' token IDs).

Signed-off-by: William Zhang <[email protected]>
1 parent 6ba1c84 commit c53d181
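
The sampling fix is only described in the message above; its diff is not among the files shown here. As a purely illustrative sketch of the described behavior (hypothetical names, not the actual TensorRT-LLM sampler code), the stop check has to treat `end_id` separately because it is not part of the stop words' token IDs:

    # Illustrative sketch only; names and structure are hypothetical, not the real sampler.
    def should_stop(token_id: int, end_id: int, stop_token_ids: set) -> bool:
        # `end_id` is not included in `stop_token_ids`, so checking only the stop words
        # would let generation run past the end-of-sequence token.
        return token_id == end_id or token_id in stop_token_ids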

File tree: 13 files changed (+737, -19 lines)


tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 11 additions & 0 deletions
@@ -211,6 +211,17 @@ def update_attn_page_size(self):
             self.attn_page_size = self.max_seq_len
         return self

+    @field_validator("model_factory", mode="after")
+    @classmethod
+    def model_factory_exists(cls, value: str) -> str:
+        if not ModelFactoryRegistry.has(value):
+            raise ValueError(
+                f"'{value}' does not exist in the model factory registry. Available values: "
+                f"{ModelFactoryRegistry.entries()}."
+            )
+
+        return value
+
     ### UTILITY METHODS ############################################################################
     def create_factory(self) -> ModelFactory:
         """Create a model factory from the arguments."""

tensorrt_llm/_torch/auto_deploy/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
-from . import hf, patches
+from . import hf, mistral3, patches
 from .factory import *

tensorrt_llm/_torch/auto_deploy/models/factory.py

Lines changed: 5 additions & 1 deletion
@@ -3,7 +3,7 @@
 import copy
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Any, Callable, Dict, Optional, Tuple, Type
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type

 import torch
 import torch.nn as nn
@@ -282,3 +282,7 @@ def get(cls, name: str) -> Type[ModelFactory]:
     @classmethod
     def has(cls, model_factory_cls: str) -> bool:
         return model_factory_cls in cls._registry
+
+    @classmethod
+    def entries(cls) -> List[str]:
+        return list(cls._registry.keys())

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 110 additions & 17 deletions
@@ -1,6 +1,7 @@
 """Interface to initialize and load HF models."""

 import os
+import re
 import types
 from contextlib import contextmanager, nullcontext
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -99,6 +100,11 @@ def __init__(self, *args, **kwargs):
         # set sharding config source to huggingface
         self._sharding_config["source"] = ShardingConfigSource.HUGGINGFACE

+        # Some models' transformers implementation has changed in between when safetensors were produced
+        # and / or uploaded to HuggingFace hub. When building the model, we will try to determine whether
+        # a mapping of the parameter names exists and hold that information in this attribute.
+        self._checkpoint_conversion_mapping: Optional[Dict[str, str]] = None
+
     @property
     def autoconfig_from_pretrained(self):
         return AutoConfig.from_pretrained
@@ -168,6 +174,7 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:

         # if present, initialize sharding config. We need head_dim for colwise sharding.
         self._set_sharding_config(model.config)
+        self._checkpoint_conversion_mapping = getattr(model, "_checkpoint_conversion_mapping", None)

         # patch forward method
         model.forward = types.MethodType(self._simple_forward, model)
@@ -326,15 +333,30 @@ def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
         """Load the checkpoint into the model."""
         # identify the most relevant checkpoint file
         ckpt_file = self._get_checkpoint_file(self.model)
+
+        load_handle = model.register_load_state_dict_pre_hook(self._remap_param_names_load_hook)
+        # Ensure it's the first one.
+        model._load_state_dict_pre_hooks.move_to_end(key=load_handle.id, last=False)
+
+        get_handle = model.register_state_dict_post_hook(
+            _StateDictParamNameConverter(self._checkpoint_conversion_mapping)
+        )
+        # Ensure it's the first one.
+        model._state_dict_hooks.move_to_end(key=get_handle.id, last=False)
+
         # reuse the load checkpoint utility from accelerate
-        with hf_load_state_dict_with_device(device):
-            # Set `full_state_dict=False` to skip Accelerate's FSDP weight sync logic.
-            # Internally, load_checkpoint_in_model → set_model_state_dict → _load_model_state_dict,
-            # which collects local model params, syncs weights from checkpoint, and applies them via
-            # model.load_state_dict.
-            # This sync step can interfere with load_hooks by mixing raw checkpoint weights and
-            # model-transformed weights,leading to unexpected key mismatches or format issues.
-            load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)
+        try:
+            with hf_load_state_dict_with_device(device):
+                # Set `full_state_dict=False` to skip Accelerate's FSDP weight sync logic.
+                # Internally, load_checkpoint_in_model → set_model_state_dict → _load_model_state_dict,
+                # which collects local model params, syncs weights from checkpoint, and applies them via
+                # model.load_state_dict.
+                # This sync step can interfere with load_hooks by mixing raw checkpoint weights and
+                # model-transformed weights, leading to unexpected key mismatches or format issues.
+                load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)
+        finally:
+            load_handle.remove()
+            get_handle.remove()

     def _load_quantization_config(self, fetched_dir: str):
         """Load the quantization config from the model directory if not done already."""
@@ -351,6 +373,63 @@ def _load_quantization_config(self, fetched_dir: str):
         self._quant_config_reader = reader
         self.model_kwargs = deep_merge_dicts(self.model_kwargs, extra_model_kwargs)

+    def _remap_param_names_load_hook(self, model, state_dict, *args, **kwargs) -> None:
+        """Hook to handle potential param name conversions.
+
+        Some models' transformers implementation can change in between when safetensors were produced
+        and / or uploaded to HuggingFace hub. This hook applies the mapping (when present) to reflect
+        these differences.
+        """
+        conversion_mapping = self._checkpoint_conversion_mapping
+        if conversion_mapping:
+            keys_to_process = list(state_dict.keys())
+            for key in keys_to_process:
+                new_key = key
+                for pattern, replacement in conversion_mapping.items():
+                    new_key = re.sub(pattern, replacement, new_key)
+
+                if new_key != key:
+                    state_dict[new_key] = state_dict.pop(key)
+
+
+class _StateDictParamNameConverter:
+    """Helper class for applying param name conversions to a state dict.
+
+    The reason this is a class instead of a method of the factory like `_remap_param_names_load_hook`
+    is because PyTorch tries to set an `_from_public_api` attribute on hooks, and bound instance
+    methods cannot have attributes set on them without major hacks.
+    """
+
+    def __init__(self, conversion_mapping: Optional[Dict[str, str]]):
+        conversion_mapping = conversion_mapping or {}
+
+        # NOTE: most of the code in this class is forked from `PreTrainedModel.save_pretrained`.
+        reverse_key_mapping = {v: k for k, v in conversion_mapping.items()}
+        self._mapping = reverse_key_mapping
+
+    def __call__(self, module, state_dict, *args, **kwargs) -> None:
+        """Hook to handle potential param name conversions.
+
+        For the same reasons as the `load` hook, we define one for `state_dict`. This is to silence
+        potentially misleading warnings about certain parameter names not being used, because the
+        `accelerate` library's logic for determining which keys are unexpected bases it on the keys
+        in the `module.state_dict()` return value, not on what `module.load_state_dict()` returns.
+        """
+        if self._mapping:
+            keys_to_process = list(state_dict.keys())
+            for key in keys_to_process:
+                new_key = key
+                for pattern, replacement in self._mapping.items():
+                    replacement = replacement.lstrip("^")  # strip off un-needed chars and patterns
+                    replacement = re.sub(r"\(.*\)", "", replacement)
+                    new_key, n_replace = re.subn(pattern, replacement, key)
+                    # Early exit of the loop
+                    if n_replace > 0:
+                        break
+
+                if new_key != key:
+                    state_dict[new_key] = state_dict.pop(key)
+

 @ModelFactoryRegistry.register("AutoModelForImageTextToText")
 class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
@@ -426,17 +505,19 @@ def _prep_seq(text, img1, img2):
                 }
             ]

-        # Create a batch of conversations (batch_size = 2)
+        # Create a batch of conversations (batch_size = 2).
+        # Note that we explicitly use 2 images in the examples to avoid potential shape specialization(s)
+        # in `torch.compile` / `torch.export`.
         batch_messages = [
             _prep_seq(
                 "Describe what you see in the two images and their differences.",
-                Image.new("RGB", (16, 16), color=(128, 128, 128)),
-                Image.new("RGB", (16, 16), color=(64, 64, 64)),
+                Image.new("RGB", self._example_image_dims, color=(128, 128, 128)),
+                Image.new("RGB", self._example_image_dims, color=(64, 64, 64)),
             ),
             _prep_seq(
                 "What are the main differences between these two images?",
-                Image.new("RGB", (16, 16), color=(255, 0, 0)),
-                Image.new("RGB", (16, 16), color=(0, 255, 0)),
+                Image.new("RGB", self._example_image_dims, color=(255, 0, 0)),
+                Image.new("RGB", self._example_image_dims, color=(0, 255, 0)),
             ),
         ]

@@ -451,10 +532,15 @@ def _prep_seq(text, img1, img2):
             return_attention_mask=False,
         )

-        return {
-            "input_ids": inputs["input_ids"],
-            "pixel_values": inputs["pixel_values"],
-        }
+        # We should have no need for the attention mask, and it can actually cause issues in
+        # downstream code.
+        inputs.pop("attention_mask", None)
+
+        # NOTES:
+        # 1. `inputs` is dict-like, but not a dict (hence the dict unpacking below).
+        # 2. Although `get_extra_inputs` allows implementations to specify "extra inputs", the example
+        #    values still need to be returned by `get_example_inputs`.
+        return {**inputs}

     def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, Optional[DynamicShapeCallback]]]:
         """Return a dictionary of extra inputs for the model.
@@ -476,3 +562,10 @@ def _get_dynamic_shape():

         none_pixel_values = torch.zeros(0, 3, 336, 336)
         return {"pixel_values": (none_pixel_values, _get_dynamic_shape)}
+
+    @property
+    def _example_image_dims(self) -> Tuple[int, int]:
+        # Some specializations (children) of this class may override this if their models have
+        # assumptions on the image dimensions. For example, they may have a lower bound due to
+        # the patch size they use.
+        return (16, 16)
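
To make the two hooks in the hf.py diff above concrete, here is a small standalone sketch of the key remapping in both directions. The mapping value is made up for illustration; in the real code it comes from the model's `_checkpoint_conversion_mapping` attribute, when the transformers implementation defines one:

    import re
    from typing import Dict

    # Illustrative mapping only; real models ship their own `_checkpoint_conversion_mapping`.
    conversion_mapping: Dict[str, str] = {"^language_model.model": "model.language_model"}


    def remap_on_load(state_dict: Dict[str, object]) -> None:
        # Mirrors `_remap_param_names_load_hook`: checkpoint keys -> current parameter names.
        for key in list(state_dict.keys()):
            new_key = key
            for pattern, replacement in conversion_mapping.items():
                new_key = re.sub(pattern, replacement, new_key)
            if new_key != key:
                state_dict[new_key] = state_dict.pop(key)


    def remap_on_state_dict(state_dict: Dict[str, object]) -> None:
        # Mirrors `_StateDictParamNameConverter`: current parameter names -> checkpoint keys,
        # using the reversed mapping with anchors / capture groups stripped from the replacement.
        reverse = {v: k for k, v in conversion_mapping.items()}
        for key in list(state_dict.keys()):
            new_key = key
            for pattern, replacement in reverse.items():
                replacement = replacement.lstrip("^")
                replacement = re.sub(r"\(.*\)", "", replacement)
                new_key, n_replace = re.subn(pattern, replacement, key)
                if n_replace > 0:
                    break
            if new_key != key:
                state_dict[new_key] = state_dict.pop(key)


    sd = {"language_model.model.embed_tokens.weight": 0}
    remap_on_load(sd)        # -> {"model.language_model.embed_tokens.weight": 0}
    remap_on_state_dict(sd)  # -> back to {"language_model.model.embed_tokens.weight": 0}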

tensorrt_llm/_torch/auto_deploy/models/mistral3.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+"""Auto-deploy model factory for Mistral3 models."""
+
+from typing import Dict, Tuple
+
+import torch
+
+from tensorrt_llm._torch.auto_deploy.custom_ops import attention_interface
+from tensorrt_llm._torch.auto_deploy.models import factory, hf
+
+
+@factory.ModelFactoryRegistry.register("Mistral3VLM")
+class Mistral3VLM(hf.AutoModelForImageTextToTextFactory):
+    def get_extra_inputs(
+        self,
+    ) -> Dict[str, Tuple[torch.Tensor, attention_interface.DynamicShapeCallback]]:
+        """Return a dictionary of extra inputs for the model.
+
+        Returns:
+            A dictionary of extra inputs for the model where the key corresponds to the argument
+            name and the value corresponds to a tuple of (example_input, dynamic_shape_callback).
+            The dynamic shape callback is a function that returns the dynamic shape of the extra
+            input.
+        """
+        extra_inputs = super().get_extra_inputs()
+        # Reuse the same dynamic batch dimension for `image_sizes`.
+        batch_dim = extra_inputs["pixel_values"][1]()[0]
+        extra_inputs["image_sizes"] = (torch.zeros(0, 2, dtype=torch.long), lambda: {0: batch_dim})
+
+        return extra_inputs
+
+    @staticmethod
+    def _simple_forward(
+        model: torch.nn.Module,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        pixel_values: torch.Tensor,
+        image_sizes: torch.Tensor,
+    ):
+        """A simple forward pass for the model to functionalize the args.
+
+        This follows the standard function signature as expected by factory.py.
+        """
+        return type(model).forward(
+            model,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            pixel_values=pixel_values,
+            image_sizes=image_sizes,
+        )
+
+    @property
+    def _example_image_dims(self) -> Tuple[int, int]:
+        # The pixtral processor requires a minimum image size, which is larger than the default
+        # (16, 16) in the parent class.
+        # TODO: figure this out on the model config somehow (patch size value, etc.).
+        return (64, 64)
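
The only non-obvious line above is `extra_inputs["pixel_values"][1]()[0]`: index `[1]` picks the dynamic-shape callback out of the `(example_input, callback)` tuple, calling it returns a dict of `{dim_index: dynamic_dim}`, and `[0]` grabs the batch dimension. A hedged sketch of that wiring, with a stand-in callback (the real one comes from the parent factory's `get_extra_inputs`, and the real dynamic dims are torch.export dimensions):

    import torch

    # Stand-in for the parent factory's `pixel_values` entry; the dim object and shapes are
    # illustrative placeholders.
    img_batch_dim = object()  # placeholder for a torch.export dynamic dimension
    parent_extra_inputs = {
        "pixel_values": (torch.zeros(0, 3, 336, 336), lambda: {0: img_batch_dim}),
    }

    # What Mistral3VLM.get_extra_inputs adds on top of the parent entries:
    extra_inputs = dict(parent_extra_inputs)
    batch_dim = extra_inputs["pixel_values"][1]()[0]  # call the callback, take dim 0
    extra_inputs["image_sizes"] = (torch.zeros(0, 2, dtype=torch.long), lambda: {0: batch_dim})

    # Because `image_sizes` reuses the very same dim object, the exporter can treat the batch
    # dimension of `pixel_values` and `image_sizes` as one shared dynamic dimension.
    assert extra_inputs["image_sizes"][1]()[0] is batch_dim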
