Skip to content

Commit

Permalink
[Tests] resolve lint & tests
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaocenxiaocen committed Jan 20, 2024
1 parent fd77cfa commit 23c53d4
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 44 deletions.
3 changes: 3 additions & 0 deletions python/hidet/ir/dtypes/integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def constant(self, value: Any):
raise ValueError('Value {} is out of range for {}.'.format(value, self.name))
return constant(value, self)

def signedness(self):
return self._min_value < 0

@property
def one(self):
return self.constant(1)
Expand Down
11 changes: 9 additions & 2 deletions python/hidet/ir/dtypes/integer_subbyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,16 @@ def __init__(self, name, short_name, storage, nbits, signed, min_value, max_valu
super().__init__(name, short_name, nbytes, min_value, max_value)
self._storage: DataType = storage
self._nbits: int = nbits
self._signed: bool = signed
self._bits_mask: int = (1 << self._nbits) - 1
self._sign_mask: int = 1 << (self._nbits - 1) if self._signed else 0
self._sign_mask: int = 1 << (self._nbits - 1) if self.signedness() else 0

@property
def bits_mask(self):
return self._bits_mask

@property
def sign_mask(self):
return self._sign_mask

def iinfo(self) -> IntInfo:
return IntInfo(self._nbits, self._max_value, self._min_value, self)
Expand Down
4 changes: 3 additions & 1 deletion python/hidet/ir/dtypes/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ class VectorType(DataType):
def __init__(self, lane_type: DataType, num_lanes: int):
name = '{}x{}'.format(lane_type.name, num_lanes)
short_name = '{}x{}'.format(lane_type.short_name, num_lanes)
nbytes = lane_type.nbytes * num_lanes if not lane_type.is_integer_subbyte() else lane_type.nbits * num_lanes // 8
nbytes = (
lane_type.nbytes * num_lanes if not lane_type.is_integer_subbyte() else lane_type.nbits * num_lanes // 8
)
super().__init__(name, short_name, nbytes)
self._num_lanes: int = num_lanes
self._lane_type: DataType = lane_type
Expand Down
2 changes: 1 addition & 1 deletion python/hidet/ir/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def __invert__(self):

def storage_bytes(self) -> Expr:
if self.dtype.is_integer_subbyte():
return self.layout.size * self.dtype._nbits // 8
return self.layout.size * self.dtype.nbits // 8
else:
return self.layout.size * self.dtype.nbytes

Expand Down
13 changes: 11 additions & 2 deletions python/hidet/transforms/flatten_tensor_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,16 @@
# limitations under the License.
from typing import Dict

from hidet.ir.type import TensorType, tensor_type, tensor_pointer_type, PointerType, TensorPointerType, ArrayType, FuncType, func_type
from hidet.ir.type import (
TensorType,
tensor_type,
tensor_pointer_type,
PointerType,
TensorPointerType,
ArrayType,
FuncType,
func_type,
)
from hidet.ir.expr import Var, TensorElement, TensorSlice, tensor_element
from hidet.ir.stmt import BufferStoreStmt, DeclareStmt
from hidet.ir.layout import row_major
Expand Down Expand Up @@ -51,7 +60,7 @@ def visit_Function(self, func: Function):
self.memo[var] = Var(var.hint, tensor_pointer_type(var.type.tensor_type.dtype, [size]))
body = self(func.body)
params = [self(p) for p in func.params]
if body is func.body and all([p is p1 for p, p1 in zip(params, func.params)]):
if body is func.body and all(p is p1 for p, p1 in zip(params, func.params)):
return func
else:
new_func = Function(func.name, params, body, func.ret_type, kind=func.kind, attrs=func.attrs)
Expand Down
24 changes: 13 additions & 11 deletions python/hidet/transforms/lower_integer_subbyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

from hidet.ir.tools import infer_type, simplify
from hidet.ir.type import DataType, TensorType, PointerType
from hidet.ir.dtypes import i8, u8, i16, i32
from hidet.ir.expr import Var, Expr, Add, TensorElement, Address, Constant, Cast, var, cast, deref, bitwise_not
from hidet.ir.dtypes import i32
from hidet.ir.expr import Var, Expr, Add, TensorElement, Address, Constant, Cast, var, cast, bitwise_not
from hidet.ir.stmt import (
Stmt,
DeclareStmt,
Expand Down Expand Up @@ -99,7 +99,7 @@ def _get_subbyte_value(self, dtype: DataType, base: Var, offset: Expr):
raise TypeError(f"data type not supported yet(got:{dtype})")
idx = simplify(offset // divisor)
offset_ = simplify(offset & (divisor - 1))
mask = storage_ty.constant(dtype._bits_mask)
mask = storage_ty.constant(dtype.bits_mask)
return (base[idx] >> (offset_ * dtype_bits)) & mask

def _set_subbyte_value(self, dtype: DataType, base: Var, offset: Expr, value: Expr):
Expand All @@ -113,17 +113,18 @@ def _set_subbyte_value(self, dtype: DataType, base: Var, offset: Expr, value: Ex
offset_ = simplify(offset & (divisor - 1))
value_ty = infer_type(value)
assert value_ty == storage_ty
mask = storage_ty.constant(dtype._bits_mask)
mask = storage_ty.constant(dtype.bits_mask)
item = self.auto_var(hint="item", e=value & mask)
updated_mask = self.auto_var(hint="updated_mask", e=bitwise_not(mask << (offset_ * dtype_bits)))
new_bits = self.auto_var(hint="new_bits", e=item << (offset_ * dtype_bits))

from hidet.ir.dtypes import i32, u32, u16
from hidet.ir.dtypes import u32, u16

if self.var2scope[base].is_memory():
if not any([storage_ty is ty for ty in [i32, u32, u16]]):
if not any(storage_ty is ty for ty in [i32, u32, u16]):
raise NotImplementedError(
"writing subbyte data to memory requires the storage type must be int32, uint32, or uint16 due to atomicCAS, but got({storage_ty})"
"writing subbyte data to memory requires the storage type must be"
" int32, uint32, or uint16 due to atomicCAS, but got({storage_ty})"
)
original = self.auto_var(hint="original", e=storage_ty.zero)
updated = self.auto_var(hint="updated", e=storage_ty.zero)
Expand Down Expand Up @@ -173,7 +174,7 @@ def visit_Constant(self, e: Constant):
value = self.visit(e.value)
dtype = e.type
storage_ty = dtype.storage
mask = storage_ty.constant(dtype._bits_mask)
mask = storage_ty.constant(dtype.bits_mask)
value = value & mask
return Constant(value, ty)
return super().visit_Constant(e)
Expand Down Expand Up @@ -207,10 +208,10 @@ def visit_Address(self, e: Address):
idx = simplify(offset // divisor)
base = self.visit(base)
return ~base[idx]
return super().visit_Address()
return super().visit_Address(e)

def _cast_int(self, dtype: DataType, expr: Expr):
if not dtype._signed:
if not dtype.signedness():
return expr
int_type = i32
int_data = cast(expr, int_type)
Expand All @@ -226,11 +227,12 @@ def visit_Cast(self, e: Cast):
return cast(self._cast_int(expr_ty, expr), e.target_type)
elif e.target_type.is_integer_subbyte():
from hidet.ir.expr import if_then_else

expr = self.visit(e.expr)
dtype = e.target_type
expr = if_then_else(expr < dtype.min_value, expr_ty(dtype.min_value), expr)
expr = if_then_else(expr >= dtype.max_value, expr_ty(dtype.max_value), expr)
if not dtype._signed:
if not dtype.signedness():
return cast(expr, dtype.storage)
storage_ty = dtype.storage
int_type = i32
Expand Down
44 changes: 17 additions & 27 deletions tests/ir/test_int_subbyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import hidet


Expand Down Expand Up @@ -62,6 +64,7 @@ def func(out: f32[4, 4], inp: i4[4, 4]):

func = script_module.build()
import torch

data = torch.empty((4, 4), dtype=torch.float32, device="cuda")
data = hidet.from_torch(data)
inp = torch.empty((4, 2), dtype=torch.int8, device="cuda")
Expand All @@ -71,7 +74,7 @@ def func(out: f32[4, 4], inp: i4[4, 4]):


def test_int_2bit():
from hidet.ir.dtypes import i4, u4, i4x8, i8, f16, f32
from hidet.ir.dtypes import i2, u2, i8, f16, f32
from hidet.ir.expr import constant, cast
from hidet.lang import attrs
from hidet.lang import shared_tensor, register_tensor
Expand All @@ -80,54 +83,41 @@ def test_int_2bit():
with hidet.script_module() as script_module:

@hidet.script
def func(out: f32[4, 4]):
def func(out: f32[2, 2]):
attrs.func_kind = "cuda_kernel"
attrs.cuda.block_dim = 128
attrs.cuda.grid_dim = 4

a = constant(1, i4)
b = register_tensor('int4b', shape=[8, 2])
ptr = ~b[3, 1]
ptr = ptr + 4
ptr = ptr + (threadIdx.x * 444 + blockIdx.x * 888 + 555)
c = register_tensor('i4x8', shape=[1])
b[0, 1] = a
b[0, 1] = b[0, 2]
d = b[0, 1]
s = shared_tensor('uint4b', shape=[7, 8])
e = i8(b[3, 0])
s1 = shared_tensor('float32', shape=[64, 64])
s1[3, 4] = f16(b[4, 0])
a = constant(1, i2)

data = register_tensor('int4b', shape=[4, 4])
data = register_tensor('int2b', shape=[2, 2])

for i in range(4):
for j in range(4):
for i in range(2):
for j in range(2):
if i == 0 and j == 0:
data[i, j] = i4(-8)
data[i, j] = i2(-2)
elif j == 0:
data[i, j] = i4(f32(data[i - 1, 3]) + 1)
data[i, j] = i2(f32(data[i - 1, 1]) + 1)
else:
data[i, j] = i4(f32(data[i, j - 1]) + 1)
data[i, j] = i2(f32(data[i, j - 1]) + 1)

if threadIdx.x == 0 and blockIdx.x == 0:
for i in range(4):
for j in range(4):
for i in range(2):
for j in range(2):
d = data[i, j]
d = i4(f32(d) + 1)
out[i, j] = f32(d)

func = script_module.build()
import torch
data = torch.empty((4, 4), dtype=torch.float32, device="cuda")

data = torch.empty((2, 2), dtype=torch.float32, device="cuda")
data = hidet.from_torch(data)
func(data)
print(data.cpu().numpy())



if __name__ == "__main__":
hidet.option.cache_dir("./demo_int_subbyte")
hidet.option.save_lower_ir(True)

test_int_4bit()
pytest.main(__file__)

0 comments on commit 23c53d4

Please sign in to comment.