Skip to content

Commit

Permalink
[Transforms] fix
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaocenxiaocen committed Jan 20, 2024
1 parent 23c53d4 commit 803b1c2
Showing 1 changed file with 52 additions and 28 deletions.
80 changes: 52 additions & 28 deletions python/hidet/transforms/lower_integer_subbyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import Dict, List, Union, Tuple

from hidet.ir.tools import infer_type, simplify
from hidet.ir.type import DataType, TensorType, PointerType
from hidet.ir.type import BaseType, DataType, TensorType, TensorPointerType, PointerType
from hidet.ir.dtypes import i32
from hidet.ir.expr import Var, Expr, Add, TensorElement, Address, Constant, Cast, var, cast, bitwise_not
from hidet.ir.stmt import (
Expand All @@ -35,6 +35,24 @@
from hidet.transforms import Pass


def is_pointer_type(base_ty: BaseType):
    """Tell whether ``base_ty`` is one of the type kinds that can be indexed or addressed.

    Returns ``True`` when ``base_ty`` is an instance of ``PointerType``,
    ``TensorPointerType`` or ``TensorType``; ``False`` for anything else.
    """
    addressable_kinds = (PointerType, TensorPointerType, TensorType)
    return isinstance(base_ty, addressable_kinds)


def get_pointer_base_type(base_ty: BaseType):
    """Return the element type that a pointer-like type refers to.

    Parameters
    ----------
    base_ty: BaseType
        A ``PointerType``, ``TensorType`` or ``TensorPointerType``.

    Returns
    -------
    ``base_ty.base_type`` for a ``PointerType``, ``base_ty.dtype`` for a
    ``TensorType``, and ``base_ty.tensor_type.dtype`` for a ``TensorPointerType``.

    Raises
    ------
    TypeError
        If ``base_ty`` is none of the three supported kinds.
    """
    if isinstance(base_ty, PointerType):
        return base_ty.base_type
    if isinstance(base_ty, TensorType):
        return base_ty.dtype
    if isinstance(base_ty, TensorPointerType):
        # a tensor pointer wraps a tensor type; the element type lives on that inner type
        return base_ty.tensor_type.dtype
    # Fail with a clear message instead of an unrelated AttributeError on `.tensor_type`
    # when the caller forgot to guard the call with a pointer-kind check.
    raise TypeError(
        f"expected a PointerType, TensorType or TensorPointerType, got {type(base_ty).__name__}"
    )


def is_integer_subbyte(dtype: BaseType):
    """Tell whether ``dtype`` is a data type that reports itself as a sub-byte integer.

    ``dtype.is_integer_subbyte()`` is consulted only after ``dtype.is_data_type()``
    has answered truthy, so objects that are not data types never need to provide
    that method.
    """
    data_type_flag = dtype.is_data_type()
    if not data_type_flag:
        # hand back the falsy answer itself, exactly as a short-circuiting `and` would
        return data_type_flag
    return dtype.is_integer_subbyte()


class LowerIntegerSubbyteRewriter(IRRewriter):
# convert subbyte integers to their storage type
# e.g.,
Expand Down Expand Up @@ -149,7 +167,7 @@ def visit_DataType(self, t: DataType):
def visit_TensorType(self, t: TensorType):
from hidet.ir.layout import row_major

if t.dtype.is_integer_subbyte():
if is_integer_subbyte(t.dtype):
shape = list(self.visit(t.shape))
assert len(shape) == 1
dtype = t.dtype
Expand All @@ -168,23 +186,22 @@ def visit_Var(self, v: Var):
return super().visit_Var(v)

def visit_Constant(self, e: Constant):
if isinstance(e.type, DataType):
if e.type.is_integer_subbyte():
ty = self.visit(e.type)
value = self.visit(e.value)
dtype = e.type
storage_ty = dtype.storage
mask = storage_ty.constant(dtype.bits_mask)
value = value & mask
return Constant(value, ty)
if is_integer_subbyte(e.type):
ty = self.visit(e.type)
value = self.visit(e.value)
dtype = e.type
storage_ty = dtype.storage
mask = storage_ty.constant(dtype.bits_mask)
value = value & mask
return Constant(value, ty)
return super().visit_Constant(e)

def visit_TensorElement(self, e: TensorElement):
if isinstance(e.base, Var):
base_ty = infer_type(e.base)
if isinstance(base_ty, TensorType):
dtype = base_ty.dtype
if dtype.is_integer_subbyte():
if is_pointer_type(base_ty):
dtype = get_pointer_base_type(base_ty)
if is_integer_subbyte(dtype):
base = self.visit(e.base)
assert len(e.indices) == 1
offset = self.visit(e.indices[0])
Expand All @@ -196,9 +213,9 @@ def visit_Address(self, e: Address):
base = e.expr.base
if isinstance(base, Var):
base_ty = infer_type(base)
if isinstance(base_ty, TensorType):
dtype = base_ty.dtype
if dtype.is_integer_subbyte():
if is_pointer_type(base_ty):
dtype = get_pointer_base_type(base_ty)
if is_integer_subbyte(dtype):
storage_ty = dtype.storage
storage_bits = storage_ty.nbits
dtype_bits = dtype.nbits
Expand All @@ -220,18 +237,20 @@ def _cast_int(self, dtype: DataType, expr: Expr):

def visit_Cast(self, e: Cast):
expr_ty = infer_type(e.expr)
if isinstance(expr_ty, DataType) and expr_ty.is_integer_subbyte():
if e.target_type.is_integer_subbyte():
if is_integer_subbyte(expr_ty):
if is_integer_subbyte(e.target_type):
raise NotImplementedError(f"casting from {expr_ty} to {e.target_type} is not supported yet")
expr = self.visit(e.expr)
return cast(self._cast_int(expr_ty, expr), e.target_type)
elif e.target_type.is_integer_subbyte():
elif is_integer_subbyte(e.target_type):
from hidet.ir.expr import if_then_else

expr = self.visit(e.expr)
dtype = e.target_type
expr = if_then_else(expr < dtype.min_value, expr_ty(dtype.min_value), expr)
expr = if_then_else(expr >= dtype.max_value, expr_ty(dtype.max_value), expr)
min_val = expr_ty(dtype.min_value.value)
max_val = expr_ty(dtype.max_value.value)
expr = if_then_else(expr < min_val, min_val, expr)
expr = if_then_else(expr >= max_val, max_val, expr)
if not dtype.signedness():
return cast(expr, dtype.storage)
storage_ty = dtype.storage
Expand All @@ -240,28 +259,33 @@ def visit_Cast(self, e: Cast):
shift = int_type.nbits - dtype.nbits
int_data = (int_data << shift) >> shift
return cast(int_data, storage_ty)
return super().visit_Cast(e)
target_type = self.visit(e.target_type)
expr = self.visit(e.expr)
if target_type is e.target_type and expr is e.expr:
return e
else:
return cast(expr, target_type)

def _subbyte_pointer_add(self, dtype: DataType, ptr: Union[Expr, Tuple[Expr]], offset: Expr):
self._get_divisor(dtype)
divisor = self._get_divisor(dtype)
if isinstance(ptr, tuple):
ptr, offset_ = ptr
offset = offset + offset_
if self.recursive_depth == 0:
return ptr + offset // 2
return ptr + offset // divisor
else:
return ptr, offset

def visit_Add(self, e: Add):
a_ty = infer_type(e.a)
b_ty = infer_type(e.b)
if isinstance(a_ty, PointerType) and a_ty.base_type.is_integer_subbyte():
if isinstance(a_ty, PointerType) and is_integer_subbyte(a_ty.base_type):
self.recursive_depth += 1
a = self.visit(e.a)
b = self.visit(e.b)
self.recursive_depth -= 1
return self._subbyte_pointer_add(a_ty.base_type, a, b)
elif isinstance(b_ty, PointerType) and b_ty.base_type.is_integer_subbyte():
elif isinstance(b_ty, PointerType) and is_integer_subbyte(b_ty.base_type):
self.recursive_depth += 1
a = self.visit(e.a)
b = self.visit(e.b)
Expand Down Expand Up @@ -295,7 +319,7 @@ def visit_BufferStoreStmt(self, stmt: BufferStoreStmt):
buf_ty = infer_type(stmt.buf)
if isinstance(buf_ty, TensorType):
dtype = buf_ty.dtype
if dtype.is_integer_subbyte():
if is_integer_subbyte(dtype):
buf = self.visit(stmt.buf)
indices = self.visit(stmt.indices)
value = self.visit(stmt.value)
Expand Down

0 comments on commit 803b1c2

Please sign in to comment.