diff --git a/python/hidet/graph/ops/__init__.py b/python/hidet/graph/ops/__init__.py
index 52a710a53..0472aefbe 100644
--- a/python/hidet/graph/ops/__init__.py
+++ b/python/hidet/graph/ops/__init__.py
@@ -10,7 +10,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # pylint: disable=redefined-builtin
-from .matmul import batch_matmul, matmul, matmul_x86, matmul_cublas
+from .matmul import batch_matmul, matmul, matmul_cublas, batch_matmul_x86
 from .conv1d import conv1d, conv1d_gemm
 from .conv1d_transpose import conv1d_transpose
 from .conv2d import conv2d, conv2d_channel_last, conv2d_winograd, conv2d_gemm, conv2d_gemm_fp16
diff --git a/python/hidet/graph/ops/fusion/apply_prologue_epilogue.py b/python/hidet/graph/ops/fusion/apply_prologue_epilogue.py
index 2b2a7f74e..ea744164d 100644
--- a/python/hidet/graph/ops/fusion/apply_prologue_epilogue.py
+++ b/python/hidet/graph/ops/fusion/apply_prologue_epilogue.py
@@ -464,7 +464,7 @@ def visit_Call(self, e: Call):
         func_name = e.func_var.name
         if func_name in self.func_records:
             args = self.process_call(func_name, list(e.args))
-            return Call(e.func_var, args)
+            return Call(e.func_var, tuple(args))
         return super().visit_Call(e)

     def visit_LaunchKernelStmt(self, stmt: LaunchKernelStmt):
diff --git a/python/hidet/graph/ops/matmul/__init__.py b/python/hidet/graph/ops/matmul/__init__.py
index ae3f4c217..d5ea6210b 100644
--- a/python/hidet/graph/ops/matmul/__init__.py
+++ b/python/hidet/graph/ops/matmul/__init__.py
@@ -16,4 +16,4 @@
 from .matmul_f32_x86 import Matmulx86Op, MatmulF32Taskx86
-from .matmul_f32_x86 import matmul_x86
+from .matmul_f32_x86 import batch_matmul_x86
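The import changes above rename the public entry point from `matmul_x86` to `batch_matmul_x86`, which (per the operator change further down) accepts only rank-3 inputs. A minimal usage sketch of the renamed API on 2-D data — the shapes here are illustrative, not taken from this patch:

```python
import hidet

# Two 2-D float32 tensors on the CPU.
a = hidet.randn([333, 444], dtype='float32', device='cpu')
b = hidet.randn([444, 555], dtype='float32', device='cpu')

# batch_matmul_x86 expects [B, M, K] x [B, K, N]: add a unit batch
# dimension, multiply, then drop it again.
c = hidet.ops.batch_matmul_x86(a.unsqueeze([0]), b.unsqueeze([0]))  # [1, 333, 555]
c = c.squeeze([0])  # back to [333, 555]
```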
diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86.py b/python/hidet/graph/ops/matmul/matmul_f32_x86.py
index eeb467b30..81e352d24 100644
--- a/python/hidet/graph/ops/matmul/matmul_f32_x86.py
+++ b/python/hidet/graph/ops/matmul/matmul_f32_x86.py
@@ -11,57 +11,34 @@
 # limitations under the License.
 from typing import List, Union
 from hidet.ir.dtypes import float32, int32
-from hidet.ir.expr import cast
+from hidet.ir.expr import cast, is_constant
 from hidet.ir.module import IRModule
 from hidet.ir.compute import TensorNode
 from hidet.ir.stmt import DeclareScope
 from hidet.ir.task import Task
 from hidet.ir.compute import compute, reduce
-from hidet.graph.ops.utils import input_like, broadcast_shape, can_mutually_broadcast
+from hidet.graph.ops.utils import input_like
 from hidet.ir.library import tune
 from hidet.graph.operator import Operator, Tensor
-from hidet.graph.ops.utils import broadcast_indices
 from hidet.lang import attrs
+from hidet.ir.expr import if_then_else


 class MatmulF32Taskx86(Task):
     def __init__(self, a: TensorNode, b: TensorNode):
-        a_shape = a.const_shape
-        b_shape = b.const_shape
-        if not a.type.dtype == float32 or not b.type.dtype == float32:
-            raise ValueError('Both inputs must be float32 tensors')
-
-        if len(a_shape) < 2 or len(b_shape) < 2:
-            raise ValueError('Matrix multiplication expect at least 2D tensor, got {} and {}'.format(a_shape, b_shape))
-
-        self._assert(
-            a_shape[-1] == b_shape[-2],
-            msg=(
-                'Matrix multiplication expect tensor A and B with shape [..., M, K] and [..., K, N]'
-                ', got {} and {}'.format(a_shape, b_shape)
-            ),
-        )
-
-        self._assert(
-            can_mutually_broadcast(a_shape[:-2], b_shape[:-2]),
-            msg=(
-                'Matrix multiplication expect tensor A and B with compatible broadcast shape, '
-                'got {} and {}'.format(a_shape, b_shape)
-            ),
-        )
-
-        k_size = a_shape[-1]
-        c_shape = broadcast_shape(a_shape[:-2], b_shape[:-2]) + [a_shape[-2], b_shape[-1]]
+        batch_size, m_size, k_size = a.shape
+        batch_size, k_size, n_size = b.shape
+        self.batch_size = batch_size
+        self.m_size = m_size
+        self.n_size = n_size
+        self.k_size = k_size

         c = compute(
             name='c',
-            shape=c_shape,
-            fcompute=lambda *indices: reduce(
-                shape=[k_size],
-                fcompute=lambda k: a[broadcast_indices(indices[:-2], a_shape[:-2], c_shape[1:-2]) + [indices[-2], k]]
-                * b[broadcast_indices(indices[:-2], b_shape[:-2], c_shape[1:-2]) + [k, indices[-1]]],
-                reduce_type='sum',
+            shape=[batch_size, m_size, n_size],
+            fcompute=lambda r, i, j: reduce(
+                shape=[k_size], fcompute=lambda k: a[r, i, k] * b[r, k, j], reduce_type='sum'
             ),
         )
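The rewritten compute definition drops the implicit broadcasting logic and pins both operands to rank 3, so the task computes exactly `c[r, i, j] = sum_k a[r, i, k] * b[r, k, j]`. A NumPy reference with the same semantics, handy for checking the kernel (an illustrative helper, not part of the patch):

```python
import numpy as np

def batch_matmul_reference(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # c[r, i, j] = sum_k a[r, i, k] * b[r, k, j], mirroring the reduce() above
    assert a.ndim == b.ndim == 3 and a.shape[0] == b.shape[0] and a.shape[2] == b.shape[1]
    batch_size, m_size, n_size = a.shape[0], a.shape[1], b.shape[2]
    c = np.zeros((batch_size, m_size, n_size), dtype=a.dtype)
    for r in range(batch_size):
        c[r] = a[r] @ b[r]
    return c
```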
@@ -69,7 +46,7 @@ def __init__(self, a: TensorNode, b: TensorNode):
             name='matmul_f32_x86',
             inputs=[a, b],
             outputs=[c],
-            attributes={'m_size': a_shape[-2], 'n_size': b_shape[-1], 'k_size': a_shape[-1]},
+            attributes={'batch_size': batch_size, 'm_size': m_size, 'n_size': n_size, 'k_size': k_size},
         )

     def allow_epilogue(self) -> bool:
@@ -81,7 +58,14 @@ def allow_prologue(self) -> bool:
     def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
         return tune.extract_ir_modules(self.schedule_matmulf32_x86)

-    @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 4, 2, 1)])
+    @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(2, 4, 2, 2)])
+    @tune.space(
+        2,
+        MC=[144, 288, 432, 576, 720],
+        NC=[800],
+        KC=[256, 560, 768, 384],
+        ways=[(1, 4, 2, 1), (2, 4, 4, 1), (1, 4, 4, 1), (1, 2, 4, 2), (1, 4, 4, 2), (2, 4, 2, 2)],
+    )
     def schedule_matmulf32_x86(self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1)) -> IRModule:
         import hidet
         from hidet.ir.type import tensor_type
@@ -94,10 +78,12 @@ def schedule_matmulf32_x86(self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1)) ->
         from hidet.lang.cpu import avx_f32x8_insert_f32x4, avx_f32x8_permute2f32x4
         from hidet.lang.cpu import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor

-        node_a, node_b = self.inputs[0], self.inputs[1]
-        a_shape = node_a.const_shape
-        b_shape = node_b.const_shape
-        m_size, n_size, k_size = a_shape[-2], b_shape[-1], a_shape[-1]
+        task = self
+
+        batch_size = task.batch_size
+        m_size = task.m_size
+        n_size = task.n_size
+        k_size = task.k_size

         MR, NR = 6, 16

@@ -335,7 +321,10 @@ def micro_kernel(
         packed_b_total_size = packed_b_total_width * packed_b_height
         packed_b_individual_size = packed_b_width * packed_b_height

-        packed_a_individual_height = min(MC, (m_size + MR - 1) // MR * MR)
+        # Compute min(MC, round_up(m_size, MR)) with if_then_else so that it also
+        # works when m_size is a symbolic expression, where Python's min() fails.
+        temp_packed_a_ind = (m_size + MR - 1) // MR * MR
+        packed_a_individual_height = if_then_else(temp_packed_a_ind > MC, MC, temp_packed_a_ind)
         packed_a_total_height = packed_a_individual_height * packed_a_buffers_needed

         packed_a_width = min(KC, (k_size + 8 - 1) // 8 * 8)
@@ -827,36 +816,56 @@ def gemm_5th_loop(

         ################### Start of the main kernel ###################
         @hidet.script
-        def matmul_kernel_x86_v3(
-            a: float32[m_size, k_size], b: float32[k_size, n_size], c: float32[m_size, n_size]
+        def matmul_kernel_x86(
+            a: float32[batch_size, m_size, k_size],
+            b: float32[batch_size, k_size, n_size],
+            c: float32[batch_size, m_size, n_size],
         ):
             attrs.func_kind = 'cpu_kernel'
-
-            init_thr(packa_thrcomm_barrier_sense, packa_thrcomm_threads_arrived, loop3_nways)
-            init_thr(packb_thrcomm_barrier_sense, packb_thrcomm_barrier_threads_arrived, loop5_nways)
+            a_ptr = cast(a, ~float32)
+            b_ptr = cast(b, ~float32)
+            c_ptr = cast(c, ~float32)

             parallel_attr = 'p' + str(nthreads)
             # The outermost loop spawning threads
-            for tidx in grid(nthreads, attrs=parallel_attr):
-                tid_5th_loop = tidx
-                work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways)
-                comm_id_5th_loop = tid_5th_loop
-
-                gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop)
-
-        assert isinstance(matmul_kernel_x86_v3, hidet.ir.Function)
-        # matmul_kernel_x86_v3.kind = "cpu_kernel"
+            for batch in range(batch_size):
+                init_thr(packa_thrcomm_barrier_sense, packa_thrcomm_threads_arrived, loop3_nways)
+                init_thr(packb_thrcomm_barrier_sense, packb_thrcomm_barrier_threads_arrived, loop5_nways)
+                # Iterate through the batch dimension; for each batch, locate the
+                # corresponding a, b, and c matrices and run the single-matmul kernel on them.
+                a_matrix_size = m_size * k_size
+                b_matrix_size = k_size * n_size
+                c_matrix_size = m_size * n_size
+                a_matrix = as_tensor_pointer(a_ptr + (batch * a_matrix_size), dtype=float32, shape=[m_size, k_size])
+                b_matrix = as_tensor_pointer(b_ptr + (batch * b_matrix_size), dtype=float32, shape=[k_size, n_size])
+                c_matrix = as_tensor_pointer(c_ptr + (batch * c_matrix_size), dtype=float32, shape=[m_size, n_size])
+                for tidx in grid(nthreads, attrs=parallel_attr):
+                    tid_5th_loop = tidx
+                    work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways)
+                    comm_id_5th_loop = tid_5th_loop
+
+                    gemm_5th_loop(a_matrix, b_matrix, c_matrix, work_id_5th_loop, comm_id_5th_loop)
+
+        assert isinstance(matmul_kernel_x86, hidet.ir.Function)
+        # matmul_kernel_x86.kind = "cpu_kernel"
         ir_module = module.ir_module()
         return ir_module


 class Matmulx86Op(Operator):
     def __init__(self, a: Tensor, b: Tensor):
-        if not (len(a.shape) == len(b.shape) == 2 and a.shape[1] == b.shape[0]):
-            raise ValueError('Matrix multiplication: incompatible sizes: {} and {}'.format(a.shape, b.shape))
+        if not (
+            len(a.shape) == len(b.shape) == 3
+            and (not is_constant(a.shape[0], b.shape[0]) or a.shape[0] == b.shape[0])
+            and (not is_constant(a.shape[2], b.shape[1]) or a.shape[2] == b.shape[1])
+        ):
+            raise ValueError(
+                "Matrix multiplication expects tensor A and B with shape [B, M, K] and [B, K, N]"
+                + ", got {} and {}".format(a.shape, b.shape)
+            )
         task = MatmulF32Taskx86(input_like(a, 'a'), input_like(b, 'b'))
         super().__init__(inputs=[a, b], attributes={}, task=task)


-def matmul_x86(a: Tensor, b: Tensor) -> Tensor:
+def batch_matmul_x86(a: Tensor, b: Tensor) -> Tensor:
     return Matmulx86Op(a, b).outputs[0]
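The batching strategy in `matmul_kernel_x86` is flat-pointer arithmetic: consecutive batches of `a`, `b`, and `c` live `m_size * k_size`, `k_size * n_size`, and `m_size * n_size` elements apart, and `as_tensor_pointer` reinterprets each offset pointer as a 2-D matrix. A plain-NumPy sketch of the same offset arithmetic (the helper name is made up for illustration):

```python
import numpy as np

def batch_matrix_view(flat: np.ndarray, batch: int, rows: int, cols: int) -> np.ndarray:
    # Same arithmetic as a_ptr + batch * (rows * cols) followed by as_tensor_pointer(...)
    matrix_size = rows * cols
    return flat[batch * matrix_size:(batch + 1) * matrix_size].reshape(rows, cols)

# A flat buffer holding 2 batches of 3x4 row-major matrices:
flat = np.arange(2 * 3 * 4, dtype=np.float32)
print(batch_matrix_view(flat, 1, 3, 4)[0, 0])  # 12.0, first element of batch 1
```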
diff --git a/python/hidet/graph/ops/matmul/resolve.py b/python/hidet/graph/ops/matmul/resolve.py
index 8d6adbdbf..4424b1362 100644
--- a/python/hidet/graph/ops/matmul/resolve.py
+++ b/python/hidet/graph/ops/matmul/resolve.py
@@ -22,6 +22,7 @@
 from .matmul import MatmulOp
 from .batch_matmul import batch_matmul
+from .matmul_f32_x86 import batch_matmul_x86
 from .matmul_f16 import matmul_f16
 from ..transform import broadcast, flatten
 from ..utils import broadcast_shapes
@@ -90,7 +91,7 @@ class MatmulResolveRule(ResolveRule):
     The generic matrix multiplication operator has the same semantics as numpy.matmul
     that accepts variable dimensions of inputs.

-    On ther other hand, the batched matrix multiplication operator accepts inputs with shape:
+    On the other hand, the batched matrix multiplication operator accepts inputs with shape:
     [batch_size, m_size, k_size] x [batch_size, k_size, n_size]

     This resolve rule also parallelize k dimension when possible, and determine the mma instruction.
@@ -125,6 +126,9 @@ def run_batch_matmul(self, a: Tensor, b: Tensor) -> Tensor:
             c = batch_matmul(aa, bb, mma=mma).reshape([batch_size, nparts, m_size, n_size]).sum(1)
         return c

+    def run_batch_matmul_cpu(self, a: Tensor, b: Tensor) -> Tensor:
+        return batch_matmul_x86(a, b)
+
     def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]:
         assert isinstance(op, MatmulOp)
         a: Tensor = op.inputs[0]
@@ -133,30 +137,35 @@ def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]:
         if a.dtype.nbytes > 4 or b.dtype.nbytes > 4:
             return None

+        if op.device.is_cpu():
+            run_func = self.run_batch_matmul_cpu
+        else:
+            run_func = self.run_batch_matmul
+
         if len(a.shape) == 1:  # shape: [a]
             a = a.unsqueeze([0, 1])  # [1, 1, a]
             if len(b.shape) == 2:  # shape: [a, b]
                 # [a] x [a, b] -> [b]
                 b = b.unsqueeze([0])  # [1, a, b]
-                c = self.run_batch_matmul(a, b)  # [1, 1, b]
+                c = run_func(a, b)  # [1, 1, b]
                 c = c.squeeze([0, 1])  # [b]
             else:
                 assert len(b.shape) >= 3  # shape example: [b, c, a, d]
                 # [a] x [b, c, a, d] -> [b, c, d]
                 b = flatten(b, start_dim=0, end_dim=-3)  # [b * c, a, d]
-                c = self.run_batch_matmul(a, b)  # [b * c, 1, d]
+                c = run_func(a, b)  # [b * c, 1, d]
                 c = c.reshape(c_shape)  # [b, c, d]
         elif len(b.shape) == 1:  # shape: [b]
             b = b.unsqueeze([0, 2])  # [1, b, 1]
             if len(a.shape) == 2:  # shape: [a, b]
                 a = a.unsqueeze([0])  # [1, a, b]
-                c = self.run_batch_matmul(a, b)  # [1, a, 1]
+                c = run_func(a, b)  # [1, a, 1]
                 c = c.squeeze([0, 2])  # [a]
             else:
                 assert len(a.shape) >= 3  # shape example: [a, c, d, b]
                 # [a, c, d, b] x [b] -> [a, c, d]
                 a = flatten(a, start_dim=0, end_dim=-3)  # [a * c, d, b]
-                c = self.run_batch_matmul(a, b)  # [a * c, d, 1]
+                c = run_func(a, b)  # [a * c, d, 1]
                 c = c.reshape(c_shape)  # [a, c, d]
         else:
             # example: [a, b, c] x [c, d] -> [a, b, d]
@@ -168,7 +177,7 @@ def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]:
             b_broadcast_shape = c_head + list(b.shape[-2:])
             a = flatten(broadcast(a, a_broadcast_shape), start_dim=0, end_dim=-3)
             b = flatten(broadcast(b, b_broadcast_shape), start_dim=0, end_dim=-3)
-            c = self.run_batch_matmul(a, b)
+            c = run_func(a, b)
             c = c.reshape(c_shape)
         return [c]

@@ -240,8 +249,8 @@ def resolve_f16(self, op: Operator) -> Optional[List[Tensor]]:
         return [c]

     def resolve(self, op: Operator) -> Optional[List[Tensor]]:
         if op.device.is_cpu():
-            return None
+            return self.resolve_generic(op)
         resolve_funcs: List[Callable[[Operator], Any]] = [self.resolve_f16, self.resolve_generic]
         for resolve_func in resolve_funcs:
             outs = resolve_func(op)
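`resolve_generic` lowers every `matmul` variant to the rank-3 batched form before dispatching to `run_func`. For instance, the 1-D x 2-D case follows these unsqueeze/squeeze steps (a NumPy rendition for illustration; the rule itself operates on hidet tensors):

```python
import numpy as np

def vec_matmul_via_batched(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """[k] x [k, n] -> [n] via one rank-3 batched multiply."""
    assert a.ndim == 1 and b.ndim == 2
    a3 = a[None, None, :]  # a.unsqueeze([0, 1]) -> [1, 1, k]
    b3 = b[None, :, :]     # b.unsqueeze([0])    -> [1, k, n]
    c3 = a3 @ b3           # stand-in for run_func(a, b) -> [1, 1, n]
    return c3[0, 0]        # c.squeeze([0, 1])   -> [n]

# Quick check against numpy.matmul:
x = np.random.rand(4).astype(np.float32)
w = np.random.rand(4, 5).astype(np.float32)
assert np.allclose(vec_matmul_via_batched(x, w), x @ w, atol=1e-6)
```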
diff --git a/python/hidet/graph/ops/softmax.py b/python/hidet/graph/ops/softmax.py
index 4421eae61..d17544adb 100644
--- a/python/hidet/graph/ops/softmax.py
+++ b/python/hidet/graph/ops/softmax.py
@@ -157,11 +157,6 @@ def softmax_kernel(xs: xdtype[shape], ys: xdtype[shape]):

         return ir_module

-    def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
-        if self.inputs[0].type.dtype != float32:
-            return NotImplemented  # use auto-scheduler
-        return tune.extract_ir_modules(self.schedule_softmax_cpu)
-

 class CPUSoftmaxTask(SoftmaxTask):
     def allow_epilogue(self) -> bool:
@@ -170,6 +165,11 @@ def allow_epilogue(self) -> bool:
     def allow_prologue(self) -> bool:
         return False

+    def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
+        if self.inputs[-1].type.dtype != float32:
+            return NotImplemented  # use auto-scheduler
+        return tune.extract_ir_modules(self.schedule_softmax_cpu)
+
     @tune.space(2, nthreads=['', 4, 8, 16, 32, 64, 96])
     @tune.space(1, nthreads=['', 8, 16])
     def schedule_softmax_cpu(self, nthreads='') -> IRModule:
diff --git a/python/hidet/ir/schedulers/cpu/scheduler.py b/python/hidet/ir/schedulers/cpu/scheduler.py
index 9089c288c..9e8a5301a 100644
--- a/python/hidet/ir/schedulers/cpu/scheduler.py
+++ b/python/hidet/ir/schedulers/cpu/scheduler.py
@@ -26,7 +26,13 @@ def schedule_grid_compute(self, node: GridCompute, tensor_map: Dict[TensorNode,
         params, param_map, call_args = self.grid_compute_params_and_args(node, tensor_map)

         if self.task is not None:
-            name = f'{self.task.name}_compute_{node.name}'
+            from hidet.graph.ops.fusion.fused_operator import FusedTask
+
+            if isinstance(self.task, FusedTask):
+                fused_name = self.task.attrs['fused_ops'].replace(' ', '_')
+                name = f'fused_{fused_name}_{node.name}'
+            else:
+                name = f'{self.task.name}_{node.name}'
         else:
             name = f'compute_{node.name}'

diff --git a/tests/operators/test_matmul.py b/tests/operators/test_matmul.py
index c5c67aa50..de38cde78 100644
--- a/tests/operators/test_matmul.py
+++ b/tests/operators/test_matmul.py
@@ -18,15 +18,13 @@
 from hidet.testing import check_binary, check_binary_dynamic, check_torch_binary


-# @pytest.mark.skip(reason="when running matmul_x86 multiple times, it will produce wrong result. need fix.")
-@pytest.mark.parametrize("a_shape, b_shape", [[[333, 444], [444, 555]], [[133, 1], [1, 177]]])
+@pytest.mark.parametrize("a_shape, b_shape", [[[1, 333, 444], [1, 444, 555]], [[1, 133, 1], [1, 1, 177]]])
 def test_matmul_x86(a_shape, b_shape):
-    # TODO: Doesn't support broadcasting yet; need to add it later?
     check_binary(
         a_shape,
         b_shape,
         lambda x, y: np.matmul(x, y),
-        lambda x, y: ops.matmul_x86(x, y) - ops.matmul_x86(x, y) + ops.matmul_x86(x, y),
+        lambda x, y: ops.batch_matmul_x86(x, y),
         dtype="float32",
         atol=1e-4,
         rtol=1e-4,
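For a quick end-to-end check of the new CPU path, the updated test can be reproduced standalone (assuming a CPU build of hidet; the shapes mirror the first parametrization above):

```python
import numpy as np
import hidet

a = hidet.randn([1, 333, 444], dtype='float32', device='cpu')
b = hidet.randn([1, 444, 555], dtype='float32', device='cpu')

c = hidet.ops.batch_matmul_x86(a, b)

# Compare against the NumPy reference with the test's tolerances.
np.testing.assert_allclose(c.numpy(), np.matmul(a.numpy(), b.numpy()), atol=1e-4, rtol=1e-4)
```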