diff --git a/README.md b/README.md
index 0a6af7d..9e4f3f0 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@ A more accessible, comprehensive, and efficient toolkit for large model compress

 ## 📣Latest News
-- [25/11/03] We have released v0.2. Quantization support for new models, such as `GLM-4.6` and `Qwen3-VL`, open-sources the Eagle3 speculative decoding training framework, and updates the Diffusion model quantization tools.
+- [25/11/05] We have released v0.2, adding quantization support for new models such as `GLM-4.6`, `Qwen3-VL`, and `Qwen3-Omni`, open-sourcing the Eagle3 speculative decoding training framework, and updating the Diffusion model quantization tools.
 - [25/09/30] We have released **SpecExit**, the reasoning early-exit algorithm: [[Paper]](http://arxiv.org/abs/2509.24248) | [[Docs]](https://angelslim.readthedocs.io/zh-cn/latest/features/speculative_decoding/spec_exit.html) | [[vLLM Code]](https://github.com/vllm-project/vllm/pull/27192)🔥🔥🔥
 - [25/09/26] We have released **TEQUILA**, the ternary quantization algorithm [[Paper]](https://arxiv.org/abs/2509.23809) | [[Code]](https://github.com/Tencent/AngelSlim/tree/tequila/TernaryQuant)🔥🔥🔥
 - [25/09/24] We now support PTQ quantization of NVFP4 for the Qwen3 series models. We also open-source the [Qwen3-32B-NVFP4](https://huggingface.co/AngelSlim/Qwen3-32B_nvfp4) and [Qwen3-235B-A22B-NVFP4](https://huggingface.co/AngelSlim/Qwen3-235B-A22B_nvfp4) weights.
@@ -171,7 +171,7 @@ A more accessible, comprehensive, and efficient toolkit for large model compress
@@ -510,7 +510,40 @@ Benchmark results for Qwen2.5VL series models with `BF16`、`FP8-Static`、`FP8-
 
-#### 1.5 Other Models
+#### 1.5 Qwen-Omni Series Models
+
+**Qwen3-Omni Text to Text Benchmark**
+
+Benchmark results for Qwen3-Omni series models in BF16, FP8-Static, and FP8-Dynamic on aime25, gpqa_diamond, and mmlu_redux are as follows:
+
+<table>
+  <thead>
+    <tr><th>Model</th><th>Quantization</th><th>aime25</th><th>gpqa_diamond</th><th>mmlu_redux</th></tr>
+  </thead>
+  <tbody>
+    <tr><td rowspan="3">Qwen3-Omni-30B-A3B-Instruct</td><td>BF16</td><td>73.32</td><td>56.77</td><td>88.09</td></tr>
+    <tr><td>FP8-Static</td><td>71.33</td><td>56.57</td><td>87.91</td></tr>
+    <tr><td>FP8-Dynamic</td><td>73.33</td><td>55.15</td><td>88.07</td></tr>
+  </tbody>
+</table>
+
+<details>
+<summary>Note</summary>
+
+> - The above evaluation results were obtained by deploying with the vLLM framework and averaging over 5 runs (vLLM only supports the thinker component).
+> - The hyperparameters used during evaluation are as follows:
+> ```json
+> {
+>   "top_p": 0.95,
+>   "temperature": 0.6,
+>   "do_sample": true,
+>   "max_model_len": 65536
+> }
+> ```
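+> - For reference, the sketch below shows one way such an evaluation could be driven through vLLM's offline Python API. The checkpoint path is a placeholder, and this is an illustrative setup under those assumptions, not the exact harness used for the numbers above:
+> ```python
+> from vllm import LLM, SamplingParams
+>
+> # Hypothetical local path to a quantized Qwen3-Omni thinker checkpoint.
+> llm = LLM(model="./Qwen3-Omni-30B-A3B-Instruct-FP8-Dynamic", max_model_len=65536)
+> params = SamplingParams(temperature=0.6, top_p=0.95)
+> outputs = llm.generate(["Question: what is 17 * 24? Answer:"], params)
+> print(outputs[0].outputs[0].text)
+> ```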
+
+</details>
+
+#### 1.6 Other Models
 
 Other models such as GLM-4.6, Qwen2.5, and Seed-OSS have been evaluated on benchmarks like `CEVAL`, `MMLU`, and `GSM8K` using quantization strategies including `FP8-Static`, `FP8-Dynamic`, `INT4-GPTQ`, and `INT4-AWQ`.
diff --git a/README_cn.md b/README_cn.md
index 0c38b91..cda6f60 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -17,7 +17,7 @@

 ## 📣最新进展
-- [25/11/03] 我们发布V0.2版本,支持了包括GLM-4.6/Qwen3-VL等更多模型的量化,开源投机采样Eagle3训练框架,更新Diffusion模型量化工具。
+- [25/11/05] 我们发布V0.2版本,支持了包括GLM-4.6/Qwen3-VL/Qwen3-Omni等更多模型的量化,开源投机采样Eagle3训练框架,更新Diffusion模型量化工具。
 - [25/09/30] 我们开源了思考早退新算法 **SpecExit** [[论文]](http://arxiv.org/abs/2509.24248) | [[文档]](https://angelslim.readthedocs.io/zh-cn/latest/features/speculative_decoding/spec_exit.html) | [[vLLM代码]](https://github.com/vllm-project/vllm/pull/27192)🔥🔥🔥
 - [25/09/30] 我们发布了三值量化新算法 **Tequila** [[论文]](https://arxiv.org/abs/2509.23809) | [[代码]](https://github.com/Tencent/AngelSlim/tree/tequila/TernaryQuant)。🔥🔥🔥
 - [25/09/24] 我们支持了Qwen3系列模型的NVFP4的PTQ量化,我们还开源了[Qwen3-32B-NVFP4](https://huggingface.co/AngelSlim/Qwen3-32B_nvfp4)、[Qwen3-235B-A22B-NVFP4](https://huggingface.co/AngelSlim/Qwen3-235B-A22B_nvfp4)权重。
@@ -172,7 +172,7 @@
@@ -517,7 +517,40 @@ Qwen2.5VL系列模型的`BF16`、`FP8-Static`、`FP8-Dynamic`、`INT4-GPTQ`、`I
 
-#### 1.5 其他模型
+#### 1.5 Qwen-Omni 系列模型
+
+**Qwen3-Omni Text to Text Benchmark**
+
+Qwen3-Omni系列模型的`BF16`、`FP8-Static`、`FP8-Dynamic`在`aime25`、`gpqa_diamond`、`mmlu_redux`上的评测结果如下:
+
+<table>
+  <thead>
+    <tr><th>Model</th><th>Quantization</th><th>aime25</th><th>gpqa_diamond</th><th>mmlu_redux</th></tr>
+  </thead>
+  <tbody>
+    <tr><td rowspan="3">Qwen3-Omni-30B-A3B-Instruct</td><td>BF16</td><td>73.32</td><td>56.77</td><td>88.09</td></tr>
+    <tr><td>FP8-Static</td><td>71.33</td><td>56.57</td><td>87.91</td></tr>
+    <tr><td>FP8-Dynamic</td><td>73.33</td><td>55.15</td><td>88.07</td></tr>
+  </tbody>
+</table>
+
+<details>
+<summary>备注</summary>
+
+> - 以上评测结果使用vLLM框架部署测试5次求平均(vLLM只支持thinker部分)
+> - 评测时使用的超参如下:
+> ```json
+> {
+>   "top_p": 0.95,
+>   "temperature": 0.6,
+>   "do_sample": true,
+>   "max_model_len": 65536
+> }
+> ```
+
+</details>
+
+#### 1.6 其他模型
 
 其他模型比如GLM、Qwen2.5、Seed-OSS等模型利用`FP8-Static`、`FP8-Dynamic`、`INT4-GPTQ`、`INT4-AWQ`量化等策略在`CEVAL`、`MMLU`、`GSM8K`上进行了评测。
diff --git a/angelslim/compressor/quant/ptq.py b/angelslim/compressor/quant/ptq.py
index 3fe1d32..8aabb1f 100644
--- a/angelslim/compressor/quant/ptq.py
+++ b/angelslim/compressor/quant/ptq.py
@@ -14,6 +14,7 @@
 
 import json
 import os
+import warnings
 
 import torch
 from safetensors.torch import load_file
@@ -193,9 +194,19 @@ def _convert(self):
                 )
                 is not None
             ):
-                self.quant_model.act_scales_dict[name] = self.ptq_hook.observer_dict[
-                    sub_layer
-                ].act_observer.scales()
+                try:
+                    self.quant_model.act_scales_dict[name] = (
+                        self.ptq_hook.observer_dict[sub_layer].act_observer.scales()
+                    )
+                except ValueError:
+                    self.quant_model.act_scales_dict[name] = torch.tensor(
+                        1.0, device=torch.cuda.current_device()
+                    )
+                    warnings.warn(
+                        f"Not calibrated for {name}. Using default act scale 1.0.",
+                        RuntimeWarning,
+                        stacklevel=2,
+                    )
             if (
                 getattr(  # noqa: B009
                     self.ptq_hook.observer_dict[sub_layer], "kv_cache_observer"
diff --git a/angelslim/data/__init__.py b/angelslim/data/__init__.py
index 97d05f9..7e5c29c 100644
--- a/angelslim/data/__init__.py
+++ b/angelslim/data/__init__.py
@@ -6,5 +6,6 @@
 from .dataloader import DataLoaderFactory  # noqa: F401
 from .multimodal_dataset import MultiModalDataset  # noqa: F401
+from .omni_dataset import OmniDataset  # noqa: F401
 from .text2image_dataset import Text2ImageDataset  # noqa: F401
 from .text_dataset import TextDataset  # noqa: F401
diff --git a/angelslim/data/dataloader.py b/angelslim/data/dataloader.py
index 1c92cdd..c41b755 100644
--- a/angelslim/data/dataloader.py
+++ b/angelslim/data/dataloader.py
@@ -20,6 +20,7 @@
 
 from .base_dataset import BaseDataset
 from .multimodal_dataset import MultiModalDataset
+from .omni_dataset import OmniDataset
 from .text2image_dataset import Text2ImageDataset
 from .text_dataset import TextDataset
@@ -39,6 +40,7 @@ def create_data_loader(
     data_type: str = "auto",
     num_workers: int = 0,
     inference_settings: Dict = None,
+    use_audio_in_video: bool = False,
     model_name: str = None,
 ) -> DataLoader:
     """
@@ -98,6 +100,16 @@ def create_data_loader(
             num_samples=num_samples,
             inference_settings=inference_settings,
         )
+    elif data_type == "OmniDataset":
+        dataset = OmniDataset(
+            processor=processor,
+            device=device,
+            max_length=max_length,
+            num_samples=num_samples,
+            data_source=data_source,
+            is_hf_dataset=not os.path.isfile(data_source),
+            use_audio_in_video=use_audio_in_video,
+        )
     else:
         raise ValueError(f"Unsupported data type: {data_type}")
diff --git a/angelslim/data/multimodal_dataset.py b/angelslim/data/multimodal_dataset.py
index bd0ddc2..62f42cc 100644
--- a/angelslim/data/multimodal_dataset.py
+++ b/angelslim/data/multimodal_dataset.py
@@ -16,12 +16,12 @@
 import os
 from typing import Dict, List, Union
 
-import qwen_vl_utils
 from datasets import load_dataset
 from PIL import Image
 from tqdm import tqdm
 from transformers import ProcessorMixin
 
+from ..utils.lazy_imports import qwen_vl_utils
 from .base_dataset import BaseDataset
diff --git a/angelslim/data/omni_dataset.py b/angelslim/data/omni_dataset.py
new file mode 100644
index 0000000..52f51ed
--- /dev/null
+++ b/angelslim/data/omni_dataset.py
@@ -0,0 +1,127 @@
+# Copyright 2025 Tencent Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from pathlib import Path
+from typing import Dict, List, Union
+
+from transformers import ProcessorMixin
+
+from ..utils.lazy_imports import qwen_omni_utils
+from .base_dataset import BaseDataset
+
+
+class OmniDataset(BaseDataset):
+    """Dataset for omni-modal (text + image + audio + video) data"""
+
+    def __init__(
+        self,
+        processor: ProcessorMixin,
+        device: str = "cpu",
+        max_length: int = 4096,
+        num_samples: int = -1,
+        data_source: Union[str, Dict] = None,
+        is_hf_dataset: bool = False,
+        use_audio_in_video: bool = False,
+    ):
+        super().__init__(processor, device, max_length)
+        self.is_hf_dataset = is_hf_dataset
+        self.use_audio_in_video = use_audio_in_video
+
+        self._load_file_based_dataset(data_source, num_samples)
+
+    def _load_file_based_dataset(self, data_path: str, num_samples: int):
+        """Load dataset from local file system"""
+        path_obj = Path(data_path)
+        data_dir = path_obj.parent
+
+        line_count = 0
+        with open(data_path, "r") as f:
+            for line in f:
+                if num_samples > 0 and line_count >= num_samples:
+                    break
+                data = json.loads(line.strip())
+                video_path = None
+                audio_path = None
+                image_path = None
+
+                if "video_path" in data:
+                    video_path = os.path.normpath(
+                        os.path.join(data_dir, data["video_path"])
+                    )
+                if "audio_path" in data:
+                    audio_path = os.path.normpath(
+                        os.path.join(data_dir, data["audio_path"])
+                    )
+                if "image_path" in data:
+                    image_path = os.path.normpath(
+                        os.path.join(data_dir, data["image_path"])
+                    )
+
+                ms = data.get("messages")
+
+                conversation = []
+                for m in ms:
+                    if m["role"] == "system":
+                        conversation.append(
+                            {
+                                "role": "system",
+                                "content": [{"type": "text", "text": m["content"]}],
+                            }
+                        )
+                    elif m["role"] == "user":
+                        content = []
+                        text_content = m["content"]
+                        text_content = (
+                            text_content.replace("
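The new `OmniDataset` reads JSONL calibration files in which each line may carry optional `video_path`, `audio_path`, and `image_path` keys (resolved relative to the JSONL file's own directory) plus a `messages` list in chat format. The hunk above is truncated, but those keys already pin down the record layout; a minimal sketch follows. The file and media names are hypothetical, and the `create_data_loader` call shape is inferred from the keyword arguments visible in `dataloader.py`, so treat it as an assumption rather than confirmed API.

```python
import json

# Hypothetical calibration record in the JSONL layout _load_file_based_dataset
# parses; each media key is optional and resolved relative to the JSONL file.
record = {
    "video_path": "media/clip_0001.mp4",
    "audio_path": "media/clip_0001.wav",
    "image_path": "media/frame_0001.jpg",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Describe what happens in this clip."},
    ],
}

# Write a one-record calibration file.
with open("omni_calib.jsonl", "w") as f:
    f.write(json.dumps(record) + "\n")

# Assumed usage, based on the parameters visible in this diff
# (`processor` would be the model's Hugging Face processor instance):
# loader = create_data_loader(
#     processor=processor,
#     data_source="omni_calib.jsonl",
#     data_type="OmniDataset",
#     use_audio_in_video=True,
# )
```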