diff --git a/.azure-pipelines/scripts/ut/run_ut_xpu.sh b/.azure-pipelines/scripts/ut/run_ut_xpu.sh
index 1978700db..7716a0cc6 100644
--- a/.azure-pipelines/scripts/ut/run_ut_xpu.sh
+++ b/.azure-pipelines/scripts/ut/run_ut_xpu.sh
@@ -4,8 +4,7 @@ set -xe
# install requirements
echo "##[group]set up UT env..."
uv pip install pytest-cov pytest-html
-uv pip install -r /auto-round/test/test_ark/requirements.txt \
- --extra-index-url https://download.pytorch.org/whl/xpu
+uv pip install -r /auto-round/test/test_ark/requirements.txt
cd /auto-round && uv pip install .
echo "##[endgroup]"
diff --git a/auto_round/inference/backend.py b/auto_round/inference/backend.py
index 7b1783998..de9ea903f 100644
--- a/auto_round/inference/backend.py
+++ b/auto_round/inference/backend.py
@@ -439,21 +439,20 @@ def fp8_static_scheme_checker(
requirements=["autoawq", "transformers"],
)
-# BackendInfos["auto_round_kernel"] = BackendInfo(
-# device=["cpu"],
-# sym=[True, False],
-# packing_format=GPTQ_FORMAT_NO_ZP,
-# bits=[2, 4, 8],
-# group_size=None,
-# priority=6,
-# checkers=[ark_feature_checker],
-# alias=["ark"],
-# compute_dtype=["float32", "float16"],
-# data_type=["int"],
-# act_bits=WOQ_DEFAULT_ACT_BITS,
-# requirements=["torch>=2.9.0", "auto_round_kernel"],
-# systems=["linux"],
-# )
+BackendInfos["auto_round_kernel"] = BackendInfo(
+ device=["cpu"],
+ sym=[True, False],
+ packing_format=GPTQ_FORMAT_NO_ZP,
+ bits=[2, 4, 8],
+ group_size=None,
+ priority=6,
+ checkers=[ark_feature_checker],
+ alias=["ark"],
+ compute_dtype=["float32", "float16"],
+ data_type=["int"],
+ act_bits=WOQ_DEFAULT_ACT_BITS,
+ requirements=["torch>=2.8.0", "auto_round_kernel"],
+)
BackendInfos["auto_round_kernel_xpu"] = BackendInfo(
device=["xpu"],
@@ -467,25 +466,23 @@ def fp8_static_scheme_checker(
compute_dtype=["float32", "float16"],
data_type=["int"],
act_bits=WOQ_DEFAULT_ACT_BITS,
- requirements=["torch>=2.9.0", "auto_round_kernel"],
- systems=["linux"],
+ requirements=["torch>=2.8.0", "auto_round_kernel"],
)
-# BackendInfos["auto_round_kernel_zp"] = BackendInfo(
-# device=["cpu"],
-# sym=[True, False],
-# packing_format=GPTQ_FORMAT,
-# bits=[2, 4, 8],
-# group_size=None,
-# priority=6,
-# checkers=[ark_feature_checker],
-# alias=["ark"],
-# compute_dtype=["float32", "float16"],
-# data_type=["int"],
-# act_bits=WOQ_DEFAULT_ACT_BITS,
-# requirements=["torch>=2.9.0", "auto_round_kernel"],
-# systems=["linux"],
-# )
+BackendInfos["auto_round_kernel_zp"] = BackendInfo(
+ device=["cpu"],
+ sym=[True, False],
+ packing_format=GPTQ_FORMAT,
+ bits=[2, 4, 8],
+ group_size=None,
+ priority=6,
+ checkers=[ark_feature_checker],
+ alias=["ark"],
+ compute_dtype=["float32", "float16"],
+ data_type=["int"],
+ act_bits=WOQ_DEFAULT_ACT_BITS,
+ requirements=["torch>=2.8.0", "auto_round_kernel"],
+)
BackendInfos["auto_round_kernel_zp_xpu"] = BackendInfo(
device=["xpu"],
@@ -499,31 +496,29 @@ def fp8_static_scheme_checker(
compute_dtype=["float32", "float16"],
data_type=["int"],
act_bits=WOQ_DEFAULT_ACT_BITS,
- requirements=["torch>=2.9.0", "auto_round_kernel"],
- systems=["linux"],
+ requirements=["torch>=2.8.0", "auto_round_kernel"],
)
-# BackendInfos["auto_round_kernel_awq"] = BackendInfo(
-# device=["cpu"],
-# sym=[True, False],
-# packing_format=AWQ_FORMAT,
-# bits=[2, 4, 8],
-# group_size=None,
-# priority=6,
-# checkers=[ark_feature_checker],
-# alias=["ark"],
-# compute_dtype=["float32", "float16"],
-# data_type=["int"],
-# act_bits=WOQ_DEFAULT_ACT_BITS,
-# requirements=["torch>=2.9.0", "auto_round_kernel"],
-# systems=["linux"],
-# )
+BackendInfos["auto_round_kernel_awq"] = BackendInfo(
+ device=["cpu"],
+ sym=[True, False],
+ packing_format=AWQ_FORMAT,
+ bits=[4],
+ group_size=None,
+ priority=6,
+ checkers=[ark_feature_checker],
+ alias=["ark"],
+ compute_dtype=["float32", "float16"],
+ data_type=["int"],
+ act_bits=WOQ_DEFAULT_ACT_BITS,
+ requirements=["torch>=2.8.0", "auto_round_kernel"],
+)
BackendInfos["auto_round_kernel_awq_xpu"] = BackendInfo(
device=["xpu"],
sym=[True],
packing_format=AWQ_FORMAT,
- bits=[4, 8],
+ bits=[4],
group_size=None,
priority=6,
checkers=[ark_feature_checker],
@@ -531,8 +526,7 @@ def fp8_static_scheme_checker(
compute_dtype=["float32", "float16"],
data_type=["int"],
act_bits=WOQ_DEFAULT_ACT_BITS,
- requirements=["torch>=2.9.0", "auto_round_kernel"],
- systems=["linux"],
+ requirements=["torch>=2.8.0", "auto_round_kernel"],
)
BackendInfos["ipex_gptq_cpu"] = BackendInfo(
diff --git a/auto_round_extension/ark/README.md b/auto_round_extension/ark/README.md
new file mode 100644
index 000000000..dd14579de
--- /dev/null
+++ b/auto_round_extension/ark/README.md
@@ -0,0 +1,90 @@
+## What is AutoRound Kernel?
+AutoRound Kernel is a low-bit acceleration library for Intel platforms.
+
+The kernels are optimized for the following CPUs:
+* Intel Xeon Scalable processors (formerly codenamed Sapphire Rapids and Emerald Rapids)
+* Intel Xeon 6 processors (formerly codenamed Sierra Forest and Granite Rapids)
+
+The kernels are optimized for the following GPUs:
+* Intel Arc B-Series Graphics and Intel Arc Pro B-Series Graphics
+ (formerly Battlemage)
+
+## Key Features
+AutoRound Kernel provides weight-only quantized linear operators for LLM inference. The supported weight-only-quantization configurations are listed in the tables below; a minimal usage sketch follows them:
+### CPU
+| Weight dtype | Compute dtype | Scale dtype | Algorithm[1] |
+| ---------------------- | :----------------: | :---------------: | :--------: |
+| INT8 | INT8[2] / BF16 / FP32 | BF16 / FP32 | sym / asym |
+| INT7 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym |
+| INT6 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym |
+| INT5 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym |
+| INT4 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym |
+| INT3 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym |
+| INT2 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym |
+| INT1 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym |
+| FP8 (E4M3, E5M2) | BF16 / FP32 | FP32 / FP8 (E8M0) | NA |
+| FP4 (E2M1) | BF16 / FP32 | BF16 / FP32 | NA |
+
+### XPU
+| Weight dtype | Compute dtype | Scale dtype | Algorithm |
+| ---------------------- | :----------------: | :---------------: | :--------: |
+| INT8 | INT8 / FP16 | FP16 | sym |
+| INT4 | INT8 / FP16 | FP16 | sym |
+| FP8 (E4M3, E5M2) | FP16 | FP16 / FP8 (E8M0) | NA |
+
+[1]: Quantization algorithms for integer types: symmetric or asymmetric.
+[2]: Includes dynamic activation quantization; results are dequantized to floating-point formats.
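+
+As a quick orientation (after installing the packages described below), the sketch quantizes a small model into one of the CPU configurations above: INT4 weights, group size 128, symmetric. It is a minimal example rather than a reference recipe: the model name is a placeholder, the call pattern follows the top-level auto-round README, and which kernel backend is picked at load time depends on the installed packages and the priorities in `auto_round/inference/backend.py`.
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from auto_round import AutoRound
+
+model_name = "facebook/opt-125m"  # placeholder; any causal LM works
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+# INT4 weights, group size 128, symmetric -- one row of the CPU table above
+autoround = AutoRound(model, tokenizer, bits=4, group_size=128, sym=True)
+autoround.quantize_and_save("./opt-125m-w4g128", format="auto_round")
+```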
+
+
+## Installation
+### Install via pip
+```bash
+# Install the latest auto-round-kernel; this may upgrade your PyTorch version automatically
+pip install auto-round-kernel
+# Install auto-round-kernel together with a specific PyTorch version (e.g., v2.8.x)
+pip install auto-round-kernel torch~=2.8.0
+```
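+
+After installation, a quick import check confirms the kernel is visible to Python (the module name `auto_round_kernel` is the one `auto_round_extension/ark/qlinear.py` imports):
+
+```bash
+python -c "import auto_round_kernel; print('auto_round_kernel is available')"
+```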
+
+
+**Other Installation Methods**
+
+### Install via Script
+```bash
+curl -fsSL https://raw.githubusercontent.com/intel/auto-round/main/auto_round_extension/ark/install_kernel.py -o install_kernel.py
+python3 install_kernel.py
+```
+**Note:**
+This method is recommended if you want to keep your current PyTorch and auto-round versions. The installation script detects the current environment and installs the matching auto-round-kernel version.
+
+### Install via auto_round
+```bash
+pip install auto-round
+auto-round-kernel-install
+```
+
+
+
+### Versioning Scheme
+The version number of auto-round-kernel follows the format:
+`{auto-round major version}.{auto-round minor version}.{oneAPI version}.{kernel version}`
+
+**For example: v0.9.1.1**
+- The first two components (0.9) correspond to the major and minor version of the auto_round framework.
+- The third component (1) identifies the Intel oneAPI release: `1` indicates oneAPI 2025.1 (typically PyTorch 2.8), `2` indicates oneAPI 2025.2 (typically PyTorch 2.9).
+- The final component (1) is the patch version of auto-round-kernel, reflecting updates, bug fixes, or improvements to the kernel package itself.
+
+**Version mapping table**
+
+| auto-round-kernel Version | auto-round Version | oneAPI Version | Typical PyTorch Version |
+|:-------------------------:|:------------------:|:--------------:|:-------------------------:|
+| 0.9.1.x | 0.9.x | 2025.1 | 2.8.x |
+| 0.9.2.x | 0.9.x | 2025.2 | 2.9.x |
+
+**Note:** The oneAPI version is aligned with the PyTorch version when the auto-round-kernel binaries are built, but the oneAPI toolkit is not required at runtime.
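+
+If you prefer to pin the kernel manually rather than use the install script, the same `~=` constraint that `install_kernel.py` builds can be applied directly. For example, for auto-round 0.9.x on PyTorch 2.8.x (the first row of the mapping table):
+
+```bash
+pip install "auto-round-kernel~=0.9.1.0"
+```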
+
+### Validated Hardware Environment
+#### CPU based on [Intel 64 architecture or compatible processors](https://en.wikipedia.org/wiki/X86-64):
+* Intel Xeon 6 processor (Granite Rapids)
+#### GPU built on Intel's Xe architecture:
+* Intel Arc B-Series Graphics (Battlemage)
\ No newline at end of file
diff --git a/auto_round_extension/ark/install_kernel.py b/auto_round_extension/ark/install_kernel.py
new file mode 100644
index 000000000..bb1adfc82
--- /dev/null
+++ b/auto_round_extension/ark/install_kernel.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2026 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import subprocess
+import sys
+
+
+def get_torch_minor():
+ try:
+ import torch
+
+ m = re.match(r"^(\d+)\.(\d+)", torch.__version__)
+ return f"{m.group(1)}.{m.group(2)}" if m else None
+ except ImportError:
+ return None
+
+
+def get_auto_round_minor():
+ try:
+ import auto_round
+
+ m = re.match(r"^(\d+)\.(\d+)", auto_round.__version__)
+ return f"{m.group(1)}.{m.group(2)}" if m else None
+ except ImportError:
+ return None
+
+
+# Map torch minor version to kernel version
+auto_round_minor = "0.9" if get_auto_round_minor() is None else get_auto_round_minor()
+KERNEL_MAP = {
+ "2.8": f"auto-round-kernel~={auto_round_minor}.1.0",
+ "2.9": f"auto-round-kernel~={auto_round_minor}.2.0",
+}
+
+
+def main():
+ torch_minor = get_torch_minor()
+ if torch_minor and torch_minor in KERNEL_MAP:
+ pkg = KERNEL_MAP[torch_minor]
+ print(f"Detected torch {torch_minor}, installing {pkg} ...")
+ subprocess.check_call([sys.executable, "-m", "pip", "install", pkg, "--upgrade-strategy", "only-if-needed"])
+ else:
+ print("torch not found or no mapping for your version. Installing the latest auto-round-kernel ...")
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "auto-round-kernel"])
+
+
+if __name__ == "__main__":
+ main()
diff --git a/auto_round_extension/ark/qlinear.py b/auto_round_extension/ark/qlinear.py
index 9e28b7347..1e34ad4f6 100644
--- a/auto_round_extension/ark/qlinear.py
+++ b/auto_round_extension/ark/qlinear.py
@@ -21,9 +21,10 @@
from auto_round.utils import convert_dtype_torch2str, logger
try:
- import auto_round_kernel as ark
+ import auto_round_kernel
ARK_INSTALLED = True
+ ark = auto_round_kernel.ARK()
except:
ARK_INSTALLED = False
@@ -68,7 +69,6 @@ def __init__(
self.pack_num = 32 // self.bits
self.weight_dtype = weight_dtype
self.asym = not sym
- ark.set_threads(torch.get_num_threads())
if not self.infeatures % self.group_size == 0:
raise NotImplementedError("in_features must be divisible by group_size")
if "awq" in self.QUANT_TYPE:
@@ -148,21 +148,21 @@ def post_init(self):
if self.qweight.device.type == "xpu":
self.sdt = "fp16"
- self.cdt = "fp16"
- scales = self.scales.to(torch.float16).contiguous()
+ self.cdt = "int8"
+ self.torch_dt = torch.float16
else:
self.sdt = "fp32"
self.cdt = "auto"
- if self.asym and self.bits == 8:
- self.cdt = "fp32"
- scales = self.scales.float().contiguous()
+ self.torch_dt = torch.float32
+
self.wdt = BITS_DTYPE_MAPPING[self.bits]
+ scales = self.scales.to(self.torch_dt).contiguous()
self.qweight = ark.repack_quantized_weight(
intweight.contiguous(),
scales,
zeros.contiguous(),
- torch.empty(0),
+ self.group_size,
# compute_dtype
self.cdt,
# weight_dtype
@@ -170,48 +170,34 @@ def post_init(self):
# scale_dtype
self.sdt,
self.asym,
- self.group_size,
)
# free mem
self.qzeros = torch.empty(0)
self.scales = torch.empty(0)
if self.bias is not None:
- if self.bias.device.type == "cpu":
- self.bias = self.bias.to(torch.float32)
- else:
- self.bias = self.bias.to(torch.float16)
+ self.bias = self.bias.to(self.torch_dt)
else:
self.bias = torch.empty(0)
def forward(self, x: torch.Tensor):
raw_input_dtype = x.dtype
- if x.device.type == "cpu":
- odt = torch.float32
- self.bias = self.bias.to(torch.float32)
- if raw_input_dtype != torch.float32:
- x = x.to(torch.float32)
- else:
- odt = x.dtype
-
+ x = x.to(self.torch_dt)
out_shape = x.shape[:-1] + (self.outfeatures,)
- x = x.view(-1, x.shape[-1]) # convert xd to 2d
- out_2d_shape = x.shape[:-1] + (self.outfeatures,)
- outputs = torch.empty(out_2d_shape, device=x.device, dtype=odt)
-
- ark.woq_linear(
- x,
+ x = x.view(-1, x.shape[-1])  # flatten input to 2D before the GEMM
+ self.bias = self.bias.to(self.torch_dt)
+ outputs = ark.woqgemm(
+ x,
self.qweight,
self.bias,
- outputs,
+ self.outfeatures,
+ self.infeatures,
+ self.group_size,
self.cdt, # compute_dtype
self.wdt, # weight_dtype
self.sdt, # scale_dtype
self.asym,
- self.group_size,
)
- if x.device.type == "xpu":
- outputs = outputs + self.bias
return outputs.to(raw_input_dtype).view(out_shape)
diff --git a/setup.cfg b/setup.cfg
index 6ec792249..3c95893f1 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -17,4 +17,5 @@ console_scripts =
auto_round_best = auto_round.__main__:run_best
auto-round-light = auto_round.__main__:run_light
auto_round_light = auto_round.__main__:run_light
+ auto-round-kernel-install = auto_round_extension.ark.install_kernel:main
diff --git a/setup.py b/setup.py
index d0d5ec397..724ff4ee8 100644
--- a/setup.py
+++ b/setup.py
@@ -108,7 +108,7 @@ def fetch_requirements(path):
),
"install_requires": fetch_requirements("requirements.txt"),
# auto-round[cpu] is deprecated, will be removed from v1.0.0
- "extras_require": {"cpu": fetch_requirements("requirements-cpu.txt"), "kernel": ["auto-round-kernel"]},
+ "extras_require": {"cpu": fetch_requirements("requirements-cpu.txt")},
}
###############################################################################
diff --git a/test/test_ark/test_model.py b/test/test_ark/test_model.py
index b30d0a083..361f1bdf9 100644
--- a/test/test_ark/test_model.py
+++ b/test/test_ark/test_model.py
@@ -58,27 +58,27 @@ def main_op(self, format, bits, group_size, sym, dtype, device, fast_cfg=True, t
@pytest.mark.parametrize("format", ["auto_round", "auto_round:gptqmodel"])
@pytest.mark.parametrize("bits, group_size, sym", [(4, 128, True), (8, 128, True)])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
- @pytest.mark.parametrize("device", ["xpu"])
+ @pytest.mark.parametrize("device", ["cpu", "xpu"])
def test_formats(self, format, bits, group_size, sym, dtype, device):
self.main_op(format, bits, group_size, sym, dtype, device)
@pytest.mark.parametrize("format", ["auto_round:auto_awq"])
@pytest.mark.parametrize("bits, group_size, sym", [(4, 32, True)])
@pytest.mark.parametrize("dtype", [torch.float16])
- @pytest.mark.parametrize("device", ["xpu"])
+ @pytest.mark.parametrize("device", ["cpu", "xpu"])
def test_awq_fp16(self, format, bits, group_size, sym, dtype, device):
self.main_op(format, bits, group_size, sym, dtype, device)
- # @pytest.mark.parametrize("format", ["auto_round"])
- # @pytest.mark.parametrize("bits, group_size, sym", [(2, 32, False)])
- # @pytest.mark.parametrize("dtype", [torch.bfloat16])
- # @pytest.mark.parametrize("device", ["cpu"])
- # def test_other_bits(self, format, bits, group_size, sym, dtype, device):
- # self.main_op(format, bits, group_size, sym, dtype, device, False, 0.2)
+ @pytest.mark.parametrize("format", ["auto_round"])
+ @pytest.mark.parametrize("bits, group_size, sym", [(2, 32, False)])
+ @pytest.mark.parametrize("dtype", [torch.bfloat16])
+ @pytest.mark.parametrize("device", ["cpu"])
+ def test_other_bits(self, format, bits, group_size, sym, dtype, device):
+ self.main_op(format, bits, group_size, sym, dtype, device, False, 0.2)
if __name__ == "__main__":
p = TestAutoRoundARKBackend()
p.setup_class()
- p.test_formats("auto_round:auto_awq", 4, 32, True, torch.bfloat16, "xpu")
+ p.test_formats("auto_round:auto_awq", 4, 64, True, torch.bfloat16, "xpu")
p.teardown_class()