diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 9a2524953..0c51822a6 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -26,7 +26,7 @@
     "cpu",
     "cuda",  # NVIDIA/AMD GPU
     "xpu",  # Intel GPU
-    "hpu",  # Gaudi
+    "hpu",  # Intel Gaudi
     "npu",  # Ascend NPU
     "mps",  # Apple Silicon
 }
@@ -37,6 +37,9 @@
 if torch.xpu.is_available():
     from .backends.xpu import ops as xpu_ops
 
+if hasattr(torch, "hpu") and torch.hpu.is_available():
+    from .backends.hpu import ops as hpu_ops
+
 
 def _import_backends():
     """
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 746d6c1ec..80fc86861 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -451,7 +451,7 @@ def matmul_4bit(
         else:
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
 
-    if A.numel() == A.shape[-1] and A.requires_grad == False:
+    if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(
                 f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
diff --git a/bitsandbytes/backends/hpu/__init__.py b/bitsandbytes/backends/hpu/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/bitsandbytes/backends/hpu/ops.py b/bitsandbytes/backends/hpu/ops.py
new file mode 100644
index 000000000..1eeb7f014
--- /dev/null
+++ b/bitsandbytes/backends/hpu/ops.py
@@ -0,0 +1,53 @@
+from collections.abc import Sequence
+import math
+
+import torch
+
+from bitsandbytes.utils import _reverse_4bit_compress_format
+
+from ..._ops import register_kernel
+from ..utils import GAUDI_SW_VER
+
+
+@register_kernel("bitsandbytes::dequantize_4bit", "hpu")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    torch._check_is_size(blocksize)
+    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4, got {quant_type}")
+    torch._check(
+        A.dtype in [torch.bfloat16, torch.uint8],
+        lambda: f"quant_storage supports uint8 or bfloat16, but got {A.dtype}",
+    )
+
+    # Allow non-uint8 quant_storage dtypes: reinterpret the packed data as uint8.
+    if A.dtype != torch.uint8:
+        A = A.view(torch.uint8)
+
+    transpose = False if len(A.shape) == 2 and A.shape[0] == 1 else True
+
+    A = A.reshape(-1)
+
+    if GAUDI_SW_VER and (GAUDI_SW_VER.major < 1 or GAUDI_SW_VER.minor < 22):
+        A = _reverse_4bit_compress_format(A)
+
+    # HPU dequantization function for NF4 quantized tensors.
+    out_dq = torch.ops.hpu.dequantize_nf4(
+        A,
+        absmax.to(dtype),
+        blocksize,
+        out_shape=(math.prod(shape),),
+        out_dtype=dtype,
+    )
+
+    output = out_dq.reshape(shape)
+
+    if transpose:
+        output = output.t()
+
+    return output
diff --git a/bitsandbytes/backends/utils.py b/bitsandbytes/backends/utils.py
index cc88ffae1..c7aba2964 100755
--- a/bitsandbytes/backends/utils.py
+++ b/bitsandbytes/backends/utils.py
@@ -1,3 +1,6 @@
+import subprocess
+
+from packaging import version
 import torch
 
 try:
@@ -55,3 +58,23 @@
     device="xpu" if torch.xpu.is_available() else "cpu",  # Only cpu/xpu use this table for now.
 )
 CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE}
+
+
+def get_gaudi_sw_version():
+    """
+    Returns the installed Gaudi SW (habana-torch-plugin) version, or None if it is not installed.
+    """
+    output = subprocess.run(
+        "pip list | grep habana-torch-plugin",
+        shell=True,
+        text=True,
+        capture_output=True,
+    )
+    # If grep returned nothing, the Gaudi plugin is not installed.
+    if not output.stdout.strip():
+        return None
+
+    return version.parse(output.stdout.split("\n")[0].split()[-1])
+
+
+GAUDI_SW_VER = get_gaudi_sw_version()
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index ccd842ce3..639ef125a 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -442,7 +442,7 @@ def __init__(
         )
         # self.persistent_buffers = []  # TODO consider as way to save quant state
         self.compute_dtype = compute_dtype
-        self.compute_type_is_set = False
+        self.compute_type_is_set = False if compute_dtype is None else True
         self.quant_state = None
         self.quant_storage = quant_storage
         self.ipex_linear_is_set = False
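Usage sketch (not part of the patch): the flow below shows how the pieces above are intended to fit together, assuming a Gaudi host where the habana-torch-plugin bridge exposes the "hpu" device to PyTorch. The habana_frameworks import, layer sizes, and bf16 compute dtype are illustrative assumptions, not part of this diff. Weights are packed to NF4 when the Linear4bit module is moved to the device; the forward pass reaches matmul_4bit, which on HPU skips the gemv fast path (per the _functions.py change) and dequantizes the weight through the "bitsandbytes::dequantize_4bit" kernel registered in backends/hpu/ops.py.

    import torch
    import habana_frameworks.torch  # noqa: F401  # assumed: registers the "hpu" device
    import bitsandbytes as bnb

    # Build an NF4-quantized linear layer; passing compute_dtype also exercises the
    # modules.py change (compute_type_is_set becomes True when a dtype is given).
    fp_layer = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16)
    q_layer = bnb.nn.Linear4bit(
        128,
        256,
        bias=False,
        compute_dtype=torch.bfloat16,
        quant_type="nf4",
    )
    q_layer.load_state_dict(fp_layer.state_dict())
    q_layer = q_layer.to("hpu")  # weights are packed to NF4 as part of the move

    x = torch.randn(4, 128, dtype=torch.bfloat16, device="hpu")
    with torch.no_grad():
        y = q_layer(x)  # dequantize path hits torch.ops.hpu.dequantize_nf4
    print(y.shape)  # expected: torch.Size([4, 256])

On Gaudi software older than 1.22, the packed nibbles are additionally reordered by _reverse_4bit_compress_format before the native op is called, gated by the GAUDI_SW_VER value added in backends/utils.py.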