From b77ba0582f0b804783dc0eb56edd5b8968f1a4bd Mon Sep 17 00:00:00 2001
From: qinxuye
Date: Wed, 22 Oct 2025 11:39:31 +0800
Subject: [PATCH] FEAT: qwen3-omni
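
Extend the Qwen2.5-Omni transformers model class so it also serves
Qwen3-Omni (Thinking and Instruct variants):

- detect the omni generation ("2.5" vs "3") from the model family name
  and import the matching classes (Qwen2_5Omni* vs Qwen3OmniMoe*) for
  both the processor and the model;
- normalize the return value of model.generate(): Qwen2.5-Omni returns a
  (generated_ids, audio) tuple, while Qwen3-Omni returns generated ids
  only, possibly wrapped in a GenerateOutput carrying them in .sequences.

A minimal launch sketch (assuming a running Xinference server on the
default endpoint and the standard Python client; the model name follows
the registrations added in this patch):

    from xinference.client import Client

    client = Client("http://localhost:9997")
    uid = client.launch_model(model_name="Qwen3-Omni-Instruct", model_type="LLM")
    model = client.get_model(uid)
    print(model.chat(messages=[{"role": "user", "content": "Hello!"}]))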
---
 .../llm/transformers/multimodal/qwen-omni.py | 71 ++++++++++++++++---
 1 file changed, 60 insertions(+), 11 deletions(-)

diff --git a/xinference/model/llm/transformers/multimodal/qwen-omni.py b/xinference/model/llm/transformers/multimodal/qwen-omni.py
index 4f8d642cf9..fc68bf95fb 100644
--- a/xinference/model/llm/transformers/multimodal/qwen-omni.py
+++ b/xinference/model/llm/transformers/multimodal/qwen-omni.py
@@ -19,6 +19,8 @@
 from threading import Thread
 from typing import Any, Dict, Iterator, List, Optional, Tuple
 
+import torch
+
 from .....types import (
     ChatCompletion,
     ChatCompletionAudio,
@@ -35,12 +37,20 @@
 
 @register_transformer
 @register_non_default_model("qwen2.5-omni")
-class Qwen2_5OmniChatModel(PytorchMultiModalModel):
+@register_non_default_model("Qwen3-Omni-Thinking")
+@register_non_default_model("Qwen3-Omni-Instruct")
+class QwenOmniChatModel(PytorchMultiModalModel):
     DEFAULT_SYSTEM_PROMPT = (
         "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
         "capable of perceiving auditory and visual inputs, as well as generating text and speech."
     )
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Detect the omni generation: "2.5" for Qwen2.5-Omni, otherwise "3"
+        model_family = self.model_family.model_family or self.model_family.model_name
+        self._omni_version = "2.5" if "2.5" in model_family else "3"
+
     @classmethod
     def match_json(
         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
@@ -48,7 +58,10 @@ def match_json(
         if model_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
             return False
         llm_family = model_family.model_family or model_family.model_name
-        if "qwen2.5-omni".lower() in llm_family.lower():
+        if (
+            "qwen2.5-omni".lower() in llm_family.lower()
+            or "qwen3-omni".lower() in llm_family.lower()
+        ):
             return True
         return False
 
@@ -58,15 +71,25 @@ def decide_device(self):
         self._device = device
 
     def load_processor(self):
-        from transformers import Qwen2_5OmniProcessor
+        if self._omni_version == "2.5":
+            from transformers import Qwen2_5OmniProcessor as QwenOmniProcessor
+        else:
+            from transformers import Qwen3OmniMoeProcessor as QwenOmniProcessor
 
-        self._processor = Qwen2_5OmniProcessor.from_pretrained(
+        self._processor = QwenOmniProcessor.from_pretrained(
             self.model_path, trust_remote_code=True
         )
         self._tokenizer = self._processor.tokenizer
 
     def load_multimodal_model(self):
-        from transformers import Qwen2_5OmniForConditionalGeneration
+        if self._omni_version == "2.5":
+            from transformers import (
+                Qwen2_5OmniForConditionalGeneration as QwenOmniForConditionalGeneration,
+            )
+        else:
+            from transformers import (
+                Qwen3OmniMoeForConditionalGeneration as QwenOmniForConditionalGeneration,
+            )
 
         # for multiple GPU, set back to auto to make multiple devices work
         device = "auto" if self._device == "cuda" else self._device
@@ -79,7 +102,7 @@ def load_multimodal_model(self):
         kwargs = self.apply_bnb_quantization(kwargs)
         logger.debug("Loading model with extra kwargs: %s", kwargs)
 
-        self._model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
+        self._model = QwenOmniForConditionalGeneration.from_pretrained(
             self.model_path,
             torch_dtype="auto",
             device_map=device,
@@ -181,11 +204,37 @@ def generate_non_streaming(
         inputs = self.build_inputs_from_messages(messages, generate_config)  # type: ignore
         use_audio_in_video = generate_config.get("use_audio_in_video", True)
         gen_kwargs = dict(**inputs, **config, use_audio_in_video=use_audio_in_video)
-        generated_ids, audio = self._model.generate(**gen_kwargs)
-        generated_ids_trimmed = [
-            out_ids[len(in_ids) :]
-            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-        ]
+        # === Run model.generate() (handle both (ids, audio) and ids-only cases) ===
+        result = self._model.generate(**gen_kwargs)
+        if isinstance(result, tuple) and len(result) == 2:
+            # Qwen2.5-Omni returns (generated_ids, audio)
+            generated_ids, audio = result
+        else:
+            # Qwen3-Omni returns only generated_ids
+            generated_ids, audio = result, None
+        if hasattr(generated_ids, "sequences"):
+            generated_ids = generated_ids.sequences
+
+        # === Handle text decoding ===
+        input_len = inputs.input_ids.shape[1]
+        # Ensure a consistent 2D structure for decoding:
+        # normalize generated_ids to list[list[int]]
+        if isinstance(generated_ids, torch.Tensor):
+            generated_ids = generated_ids.tolist()
+        elif isinstance(generated_ids, list) and all(
+            isinstance(x, int) for x in generated_ids
+        ):
+            # Single sequence as a flat list of ints
+            generated_ids = [generated_ids]
+        elif isinstance(generated_ids, list) and all(
+            isinstance(x, list) for x in generated_ids
+        ):
+            pass  # already list[list[int]]
+        else:
+            raise TypeError(f"Unexpected generated_ids type: {type(generated_ids)}")
+
+        # Remove prompt tokens
+        generated_ids_trimmed = [out_ids[input_len:] for out_ids in generated_ids]
         output_text = self._processor.batch_decode(
             generated_ids_trimmed,
             skip_special_tokens=True,
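-- 
Reviewer note, not part of the patch: the generate-output normalization in
generate_non_streaming() can be exercised without loading a model. A minimal
sketch with synthetic tensors; normalize_generated_ids is a hypothetical
helper mirroring the new logic:

    import torch

    def normalize_generated_ids(result):
        # Qwen2.5-Omni style: (generated_ids, audio) tuple; Qwen3-Omni: ids only
        if isinstance(result, tuple) and len(result) == 2:
            generated_ids, audio = result
        else:
            generated_ids, audio = result, None
        # GenerateOutput objects carry the token ids in .sequences
        if hasattr(generated_ids, "sequences"):
            generated_ids = generated_ids.sequences
        if isinstance(generated_ids, torch.Tensor):
            generated_ids = generated_ids.tolist()
        elif isinstance(generated_ids, list) and all(isinstance(x, int) for x in generated_ids):
            generated_ids = [generated_ids]  # single flat sequence -> 2D
        return generated_ids, audio

    ids = torch.tensor([[1, 2, 3, 4]])
    assert normalize_generated_ids((ids, None)) == ([[1, 2, 3, 4]], None)  # 2.5 style
    assert normalize_generated_ids(ids) == ([[1, 2, 3, 4]], None)          # 3 style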