From 25e9a56d3ca370de5481d1f38d854346625cd004 Mon Sep 17 00:00:00 2001 From: xiaocenxiaocen Date: Thu, 25 Jan 2024 17:36:52 -0500 Subject: [PATCH] [Ir][DTypes] resolve review comments --- python/hidet/ir/dtypes/boolean.py | 3 ++ python/hidet/ir/dtypes/complex.py | 3 ++ python/hidet/ir/dtypes/floats.py | 3 ++ python/hidet/ir/dtypes/integer.py | 3 ++ python/hidet/ir/dtypes/integer_subbyte.py | 15 ++++++++ python/hidet/ir/type.py | 35 ++++++++++++------- .../hidet/transforms/lower_integer_subbyte.py | 9 ++--- tests/ir/test_int_subbyte.py | 10 ++++-- 8 files changed, 62 insertions(+), 19 deletions(-) diff --git a/python/hidet/ir/dtypes/boolean.py b/python/hidet/ir/dtypes/boolean.py index 1336ca303..b613c8314 100644 --- a/python/hidet/ir/dtypes/boolean.py +++ b/python/hidet/ir/dtypes/boolean.py @@ -18,6 +18,9 @@ class Boolean(DataType): def __init__(self): super().__init__('bool', 'bool', 1) + def is_integer_subbyte(self) -> bool: + return False + def is_float(self) -> bool: return False diff --git a/python/hidet/ir/dtypes/complex.py b/python/hidet/ir/dtypes/complex.py index cecda8ad6..948668086 100644 --- a/python/hidet/ir/dtypes/complex.py +++ b/python/hidet/ir/dtypes/complex.py @@ -19,6 +19,9 @@ def __init__(self, name, short_name, base_dtype: DataType): super().__init__(name, short_name, 2 * base_dtype.nbytes) self.base_dtype: DataType = base_dtype + def is_integer_subbyte(self) -> bool: + return False + def is_float(self) -> bool: return False diff --git a/python/hidet/ir/dtypes/floats.py b/python/hidet/ir/dtypes/floats.py index afd627d38..4e3074443 100644 --- a/python/hidet/ir/dtypes/floats.py +++ b/python/hidet/ir/dtypes/floats.py @@ -35,6 +35,9 @@ def __init__(self, name, short_name, nbytes, min_value, max_value, eps, smallest self._eps: float = eps self._smallest_normal: float = smallest_normal + def is_integer_subbyte(self) -> bool: + return False + def is_float(self) -> bool: return True diff --git a/python/hidet/ir/dtypes/integer.py 
b/python/hidet/ir/dtypes/integer.py index 6e65abc16..27268c106 100644 --- a/python/hidet/ir/dtypes/integer.py +++ b/python/hidet/ir/dtypes/integer.py @@ -29,6 +29,9 @@ def __init__(self, name, short_name, nbytes, min_value, max_value): self._min_value: int = min_value self._max_value: int = max_value + def is_integer_subbyte(self) -> bool: + return False + def is_float(self) -> bool: return False diff --git a/python/hidet/ir/dtypes/integer_subbyte.py b/python/hidet/ir/dtypes/integer_subbyte.py index 5c6fa2865..ed44de9a4 100644 --- a/python/hidet/ir/dtypes/integer_subbyte.py +++ b/python/hidet/ir/dtypes/integer_subbyte.py @@ -22,6 +22,18 @@ def __init__(self, name, short_name, storage, nbits, signed, min_value, max_valu self._bits_mask: int = (1 << self._nbits) - 1 self._sign_mask: int = 1 << (self._nbits - 1) if self.signedness() else 0 + @property + def storage(self): + return self._storage + + @property + def nbytes(self): + raise TypeError(f"Cannot access nbytes property for the type({self}") + + @property + def nbits(self): + return self._nbits + @property def bits_mask(self): return self._bits_mask @@ -30,6 +42,9 @@ def bits_mask(self): def sign_mask(self): return self._sign_mask + def is_integer_subbyte(self): + return True + def iinfo(self) -> IntInfo: return IntInfo(self._nbits, self._max_value, self._min_value, self) diff --git a/python/hidet/ir/type.py b/python/hidet/ir/type.py index 32d1e46af..9c479ce43 100644 --- a/python/hidet/ir/type.py +++ b/python/hidet/ir/type.py @@ -68,20 +68,12 @@ class DataType(BaseType): """ The data type that defines how to interpret the data in memory. - Note: - 1. The _storage field for non-subbyte types is the type itself, while the _storage - for subbyte types is the type of its actual storage. e.g., the storage for int4b is uint8 - 2. The _storage field will be overwritten during the construction of subbyte types - 2. 
The _nbits field in the constructor denotes the bit length of the storage, and - it will be overwritten in the constructor of subbyte types """ def __init__(self, name: str, short_name: str, nbytes: int): self._name: str = name self._short_name: str = short_name - self._storage = self self._nbytes: int = nbytes - self._nbits: int = self._nbytes * 8 def __str__(self): return 'hidet.{}'.format(self.name) @@ -138,20 +130,37 @@ def short_name(self) -> str: @property def nbytes(self) -> int: - if self._nbits < 8: - raise TypeError(f"Cannot access nbytes property for the type({self}") return self._nbytes @property def nbits(self) -> int: - return self._nbits + """ + Get the bit length of the data type. + + Note: + 1. This is the bit length of the data type itself rather than the bit length of its storage. + 2. For regular data types, nbits is computed from the nbytes property. + 3. For subbyte data types, nbits is defined when constructing the data type, + and this property is overridden by the subbyte data types. + 4. In addition, nbytes cannot be accessed for a subbyte data type; attempting to do so + raises a TypeError. + """ + return self._nbytes * 8 @property def storage(self) -> DataType: - return self._storage + """ + Get the actual storage type of the data type. + + Note: + 1. The storage of a regular data type is the data type itself, while the storage + of a subbyte type is the type of its actual storage, e.g., the storage of int4b is uint8. + 2. This property is overridden in the subclass of subbyte types. 
+ """ + return self def is_integer_subbyte(self) -> bool: - return self.is_integer() and self._nbits < 8 + raise NotImplementedError() def is_float(self) -> bool: raise NotImplementedError() diff --git a/python/hidet/transforms/lower_integer_subbyte.py b/python/hidet/transforms/lower_integer_subbyte.py index ad30f4c49..d4e102fbc 100644 --- a/python/hidet/transforms/lower_integer_subbyte.py +++ b/python/hidet/transforms/lower_integer_subbyte.py @@ -33,6 +33,7 @@ from hidet.ir.module import IRModule from hidet.ir.functors import IRRewriter from hidet.transforms import Pass +from hidet.utils.py import is_power_of_two def is_pointer_type(base_ty: BaseType): @@ -113,10 +114,10 @@ def _get_subbyte_value(self, dtype: DataType, base: Var, offset: Expr): storage_bits = storage_ty.nbits dtype_bits = dtype.nbits divisor = storage_bits // dtype_bits - if divisor & (divisor - 1) != 0: + if not is_power_of_two(divisor): raise TypeError(f"data type not supported yet(got:{dtype})") idx = simplify(offset // divisor) - offset_ = simplify(offset & (divisor - 1)) + offset_ = simplify(offset % divisor) mask = storage_ty.constant(dtype.bits_mask) return (base[idx] >> (offset_ * dtype_bits)) & mask @@ -125,10 +126,10 @@ def _set_subbyte_value(self, dtype: DataType, base: Var, offset: Expr, value: Ex storage_bits = storage_ty.nbits dtype_bits = dtype.nbits divisor = storage_bits // dtype_bits - if divisor & (divisor - 1) != 0: + if not is_power_of_two(divisor): raise TypeError(f"data type not supported yet(got:{dtype})") idx = simplify(offset // divisor) - offset_ = simplify(offset & (divisor - 1)) + offset_ = simplify(offset % divisor) value_ty = infer_type(value) assert value_ty == storage_ty mask = storage_ty.constant(dtype.bits_mask) diff --git a/tests/ir/test_int_subbyte.py b/tests/ir/test_int_subbyte.py index fc695e3de..118aed2a4 100644 --- a/tests/ir/test_int_subbyte.py +++ b/tests/ir/test_int_subbyte.py @@ -70,7 +70,10 @@ def func(out: f32[4, 4], inp: i4[4, 4]): inp = torch.empty((4, 2), 
dtype=torch.int8, device="cuda") inp = hidet.from_torch(inp) func(data, inp) - print(data.cpu().numpy()) + import numpy as np + + groundtruth = np.resize(np.arange(-8, 8), (4, 4)).astype(np.float32) + np.testing.assert_equal(data.cpu().numpy(), groundtruth) def test_int_2bit(): @@ -113,7 +116,10 @@ def func(out: f32[2, 2]): data = torch.empty((2, 2), dtype=torch.float32, device="cuda") data = hidet.from_torch(data) func(data) - print(data.cpu().numpy()) + import numpy as np + + groundtruth = np.resize(np.arange(-2, 2), (2, 2)).astype(np.float32) + np.testing.assert_equal(data.cpu().numpy(), groundtruth) if __name__ == "__main__":