diff --git a/python/hidet/backend/codegen.py b/python/hidet/backend/codegen.py index c43c5864b..5c141c6a8 100644 --- a/python/hidet/backend/codegen.py +++ b/python/hidet/backend/codegen.py @@ -626,6 +626,8 @@ def visit_DataType(self, t: DataType): 'float32x8': '__m256', 'int8x4': 'char4', 'uint8x4': 'uint4', + 'int4bx8': 'uint32_t', + 'uint4bx8': 'uint32_t', } self.require_complex = self.require_complex or t.name in ['complex64', 'complex128'] diff --git a/python/hidet/ir/dtypes/__init__.py b/python/hidet/ir/dtypes/__init__.py index b51d485b9..7a1cf8396 100644 --- a/python/hidet/ir/dtypes/__init__.py +++ b/python/hidet/ir/dtypes/__init__.py @@ -12,11 +12,13 @@ from hidet.ir.type import DataType from .integer import int8, int16, int32, int64, uint8, uint16, uint32, uint64 from .integer import i8, i16, i32, i64, u8, u16, u32, u64 +from .integer_subbyte import int4b, int3b, int2b, int1b, uint4b, uint3b, uint2b, uint1b +from .integer_subbyte import i4, i3, i2, i1, u4, u3, u2, u1 from .floats import float16, float32, float64, bfloat16, tfloat32 from .floats import f16, f32, f64, bf16, tf32 from .boolean import boolean -from .vector import float16x2, float32x4, float32x8, int8x4, uint8x4, vectorize -from .vector import f16x2, f32x4, f32x8 +from .vector import float16x2, float32x4, float32x8, int8x4, uint8x4, int4bx8, uint4bx8, vectorize +from .vector import f16x2, f32x4, f32x8, i4x8, u4x8 from .complex import complex64, complex128 from .promotion import promote_type from .utils import dtype_to_numpy, finfo, iinfo @@ -43,6 +45,16 @@ 'float16x2': float16x2, 'int8x4': int8x4, 'uint8x4': uint8x4, + 'int4b': int4b, + 'int3b': int3b, + 'int2b': int2b, + 'int1b': int1b, + 'uint4b': uint4b, + 'uint3b': uint3b, + 'uint2b': uint2b, + 'uint1b': uint1b, + 'int4bx8': int4bx8, + 'uint4bx8': uint4bx8, } sname2dtype = { @@ -66,6 +78,16 @@ 'f32x8': f32x8, 'f16x2': f16x2, 'i8x4': int8x4, + 'i4': int4b, + 'i3': int3b, + 'i2': int2b, + 'i1': int1b, + 'u4': uint4b, + 'u3': uint3b, + 'u2': uint2b, + 'u1': uint1b, + 'i4x8': int4bx8, + 'u4x8': uint4bx8, } diff --git a/python/hidet/ir/dtypes/boolean.py b/python/hidet/ir/dtypes/boolean.py index 1336ca303..b613c8314 100644 --- a/python/hidet/ir/dtypes/boolean.py +++ b/python/hidet/ir/dtypes/boolean.py @@ -18,6 +18,9 @@ class Boolean(DataType): def __init__(self): super().__init__('bool', 'bool', 1) + def is_integer_subbyte(self) -> bool: + return False + def is_float(self) -> bool: return False diff --git a/python/hidet/ir/dtypes/complex.py b/python/hidet/ir/dtypes/complex.py index cecda8ad6..948668086 100644 --- a/python/hidet/ir/dtypes/complex.py +++ b/python/hidet/ir/dtypes/complex.py @@ -19,6 +19,9 @@ def __init__(self, name, short_name, base_dtype: DataType): super().__init__(name, short_name, 2 * base_dtype.nbytes) self.base_dtype: DataType = base_dtype + def is_integer_subbyte(self) -> bool: + return False + def is_float(self) -> bool: return False diff --git a/python/hidet/ir/dtypes/floats.py b/python/hidet/ir/dtypes/floats.py index afd627d38..4e3074443 100644 --- a/python/hidet/ir/dtypes/floats.py +++ b/python/hidet/ir/dtypes/floats.py @@ -35,6 +35,9 @@ def __init__(self, name, short_name, nbytes, min_value, max_value, eps, smallest self._eps: float = eps self._smallest_normal: float = smallest_normal + def is_integer_subbyte(self) -> bool: + return False + def is_float(self) -> bool: return True diff --git a/python/hidet/ir/dtypes/integer.py b/python/hidet/ir/dtypes/integer.py index bc605c1a1..27268c106 100644 --- a/python/hidet/ir/dtypes/integer.py +++ b/python/hidet/ir/dtypes/integer.py @@ -29,6 +29,9 @@ def __init__(self, name, short_name, nbytes, min_value, max_value): self._min_value: int = min_value self._max_value: int = max_value + def is_integer_subbyte(self) -> bool: + return False + def is_float(self) -> bool: return False @@ -54,6 +57,9 @@ def constant(self, value: Any): raise ValueError('Value {} is out of range for {}.'.format(value, self.name)) return constant(value, self) + def signedness(self): + return self._min_value < 0 + @property def one(self): return self.constant(1) diff --git a/python/hidet/ir/dtypes/integer_subbyte.py b/python/hidet/ir/dtypes/integer_subbyte.py new file mode 100644 index 000000000..ed44de9a4 --- /dev/null +++ b/python/hidet/ir/dtypes/integer_subbyte.py @@ -0,0 +1,70 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from hidet.ir.type import DataType +from .integer import IntegerType, IntInfo, uint8, uint32 + + +class IntegerSubbyteType(IntegerType): + def __init__(self, name, short_name, storage, nbits, signed, min_value, max_value): + nbytes = storage.nbytes + super().__init__(name, short_name, nbytes, min_value, max_value) + self._storage: DataType = storage + self._nbits: int = nbits + self._bits_mask: int = (1 << self._nbits) - 1 + self._sign_mask: int = 1 << (self._nbits - 1) if self.signedness() else 0 + + @property + def storage(self): + return self._storage + + @property + def nbytes(self): + raise TypeError(f"Cannot access nbytes property for the type({self}") + + @property + def nbits(self): + return self._nbits + + @property + def bits_mask(self): + return self._bits_mask + + @property + def sign_mask(self): + return self._sign_mask + + def is_integer_subbyte(self): + return True + + def iinfo(self) -> IntInfo: + return IntInfo(self._nbits, self._max_value, self._min_value, self) + + +int4b = IntegerSubbyteType('int4b', 'i4', uint8, 4, True, -8, 7) +int3b = IntegerSubbyteType('int3b', 'i3', uint32, 3, True, -4, 3) +int2b = IntegerSubbyteType('int2b', 'i2', uint8, 2, True, -2, 1) +int1b = IntegerSubbyteType('int1b', 'i1', uint8, 1, True, -1, 0) + +uint4b = IntegerSubbyteType('uint4b', 'u4', uint8, 4, False, 0, 16) +uint3b = IntegerSubbyteType('uint3b', 'u3', uint32, 3, False, 0, 8) +uint2b = IntegerSubbyteType('uint2b', 'u2', uint8, 2, False, 0, 4) +uint1b = IntegerSubbyteType('uint1b', 'u1', uint8, 1, False, 0, 1) + +i4 = int4b +i3 = int3b +i2 = int2b +i1 = int1b + +u4 = uint4b +u3 = uint3b +u2 = uint2b +u1 = uint1b diff --git a/python/hidet/ir/dtypes/vector.py b/python/hidet/ir/dtypes/vector.py index 0ad01cadd..6d0ac6da4 100644 --- a/python/hidet/ir/dtypes/vector.py +++ b/python/hidet/ir/dtypes/vector.py @@ -13,13 +13,16 @@ from hidet.ir.type import DataType from .floats import float32, float16 from .integer import int8, uint8 +from .integer_subbyte import int4b, uint4b class VectorType(DataType): def __init__(self, lane_type: DataType, num_lanes: int): name = '{}x{}'.format(lane_type.name, num_lanes) short_name = '{}x{}'.format(lane_type.short_name, num_lanes) - nbytes = lane_type.nbytes * num_lanes + nbytes = ( + lane_type.nbytes * num_lanes if not lane_type.is_integer_subbyte() else lane_type.nbits * num_lanes // 8 + ) super().__init__(name, short_name, nbytes) self._num_lanes: int = num_lanes self._lane_type: DataType = lane_type @@ -27,6 +30,9 @@ def __init__(self, lane_type: DataType, num_lanes: int): if lane_type.is_vector(): raise ValueError('Cannot create a vector type of vectors') + def is_integer_subbyte(self) -> bool: + return False + def is_float(self) -> bool: return False @@ -90,6 +96,18 @@ def max_value(self): float16x2 = VectorType(float16, 2) f16x2 = float16x2 +int4bx2 = VectorType(int4b, 2) +i4x2 = int4bx2 + +uint4bx2 = VectorType(uint4b, 2) +u4x2 = uint4bx2 + +int4bx8 = VectorType(int4b, 8) +i4x8 = int4bx8 + +uint4bx8 = VectorType(uint4b, 8) +u4x8 = uint4bx8 + def vectorize(base_dtype: DataType, num_lanes: int) -> VectorType: table = { diff --git a/python/hidet/ir/stmt.py b/python/hidet/ir/stmt.py index 1c0af2423..748256e57 100644 --- a/python/hidet/ir/stmt.py +++ b/python/hidet/ir/stmt.py @@ -41,6 +41,18 @@ def from_str(name): else: return DeclareScope.Default + def is_global(self): + return self == DeclareScope.Global + + def is_shared(self): + return self == DeclareScope.Shared + + def is_register(self): + return self == DeclareScope.Register + + def is_memory(self): + return not self.is_register() + class ForStmtAttr: def __init__(self, unroll=False, unroll_factor=None, unroll_explicit=False, parallel=False, parallel_threads=None): diff --git a/python/hidet/ir/type.py b/python/hidet/ir/type.py index 01e34f34f..9c479ce43 100644 --- a/python/hidet/ir/type.py +++ b/python/hidet/ir/type.py @@ -67,6 +67,7 @@ def as_data_type(self) -> Optional[DataType]: class DataType(BaseType): """ The data type that defines how to interpret the data in memory. + """ def __init__(self, name: str, short_name: str, nbytes: int): @@ -131,6 +132,36 @@ def short_name(self) -> str: def nbytes(self) -> int: return self._nbytes + @property + def nbits(self) -> int: + """ + Get the bit length of the data type + + Note: + 1. The bit length of the data type itself other than the bit length of its storage. + 2. For regular data types, the nbits can be computed from its nbytes property. + 3. For subbyte data types, the nbits is defined when constructing the data type, + and this method will also be overridden for subbyte data types. + 4. In addition, we cannot access the nbytes for a subbyte data type, otherwise + a type error will be raised. + """ + return self._nbytes * 8 + + @property + def storage(self) -> DataType: + """ + Get the actual storage type of the data type + + Note: + 1. The storage of a regular data type is the data type itself, while the storage + of a subbyte type is the type of its actual storage. e.g., the storage of int4b is uint8 + 2. The property will be overridden in the subclass of subbyte types. + """ + return self + + def is_integer_subbyte(self) -> bool: + raise NotImplementedError() + def is_float(self) -> bool: raise NotImplementedError() @@ -187,7 +218,10 @@ def __invert__(self): return TensorPointerType.from_tensor_type(self) def storage_bytes(self) -> Expr: - return self.layout.size * self.dtype.nbytes + if self.dtype.is_integer_subbyte(): + return self.layout.size * self.dtype.nbits // 8 + else: + return self.layout.size * self.dtype.nbytes def const_shape(self) -> List[int]: return [int(v) for v in self.shape] diff --git a/python/hidet/transforms/__init__.py b/python/hidet/transforms/__init__.py index 3da283c64..857f1efca 100644 --- a/python/hidet/transforms/__init__.py +++ b/python/hidet/transforms/__init__.py @@ -37,6 +37,7 @@ from .check_launch_configuration import check_launch_configuration_pass from .lower_special_cast import lower_special_cast_pass from .annotate_header_and_libs import annotate_header_and_libs_pass +from .lower_integer_subbyte import lower_integer_subbyte_pass def lower_with(ir_module: IRModule, transforms: Sequence[Pass]) -> IRModule: @@ -63,6 +64,7 @@ def lower(ir_module: IRModule) -> IRModule: declare_to_let_pass(), rule_based_simplify_pass(), # make ir more readable flatten_tensor_index_pass(), + lower_integer_subbyte_pass(), lower_special_cast_pass(), inline_function_pass(), resolve_primitive_func_pass(), diff --git a/python/hidet/transforms/flatten_tensor_index.py b/python/hidet/transforms/flatten_tensor_index.py index ded30854e..c50429af2 100644 --- a/python/hidet/transforms/flatten_tensor_index.py +++ b/python/hidet/transforms/flatten_tensor_index.py @@ -9,11 +9,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from hidet.ir.type import TensorType, tensor_type, tensor_pointer_type, PointerType, TensorPointerType, ArrayType +from typing import Dict + +from hidet.ir.type import ( + TensorType, + tensor_type, + tensor_pointer_type, + PointerType, + TensorPointerType, + ArrayType, + FuncType, + func_type, +) from hidet.ir.expr import Var, TensorElement, TensorSlice, tensor_element from hidet.ir.stmt import BufferStoreStmt, DeclareStmt from hidet.ir.layout import row_major from hidet.ir.func import Function +from hidet.ir.module import IRModule from hidet.ir.functors import IRRewriter from hidet.ir.tools import simplify, TypeInfer from hidet.transforms import Pass @@ -28,6 +40,15 @@ class FlattenTensorAccessRewriter(IRRewriter): def __init__(self): super().__init__() self.type_infer = TypeInfer() + self.func2func_type: Dict[str, FuncType] = {} + + def visit_Var(self, v: Var): + if isinstance(v.type, FuncType): + if v.name in self.func2func_type: + func_ty = self.func2func_type[v.name] + if func_ty is not v.type: + return Var(v.hint, func_ty, v.name) + return super().visit_Var(v) def visit_Function(self, func: Function): for var in func.params: @@ -39,7 +60,13 @@ def visit_Function(self, func: Function): self.memo[var] = Var(var.hint, tensor_pointer_type(var.type.tensor_type.dtype, [size])) body = self(func.body) params = [self(p) for p in func.params] - return Function(func.name, params, body, func.ret_type, kind=func.kind, attrs=func.attrs) + if body is func.body and all(p is p1 for p, p1 in zip(params, func.params)): + return func + else: + new_func = Function(func.name, params, body, func.ret_type, kind=func.kind, attrs=func.attrs) + param_types = [p.type for p in params] + self.func2func_type[func.name] = func_type(param_types, func.ret_type) + return new_func def get_layout(self, e) -> DataLayout: e_type = self.type_infer(e) @@ -103,9 +130,16 @@ def visit_TensorSlice(self, e: TensorSlice): class FlattenTensorIndexPass(Pass): - def process_func(self, func: Function) -> Function: + def process_module(self, ir_module: IRModule) -> IRModule: flatten_index = FlattenTensorAccessRewriter() - return flatten_index(func) + + new_funcs = {} + for name, func in ir_module.functions.items(): + new_funcs[name] = flatten_index(func) + if all(new_funcs[name] is ir_module.functions[name] for name in new_funcs): + return ir_module + else: + return ir_module.copy().reset_funcs(new_funcs, ir_module.global_vars) def flatten_tensor_index_pass(): diff --git a/python/hidet/transforms/lower_integer_subbyte.py b/python/hidet/transforms/lower_integer_subbyte.py new file mode 100644 index 000000000..9742b4e2f --- /dev/null +++ b/python/hidet/transforms/lower_integer_subbyte.py @@ -0,0 +1,350 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# pylint: disable=unused-variable +from typing import Dict, List, Union, Tuple + +from hidet.ir.tools import infer_type, simplify +from hidet.ir.type import BaseType, DataType, TensorType, TensorPointerType, PointerType +from hidet.ir.dtypes import i32 +from hidet.ir.expr import Var, Expr, Add, TensorElement, Address, Constant, Cast, var, cast, bitwise_not +from hidet.ir.stmt import ( + Stmt, + DeclareStmt, + AssignStmt, + LetStmt, + EvaluateStmt, + BufferStoreStmt, + SeqStmt, + BlackBoxStmt, + WhileStmt, + DeclareScope, +) + +from hidet.ir.func import Function +from hidet.ir.module import IRModule +from hidet.ir.functors import IRRewriter +from hidet.transforms import Pass +from hidet.utils.py import is_power_of_two + + +def is_pointer_type(base_ty: BaseType): + return isinstance(base_ty, (PointerType, TensorPointerType, TensorType)) + + +def get_pointer_base_type(base_ty: BaseType): + if isinstance(base_ty, PointerType): + return base_ty.base_type + elif isinstance(base_ty, TensorType): + return base_ty.dtype + else: + ttype = base_ty.tensor_type + return ttype.dtype + + +def is_integer_subbyte(dtype: BaseType): + return dtype.is_data_type() and dtype.is_integer_subbyte() + + +class LowerIntegerSubbyteRewriter(IRRewriter): + # convert subbyte integers to their storage type + # e.g., + # a = register_tensor("int4b", [16]) ==> a = register_tensor("uint8", [8]) + # int4b* ptr = &a[8] ==> uint8* ptr = &a[4] + # int4b* ptr = ptr + 10 ==> uint8* ptr = ptr + 5 + # we support tensor element access for global|shared|register tensors, and + # we support buffer store statment for register tensors, i.e. + # a = register_tensor("int4b", [4, 4]) + # b = a[2, 2] + # a[2, 2] = int4b(-5) + # user has to explicitly insert dtype conversion before applying arithmetic + # operations on subbyte integers. + # int4b a = int4b(-2) + # int4b b = int4b(-3) + # int4b c = a + b ==> not allowed + def __init__(self): + super().__init__() + self.old2new: Dict[Var, Var] = {} + self.stmts: List[Stmt] = [] + self.var2scope: Dict[Var, DeclareScope] = {} + self.recursive_depth = 0 + + def auto_var(self, v: Var = None, hint: str = None, e: Expr = None): + if v is not None: + self.stmts.append(DeclareStmt(v)) + return v + v_ty = infer_type(e) + v = var(hint, v_ty) + self.stmts.append(DeclareStmt(v, e)) + return v + + def append_stmt(self, stmt: Union[Stmt, Expr]): + if isinstance(stmt, Expr): + stmt = EvaluateStmt(stmt) + self.stmts.append(stmt) + + def flush_stmts(self): + stmts = self.stmts + self.stmts = [] + return stmts + + def flatten_stmts(self, stmts: List[Stmt]): + if len(stmts) == 1: + return stmts[0] + else: + return SeqStmt(stmts) + + def _get_divisor(self, dtype: DataType): + storage_ty = dtype.storage + storage_bits = storage_ty.nbits + dtype_bits = dtype.nbits + divisor = storage_bits // dtype_bits + return divisor + + def _get_subbyte_value(self, dtype: DataType, base: Var, offset: Expr): + storage_ty = dtype.storage + storage_bits = storage_ty.nbits + dtype_bits = dtype.nbits + divisor = storage_bits // dtype_bits + if not is_power_of_two(divisor): + raise TypeError(f"data type not supported yet(got:{dtype})") + idx = simplify(offset // divisor) + offset_ = simplify(offset % divisor) + mask = storage_ty.constant(dtype.bits_mask) + return (base[idx] >> (offset_ * dtype_bits)) & mask + + def _set_subbyte_value(self, dtype: DataType, base: Var, offset: Expr, value: Expr): + storage_ty = dtype.storage + storage_bits = storage_ty.nbits + dtype_bits = dtype.nbits + divisor = storage_bits // dtype_bits + if not is_power_of_two(divisor): + raise TypeError(f"data type not supported yet(got:{dtype})") + idx = simplify(offset // divisor) + offset_ = simplify(offset % divisor) + value_ty = infer_type(value) + assert value_ty == storage_ty + mask = storage_ty.constant(dtype.bits_mask) + item = self.auto_var(hint="item", e=value & mask) + updated_mask = self.auto_var(hint="updated_mask", e=bitwise_not(mask << (offset_ * dtype_bits))) + new_bits = self.auto_var(hint="new_bits", e=item << (offset_ * dtype_bits)) + + from hidet.ir.dtypes import u32, u16 + + if self.var2scope[base].is_memory(): + if not any(storage_ty is ty for ty in [i32, u32, u16]): + raise NotImplementedError( + "writing subbyte data to memory requires the storage type must be" + " int32, uint32, or uint16 due to atomicCAS, but got({storage_ty})" + ) + original = self.auto_var(hint="original", e=storage_ty.zero) + updated = self.auto_var(hint="updated", e=storage_ty.zero) + body = [] + body.append(AssignStmt(original, base[idx])) + body.append(AssignStmt(updated, (original & updated_mask) | new_bits)) + body.append(BlackBoxStmt("atomicCAS({}, {}, {});", ~base[idx], original, updated)) + body = SeqStmt(body) + self.stmts.append(WhileStmt(original == updated, body)) + else: + assert self.var2scope[base].is_register() + original = self.auto_var(hint="original", e=base[idx]) + updated = self.auto_var(hint="updated", e=(original & updated_mask) | new_bits) + self.stmts.append(BufferStoreStmt(base, [idx], updated)) + + def visit_DataType(self, t: DataType): + if t.is_integer_subbyte(): + return t.storage + else: + return t + + def visit_TensorType(self, t: TensorType): + from hidet.ir.layout import row_major + + if is_integer_subbyte(t.dtype): + shape = list(self.visit(t.shape)) + assert len(shape) == 1 + dtype = t.dtype + storage_ty = dtype.storage + storage_bits = storage_ty.nbits + dtype_bits = dtype.nbits + divisor = storage_bits // dtype_bits + shape[-1] = shape[-1] // divisor + layout = row_major(*shape) + return TensorType(storage_ty, shape, layout) + return super().visit_TensorType(t) + + def visit_Var(self, v: Var): + if v in self.old2new: + return self.old2new[v] + return super().visit_Var(v) + + def visit_Constant(self, e: Constant): + if is_integer_subbyte(e.type): + ty = self.visit(e.type) + value = self.visit(e.value) + dtype = e.type + storage_ty = dtype.storage + mask = storage_ty.constant(dtype.bits_mask) + value = value & mask + return Constant(value, ty) + return super().visit_Constant(e) + + def visit_TensorElement(self, e: TensorElement): + if isinstance(e.base, Var): + base_ty = infer_type(e.base) + if is_pointer_type(base_ty): + dtype = get_pointer_base_type(base_ty) + if is_integer_subbyte(dtype): + base = self.visit(e.base) + assert len(e.indices) == 1 + offset = self.visit(e.indices[0]) + return self._get_subbyte_value(dtype, base, offset) + return super().visit_TensorElement(e) + + def visit_Address(self, e: Address): + if isinstance(e.expr, TensorElement): + base = e.expr.base + if isinstance(base, Var): + base_ty = infer_type(base) + if is_pointer_type(base_ty): + dtype = get_pointer_base_type(base_ty) + if is_integer_subbyte(dtype): + storage_ty = dtype.storage + storage_bits = storage_ty.nbits + dtype_bits = dtype.nbits + divisor = storage_bits // dtype_bits + assert len(e.expr.indices) == 1 + offset = self.visit(e.expr.indices[0]) + idx = simplify(offset // divisor) + base = self.visit(base) + return ~base[idx] + return super().visit_Address(e) + + def _cast_int(self, dtype: DataType, expr: Expr): + if not dtype.signedness(): + return expr + int_type = i32 + int_data = cast(expr, int_type) + shift = int_type.nbits - dtype.nbits + return (int_data << shift) >> shift + + def visit_Cast(self, e: Cast): + expr_ty = infer_type(e.expr) + if is_integer_subbyte(expr_ty): + if is_integer_subbyte(e.target_type): + raise NotImplementedError(f"casting from {expr_ty} to {e.target_type} is not supported yet") + expr = self.visit(e.expr) + return cast(self._cast_int(expr_ty, expr), e.target_type) + elif is_integer_subbyte(e.target_type): + from hidet.ir.expr import if_then_else + + expr = self.visit(e.expr) + dtype = e.target_type + min_val = expr_ty(dtype.min_value.value) + max_val = expr_ty(dtype.max_value.value) + expr = if_then_else(expr < min_val, min_val, expr) + expr = if_then_else(expr >= max_val, max_val, expr) + if not dtype.signedness(): + return cast(expr, dtype.storage) + storage_ty = dtype.storage + int_type = i32 + int_data = cast(expr, int_type) + shift = int_type.nbits - dtype.nbits + int_data = (int_data << shift) >> shift + return cast(int_data, storage_ty) + target_type = self.visit(e.target_type) + expr = self.visit(e.expr) + if target_type is e.target_type and expr is e.expr: + return e + else: + return cast(expr, target_type) + + def _subbyte_pointer_add(self, dtype: DataType, ptr: Union[Expr, Tuple[Expr]], offset: Expr): + divisor = self._get_divisor(dtype) + if isinstance(ptr, tuple): + ptr, offset_ = ptr + offset = offset + offset_ + if self.recursive_depth == 0: + return ptr + offset // divisor + else: + return ptr, offset + + def visit_Add(self, e: Add): + a_ty = infer_type(e.a) + b_ty = infer_type(e.b) + if isinstance(a_ty, PointerType) and is_integer_subbyte(a_ty.base_type): + self.recursive_depth += 1 + a = self.visit(e.a) + b = self.visit(e.b) + self.recursive_depth -= 1 + return self._subbyte_pointer_add(a_ty.base_type, a, b) + elif isinstance(b_ty, PointerType) and is_integer_subbyte(b_ty.base_type): + self.recursive_depth += 1 + a = self.visit(e.a) + b = self.visit(e.b) + self.recursive_depth -= 1 + return self._subbyte_pointer_add(b_ty.base_type, b, a) + return super().visit_Add(e) + + def visit_DeclareStmt(self, stmt: DeclareStmt): + v_type = self.visit(stmt.var.type) + if v_type is not stmt.var.type: + v = var(stmt.var.hint, v_type) + init = self.visit(stmt.init) + self.old2new[stmt.var] = v + if isinstance(v_type, TensorType): + self.var2scope[v] = stmt.scope + self.append_stmt(DeclareStmt(v, init, stmt.is_static, stmt.scope)) + return self.flatten_stmts(self.flush_stmts()) + self.append_stmt(super().visit_DeclareStmt(stmt)) + return self.flatten_stmts(self.flush_stmts()) + + def visit_AssignStmt(self, stmt: AssignStmt): + self.append_stmt(super().visit_AssignStmt(stmt)) + return self.flatten_stmts(self.flush_stmts()) + + def visit_LetStmt(self, stmt: LetStmt): + self.append_stmt(super().visit_LetStmt(stmt)) + return self.flatten_stmts(self.flush_stmts()) + + def visit_BufferStoreStmt(self, stmt: BufferStoreStmt): + if isinstance(stmt.buf, Var): + buf_ty = infer_type(stmt.buf) + if isinstance(buf_ty, TensorType): + dtype = buf_ty.dtype + if is_integer_subbyte(dtype): + buf = self.visit(stmt.buf) + indices = self.visit(stmt.indices) + value = self.visit(stmt.value) + assert len(indices) == 1 + self._set_subbyte_value(dtype, buf, indices[0], value) + return self.flatten_stmts(self.flush_stmts()) + self.append_stmt(super().visit_BufferStoreStmt(stmt)) + return self.flatten_stmts(self.flush_stmts()) + + +class LowerIntegerSubbytePass(Pass): + def process_func(self, func: Function) -> Function: + rewriter = LowerIntegerSubbyteRewriter() + return rewriter(func) + + def process_module(self, ir_module: IRModule) -> IRModule: + new_funcs = {} + for name, func in ir_module.functions.items(): + new_funcs[name] = self.process_func(func) + if all(new_funcs[name] is ir_module.functions[name] for name in new_funcs): + return ir_module + else: + return ir_module.copy().reset_funcs(new_funcs, ir_module.global_vars) + + +def lower_integer_subbyte_pass() -> Pass: + return LowerIntegerSubbytePass() diff --git a/tests/ir/test_int_subbyte.py b/tests/ir/test_int_subbyte.py new file mode 100644 index 000000000..118aed2a4 --- /dev/null +++ b/tests/ir/test_int_subbyte.py @@ -0,0 +1,129 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +import hidet + + +def test_int_4bit(): + from hidet.ir.dtypes import i4, u4, i4x8, i8, f16, f32 + from hidet.ir.expr import constant, cast + from hidet.lang import attrs + from hidet.lang import shared_tensor, register_tensor + from hidet.lang.cuda import blockIdx, threadIdx, dynamic_shared_memory + + with hidet.script_module() as script_module: + + @hidet.script + def func(out: f32[4, 4], inp: i4[4, 4]): + attrs.func_kind = "cuda_kernel" + attrs.cuda.block_dim = 128 + attrs.cuda.grid_dim = 4 + + a = constant(1, i4) + b = register_tensor('int4b', shape=[8, 2]) + ptr = ~b[3, 1] + ptr = ptr + 4 + ptr = ptr + (threadIdx.x * 444 + blockIdx.x * 888 + 555) + c = register_tensor('i4x8', shape=[1]) + b[0, 1] = a + b[0, 1] = b[0, 2] + d = b[0, 1] + s = shared_tensor('uint4b', shape=[7, 8]) + e = f32(s[2, 4]) + s1 = shared_tensor('float32', shape=[64, 64]) + s1[3, 4] = f16(b[4, 0]) + s2 = s[:, 4:] + f = f32(s2[2, 0]) + + data = register_tensor('int4b', shape=[4, 4]) + + for i in range(4): + for j in range(4): + if i == 0 and j == 0: + data[i, j] = i4(-8) + elif j == 0: + data[i, j] = i4(f32(data[i - 1, 3]) + 1) + else: + data[i, j] = i4(f32(data[i, j - 1]) + 1) + + if threadIdx.x == 0 and blockIdx.x == 0: + for i in range(4): + for j in range(4): + d = data[i, j] + out[i, j] = f32(d) + + func = script_module.build() + import torch + + data = torch.empty((4, 4), dtype=torch.float32, device="cuda") + data = hidet.from_torch(data) + inp = torch.empty((4, 2), dtype=torch.int8, device="cuda") + inp = hidet.from_torch(inp) + func(data, inp) + import numpy as np + + groundtruth = np.resize(np.arange(-8, 8), (4, 4)).astype(np.float32) + np.testing.assert_equal(data.cpu().numpy(), groundtruth) + + +def test_int_2bit(): + from hidet.ir.dtypes import i2, u2, i8, f16, f32 + from hidet.ir.expr import constant, cast + from hidet.lang import attrs + from hidet.lang import shared_tensor, register_tensor + from hidet.lang.cuda import blockIdx, threadIdx, dynamic_shared_memory + + with hidet.script_module() as script_module: + + @hidet.script + def func(out: f32[2, 2]): + attrs.func_kind = "cuda_kernel" + attrs.cuda.block_dim = 128 + attrs.cuda.grid_dim = 4 + + a = constant(1, i2) + + data = register_tensor('int2b', shape=[2, 2]) + + for i in range(2): + for j in range(2): + if i == 0 and j == 0: + data[i, j] = i2(-2) + elif j == 0: + data[i, j] = i2(f32(data[i - 1, 1]) + 1) + else: + data[i, j] = i2(f32(data[i, j - 1]) + 1) + + if threadIdx.x == 0 and blockIdx.x == 0: + for i in range(2): + for j in range(2): + d = data[i, j] + out[i, j] = f32(d) + + func = script_module.build() + import torch + + data = torch.empty((2, 2), dtype=torch.float32, device="cuda") + data = hidet.from_torch(data) + func(data) + import numpy as np + + groundtruth = np.resize(np.arange(-2, 2), (2, 2)).astype(np.float32) + np.testing.assert_equal(data.cpu().numpy(), groundtruth) + + +if __name__ == "__main__": + hidet.option.cache_dir("./demo_int_subbyte") + hidet.option.save_lower_ir(True) + + pytest.main(__file__)