From 1ccd3d77f99258110dae5e6f1e71797f8ea8a4c2 Mon Sep 17 00:00:00 2001
From: Ruheena Suhani Shaik
Date: Thu, 22 May 2025 12:46:42 +0300
Subject: [PATCH 1/5] supports hpu backend in main branch

---
 bitsandbytes/__init__.py              |  5 ++-
 bitsandbytes/autograd/_functions.py   |  2 +-
 bitsandbytes/backends/hpu/__init__.py |  0
 bitsandbytes/backends/hpu/ops.py      | 53 +++++++++++++++++++++++++++
 bitsandbytes/backends/utils.py        | 23 ++++++++++++
 bitsandbytes/nn/modules.py            |  2 +-
 6 files changed, 82 insertions(+), 3 deletions(-)
 create mode 100644 bitsandbytes/backends/hpu/__init__.py
 create mode 100644 bitsandbytes/backends/hpu/ops.py

diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 9a2524953..0c51822a6 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -26,7 +26,7 @@
     "cpu",
     "cuda",  # NVIDIA/AMD GPU
     "xpu",  # Intel GPU
-    "hpu",  # Gaudi
+    "hpu",  # Intel Gaudi
     "npu",  # Ascend NPU
     "mps",  # Apple Silicon
 }
@@ -37,6 +37,9 @@
 if torch.xpu.is_available():
     from .backends.xpu import ops as xpu_ops
 
+if hasattr(torch, "hpu") and torch.hpu.is_available():
+    from .backends.hpu import ops as hpu_ops
+
 
 def _import_backends():
     """
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 746d6c1ec..80fc86861 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -451,7 +451,7 @@ def matmul_4bit(
         else:
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
 
-    if A.numel() == A.shape[-1] and A.requires_grad == False:
+    if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(
                 f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
diff --git a/bitsandbytes/backends/hpu/__init__.py b/bitsandbytes/backends/hpu/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/bitsandbytes/backends/hpu/ops.py b/bitsandbytes/backends/hpu/ops.py
new file mode 100644
index 000000000..e425f934e
--- /dev/null
+++ b/bitsandbytes/backends/hpu/ops.py
@@ -0,0 +1,53 @@
+from collections.abc import Sequence
+import math
+
+import torch
+
+from bitsandbytes.utils import _reverse_4bit_compress_format
+
+from ..._ops import register_kernel
+from ..utils import GAUDI_SW_VER
+
+
+@register_kernel("bitsandbytes::dequantize_4bit", "hpu")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    torch._check_is_size(blocksize)
+    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
+    torch._check(
+        dtype in (torch.bfloat16, torch.float32), lambda: f"4bit dequantization only bf16/f32, but got {dtype}"
+    )
+    torch._check(A.dtype in [torch.bfloat16, torch.uint8], lambda: f"quant_storage supports uint8, but got {A.dtype}")
+
+    # Enable non uint8 dtype
+    if A.dtype != torch.uint8:
+        A = A.view(torch.uint8)
+
+    transpose = False if len(A.shape) == 2 and A.shape[0] == 1 else True
+
+    A = A.reshape(-1)
+
+    if GAUDI_SW_VER and (GAUDI_SW_VER.major < 1 or GAUDI_SW_VER.minor < 22):
+        A = _reverse_4bit_compress_format(A)
+
+    # HPU dequantization function for NF4 quantized tensors.
+    out_dq = torch.ops.hpu.dequantize_nf4(
+        A,
+        absmax.to(dtype),
+        blocksize,
+        out_shape=(math.prod(shape),),
+        out_dtype=dtype,
+    )
+
+    output = out_dq.reshape(shape)
+
+    if transpose:
+        output = output.t()
+
+    return output
diff --git a/bitsandbytes/backends/utils.py b/bitsandbytes/backends/utils.py
index cc88ffae1..c7aba2964 100755
--- a/bitsandbytes/backends/utils.py
+++ b/bitsandbytes/backends/utils.py
@@ -1,3 +1,6 @@
+import subprocess
+
+from packaging import version
 import torch
 
 try:
@@ -55,3 +58,23 @@
     device="xpu" if torch.xpu.is_available() else "cpu",  # Only cpu/xpu use this table for now.
 )
 CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE}
+
+
+def get_gaudi_sw_version():
+    """
+    Returns the installed version of Gaudi SW.
+    """
+    output = subprocess.run(
+        "pip list | grep habana-torch-plugin",
+        shell=True,
+        text=True,
+        capture_output=True,
+    )
+    # If grep return nothing
+    if not output.stdout.strip():
+        return None
+
+    return version.parse(output.stdout.split("\n")[0].split()[-1])
+
+
+GAUDI_SW_VER = get_gaudi_sw_version()
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index ccd842ce3..639ef125a 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -442,7 +442,7 @@ def __init__(
         )
         # self.persistent_buffers = []  # TODO consider as way to save quant state
         self.compute_dtype = compute_dtype
-        self.compute_type_is_set = False
+        self.compute_type_is_set = False if compute_dtype is None else True
         self.quant_state = None
         self.quant_storage = quant_storage
         self.ipex_linear_is_set = False

From fff24d6d36f70c63d0dabfcaee68f11751bfaf09 Mon Sep 17 00:00:00 2001
From: Ruheena Suhani Shaik
Date: Thu, 5 Jun 2025 10:27:44 +0530
Subject: [PATCH 2/5] Update bitsandbytes/backends/hpu/ops.py

updates the assertion message

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
---
 bitsandbytes/backends/hpu/ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bitsandbytes/backends/hpu/ops.py b/bitsandbytes/backends/hpu/ops.py
index e425f934e..0c2b62ebc 100644
--- a/bitsandbytes/backends/hpu/ops.py
+++ b/bitsandbytes/backends/hpu/ops.py
@@ -23,7 +23,7 @@ def _(
     torch._check(
         dtype in (torch.bfloat16, torch.float32), lambda: f"4bit dequantization only bf16/f32, but got {dtype}"
     )
-    torch._check(A.dtype in [torch.bfloat16, torch.uint8], lambda: f"quant_storage supports uint8, but got {A.dtype}")
+    torch._check(A.dtype in [torch.bfloat16, torch.uint8], lambda: f"quant_storage supports uint8 or bfloat16, but got {A.dtype}")
 
     # Enable non uint8 dtype
     if A.dtype != torch.uint8:

From 805fb6b347b782b08aeb2711ebf0a6d505113702 Mon Sep 17 00:00:00 2001
From: Ruheena Suhani Shaik
Date: Thu, 5 Jun 2025 10:27:57 +0530
Subject: [PATCH 3/5] Update bitsandbytes/backends/hpu/ops.py

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
---
 bitsandbytes/backends/hpu/ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bitsandbytes/backends/hpu/ops.py b/bitsandbytes/backends/hpu/ops.py
index 0c2b62ebc..ab2eb1dc4 100644
--- a/bitsandbytes/backends/hpu/ops.py
+++ b/bitsandbytes/backends/hpu/ops.py
@@ -19,7 +19,7 @@ def _(
     dtype: torch.dtype,
 ) -> torch.Tensor:
     torch._check_is_size(blocksize)
-    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
+    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4, got {quant_type}")
     torch._check(
         dtype in (torch.bfloat16, torch.float32), lambda: f"4bit dequantization only bf16/f32, but got {dtype}"
     )

From 3900187617f5b949b1164832f92574781a6cfeb5 Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Thu, 5 Jun 2025 09:46:37 -0400
Subject: [PATCH 4/5] Update ops.py

Fix lint issue
---
 bitsandbytes/backends/hpu/ops.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/bitsandbytes/backends/hpu/ops.py b/bitsandbytes/backends/hpu/ops.py
index ab2eb1dc4..98a80b965 100644
--- a/bitsandbytes/backends/hpu/ops.py
+++ b/bitsandbytes/backends/hpu/ops.py
@@ -21,7 +21,8 @@ def _(
     torch._check_is_size(blocksize)
     torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4, got {quant_type}")
     torch._check(
-        dtype in (torch.bfloat16, torch.float32), lambda: f"4bit dequantization only bf16/f32, but got {dtype}"
+        A.dtype in [torch.bfloat16, torch.uint8],
+        lambda: f"quant_storage supports uint8 or bfloat16, but got {A.dtype}",
     )
     torch._check(A.dtype in [torch.bfloat16, torch.uint8], lambda: f"quant_storage supports uint8 or bfloat16, but got {A.dtype}")
 

From 1542c188bd220ce5339daad2fe2927a553be6e31 Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Thu, 5 Jun 2025 10:05:33 -0400
Subject: [PATCH 5/5] Update ops.py

---
 bitsandbytes/backends/hpu/ops.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/bitsandbytes/backends/hpu/ops.py b/bitsandbytes/backends/hpu/ops.py
index 98a80b965..1eeb7f014 100644
--- a/bitsandbytes/backends/hpu/ops.py
+++ b/bitsandbytes/backends/hpu/ops.py
@@ -24,7 +24,6 @@ def _(
         A.dtype in [torch.bfloat16, torch.uint8],
         lambda: f"quant_storage supports uint8 or bfloat16, but got {A.dtype}",
     )
-    torch._check(A.dtype in [torch.bfloat16, torch.uint8], lambda: f"quant_storage supports uint8 or bfloat16, but got {A.dtype}")
 
     # Enable non uint8 dtype
    if A.dtype != torch.uint8:
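
Usage note: the series registers only a "bitsandbytes::dequantize_4bit" kernel for the "hpu"
device (no HPU quantize kernel), so the most direct way to exercise it is to quantize on the
host and dequantize on the Gaudi device. Below is a minimal sketch, assuming a host with the
habana-torch-plugin installed (so that torch.hpu exists) and host-side quantization via the
public bitsandbytes functional API; tensor names and sizes are arbitrary.

    import torch
    import habana_frameworks.torch  # noqa: F401 -- assumed bridge import; load before bitsandbytes so its hpu backend registers

    import bitsandbytes.functional as F

    # Quantize an NF4 weight on the host (this series does not add an HPU quantize kernel).
    weight = torch.randn(1024, 1024, dtype=torch.bfloat16)
    packed, quant_state = F.quantize_4bit(weight, blocksize=64, quant_type="nf4")

    # Move the packed 4-bit storage and the quantization state to the Gaudi device.
    packed = packed.to("hpu")
    quant_state.to("hpu")  # QuantState.to() moves absmax (and any nested state) in place

    # Dispatches to the kernel registered in bitsandbytes/backends/hpu/ops.py,
    # which calls torch.ops.hpu.dequantize_nf4 under the hood.
    restored = F.dequantize_4bit(packed, quant_state, quant_type="nf4")
    print(restored.shape, restored.dtype)  # torch.Size([1024, 1024]) torch.bfloat16

The matmul_4bit change in patch 1 excludes "hpu" inputs from the gemv-style inference shortcut,
so Linear4bit forward passes on Gaudi fall back to MatMul4Bit, which dequantizes the weight
through this same kernel before performing a regular matmul.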