diff --git a/.azure-pipelines/scripts/ut/run_ut_xpu.sh b/.azure-pipelines/scripts/ut/run_ut_xpu.sh index 1978700db..7716a0cc6 100644 --- a/.azure-pipelines/scripts/ut/run_ut_xpu.sh +++ b/.azure-pipelines/scripts/ut/run_ut_xpu.sh @@ -4,8 +4,7 @@ set -xe # install requirements echo "##[group]set up UT env..." uv pip install pytest-cov pytest-html -uv pip install -r /auto-round/test/test_ark/requirements.txt \ - --extra-index-url https://download.pytorch.org/whl/xpu +uv pip install -r /auto-round/test/test_ark/requirements.txt cd /auto-round && uv pip install . echo "##[endgroup]" diff --git a/auto_round/inference/backend.py b/auto_round/inference/backend.py index 7b1783998..de9ea903f 100644 --- a/auto_round/inference/backend.py +++ b/auto_round/inference/backend.py @@ -439,21 +439,20 @@ def fp8_static_scheme_checker( requirements=["autoawq", "transformers"], ) -# BackendInfos["auto_round_kernel"] = BackendInfo( -# device=["cpu"], -# sym=[True, False], -# packing_format=GPTQ_FORMAT_NO_ZP, -# bits=[2, 4, 8], -# group_size=None, -# priority=6, -# checkers=[ark_feature_checker], -# alias=["ark"], -# compute_dtype=["float32", "float16"], -# data_type=["int"], -# act_bits=WOQ_DEFAULT_ACT_BITS, -# requirements=["torch>=2.9.0", "auto_round_kernel"], -# systems=["linux"], -# ) +BackendInfos["auto_round_kernel"] = BackendInfo( + device=["cpu"], + sym=[True, False], + packing_format=GPTQ_FORMAT_NO_ZP, + bits=[2, 4, 8], + group_size=None, + priority=6, + checkers=[ark_feature_checker], + alias=["ark"], + compute_dtype=["float32", "float16"], + data_type=["int"], + act_bits=WOQ_DEFAULT_ACT_BITS, + requirements=["torch>=2.8.0", "auto_round_kernel"], +) BackendInfos["auto_round_kernel_xpu"] = BackendInfo( device=["xpu"], @@ -467,25 +466,23 @@ def fp8_static_scheme_checker( compute_dtype=["float32", "float16"], data_type=["int"], act_bits=WOQ_DEFAULT_ACT_BITS, - requirements=["torch>=2.9.0", "auto_round_kernel"], - systems=["linux"], + requirements=["torch>=2.8.0", "auto_round_kernel"], ) -# BackendInfos["auto_round_kernel_zp"] = BackendInfo( -# device=["cpu"], -# sym=[True, False], -# packing_format=GPTQ_FORMAT, -# bits=[2, 4, 8], -# group_size=None, -# priority=6, -# checkers=[ark_feature_checker], -# alias=["ark"], -# compute_dtype=["float32", "float16"], -# data_type=["int"], -# act_bits=WOQ_DEFAULT_ACT_BITS, -# requirements=["torch>=2.9.0", "auto_round_kernel"], -# systems=["linux"], -# ) +BackendInfos["auto_round_kernel_zp"] = BackendInfo( + device=["cpu"], + sym=[True, False], + packing_format=GPTQ_FORMAT, + bits=[2, 4, 8], + group_size=None, + priority=6, + checkers=[ark_feature_checker], + alias=["ark"], + compute_dtype=["float32", "float16"], + data_type=["int"], + act_bits=WOQ_DEFAULT_ACT_BITS, + requirements=["torch>=2.8.0", "auto_round_kernel"], +) BackendInfos["auto_round_kernel_zp_xpu"] = BackendInfo( device=["xpu"], @@ -499,31 +496,29 @@ def fp8_static_scheme_checker( compute_dtype=["float32", "float16"], data_type=["int"], act_bits=WOQ_DEFAULT_ACT_BITS, - requirements=["torch>=2.9.0", "auto_round_kernel"], - systems=["linux"], + requirements=["torch>=2.8.0", "auto_round_kernel"], ) -# BackendInfos["auto_round_kernel_awq"] = BackendInfo( -# device=["cpu"], -# sym=[True, False], -# packing_format=AWQ_FORMAT, -# bits=[2, 4, 8], -# group_size=None, -# priority=6, -# checkers=[ark_feature_checker], -# alias=["ark"], -# compute_dtype=["float32", "float16"], -# data_type=["int"], -# act_bits=WOQ_DEFAULT_ACT_BITS, -# requirements=["torch>=2.9.0", "auto_round_kernel"], -# systems=["linux"], -# ) 
+BackendInfos["auto_round_kernel_awq"] = BackendInfo( + device=["cpu"], + sym=[True, False], + packing_format=AWQ_FORMAT, + bits=[4], + group_size=None, + priority=6, + checkers=[ark_feature_checker], + alias=["ark"], + compute_dtype=["float32", "float16"], + data_type=["int"], + act_bits=WOQ_DEFAULT_ACT_BITS, + requirements=["torch>=2.8.0", "auto_round_kernel"], +) BackendInfos["auto_round_kernel_awq_xpu"] = BackendInfo( device=["xpu"], sym=[True], packing_format=AWQ_FORMAT, - bits=[4, 8], + bits=[4], group_size=None, priority=6, checkers=[ark_feature_checker], @@ -531,8 +526,7 @@ def fp8_static_scheme_checker( compute_dtype=["float32", "float16"], data_type=["int"], act_bits=WOQ_DEFAULT_ACT_BITS, - requirements=["torch>=2.9.0", "auto_round_kernel"], - systems=["linux"], + requirements=["torch>=2.8.0", "auto_round_kernel"], ) BackendInfos["ipex_gptq_cpu"] = BackendInfo( diff --git a/auto_round_extension/ark/README.md b/auto_round_extension/ark/README.md new file mode 100644 index 000000000..dd14579de --- /dev/null +++ b/auto_round_extension/ark/README.md @@ -0,0 +1,90 @@ +## What is AutoRound Kernel? +AutoRound Kernel is a low-bit acceleration library for Intel platform. + +The kernels are optimized for the following CPUs: +* Intel Xeon Scalable processor (formerly Sapphire Rapids, and Emerald Rapids) +* Intel Xeon 6 processors (formerly Sierra Forest and Granite Rapids) + +The kernels are optimized for the following GPUs: +* Intel Arc B-Series Graphics and Intel Arc Pro B-Series Graphics + (formerly Battlemage) + +## Key Features +AutoRound Kernel provides weight-only linear computational capabilities for LLM inference. Specifically, the weight-only-quantization configs we support are given in the table below: +### CPU +| Weight dtype | Compute dtype | Scale dtype | Algorithm[1] | +| ---------------------- | :----------------: | :---------------: | :--------: | +| INT8 | INT8[2] / BF16 / FP32 | BF16 / FP32 | sym / asym | +| INT4 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym | +| INT3 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym | +| INT2 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym | +| INT5 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym | +| INT6 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym | +| INT7 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym | +| INT1 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym | +| FP8 (E4M3, E5M2) | BF16 / FP32 | FP32 / FP8 (E8M0) | NA | +| FP4 (E2M1) | BF16 / FP32 | BF16 / FP32 | NA | + +### XPU +| Weight dtype | Compute dtype | Scale dtype | Algorithm | +| ---------------------- | :----------------: | :---------------: | :--------: | +| INT8 | INT8 / FP16 | FP16 | sym | +| INT4 | INT8 / FP16 | FP16 | sym | +| FP8 (E4M3, E5M2) | FP16 | FP16 / FP8 (E8M0) | NA | + +[1]: Quantization algorithms for integer types: symmetric or asymmetric. +[2]: Includes dynamic activation quantization; results are dequantized to floating-point formats. + + +## Installation +### Install via pip +```bash +# Install the latest auto-round kernel which may upgrade your PyTorch version automatically +pip install auto-round-kernel +# Install auto-round kernel with respective to specific PyTorch version (e.g., v2.8.x) +pip install auto-round-kernel torch~=2.8.0 +``` + +
+### Other Installation Methods + +#### Install via Script +```bash +curl -fsSL -o install_kernel.py https://raw.githubusercontent.com/intel/auto-round/main/auto_round_extension/ark/install_kernel.py +python3 install_kernel.py +``` +**Notes:** +This method is recommended if you want to keep your current PyTorch and auto-round versions. +The installation script detects the current environment and installs the matching auto-round-kernel version. + +#### Install via auto_round +```bash +pip install auto-round +auto-round-kernel-install +``` + +
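+### Quick Usage Example
+The snippet below is a minimal sketch of loading a model that was already quantized with AutoRound and routing inference through the kernel via the `ark` backend alias registered in `auto_round/inference/backend.py`. It follows the general `AutoRoundConfig(backend=...)` loading flow from the main auto-round documentation; the model path is a placeholder and the exact configuration options may differ between auto-round releases.
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from auto_round import AutoRoundConfig  # makes the AutoRound inference backends available
+
+# Placeholder path: any model already quantized and exported with AutoRound
+model_path = "path/to/autoround-quantized-model"
+
+# Prefer the AutoRound Kernel backend via its "ark" alias
+quantization_config = AutoRoundConfig(backend="ark")
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_path,
+    device_map="cpu",  # or "xpu" for the Intel GPU variants of the backend
+    quantization_config=quantization_config,
+)
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+inputs = tokenizer("There is a girl who likes adventure,", return_tensors="pt")
+print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))
+```
+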
+ + +### Versioning Scheme +The version number of auto-round-kernel follows the format: +`{auto-round major version}.{auto-round minor version}.{oneAPI version}.{kernel version}` + +**For example: v0.9.1.1** +- The first two fields (0.9) correspond to the major and minor version of the auto_round framework. +- The third field (1) identifies the Intel oneAPI release the kernel was built against: `1` indicates oneAPI 2025.1 (typically PyTorch 2.8), `2` indicates oneAPI 2025.2 (typically PyTorch 2.9). +- The final field (1) is the patch version of auto-round-kernel, reflecting updates, bug fixes, or improvements to the kernel package itself. + +**Version mapping table** + +| auto-round-kernel Version | auto-round Version | oneAPI Version | Typical PyTorch Version | +|:-------------------------:|:------------------:|:--------------:|:-------------------------:| +| 0.9.1.x | 0.9.x | 2025.1 | 2.8.x | +| 0.9.2.x | 0.9.x | 2025.2 | 2.9.x | + +**Notes:** The oneAPI version is aligned with the PyTorch version when the auto-round-kernel binaries are built, but the oneAPI toolkit is not required at runtime. + +### Validated Hardware Environment +#### CPU based on [Intel 64 architecture or compatible processors](https://en.wikipedia.org/wiki/X86-64): +* Intel Xeon 6 processor (Granite Rapids) +#### GPU built on Intel's Xe architecture: +* Intel Arc B-Series Graphics (Battlemage) \ No newline at end of file diff --git a/auto_round_extension/ark/install_kernel.py b/auto_round_extension/ark/install_kernel.py new file mode 100644 index 000000000..bb1adfc82 --- /dev/null +++ b/auto_round_extension/ark/install_kernel.py @@ -0,0 +1,60 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import subprocess +import sys + + +def get_torch_minor(): + try: + import torch + + m = re.match(r"^(\d+)\.(\d+)", torch.__version__) + return f"{m.group(1)}.{m.group(2)}" if m else None + except ImportError: + return None + + +def get_auto_round_minor(): + try: + import auto_round + + m = re.match(r"^(\d+)\.(\d+)", auto_round.__version__) + return f"{m.group(1)}.{m.group(2)}" if m else None + except ImportError: + return None + + +# Map torch minor version to kernel version +auto_round_minor = "0.9" if get_auto_round_minor() is None else get_auto_round_minor() +KERNEL_MAP = { + "2.8": f"auto-round-kernel~={auto_round_minor}.1.0", + "2.9": f"auto-round-kernel~={auto_round_minor}.2.0", +} + + +def main(): + torch_minor = get_torch_minor() + if torch_minor and torch_minor in KERNEL_MAP: + pkg = KERNEL_MAP[torch_minor] + print(f"Detected torch {torch_minor}, installing {pkg} ...") + subprocess.check_call([sys.executable, "-m", "pip", "install", pkg, "--upgrade-strategy", "only-if-needed"]) + else: + print("torch not found or no mapping for your version. 
Installing the latest auto-round-kernel ...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "auto-round-kernel"]) + + +if __name__ == "__main__": + main() diff --git a/auto_round_extension/ark/qlinear.py b/auto_round_extension/ark/qlinear.py index 9e28b7347..1e34ad4f6 100644 --- a/auto_round_extension/ark/qlinear.py +++ b/auto_round_extension/ark/qlinear.py @@ -21,9 +21,10 @@ from auto_round.utils import convert_dtype_torch2str, logger try: - import auto_round_kernel as ark + import auto_round_kernel ARK_INSTALLED = True + ark = auto_round_kernel.ARK() except: ARK_INSTALLED = False @@ -68,7 +69,6 @@ def __init__( self.pack_num = 32 // self.bits self.weight_dtype = weight_dtype self.asym = not sym - ark.set_threads(torch.get_num_threads()) if not self.infeatures % self.group_size == 0: raise NotImplementedError("in_features must be divisible by group_size") if "awq" in self.QUANT_TYPE: @@ -148,21 +148,21 @@ def post_init(self): if self.qweight.device.type == "xpu": self.sdt = "fp16" - self.cdt = "fp16" - scales = self.scales.to(torch.float16).contiguous() + self.cdt = "int8" + self.torch_dt = torch.float16 else: self.sdt = "fp32" self.cdt = "auto" - if self.asym and self.bits == 8: - self.cdt = "fp32" - scales = self.scales.float().contiguous() + self.torch_dt = torch.float32 + self.wdt = BITS_DTYPE_MAPPING[self.bits] + scales = self.scales.to(self.torch_dt).contiguous() self.qweight = ark.repack_quantized_weight( intweight.contiguous(), scales, zeros.contiguous(), - torch.empty(0), + self.group_size, # compute_dtype self.cdt, # weight_dtype @@ -170,48 +170,34 @@ def post_init(self): # scale_dtype self.sdt, self.asym, - self.group_size, ) # free mem self.qzeros = torch.empty(0) self.scales = torch.empty(0) if self.bias is not None: - if self.bias.device.type == "cpu": - self.bias = self.bias.to(torch.float32) - else: - self.bias = self.bias.to(torch.float16) + self.bias = self.bias.to(self.torch_dt) else: self.bias = torch.empty(0) def forward(self, x: torch.Tensor): raw_input_dtype = x.dtype - if x.device.type == "cpu": - odt = torch.float32 - self.bias = self.bias.to(torch.float32) - if raw_input_dtype != torch.float32: - x = x.to(torch.float32) - else: - odt = x.dtype - + x = x.to(self.torch_dt) out_shape = x.shape[:-1] + (self.outfeatures,) - x = x.view(-1, x.shape[-1]) # convert xd to 2d - out_2d_shape = x.shape[:-1] + (self.outfeatures,) - outputs = torch.empty(out_2d_shape, device=x.device, dtype=odt) - - ark.woq_linear( - x, + x = x.view(-1, x.shape[-1]) + self.bias = self.bias.to(self.torch_dt) + outputs = ark.woqgemm( + x, # convert xd to 2d, self.qweight, self.bias, - outputs, + self.outfeatures, + self.infeatures, + self.group_size, self.cdt, # compute_dtype self.wdt, # weight_dtype self.sdt, # scale_dtype self.asym, - self.group_size, ) - if x.device.type == "xpu": - outputs = outputs + self.bias return outputs.to(raw_input_dtype).view(out_shape) diff --git a/setup.cfg b/setup.cfg index 6ec792249..3c95893f1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,4 +17,5 @@ console_scripts = auto_round_best = auto_round.__main__:run_best auto-round-light = auto_round.__main__:run_light auto_round_light = auto_round.__main__:run_light + auto-round-kernel-install = auto_round_extension.ark.install_kernel:main diff --git a/setup.py b/setup.py index d0d5ec397..724ff4ee8 100644 --- a/setup.py +++ b/setup.py @@ -108,7 +108,7 @@ def fetch_requirements(path): ), "install_requires": fetch_requirements("requirements.txt"), # auto-round[cpu] is deprecated, will be removed from 
v1.0.0 - "extras_require": {"cpu": fetch_requirements("requirements-cpu.txt"), "kernel": ["auto-round-kernel"]}, + "extras_require": {"cpu": fetch_requirements("requirements-cpu.txt")}, } ############################################################################### diff --git a/test/test_ark/test_model.py b/test/test_ark/test_model.py index b30d0a083..361f1bdf9 100644 --- a/test/test_ark/test_model.py +++ b/test/test_ark/test_model.py @@ -58,27 +58,27 @@ def main_op(self, format, bits, group_size, sym, dtype, device, fast_cfg=True, t @pytest.mark.parametrize("format", ["auto_round", "auto_round:gptqmodel"]) @pytest.mark.parametrize("bits, group_size, sym", [(4, 128, True), (8, 128, True)]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) - @pytest.mark.parametrize("device", ["xpu"]) + @pytest.mark.parametrize("device", ["cpu", "xpu"]) def test_formats(self, format, bits, group_size, sym, dtype, device): self.main_op(format, bits, group_size, sym, dtype, device) @pytest.mark.parametrize("format", ["auto_round:auto_awq"]) @pytest.mark.parametrize("bits, group_size, sym", [(4, 32, True)]) @pytest.mark.parametrize("dtype", [torch.float16]) - @pytest.mark.parametrize("device", ["xpu"]) + @pytest.mark.parametrize("device", ["cpu", "xpu"]) def test_awq_fp16(self, format, bits, group_size, sym, dtype, device): self.main_op(format, bits, group_size, sym, dtype, device) - # @pytest.mark.parametrize("format", ["auto_round"]) - # @pytest.mark.parametrize("bits, group_size, sym", [(2, 32, False)]) - # @pytest.mark.parametrize("dtype", [torch.bfloat16]) - # @pytest.mark.parametrize("device", ["cpu"]) - # def test_other_bits(self, format, bits, group_size, sym, dtype, device): - # self.main_op(format, bits, group_size, sym, dtype, device, False, 0.2) + @pytest.mark.parametrize("format", ["auto_round"]) + @pytest.mark.parametrize("bits, group_size, sym", [(2, 32, False)]) + @pytest.mark.parametrize("dtype", [torch.bfloat16]) + @pytest.mark.parametrize("device", ["cpu"]) + def test_other_bits(self, format, bits, group_size, sym, dtype, device): + self.main_op(format, bits, group_size, sym, dtype, device, False, 0.2) if __name__ == "__main__": p = TestAutoRoundARKBackend() p.setup_class() - p.test_formats("auto_round:auto_awq", 4, 32, True, torch.bfloat16, "xpu") + p.test_formats("auto_round:auto_awq", 4, 64, True, torch.bfloat16, "xpu") p.teardown_class()