Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Support integer subbyte #403

Merged
merged 4 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/hidet/backend/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,8 @@ def visit_DataType(self, t: DataType):
'float32x8': '__m256',
'int8x4': 'char4',
'uint8x4': 'uint4',
'int4bx8': 'uint32_t',
'uint4bx8': 'uint32_t',
}

self.require_complex = self.require_complex or t.name in ['complex64', 'complex128']
Expand Down
26 changes: 24 additions & 2 deletions python/hidet/ir/dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
from hidet.ir.type import DataType
from .integer import int8, int16, int32, int64, uint8, uint16, uint32, uint64
from .integer import i8, i16, i32, i64, u8, u16, u32, u64
from .integer_subbyte import int4b, int3b, int2b, int1b, uint4b, uint3b, uint2b, uint1b
from .integer_subbyte import i4, i3, i2, i1, u4, u3, u2, u1
from .floats import float16, float32, float64, bfloat16, tfloat32
from .floats import f16, f32, f64, bf16, tf32
from .boolean import boolean
from .vector import float16x2, float32x4, float32x8, int8x4, uint8x4, vectorize
from .vector import f16x2, f32x4, f32x8
from .vector import float16x2, float32x4, float32x8, int8x4, uint8x4, int4bx8, uint4bx8, vectorize
from .vector import f16x2, f32x4, f32x8, i4x8, u4x8
from .complex import complex64, complex128
from .promotion import promote_type
from .utils import dtype_to_numpy, finfo, iinfo
Expand All @@ -43,6 +45,16 @@
'float16x2': float16x2,
'int8x4': int8x4,
'uint8x4': uint8x4,
'int4b': int4b,
'int3b': int3b,
'int2b': int2b,
'int1b': int1b,
'uint4b': uint4b,
'uint3b': uint3b,
'uint2b': uint2b,
'uint1b': uint1b,
'int4bx8': int4bx8,
'uint4bx8': uint4bx8,
}

sname2dtype = {
Expand All @@ -66,6 +78,16 @@
'f32x8': f32x8,
'f16x2': f16x2,
'i8x4': int8x4,
'i4': int4b,
'i3': int3b,
'i2': int2b,
'i1': int1b,
'u4': uint4b,
'u3': uint3b,
'u2': uint2b,
'u1': uint1b,
'i4x8': int4bx8,
'u4x8': uint4bx8,
}


Expand Down
3 changes: 3 additions & 0 deletions python/hidet/ir/dtypes/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ class Boolean(DataType):
def __init__(self):
super().__init__('bool', 'bool', 1)

def is_integer_subbyte(self) -> bool:
    """Return False: the boolean type is not a sub-byte integer type."""
    return False

def is_float(self) -> bool:
return False

Expand Down
3 changes: 3 additions & 0 deletions python/hidet/ir/dtypes/complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def __init__(self, name, short_name, base_dtype: DataType):
super().__init__(name, short_name, 2 * base_dtype.nbytes)
self.base_dtype: DataType = base_dtype

def is_integer_subbyte(self) -> bool:
    """Return False: this (complex) data type is not a sub-byte integer type."""
    return False

def is_float(self) -> bool:
return False

