Skip to content

Commit

Permalink
[Ir][DTypes] resolve review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaocenxiaocen committed Jan 25, 2024
1 parent 803b1c2 commit 25e9a56
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 19 deletions.
3 changes: 3 additions & 0 deletions python/hidet/ir/dtypes/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ class Boolean(DataType):
def __init__(self):
super().__init__('bool', 'bool', 1)

# Always False: the boolean data type is not a sub-byte integer type.
def is_integer_subbyte(self) -> bool:
return False

# Always False: the boolean data type is not a floating-point type.
def is_float(self) -> bool:
return False

Expand Down
3 changes: 3 additions & 0 deletions python/hidet/ir/dtypes/complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def __init__(self, name, short_name, base_dtype: DataType):
super().__init__(name, short_name, 2 * base_dtype.nbytes)
self.base_dtype: DataType = base_dtype

# Always False: the data type defined in complex.py is not a sub-byte integer type.
def is_integer_subbyte(self) -> bool:
return False

# Always False: the data type defined in complex.py does not report itself as a float type.
def is_float(self) -> bool:
return False

Expand Down
3 changes: 3 additions & 0 deletions python/hidet/ir/dtypes/floats.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def __init__(self, name, short_name, nbytes, min_value, max_value, eps, smallest
self._eps: float = eps
self._smallest_normal: float = smallest_normal

# Always False: the data type defined in floats.py is not a sub-byte integer type.
def is_integer_subbyte(self) -> bool:
return False

# Always True: the data type defined in floats.py is a floating-point type.
def is_float(self) -> bool:
return True

Expand Down
3 changes: 3 additions & 0 deletions python/hidet/ir/dtypes/integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def __init__(self, name, short_name, nbytes, min_value, max_value):
self._min_value: int = min_value
self._max_value: int = max_value

# Always False here: the data type defined in integer.py is not a sub-byte
# integer; the sub-byte type in integer_subbyte.py has its own definition
# that returns True.
def is_integer_subbyte(self) -> bool:
return False

# Always False: the data type defined in integer.py is not a floating-point type.
def is_float(self) -> bool:
return False

Expand Down
15 changes: 15 additions & 0 deletions python/hidet/ir/dtypes/integer_subbyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ def __init__(self, name, short_name, storage, nbits, signed, min_value, max_valu
self._bits_mask: int = (1 << self._nbits) - 1
self._sign_mask: int = 1 << (self._nbits - 1) if self.signedness() else 0

# Storage type that physically holds values of this sub-byte type.
# Returns self._storage; presumably assigned from the constructor's `storage`
# argument -- the assignment is outside this hunk, confirm there.
@property
def storage(self):
return self._storage

@property
def nbytes(self):
    """
    Reject byte-size queries on a sub-byte integer type.

    The width of a sub-byte type is expressed in bits (see ``nbits``), so
    this property never returns a value.

    Raises
    ------
    TypeError
        Always. The message names the offending type.
    """
    # The message previously opened "type(" without closing it; the
    # parenthesis is balanced here so the text reads "...the type(<dtype>)".
    raise TypeError(f"Cannot access nbytes property for the type({self})")

# Bit width of this sub-byte type. Returns self._nbits, the same field the
# constructor uses to derive the bit mask and the sign mask.
@property
def nbits(self):
return self._nbits

# Mask with the low `nbits` bits set; the constructor computes it as
# (1 << nbits) - 1.
@property
def bits_mask(self):
return self._bits_mask
Expand All @@ -30,6 +42,9 @@ def bits_mask(self):
def sign_mask(self):
return self._sign_mask

# Always True: this is the sub-byte integer data type.
def is_integer_subbyte(self):
return True

# Build an IntInfo describing this type from its bit width and value bounds.
# NOTE(review): the arguments are passed as (_nbits, _max_value, _min_value, self),
# i.e. max before min -- confirm this matches IntInfo's field order, which is
# not visible here.
def iinfo(self) -> IntInfo:
return IntInfo(self._nbits, self._max_value, self._min_value, self)

Expand Down
35 changes: 22 additions & 13 deletions python/hidet/ir/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,12 @@ class DataType(BaseType):
"""
The data type that defines how to interpret the data in memory.
Note:
1. The _storage field for non-subbyte types is the type itself, while the _storage
for subbyte types is the type of its actual storage. e.g., the storage for int4b is uint8
2. The _storage field will be overwritten during the construction of subbyte types
2. The _nbits field in the constructor denotes the bit length of the storage, and
it will be overwritten in the constructor of subbyte types
"""

def __init__(self, name: str, short_name: str, nbytes: int):
self._name: str = name
self._short_name: str = short_name
self._storage = self
self._nbytes: int = nbytes
self._nbits: int = self._nbytes * 8

def __str__(self):
return 'hidet.{}'.format(self.name)
Expand Down Expand Up @@ -138,20 +130,37 @@ def short_name(self) -> str:

@property
def nbytes(self) -> int:
if self._nbits < 8:
raise TypeError(f"Cannot access nbytes property for the type({self}")
return self._nbytes

@property
def nbits(self) -> int:
return self._nbits
"""
Get the bit length of the data type
Note:
1. This is the bit length of the data type itself rather than the bit length of its storage.
2. For regular data types, the nbits can be computed from its nbytes property.
3. For subbyte data types, the nbits is defined when constructing the data type,
and this method will also be overridden for subbyte data types.
4. In addition, we cannot access the nbytes for a subbyte data type, otherwise
a type error will be raised.
"""
return self._nbytes * 8

@property
def storage(self) -> DataType:
return self._storage
"""
Get the actual storage type of the data type
Note:
1. The storage of a regular data type is the data type itself, while the storage
of a subbyte type is the type of its actual storage. e.g., the storage of int4b is uint8
2. The property will be overridden in the subclass of subbyte types.
"""
return self

def is_integer_subbyte(self) -> bool:
return self.is_integer() and self._nbits < 8
raise NotImplementedError()

def is_float(self) -> bool:
raise NotImplementedError()
Expand Down
9 changes: 5 additions & 4 deletions python/hidet/transforms/lower_integer_subbyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from hidet.ir.module import IRModule
from hidet.ir.functors import IRRewriter
from hidet.transforms import Pass
from hidet.utils.py import is_power_of_two


def is_pointer_type(base_ty: BaseType):
Expand Down Expand Up @@ -113,10 +114,10 @@ def _get_subbyte_value(self, dtype: DataType, base: Var, offset: Expr):
storage_bits = storage_ty.nbits
dtype_bits = dtype.nbits
divisor = storage_bits // dtype_bits
if divisor & (divisor - 1) != 0:
if is_power_of_two(divisor):
raise TypeError(f"data type not supported yet(got:{dtype})")
idx = simplify(offset // divisor)
offset_ = simplify(offset & (divisor - 1))
offset_ = simplify(offset % divisor)
mask = storage_ty.constant(dtype.bits_mask)
return (base[idx] >> (offset_ * dtype_bits)) & mask

Expand All @@ -125,10 +126,10 @@ def _set_subbyte_value(self, dtype: DataType, base: Var, offset: Expr, value: Ex
storage_bits = storage_ty.nbits
dtype_bits = dtype.nbits
divisor = storage_bits // dtype_bits
if divisor & (divisor - 1) != 0:
if is_power_of_two(divisor):
raise TypeError(f"data type not supported yet(got:{dtype})")
idx = simplify(offset // divisor)
offset_ = simplify(offset & (divisor - 1))
offset_ = simplify(offset % divisor)
value_ty = infer_type(value)
assert value_ty == storage_ty
mask = storage_ty.constant(dtype.bits_mask)
Expand Down
10 changes: 8 additions & 2 deletions tests/ir/test_int_subbyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ def func(out: f32[4, 4], inp: i4[4, 4]):
inp = torch.empty((4, 2), dtype=torch.int8, device="cuda")
inp = hidet.from_torch(inp)
func(data, inp)
print(data.cpu().numpy())
import numpy as np

groundtruth = np.resize(np.arange(-8, 8), (4, 4)).astype(np.float32)
np.testing.assert_equal(data.cpu().numpy(), groundtruth)


def test_int_2bit():
Expand Down Expand Up @@ -113,7 +116,10 @@ def func(out: f32[2, 2]):
data = torch.empty((2, 2), dtype=torch.float32, device="cuda")
data = hidet.from_torch(data)
func(data)
print(data.cpu().numpy())
import numpy as np

groundtruth = np.resize(np.arange(-2, 2), (2, 2)).astype(np.float32)
np.testing.assert_equal(data.cpu().numpy(), groundtruth)


if __name__ == "__main__":
Expand Down

0 comments on commit 25e9a56

Please sign in to comment.