diff --git a/README.md b/README.md
index 0a6af7d..9e4f3f0 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@ A more accessible, comprehensive, and efficient toolkit for large model compress
## 📣Latest News
-- [25/11/03] We have released v0.2. Quantization support for new models, such as `GLM-4.6` and `Qwen3-VL`, open-sources the Eagle3 speculative decoding training framework, and updates the Diffusion model quantization tools.
+- [25/11/05] We have released v0.2, adding quantization support for new models such as `GLM-4.6`, `Qwen3-VL`, and `Qwen3-Omni`, open-sourcing the Eagle3 speculative decoding training framework, and updating the Diffusion model quantization tools.
- [25/09/30] We have released **SpecExit**, the reasoning early-exit algorithm: [[Paper]](http://arxiv.org/abs/2509.24248) | [[Docs]](https://angelslim.readthedocs.io/zh-cn/latest/features/speculative_decoding/spec_exit.html) | [[vLLM Code]](https://github.com/vllm-project/vllm/pull/27192)🔥🔥🔥
- [25/09/26] We have released **TEQUILA**, the ternary quantization algorithm [[Paper]](https://arxiv.org/abs/2509.23809) | [[Code]](https://github.com/Tencent/AngelSlim/tree/tequila/TernaryQuant)🔥🔥🔥
- [25/09/24] We now support the PTQ quantification of NVFP4 for the Qwen3 series models. We also opensource [Qwen3-32B-NVFP4](https://huggingface.co/AngelSlim/Qwen3-32B_nvfp4) and [Qwen3-235B-A22B-NVFP4](https://huggingface.co/AngelSlim/Qwen3-235B-A22B_nvfp4) weights.
@@ -171,7 +171,7 @@ A more accessible, comprehensive, and efficient toolkit for large model compress
@@ -510,7 +510,40 @@ Benchmark results for Qwen2.5VL series models with `BF16`、`FP8-Static`、`FP8-
-#### 1.5 Other Models
+#### 1.5 Qwen-Omni Series Models
+
+**Qwen3-Omni Text to Text Benchmark**
+
+Benchmark results for Qwen3-Omni series models with `BF16`, `FP8-Static`, and `FP8-Dynamic` on `aime25`, `gpqa_diamond`, and `mmlu_redux` are as follows:
+
+| Model                       | Quantization | aime25 | gpqa_diamond | mmlu_redux |
+|-----------------------------|--------------|--------|--------------|------------|
+| Qwen3-Omni-30B-A3B-Instruct | BF16         | 73.32  | 56.77        | 88.09      |
+|                             | FP8-Static   | 71.33  | 56.57        | 87.91      |
+|                             | FP8-Dynamic  | 73.33  | 55.15        | 88.07      |
+
+**Note**
+
+> - The above evaluation results were obtained by deploying with the vLLM framework and averaging over 5 runs (vLLM only supports the thinker component).
+> - The hyperparameters used during evaluation are as follows:
+> ```json
+> {
+>   "top_p": 0.95,
+>   "temperature": 0.6,
+>   "do_sample": true,
+>   "max_model_len": 65536
+> }
+>```
+
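+For illustration, a minimal vLLM offline-inference sketch that applies the same sampling settings (the model path below is a placeholder for your quantized output directory, not a released checkpoint):
+
+```python
+from vllm import LLM, SamplingParams
+
+# Placeholder path: point this at the quantized Qwen3-Omni thinker weights.
+llm = LLM(model="./output/Qwen3-Omni-30B-A3B-Instruct-FP8", max_model_len=65536)
+
+params = SamplingParams(temperature=0.6, top_p=0.95, max_tokens=8192)
+outputs = llm.generate(["Solve the following problem step by step: ..."], params)
+print(outputs[0].outputs[0].text)
+```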
+
+
+#### 1.6 Other Models
Other models such as GLM-4.6, Qwen2.5, and Seed-OSS have been evaluated on benchmarks like `CEVAL`, `MMLU`, and `GSM8K` using quantization strategies including `FP8-Static`, `FP8-Dynamic`, `INT4-GPTQ`, and `INT4-AWQ`.
diff --git a/README_cn.md b/README_cn.md
index 0c38b91..cda6f60 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -17,7 +17,7 @@
## 📣最新进展
-- [25/11/03] 我们发布V0.2版本,支持了包括GLM-4.6/Qwen3-VL等更多模型的量化,开源投机采样Eagle3训练框架,更新Diffusion模型量化工具。
+- [25/11/05] 我们发布V0.2版本,支持了包括GLM-4.6/Qwen3-VL/Qwen3-Omni等更多模型的量化,开源投机采样Eagle3训练框架,更新Diffusion模型量化工具。
- [25/09/30] 我们开源了思考早退新算法 **SpecExit** [[论文]](http://arxiv.org/abs/2509.24248) | [[文档]](https://angelslim.readthedocs.io/zh-cn/latest/features/speculative_decoding/spec_exit.html) | [[vLLM代码]](https://github.com/vllm-project/vllm/pull/27192)🔥🔥🔥
- [25/09/30] 我们发布了三值量化新算法 **Tequila** [[论文]](https://arxiv.org/abs/2509.23809) | [[代码]](https://github.com/Tencent/AngelSlim/tree/tequila/TernaryQuant)。🔥🔥🔥
- [25/09/24] 我们支持了Qwen3系列模型的NVFP4的PTQ量化,我们还开源了[Qwen3-32B-NVFP4](https://huggingface.co/AngelSlim/Qwen3-32B_nvfp4)、[Qwen3-235B-A22B-NVFP4](https://huggingface.co/AngelSlim/Qwen3-235B-A22B_nvfp4)权重。
@@ -172,7 +172,7 @@
@@ -517,7 +517,40 @@ Qwen2.5VL系列模型的`BF16`、`FP8-Static`、`FP8-Dynamic`、`INT4-GPTQ`、`I
-#### 1.5 其他模型
+#### 1.5 Qwen-Omni 系列模型
+
+**Qwen3-Omni Text to Text Benchmark**
+
+Qwen3-Omni系列模型的`BF16`、`FP8-Static`、`FP8-Dynamic`在`aime25`、`gpqa_diamond`、`mmlu_redux`上的评测结果如下:
+
+| Model                       | Quantization | aime25 | gpqa_diamond | mmlu_redux |
+|-----------------------------|--------------|--------|--------------|------------|
+| Qwen3-Omni-30B-A3B-Instruct | BF16         | 73.32  | 56.77        | 88.09      |
+|                             | FP8-Static   | 71.33  | 56.57        | 87.91      |
+|                             | FP8-Dynamic  | 73.33  | 55.15        | 88.07      |
+
+**备注**
+
+> - 以上评测结果使用vLLM框架部署,测试5次取平均(vLLM只支持thinker部分)。
+> - 评测时使用的超参如下:
+> ```json
+> {
+>   "top_p": 0.95,
+>   "temperature": 0.6,
+>   "do_sample": true,
+>   "max_model_len": 65536
+> }
+>```
+
+
+
+#### 1.6 其他模型
其他模型比如GLM、Qwen2.5、Seed-OSS等模型利用`FP8-Static`、`FP8-Dynamic`、`INT4-GPTQ`、`INT4-AWQ`量化等策略在`CEVAL`、`MMLU`、`GSM8K`上进行了评测。
diff --git a/angelslim/compressor/quant/ptq.py b/angelslim/compressor/quant/ptq.py
index 3fe1d32..8aabb1f 100644
--- a/angelslim/compressor/quant/ptq.py
+++ b/angelslim/compressor/quant/ptq.py
@@ -14,6 +14,7 @@
import json
import os
+import warnings
import torch
from safetensors.torch import load_file
@@ -193,9 +194,19 @@ def _convert(self):
)
is not None
):
- self.quant_model.act_scales_dict[name] = self.ptq_hook.observer_dict[
- sub_layer
- ].act_observer.scales()
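+            # scales() raises ValueError when the activation observer never saw any
+            # data (e.g. a layer not exercised by the calibration set); fall back to
+            # a default scale of 1.0 and emit a warning instead of failing.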
+ try:
+ self.quant_model.act_scales_dict[name] = (
+ self.ptq_hook.observer_dict[sub_layer].act_observer.scales()
+ )
+ except ValueError:
+ self.quant_model.act_scales_dict[name] = torch.tensor(
+ 1.0, device=torch.cuda.current_device()
+ )
+ warnings.warn(
+ f"Not calibrated for {name}. Using default act scale 1.0.",
+ RuntimeWarning,
+ stacklevel=2,
+ )
if (
getattr( # noqa: B009
self.ptq_hook.observer_dict[sub_layer], "kv_cache_observer"
diff --git a/angelslim/data/__init__.py b/angelslim/data/__init__.py
index 97d05f9..7e5c29c 100644
--- a/angelslim/data/__init__.py
+++ b/angelslim/data/__init__.py
@@ -6,5 +6,6 @@
from .dataloader import DataLoaderFactory # noqa: F401
from .multimodal_dataset import MultiModalDataset # noqa: F401
+from .omni_dataset import OmniDataset # noqa: F401
from .text2image_dataset import Text2ImageDataset # noqa: F401
from .text_dataset import TextDataset # noqa: F401
diff --git a/angelslim/data/dataloader.py b/angelslim/data/dataloader.py
index 1c92cdd..c41b755 100644
--- a/angelslim/data/dataloader.py
+++ b/angelslim/data/dataloader.py
@@ -20,6 +20,7 @@
from .base_dataset import BaseDataset
from .multimodal_dataset import MultiModalDataset
+from .omni_dataset import OmniDataset
from .text2image_dataset import Text2ImageDataset
from .text_dataset import TextDataset
@@ -39,6 +40,7 @@ def create_data_loader(
data_type: str = "auto",
num_workers: int = 0,
inference_settings: Dict = None,
+ use_audio_in_video: bool = False,
model_name: str = None,
) -> DataLoader:
"""
@@ -98,6 +100,16 @@ def create_data_loader(
num_samples=num_samples,
inference_settings=inference_settings,
)
+ elif data_type == "OmniDataset":
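+        # A data_source that is not a local file is flagged as a Hugging Face
+        # dataset identifier via is_hf_dataset.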
+ dataset = OmniDataset(
+ processor=processor,
+ device=device,
+ max_length=max_length,
+ num_samples=num_samples,
+ data_source=data_source,
+ is_hf_dataset=not os.path.isfile(data_source),
+ use_audio_in_video=use_audio_in_video,
+ )
else:
raise ValueError(f"Unsupported data type: {data_type}")
diff --git a/angelslim/data/multimodal_dataset.py b/angelslim/data/multimodal_dataset.py
index bd0ddc2..62f42cc 100644
--- a/angelslim/data/multimodal_dataset.py
+++ b/angelslim/data/multimodal_dataset.py
@@ -16,12 +16,12 @@
import os
from typing import Dict, List, Union
-import qwen_vl_utils
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm
from transformers import ProcessorMixin
+from ..utils.lazy_imports import qwen_vl_utils
from .base_dataset import BaseDataset
diff --git a/angelslim/data/omni_dataset.py b/angelslim/data/omni_dataset.py
new file mode 100644
index 0000000..52f51ed
--- /dev/null
+++ b/angelslim/data/omni_dataset.py
@@ -0,0 +1,127 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from pathlib import Path
+from typing import Dict, List, Union
+
+from transformers import ProcessorMixin
+
+from ..utils.lazy_imports import qwen_omni_utils
+from .base_dataset import BaseDataset
+
+
+class OmniDataset(BaseDataset):
+    """Dataset for omni-modal (text, image, audio and video) calibration data"""
+
+ def __init__(
+ self,
+ processor: ProcessorMixin,
+ device: str = "cpu",
+ max_length: int = 4096,
+ num_samples: int = -1,
+ data_source: Union[str, Dict] = None,
+ is_hf_dataset: bool = False,
+ use_audio_in_video: bool = False,
+ ):
+ super().__init__(processor, device, max_length)
+ self.is_hf_dataset = is_hf_dataset
+ self.use_audio_in_video = use_audio_in_video
+
+ self._load_file_based_dataset(data_source, num_samples)
+
+ def _load_file_based_dataset(self, data_path: str, num_samples: int):
+ """Load dataset from local file system"""
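+        # Each line of the jsonl file is a JSON object with a "messages" list and
+        # optional "video_path" / "audio_path" / "image_path" fields, resolved
+        # relative to the directory containing the file (see
+        # dataset/omni_fake_data/fake_data.json for the expected layout).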
+ path_obj = Path(data_path)
+ data_dir = path_obj.parent
+
+ line_count = 0
+ with open(data_path, "r") as f:
+ for line in f:
+ if num_samples > 0 and line_count >= num_samples:
+ break
+ data = json.loads(line.strip())
+ video_path = None
+ audio_path = None
+ image_path = None
+
+ if "video_path" in data:
+ video_path = os.path.normpath(
+ os.path.join(data_dir, data["video_path"])
+ )
+ if "audio_path" in data:
+ audio_path = os.path.normpath(
+ os.path.join(data_dir, data["audio_path"])
+ )
+ if "image_path" in data:
+ image_path = os.path.normpath(
+ os.path.join(data_dir, data["image_path"])
+ )
+
+ ms = data.get("messages")
+
+ conversation = []
+ for m in ms:
+ if m["role"] == "system":
+ conversation.append(
+ {
+ "role": "system",
+ "content": [{"type": "text", "text": m["content"]}],
+ }
+ )
+ elif m["role"] == "user":
+ content = []
+ text_content = m["content"]
+                    # The raw text may embed <video>/<audio>/<image> placeholder
+                    # tags; strip them here and re-attach the media as structured
+                    # content entries below.
+                    text_content = (
+                        text_content.replace("<video>", "")
+                        .replace("<audio>", "")
+                        .replace("<image>", "")
+                    )
+ content.append({"type": "text", "text": text_content})
+ if video_path:
+ content.append({"type": "video", "video": video_path})
+ if audio_path:
+ content.append({"type": "audio", "audio": audio_path})
+ if image_path:
+ content.append({"type": "image", "image": image_path})
+ conversation.append(
+ {
+ "role": "user",
+ "content": content,
+ }
+ )
+ self._process_and_append(conversation)
+ line_count += 1
+
+ def _process_and_append(self, messages: List[Dict]):
+ """Process messages and append to dataset"""
+ text = self.processor.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ audios, images, videos = qwen_omni_utils.process_mm_info(
+ messages, use_audio_in_video=self.use_audio_in_video
+ )
+
+ # Process inputs
+ inputs = self.processor(
+ text=text,
+ images=images,
+ audios=audios,
+ videos=videos,
+ padding=True,
+ return_tensors="pt",
+ use_audio_in_video=self.use_audio_in_video,
+ )
+ self.data.append(inputs)
diff --git a/angelslim/engine.py b/angelslim/engine.py
index 9aa0769..28fb0e8 100644
--- a/angelslim/engine.py
+++ b/angelslim/engine.py
@@ -73,6 +73,7 @@ def prepare_model(
cache_dir=None,
deploy_backend="vllm",
using_multi_nodes=False,
+ use_audio_in_video=False,
) -> Any:
"""Load pretrained model and tokenizer
Args:
@@ -116,6 +117,16 @@ def prepare_model(
using_multi_nodes=using_multi_nodes,
)
self.model_path = model_path
+ elif self.series in ["Omni"]:
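+            # Omni models also take use_audio_in_video so that calibration-time
+            # generation knows whether to consume the audio track of input videos.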
+ if not model:
+ self.slim_model.from_pretrained(
+ model_path,
+ torch_dtype=torch_dtype,
+ device_map=device_map,
+ trust_remote_code=trust_remote_code,
+ use_audio_in_video=use_audio_in_video,
+ )
+ self.model_path = model_path
else:
raise ValueError(f"Unsupported series: {self.series}")
@@ -131,6 +142,7 @@ def prepare_data(
num_samples=128,
shuffle=True,
inference_settings=None,
+ use_audio_in_video=False,
model_name=None,
) -> Optional[Any]:
"""Prepare compression dataset"""
@@ -145,7 +157,7 @@ def prepare_data(
data_type=data_type,
processor=(
self.slim_model.processor
- if self.series == "VLM"
+ if self.series == "VLM" or self.series == "Omni"
else self.slim_model.tokenizer
),
device=self.slim_model.model.device,
@@ -155,6 +167,7 @@ def prepare_data(
num_samples=num_samples,
data_source=data_path,
inference_settings=inference_settings,
+ use_audio_in_video=use_audio_in_video,
model_name=model_name,
)
self.max_seq_length = max_length
@@ -187,7 +200,7 @@ def prepare_compressor(
f"Compression method '{method_name}' not registered. "
f"Available methods: {CompressorFactory.get_available_compressor()}"
)
- if self.series in ["LLM", "VLM"]:
+ if self.series in ["LLM", "VLM", "Omni"]:
global_config.update(self.model_path, self.max_seq_length)
if default_method:
diff --git a/angelslim/models/__init__.py b/angelslim/models/__init__.py
index 7202adb..4036371 100644
--- a/angelslim/models/__init__.py
+++ b/angelslim/models/__init__.py
@@ -15,4 +15,5 @@
from .diffusion import * # noqa: F401 F403
from .llm import * # noqa: F401 F403
from .model_factory import SlimModelFactory # noqa: F401
+from .omni import * # noqa: F401 F403
from .vlm import * # noqa: F401 F403
diff --git a/angelslim/models/model_factory.py b/angelslim/models/model_factory.py
index 3427e09..8e02996 100644
--- a/angelslim/models/model_factory.py
+++ b/angelslim/models/model_factory.py
@@ -22,7 +22,7 @@ class SlimModelFactory:
registry: Dict[str, Type] = {}
series_registry: Dict[str, str] = {}
- ALLOWED_SERIES = ("LLM", "VLM", "Diffusion")
+ ALLOWED_SERIES = ("LLM", "VLM", "Diffusion", "Omni")
@classmethod
def register(cls, model_class: Type) -> Type:
@@ -39,6 +39,8 @@ def register(cls, model_class: Type) -> Type:
series = "VLM"
elif "diffusion" in module_path:
series = "Diffusion"
+ elif "omni" in module_path:
+ series = "Omni"
else:
raise ValueError(
f"model_class '{class_name}' is not in a valid series: {cls.ALLOWED_SERIES}" # noqa: E501
diff --git a/angelslim/models/omni/__init__.py b/angelslim/models/omni/__init__.py
new file mode 100644
index 0000000..6fb9571
--- /dev/null
+++ b/angelslim/models/omni/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .qwen3_omni import Qwen_Omni # noqa: F401
diff --git a/angelslim/models/omni/qwen3_omni.py b/angelslim/models/omni/qwen3_omni.py
new file mode 100644
index 0000000..3dce11b
--- /dev/null
+++ b/angelslim/models/omni/qwen3_omni.py
@@ -0,0 +1,149 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from tqdm import tqdm
+from transformers import (
+ AutoProcessor,
+ AutoTokenizer,
+ Qwen3OmniMoeForConditionalGeneration,
+)
+
+from ...compressor.quant.core import PTQVLMSaveVllmHF
+from ...utils import find_layers, print_info
+from ..base_model import BaseLLMModel
+from ..model_factory import SlimModelFactory
+
+
+@SlimModelFactory.register
+class Qwen_Omni(BaseLLMModel):
+ def __init__(
+ self,
+ model=None,
+ deploy_backend="vllm",
+ ):
+ super().__init__(
+ model=model,
+ deploy_backend=deploy_backend,
+ )
+ self.modal_type = "Omni"
+ self.block_name = ["thinker.model.layers", "talker.model.layers"]
+
+ def from_pretrained(
+ self,
+ model_path,
+ torch_dtype="auto",
+ device_map="auto",
+ trust_remote_code=True,
+ use_audio_in_video=False,
+ ):
+ self.use_audio_in_video = use_audio_in_video
+ self.model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
+ model_path,
+ torch_dtype=torch_dtype,
+ device_map=device_map,
+ attn_implementation="flash_attention_2",
+ )
+
+ # Load tokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model_path, trust_remote_code=trust_remote_code
+ )
+
+ # Load processor
+ self.processor = AutoProcessor.from_pretrained(
+ model_path, trust_remote_code=trust_remote_code
+ )
+
+ def get_observer_layers(self):
+ names = [
+ "k_proj",
+ "v_proj",
+ "q_proj",
+ "o_proj",
+ "up_proj",
+ "gate_proj",
+ "down_proj",
+ ]
+
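+        # Observe only the attention/MLP linear projections that live inside the
+        # thinker/talker decoder blocks; every other matched layer is appended to
+        # ignore_layers in the quantization config.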
+ observer_layers_dict = {}
+ layers_dict = find_layers(self.model, layers=self.observer_layer_classes)
+
+ ignore_layers = self.skip_layer_names()
+ for name, module in layers_dict.items():
+ block_condition = any(name.startswith(block) for block in self.block_name)
+ if block_condition and name.split(".")[-1] in names:
+ observer_layers_dict[name] = module
+ else:
+ ignore_layers.append(name)
+ self.quant_config.quant_algo_info["ignore_layers"] = ignore_layers
+
+ if self.quant_config.custom_observe_layers_names != "default":
+ for custom_observe_name in self.quant_config.custom_observe_layers_names:
+                for default_name in list(observer_layers_dict.keys()):
+ if custom_observe_name not in default_name:
+ observer_layers_dict.pop(default_name)
+ return observer_layers_dict
+
+ def get_kvcache_observer_layers_names(self, observe_names):
+ names = ["self_attn.k_proj", "self_attn.v_proj"]
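+        # KV-cache observers are attached only to the k/v projections of attention
+        # blocks inside the thinker/talker decoder layers.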
+ return [
+ k
+ for k in observe_names
+ if any(k.startswith(block) for block in self.block_name)
+ and k.split(".")[-2] + "." + k.split(".")[-1] in names
+ ]
+
+ def model_forward(self, dataloader, **kwargs):
+ self.model.use_cache = False
+
+ calibrated_cnt = 0
+ if (
+ "gptq" in self.quant_config.quant_algo
+ or "awq" in self.quant_config.quant_algo
+ or "gptaq" in self.quant_config.quant_algo
+ ):
+ device = "cuda:0"
+ else:
+ device = self.model.device
+ print_info(f"device is {device}")
+ if dataloader is not None:
+ with torch.no_grad():
+ for batch in tqdm(
+ dataloader, desc="calibrating...", total=len(dataloader)
+ ):
+ inputs = {k: v.to(device) for k, v in batch.items()}
+ try:
+ text_ids, audio = self.model.generate(
+ **inputs, use_audio_in_video=self.use_audio_in_video
+ )
+ calibrated_cnt += 1
+ except ValueError:
+ calibrated_cnt += 1
+ pass
+
+ def get_quant_module(self):
+ """
+ Returns the module that will be quantized.
+ This is typically the main transformer module of the model.
+ """
+ return self.model.thinker.model.layers
+
+ def get_save_func(self):
+ if self.deploy_backend in ["vllm", "huggingface"]:
+ return PTQVLMSaveVllmHF
+ else:
+ raise NotImplementedError(
+ f"deploy_backend {self.deploy_backend} is not supported for saving."
+ )
diff --git a/angelslim/utils/config_parser.py b/angelslim/utils/config_parser.py
index 6b24dd1..c64ee00 100644
--- a/angelslim/utils/config_parser.py
+++ b/angelslim/utils/config_parser.py
@@ -93,6 +93,12 @@ def set_model_hidden_size(self, model_path) -> int:
json_data = get_hf_config(model_path)
if json_data["model_type"] in ["qwen3_vl"]:
self.hidden_size = json_data["text_config"]["hidden_size"]
+ elif (
+ json_data["architectures"][0]
+ if isinstance(json_data["architectures"], list)
+ else json_data["architectures"]
+ ) == "Qwen3OmniMoeForConditionalGeneration":
+ self.hidden_size = json_data["thinker_config"]["text_config"]["hidden_size"]
else:
self.hidden_size = json_data["hidden_size"]
@@ -125,6 +131,7 @@ class ModelConfig:
low_cpu_mem_usage: bool = field(default=True)
use_cache: bool = field(default=False)
cache_dir: Optional[str] = field(default=None)
+ use_audio_in_video: bool = field(default=False)
@dataclass
diff --git a/angelslim/utils/lazy_imports.py b/angelslim/utils/lazy_imports.py
index cc371c8..2db14ea 100644
--- a/angelslim/utils/lazy_imports.py
+++ b/angelslim/utils/lazy_imports.py
@@ -203,5 +203,6 @@ def __getattr__(self, name: str) -> Any:
deepspeed = LazyModule("deepspeed", "speculative")
-# --- VLM related lazy imports ---
-# qwen_vl_utils = LazyModule("qwen_vl_utils", "vlm")
+# --- multimodal related lazy imports ---
+qwen_vl_utils = LazyModule("qwen_vl_utils", "multimodal")
+qwen_omni_utils = LazyModule("qwen_omni_utils", "multimodal")
diff --git a/configs/qwen3_omni/fp8_dynamic/qwen3_omni_fp8_dynamic.yaml b/configs/qwen3_omni/fp8_dynamic/qwen3_omni_fp8_dynamic.yaml
new file mode 100644
index 0000000..58b8d74
--- /dev/null
+++ b/configs/qwen3_omni/fp8_dynamic/qwen3_omni_fp8_dynamic.yaml
@@ -0,0 +1,24 @@
+# Global configuration of pipeline
+global:
+ save_path: ./output
+
+# Simplified Configuration for LLM compression
+model:
+ name: Qwen_Omni
+ model_path: Qwen/Qwen3-Omni-30B-A3B-Instruct
+ trust_remote_code: true
+ low_cpu_mem_usage: true
+ use_cache: false
+ torch_dtype: auto
+ device_map: auto
+ use_audio_in_video: false
+
+# Compression configuration
+compression:
+ name: PTQ
+ quantization:
+ name: fp8_dynamic # Supported: fp8_static, fp8_dynamic, int4_awq, int4_gptq
+ bits: 8 # Quantization bits (4/8)
+ quant_method:
+ weight: "per-tensor"
+ activation: "per-tensor"
diff --git a/configs/qwen3_omni/fp8_static/qwen3_omni_fp8_static.yaml b/configs/qwen3_omni/fp8_static/qwen3_omni_fp8_static.yaml
new file mode 100644
index 0000000..3c30997
--- /dev/null
+++ b/configs/qwen3_omni/fp8_static/qwen3_omni_fp8_static.yaml
@@ -0,0 +1,32 @@
+# Global configuration of pipeline
+global:
+ save_path: ./output
+
+# Simplified Configuration for LLM compression
+model:
+ name: Qwen_Omni
+ model_path: Qwen/Qwen3-Omni-30B-A3B-Instruct
+ trust_remote_code: true
+ low_cpu_mem_usage: true
+ use_cache: false
+ torch_dtype: auto
+ device_map: auto
+ use_audio_in_video: false
+
+# Compression configuration
+compression:
+ name: PTQ
+ quantization:
+ name: fp8_static
+ bits: 8
+ quant_method:
+ weight: "per-tensor"
+ activation: "per-tensor"
+
+# Dataset for calibration
+dataset:
+ name: OmniDataset
+ data_path: your/data/path
+ max_seq_length: 8192
+ num_samples: 256
+ batch_size: 1
\ No newline at end of file
diff --git a/dataset/omni_fake_data/fake_data.json b/dataset/omni_fake_data/fake_data.json
new file mode 100755
index 0000000..885bced
--- /dev/null
+++ b/dataset/omni_fake_data/fake_data.json
@@ -0,0 +1,3 @@
+{"messages": [{"role": "user", "content": "What happens after the text disappears from the screen?"}], "video_path": "./videos/0.mp4"}
+{"messages": [{"role": "user", "content": "How many food item is shown in the bar graph?"}], "image_path": "./images/0.png"}
+{"messages": [{"role": "user", "content": "Why is the speech described as rich in frequency content?"}], "audio_path": "./audios/0.png"}
\ No newline at end of file
diff --git a/dataset/omni_fake_data/images/0.png b/dataset/omni_fake_data/images/0.png
new file mode 100755
index 0000000..b433d7d
Binary files /dev/null and b/dataset/omni_fake_data/images/0.png differ
diff --git a/docs/source/index.md b/docs/source/index.md
index 305a8da..b5a0e08 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -57,7 +57,7 @@ AngelSlim是腾讯自研的,致力于打造更易用、更全面和更高效
| | Wan | | | 建设中 |
| | SDXL | | | |
+-------------------+-----------------+----------------------+------------+-----------------+
- | **语音(TTS/ASR)** | Qwen3-Omni | 建设中 | 建设中 | **Token剪枝** |
+ | **语音(TTS/ASR)** | Qwen3-Omni | FP8-Static/Dynamic | 建设中 | **Token剪枝** |
| | | | | 建设中 |
+-------------------+-----------------+----------------------+------------+-----------------+
@@ -94,6 +94,7 @@ models/hunyuan/hunyuan_quant
models/deepseek/deepseek_quant
models/qwen/qwen_quant
models/qwenvl/qwenvl_quant
+models/qwen3_omni/qwen3_omni_quant
:::
diff --git a/docs/source/models/qwen3_omni/qwen3_omni_quant.md b/docs/source/models/qwen3_omni/qwen3_omni_quant.md
new file mode 100644
index 0000000..2e11254
--- /dev/null
+++ b/docs/source/models/qwen3_omni/qwen3_omni_quant.md
@@ -0,0 +1,50 @@
+# Qwen3-Omni量化指南
+
+Qwen3-Omni模型可采用**FP8(static、dynamic)** 方式进行模型压缩,以下是详细的量化配置与操作说明。
+
+
+## FP8 量化(W8A8)
+
+Qwen3-Omni的FP8量化采用**per-tensor粒度**,支持thinker与talker中LLM模块的动态量化(dynamic)和静态量化(static)两种模式。
+
+### 配置参数说明
+
+FP8量化的配置文件可参考路径:`configs/qwen3_omni/fp8_static` 和 `configs/qwen3_omni/fp8_dynamic`,核心参数如下:
+
+#### model配置
+- `name`:模型名称,固定填写`Qwen_Omni`。
+- `model_path`:可填写Hugging Face模型卡片名称或本地路径。
+- `use_audio_in_video`:控制是否使用源视频中的音频轨道。
+
+#### compression配置
+- `name`:压缩策略类型,固定选择量化模式`PTQ`。
+- `quantization.name`:量化算法类型,根据需求选择`fp8_static`(静态量化)或`fp8_dynamic`(动态量化)。
+- `quantization.bits`:量化比特数,FP8量化固定填写`8`。
+- `quantization.quant_method`:权重量化粒度,FP8量化固定为`per-tensor`。
+
+#### dataset配置
+- `name`:数据集类型,固定选择`OmniDataset`。
+- `data_path`:数据集路径,支持jsonl文件路径。自定义数据集需参考`dataset/omni_fake_data/fake_data.json`的格式,构造方式可参考下方示例。
+
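+自定义校准数据可按如下方式构造(示意代码,问题内容与文件名仅作演示,素材路径请替换为本地真实文件):
+
+```python
+import json
+
+# 每行一个样本:messages 为对话内容,video_path / image_path / audio_path 为可选的
+# 多模态素材路径,加载时相对该 jsonl 文件所在目录解析。
+samples = [
+    {"messages": [{"role": "user", "content": "视频中的文字消失后发生了什么?"}],
+     "video_path": "./videos/0.mp4"},
+    {"messages": [{"role": "user", "content": "柱状图中展示了几种食物?"}],
+     "image_path": "./images/0.png"},
+]
+
+with open("my_calib_data.json", "w", encoding="utf-8") as f:
+    for sample in samples:
+        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
+```
+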
+### 启动量化流程
+
+通过以下命令启动FP8量化校准:
+
+```shell
+# 动态FP8量化
+python3 tools/run.py -c configs/qwen3_omni/fp8_dynamic/qwen3_omni_fp8_dynamic.yaml
+```
+
+```shell
+# 静态FP8量化
+python3 tools/run.py -c configs/qwen3_omni/fp8_static/qwen3_omni_fp8_static.yaml
+```
+
+## 模型部署
+
+vLLM框架支持Qwen3-Omni的FP8(per-tensor)量化模型部署,建议使用官方部署方式:
+### vllm
+参考 [Qwen3-Omni vLLM Usage](https://github.com/QwenLM/Qwen3-Omni?tab=readme-ov-file#vllm-usage)
+
+### transformers
+参考 [Qwen3-Omni Transformers Usage](https://github.com/QwenLM/Qwen3-Omni?tab=readme-ov-file#transformers-usage)
\ No newline at end of file
diff --git a/docs/source/performance/quantization/benchmarks.md b/docs/source/performance/quantization/benchmarks.md
index 863ee88..f008705 100644
--- a/docs/source/performance/quantization/benchmarks.md
+++ b/docs/source/performance/quantization/benchmarks.md
@@ -458,3 +458,26 @@ Qwen3VL系列模型的`BF16`、`FP8-Static`、`FP8-Dynamic`在`MMMU_VAL`、`DocV
```
FP8-Dynamic采用Block-wise的量化,启动命令:python3 tools/fp8_quant_blockwise.py --block_size --input_path --output_path
+
+
+## Qwen3-Omni
+
+**Qwen3-Omni Text -> Text Benchmark**
+
+Qwen3-Omni模型的`BF16`、`FP8-Static`、`FP8-Dynamic`在`aime25`、`gpqa_diamond`、`mmlu_redux`上的评测结果如下:
+
+```{eval-rst}
+.. table::
+ :align: center
+ :name: table-qwen3-omni-performance
+
+ +-----------------------------+----------------+----------+--------------+------------+
+ | Model | Quantization | aime25 | gpqa_diamond | mmlu_redux |
+ +=============================+================+==========+==============+============+
+ | Qwen3-Omni-30B-A3B-Instruct | BF16 | 73.32 | 56.77 | 88.09 |
+ + +----------------+----------+--------------+------------+
+ | | FP8-Static | 71.33 | 56.57 | 87.91 |
+ + +----------------+----------+--------------+------------+
+ | | FP8-Dynamic | 73.33 | 55.15 | 88.07 |
+ +-----------------------------+----------------+----------+--------------+------------+
+```
\ No newline at end of file
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 47d5610..d093b8c 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -10,4 +10,3 @@ tiktoken
datasets
threadpoolctl
shortuuid
-qwen_vl_utils==0.0.11
diff --git a/requirements/requirements_multimodal.txt b/requirements/requirements_multimodal.txt
new file mode 100644
index 0000000..0c073e6
--- /dev/null
+++ b/requirements/requirements_multimodal.txt
@@ -0,0 +1,2 @@
+qwen_vl_utils==0.0.11
+qwen_omni_utils
diff --git a/setup.py b/setup.py
index a8bdf69..bd22c04 100644
--- a/setup.py
+++ b/setup.py
@@ -51,11 +51,14 @@ def get_requirements(filename):
"all": (
get_requirements("requirements/requirements_speculative.txt")
+ get_requirements("requirements/requirements_diffusion.txt")
+ + get_requirements("requirements/requirements_multimodal.txt")
),
# Install speculative sampling functionality: pip install angelslim[speculative]
"speculative": get_requirements("requirements/requirements_speculative.txt"),
# Install Diffusion functionality: pip install angelslim[diffusion]
"diffusion": get_requirements("requirements/requirements_diffusion.txt"),
+ # Install multimodal functionality: pip install angelslim[multimodal]
+ "multimodal": get_requirements("requirements/requirements_multimodal.txt"),
},
packages=find_packages(),
python_requires=">=3.0",
diff --git a/tools/run.py b/tools/run.py
index dca463e..8ebb86a 100644
--- a/tools/run.py
+++ b/tools/run.py
@@ -90,6 +90,7 @@ def multi_nodes_run(config):
low_cpu_mem_usage=model_config.low_cpu_mem_usage,
use_cache=model_config.use_cache,
cache_dir=model_config.cache_dir,
+ use_audio_in_video=model_config.use_audio_in_video,
deploy_backend=global_config.deploy_backend,
using_multi_nodes=True,
)
@@ -105,6 +106,7 @@ def multi_nodes_run(config):
num_samples=dataset_config.num_samples,
shuffle=dataset_config.shuffle,
inference_settings=dataset_config.inference_settings,
+ use_audio_in_video=model_config.use_audio_in_video,
)
# Step 6: Initialize compressor
@@ -148,6 +150,7 @@ def run(config):
low_cpu_mem_usage=model_config.low_cpu_mem_usage,
use_cache=model_config.use_cache,
cache_dir=model_config.cache_dir,
+ use_audio_in_video=model_config.use_audio_in_video,
deploy_backend=global_config.deploy_backend,
)
@@ -162,6 +165,7 @@ def run(config):
num_samples=dataset_config.num_samples,
shuffle=dataset_config.shuffle,
inference_settings=dataset_config.inference_settings,
+ use_audio_in_video=model_config.use_audio_in_video,
model_name=model_config.name,
)