Merged: xinference/model/llm/transformers/multimodal/qwen-omni.py (60 additions & 11 deletions)

@@ -19,6 +19,8 @@
 from threading import Thread
 from typing import Any, Dict, Iterator, List, Optional, Tuple
 
+import torch
+
 from .....types import (
     ChatCompletion,
     ChatCompletionAudio,
@@ -35,20 +37,31 @@

 @register_transformer
 @register_non_default_model("qwen2.5-omni")
-class Qwen2_5OmniChatModel(PytorchMultiModalModel):
+@register_non_default_model("Qwen3-Omni-Thinking")
+@register_non_default_model("Qwen3-Omni-Instruct")
+class QwenOmniChatModel(PytorchMultiModalModel):
     DEFAULT_SYSTEM_PROMPT = (
         "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
         "capable of perceiving auditory and visual inputs, as well as generating text and speech."
     )
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # 2.5 or 3
+        model_family = self.model_family.model_family or self.model_family.model_name
+        self._omni_version = "2.5" if "2.5" in model_family else "3"
+
     @classmethod
     def match_json(
         cls, model_family: "LLMFamilyV2", model_spec: "LLMSpecV1", quantization: str
     ) -> bool:
         if model_spec.model_format not in ["pytorch", "gptq", "awq", "bnb"]:
             return False
         llm_family = model_family.model_family or model_family.model_name
-        if "qwen2.5-omni".lower() in llm_family.lower():
+        if (
+            "qwen2.5-omni".lower() in llm_family.lower()
+            or "qwen3-omni".lower() in llm_family.lower()
+        ):
             return True
         return False
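Aside: the new __init__ records the generation once so the rest of the class can branch on a single flag. A minimal sketch of that heuristic in isolation (detect_omni_version is a hypothetical name used here for illustration; it is not part of this diff):

    def detect_omni_version(model_family: str) -> str:
        # Family names containing "2.5" (e.g. "qwen2.5-omni") map to the
        # 2.5 generation; the remaining families accepted by match_json
        # (e.g. "Qwen3-Omni-Instruct") fall through to generation 3.
        return "2.5" if "2.5" in model_family else "3"

    assert detect_omni_version("qwen2.5-omni") == "2.5"
    assert detect_omni_version("Qwen3-Omni-Thinking") == "3"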

@@ -58,15 +71,25 @@ def decide_device(self):
         self._device = device
 
     def load_processor(self):
-        from transformers import Qwen2_5OmniProcessor
+        if self._omni_version == "2.5":
+            from transformers import Qwen2_5OmniProcessor as QwenOmniProcessor
+        else:
+            from transformers import Qwen3OmniMoeProcessor as QwenOmniProcessor
 
-        self._processor = Qwen2_5OmniProcessor.from_pretrained(
+        self._processor = QwenOmniProcessor.from_pretrained(
             self.model_path, trust_remote_code=True
         )
         self._tokenizer = self._processor.tokenizer
 
     def load_multimodal_model(self):
-        from transformers import Qwen2_5OmniForConditionalGeneration
+        if self._omni_version == "2.5":
+            from transformers import (
+                Qwen2_5OmniForConditionalGeneration as QwenOmniForConditionalGeneration,
+            )
+        else:
+            from transformers import (
+                Qwen3OmniMoeForConditionalGeneration as QwenOmniForConditionalGeneration,
+            )
 
         # for multiple GPU, set back to auto to make multiple devices work
         device = "auto" if self._device == "cuda" else self._device
@@ -79,7 +102,7 @@ def load_multimodal_model(self):
         kwargs = self.apply_bnb_quantization(kwargs)
         logger.debug("Loading model with extra kwargs: %s", kwargs)
 
-        self._model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
+        self._model = QwenOmniForConditionalGeneration.from_pretrained(
             self.model_path,
             torch_dtype="auto",
             device_map=device,
@@ -181,11 +204,37 @@ def generate_non_streaming(
         inputs = self.build_inputs_from_messages(messages, generate_config)  # type: ignore
         use_audio_in_video = generate_config.get("use_audio_in_video", True)
         gen_kwargs = dict(**inputs, **config, use_audio_in_video=use_audio_in_video)
-        generated_ids, audio = self._model.generate(**gen_kwargs)
-        generated_ids_trimmed = [
-            out_ids[len(in_ids) :]
-            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-        ]
+        # === Run model.generate() (handle both (ids, audio) and ids-only cases) ===
+        result = self._model.generate(**gen_kwargs)
+        if isinstance(result, tuple) and len(result) == 2:
+            # Qwen2.5-Omni returns (generated_ids, audio)
+            generated_ids, audio = result
+        else:
+            # Qwen3-Omni returns only generated_ids
+            generated_ids, audio = result, None
+        if hasattr(generated_ids, "sequences"):
+            generated_ids = generated_ids.sequences
+
+        # === Handle text decoding ===
+        input_len = inputs.input_ids.shape[1]
+        # Ensure we have a consistent 2D structure
+        # Normalize to list[list[int]]
+        if isinstance(generated_ids, torch.Tensor):
+            generated_ids = generated_ids.tolist()
+        elif isinstance(generated_ids, list) and all(
+            isinstance(x, int) for x in generated_ids
+        ):
+            # Single sequence as flat list of ints
+            generated_ids = [generated_ids]
+        elif isinstance(generated_ids, list) and all(
+            isinstance(x, list) for x in generated_ids
+        ):
+            pass  # already correct
+        else:
+            raise TypeError(f"Unexpected generated_ids type: {type(generated_ids)}")
+
+        # Remove prompt tokens
+        generated_ids_trimmed = [out_ids[input_len:] for out_ids in generated_ids]
         output_text = self._processor.batch_decode(
             generated_ids_trimmed,
             skip_special_tokens=True,
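Aside: the normalization above accepts the three shapes generate() can produce (a (ids, audio) tuple from Qwen2.5-Omni, a bare tensor or list from Qwen3-Omni, or a GenerateOutput-style object carrying .sequences) and reduces them all to list[list[int]] before trimming the prompt. The same logic as a self-contained sketch (normalize_generated_ids is a hypothetical helper, not part of this diff):

    import torch

    def normalize_generated_ids(generated_ids) -> list:
        # GenerateOutput-style objects keep token ids in .sequences.
        if hasattr(generated_ids, "sequences"):
            generated_ids = generated_ids.sequences
        if isinstance(generated_ids, torch.Tensor):
            return generated_ids.tolist()  # (batch, seq_len) -> list[list[int]]
        if isinstance(generated_ids, list) and all(isinstance(x, int) for x in generated_ids):
            return [generated_ids]  # single flat sequence -> batch of one
        if isinstance(generated_ids, list) and all(isinstance(x, list) for x in generated_ids):
            return generated_ids  # already list[list[int]]
        raise TypeError(f"Unexpected generated_ids type: {type(generated_ids)}")

    # e.g. normalize_generated_ids(torch.tensor([[1, 2, 3]])) == [[1, 2, 3]]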