Expand Down
3 changes: 3 additions & 0 deletions python/hidet/ir/dtypes/floats.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def __init__(self, name, short_name, nbytes, min_value, max_value, eps, smallest
self._eps: float = eps
self._smallest_normal: float = smallest_normal

def is_integer_subbyte(self) -> bool:
    """Return False: this (floating-point) data type is not a sub-byte integer type."""
    return False

def is_float(self) -> bool:
return True

Expand Down
6 changes: 6 additions & 0 deletions python/hidet/ir/dtypes/integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def __init__(self, name, short_name, nbytes, min_value, max_value):
self._min_value: int = min_value
self._max_value: int = max_value

def is_integer_subbyte(self) -> bool:
    """Return False; subclasses that model sub-byte integers are expected to override this."""
    return False

def is_float(self) -> bool:
return False

Expand All @@ -54,6 +57,9 @@ def constant(self, value: Any):
raise ValueError('Value {} is out of range for {}.'.format(value, self.name))
return constant(value, self)

def signedness(self) -> bool:
    """Return True if the type is signed, i.e. its minimum representable value is negative."""
    return self._min_value < 0

@property
def one(self):
return self.constant(1)
Expand Down
70 changes: 70 additions & 0 deletions python/hidet/ir/dtypes/integer_subbyte.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hidet.ir.type import DataType
from .integer import IntegerType, IntInfo, uint8, uint32


class IntegerSubbyteType(IntegerType):
    """
    Integer data type whose logical width is smaller than one byte.

    A sub-byte integer has no byte size of its own; values are packed into a
    wider ``storage`` data type. The logical width is exposed through ``nbits``
    and reading ``nbytes`` raises a ``TypeError``.

    Parameters
    ----------
    name: str
        Full name of the data type, e.g. ``'int4b'``.
    short_name: str
        Short name of the data type, e.g. ``'i4'``.
    storage: DataType
        The data type that holds the packed sub-byte values.
    nbits: int
        The logical bit width of a single value.
    signed: bool
        Declared signedness of the type.
        NOTE(review): this argument is accepted but not read; the sign mask is
        derived from ``self.signedness()`` instead - confirm this is intended.
    min_value: int
        The minimum representable value.
    max_value: int
        The maximum representable value.
    """

    def __init__(self, name, short_name, storage, nbits, signed, min_value, max_value):
        # The base class is initialised with the byte size of the storage type,
        # since a sub-byte type has no whole-byte size of its own.
        nbytes = storage.nbytes
        super().__init__(name, short_name, nbytes, min_value, max_value)
        self._storage: DataType = storage
        self._nbits: int = nbits
        # Mask with the low ``nbits`` bits set, e.g. 0b1111 for a 4-bit type.
        self._bits_mask: int = (1 << self._nbits) - 1
        # Mask selecting the most significant (sign) bit; 0 for unsigned types.
        self._sign_mask: int = 1 << (self._nbits - 1) if self.signedness() else 0

    @property
    def storage(self):
        """The data type that holds the packed sub-byte values."""
        return self._storage

    @property
    def nbytes(self):
        """Always raises ``TypeError``: a sub-byte type has no size in whole bytes."""
        raise TypeError(f"Cannot access nbytes property for the type({self})")

    @property
    def nbits(self):
        """The logical bit width of a single value of this type."""
        return self._nbits

    @property
    def bits_mask(self):
        """Integer mask with the low ``nbits`` bits set."""
        return self._bits_mask

    @property
    def sign_mask(self):
        """Integer mask of the sign bit, or 0 when the type is unsigned."""
        return self._sign_mask

    def is_integer_subbyte(self):
        """Return True: this is a sub-byte integer type."""
        return True

    def iinfo(self) -> IntInfo:
        """Return the ``IntInfo`` built from this type's bit width, maximum and minimum values."""
        return IntInfo(self._nbits, self._max_value, self._min_value, self)


# Signed sub-byte integers: an n-bit two's-complement value spans [-2**(n-1), 2**(n-1) - 1].
int4b = IntegerSubbyteType('int4b', 'i4', uint8, 4, True, -8, 7)
int3b = IntegerSubbyteType('int3b', 'i3', uint32, 3, True, -4, 3)
int2b = IntegerSubbyteType('int2b', 'i2', uint8, 2, True, -2, 1)
int1b = IntegerSubbyteType('int1b', 'i1', uint8, 1, True, -1, 0)

# Unsigned sub-byte integers: an n-bit value spans [0, 2**n - 1].
# The maximum is 2**n - 1 (not 2**n), since 2**n does not fit in n bits.
uint4b = IntegerSubbyteType('uint4b', 'u4', uint8, 4, False, 0, 15)
uint3b = IntegerSubbyteType('uint3b', 'u3', uint32, 3, False, 0, 7)
uint2b = IntegerSubbyteType('uint2b', 'u2', uint8, 2, False, 0, 3)
uint1b = IntegerSubbyteType('uint1b', 'u1', uint8, 1, False, 0, 1)

# Short aliases for the signed types.
i4 = int4b
i3 = int3b
i2 = int2b
i1 = int1b

# Short aliases for the unsigned types.
u4 = uint4b
u3 = uint3b
u2 = uint2b
u1 = uint1b
20 changes: 19 additions & 1 deletion python/hidet/ir/dtypes/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,26 @@
from hidet.ir.type import DataType
from .floats import float32, float16
from .integer import int8, uint8
from .integer_subbyte import int4b, uint4b


class VectorType(DataType):
def __init__(self, lane_type: DataType, num_lanes: int):
name = '{}x{}'.format(lane_type.name, num_lanes)
short_name = '{}x{}'.format(lane_type.short_name, num_lanes)
nbytes = lane_type.nbytes * num_lanes
nbytes = (
lane_type.nbytes * num_lanes if not lane_type.is_integer_subbyte() else lane_type.nbits * num_lanes // 8
)
super().__init__(name, short_name, nbytes)
self._num_lanes: int = num_lanes
self._lane_type: DataType = lane_type

if lane_type.is_vector():
raise ValueError('Cannot create a vector type of vectors')

def is_integer_subbyte(self) -> bool:
    """Return False: a vector type is never itself reported as a sub-byte integer type."""
    return False

def is_float(self) -> bool:
return False

Expand Down Expand Up @@ -90,6 +96,18 @@ def max_value(self):
float16x2 = VectorType(float16, 2)
f16x2 = float16x2

int4bx2 = VectorType(int4b, 2)
i4x2 = int4bx2

uint4bx2 = VectorType(uint4b, 2)
u4x2 = uint4bx2

int4bx8 = VectorType(int4b, 8)
i4x8 = int4bx8

uint4bx8 = VectorType(uint4b, 8)
u4x8 = uint4bx8


def vectorize(base_dtype: DataType, num_lanes: int) -> VectorType:
table = {
Expand Down
12 changes: 12 additions & 0 deletions python/hidet/ir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ def from_str(name):
else:
return DeclareScope.Default

def is_global(self):
    """Return True if this scope is ``DeclareScope.Global``."""
    return self == DeclareScope.Global

def is_shared(self):
    """Return True if this scope is ``DeclareScope.Shared``."""
    return self == DeclareScope.Shared

def is_register(self):
    """Return True if this scope is ``DeclareScope.Register``."""
    return self == DeclareScope.Register

def is_memory(self):
    """
    Return True for every scope other than ``DeclareScope.Register``.

    NOTE(review): this also returns True for any other non-register scope
    (e.g. ``DeclareScope.Default``) - confirm that is intended.
    """
    return not self.is_register()


class ForStmtAttr:
def __init__(self, unroll=False, unroll_factor=None, unroll_explicit=False, parallel=False, parallel_threads=None):
Expand Down
36 changes: 35 additions & 1 deletion python/hidet/ir/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def as_data_type(self) -> Optional[DataType]:
class DataType(BaseType):
"""
The data type that defines how to interpret the data in memory.

"""

def __init__(self, name: str, short_name: str, nbytes: int):
Expand Down Expand Up @@ -131,6 +132,36 @@ def short_name(self) -> str:
def nbytes(self) -> int:
return self._nbytes

@property
def nbits(self) -> int:
    """
    Get the bit length of the data type.

    Note:
        1. This is the bit length of the data type itself, not the bit length of its storage.
        2. For regular data types, nbits is computed from the byte size (nbytes * 8).
        3. For subbyte data types, nbits is defined when the data type is constructed,
           and this property is overridden by the subbyte data types.
        4. The nbytes property of a subbyte data type cannot be accessed; attempting to
           do so raises a TypeError.
    """
    return self._nbytes * 8

@property
def storage(self) -> DataType:
    """
    Get the actual storage type of the data type.

    Note:
        1. The storage of a regular data type is the data type itself, while the storage
           of a subbyte type is the type that holds its values, e.g., the storage of
           int4b is uint8.
        2. This property is overridden in the subclasses for subbyte types.
    """
    return self

def is_integer_subbyte(self) -> bool:
    """Return whether this is a sub-byte integer type; must be implemented by subclasses."""
    raise NotImplementedError()

def is_float(self) -> bool:
raise NotImplementedError()

Expand Down Expand Up @@ -187,7 +218,10 @@ def __invert__(self):
return TensorPointerType.from_tensor_type(self)

def storage_bytes(self) -> Expr:
return self.layout.size * self.dtype.nbytes
if self.dtype.is_integer_subbyte():
return self.layout.size * self.dtype.nbits // 8
else:
return self.layout.size * self.dtype.nbytes

def const_shape(self) -> List[int]:
return [int(v) for v in self.shape]
Expand Down
2 changes: 2 additions & 0 deletions python/hidet/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .check_launch_configuration import check_launch_configuration_pass
from .lower_special_cast import lower_special_cast_pass
from .annotate_header_and_libs import annotate_header_and_libs_pass
from .lower_integer_subbyte import lower_integer_subbyte_pass


def lower_with(ir_module: IRModule, transforms: Sequence[Pass]) -> IRModule:
Expand All @@ -63,6 +64,7 @@ def lower(ir_module: IRModule) -> IRModule:
declare_to_let_pass(),
rule_based_simplify_pass(), # make ir more readable
flatten_tensor_index_pass(),
lower_integer_subbyte_pass(),
lower_special_cast_pass(),
inline_function_pass(),
resolve_primitive_func_pass(),
Expand Down
42 changes: 38 additions & 4 deletions python/hidet/transforms/flatten_tensor_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hidet.ir.type import TensorType, tensor_type, tensor_pointer_type, PointerType, TensorPointerType, ArrayType
from typing import Dict

from hidet.ir.type import (
TensorType,
tensor_type,
tensor_pointer_type,
PointerType,
TensorPointerType,
ArrayType,
FuncType,
func_type,
)
from hidet.ir.expr import Var, TensorElement, TensorSlice, tensor_element
from hidet.ir.stmt import BufferStoreStmt, DeclareStmt
from hidet.ir.layout import row_major
from hidet.ir.func import Function
from hidet.ir.module import IRModule
from hidet.ir.functors import IRRewriter
from hidet.ir.tools import simplify, TypeInfer
from hidet.transforms import Pass
Expand All @@ -28,6 +40,15 @@ class FlattenTensorAccessRewriter(IRRewriter):
def __init__(self):
super().__init__()
self.type_infer = TypeInfer()
self.func2func_type: Dict[str, FuncType] = {}

def visit_Var(self, v: Var):
    """
    Rewrite a function-typed variable so that it carries the updated function type.

    When ``v`` has a ``FuncType`` and a different function type has been recorded
    under its name in ``self.func2func_type``, a new ``Var`` with the same hint and
    name but the recorded type is returned. Every other variable is handled by the
    parent rewriter.
    """
    if isinstance(v.type, FuncType):
        # Single lookup: None means no function type was recorded for this name.
        recorded_type = self.func2func_type.get(v.name)
        if recorded_type is not None and recorded_type is not v.type:
            return Var(v.hint, recorded_type, v.name)
    return super().visit_Var(v)

def visit_Function(self, func: Function):
for var in func.params:
Expand All @@ -39,7 +60,13 @@ def visit_Function(self, func: Function):
self.memo[var] = Var(var.hint, tensor_pointer_type(var.type.tensor_type.dtype, [size]))
body = self(func.body)
params = [self(p) for p in func.params]
return Function(func.name, params, body, func.ret_type, kind=func.kind, attrs=func.attrs)
if body is func.body and all(p is p1 for p, p1 in zip(params, func.params)):
return func
else:
new_func = Function(func.name, params, body, func.ret_type, kind=func.kind, attrs=func.attrs)
param_types = [p.type for p in params]
self.func2func_type[func.name] = func_type(param_types, func.ret_type)
return new_func

def get_layout(self, e) -> DataLayout:
e_type = self.type_infer(e)
Expand Down Expand Up @@ -103,9 +130,16 @@ def visit_TensorSlice(self, e: TensorSlice):


class FlattenTensorIndexPass(Pass):
def process_func(self, func: Function) -> Function:
def process_module(self, ir_module: IRModule) -> IRModule:
flatten_index = FlattenTensorAccessRewriter()
return flatten_index(func)

new_funcs = {}
for name, func in ir_module.functions.items():
new_funcs[name] = flatten_index(func)
if all(new_funcs[name] is ir_module.functions[name] for name in new_funcs):
return ir_module
else:
return ir_module.copy().reset_funcs(new_funcs, ir_module.global_vars)


def flatten_tensor_index_pass():
Expand Down
Loading
Loading