From efe3e14c94d7931557705635c3b1e79226a015d7 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Sat, 27 May 2023 22:01:42 -0400 Subject: [PATCH 001/148] . --- tests/operators/test_matmul.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/operators/test_matmul.py b/tests/operators/test_matmul.py index 4c3402f42..9fb3dbfa1 100644 --- a/tests/operators/test_matmul.py +++ b/tests/operators/test_matmul.py @@ -17,7 +17,7 @@ from hidet.testing import check_binary -@pytest.mark.parametrize("a_shape, b_shape", [[[333, 444], [444, 555]], [[133, 1], [1, 177]]]) +@pytest.mark.parametrize("a_shape, b_shape", [[[367, 369], [369, 470]], [[133, 1], [1, 177]]]) def test_matmul_x86(a_shape, b_shape): # TODO: Doesn't support broadcasting yet; need to add it later? check_binary( From a7bce75ca0a0fb5b5c2fe306eb754af66af542d9 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Mon, 17 Jul 2023 13:11:45 -0400 Subject: [PATCH 002/148] added basic openMP primitives --- python/hidet/ir/primitives/__init__.py | 1 + python/hidet/ir/primitives/cpu/__init__.py | 2 ++ python/hidet/lang/cpu.py | 2 ++ 3 files changed, 5 insertions(+) diff --git a/python/hidet/ir/primitives/__init__.py b/python/hidet/ir/primitives/__init__.py index 21c1ae524..29ac28c6e 100644 --- a/python/hidet/ir/primitives/__init__.py +++ b/python/hidet/ir/primitives/__init__.py @@ -26,6 +26,7 @@ from . import cpu from .cpu import avx_f32x4_store, avx_f32x4_broadcast, avx_f32x4_fmadd, avx_f32x4_load, avx_f32x4_setzero from .cpu import avx_free, avx_malloc +from .cpu import openmp_get_num_threads, openmp_get_thread_num # cuda primitive functions and variables from . import cuda diff --git a/python/hidet/ir/primitives/cpu/__init__.py b/python/hidet/ir/primitives/cpu/__init__.py index a31a708ef..60fdd534a 100644 --- a/python/hidet/ir/primitives/cpu/__init__.py +++ b/python/hidet/ir/primitives/cpu/__init__.py @@ -14,3 +14,5 @@ from .avx import avx_f32x4_broadcast, avx_f32x4_fmadd, avx_f32x4_load, avx_f32x4_store, avx_f32x4_setzero from .avx import avx_f32x8_broadcast, avx_f32x8_fmadd, avx_f32x8_load, avx_f32x8_store, avx_f32x8_setzero from .avx import avx_free, avx_malloc, x86_memcpy, x86_memset, aligned_alloc + +from .openmp import openmp_get_thread_num, openmp_get_num_threads diff --git a/python/hidet/lang/cpu.py b/python/hidet/lang/cpu.py index 9e58fb445..a1073aa0e 100644 --- a/python/hidet/lang/cpu.py +++ b/python/hidet/lang/cpu.py @@ -26,3 +26,5 @@ avx_f32x8_setzero, ) from hidet.ir.primitives.cpu import avx_free, avx_malloc, x86_memcpy, x86_memset, aligned_alloc + +from hidet.ir.primitives.cpu import openmp_get_thread_num, openmp_get_num_threads From d7f64698b7d0626cbfa04812650a35c88e7c3e96 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Tue, 8 Aug 2023 12:58:05 -0400 Subject: [PATCH 003/148] added those primitives back --- python/hidet/ir/primitives/cpu/avx.py | 65 +++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/python/hidet/ir/primitives/cpu/avx.py b/python/hidet/ir/primitives/cpu/avx.py index bc87a79e0..bc0392df1 100644 --- a/python/hidet/ir/primitives/cpu/avx.py +++ b/python/hidet/ir/primitives/cpu/avx.py @@ -24,12 +24,18 @@ def register_primitive_functions(): ('avx_x86_float32x4_broadcast', '_mm_broadcast_ss', FuncType([PointerType('float32')], 'float32x4')), ('avx_x86_float32x4_fmadd', '_mm_fmadd_ps', FuncType(['float32x4', 'float32x4', 'float32x4'], 'float32x4')), ('avx_x86_float32x4_load', '_mm_loadu_ps', FuncType([PointerType('float32')], 'float32x4')), + ('avx_x86_float32x4_load_aligned', 
'_mm_load_ps', FuncType([PointerType('float32')], 'float32x4')), ('avx_x86_float32x4_store', '_mm_storeu_ps', FuncType([PointerType('float32'), 'float32x4'], VoidType())), + ( + 'avx_x86_float32x4_store_aligned', '_mm_store_ps', FuncType([PointerType('float32'), 'float32x4'], VoidType())), ('avx_x86_float32x4_setzero', '_mm_setzero_ps', FuncType([], 'float32x4')), ('avx_x86_float32x8_broadcast', '_mm256_broadcast_ss', FuncType([PointerType('float32')], 'float32x8')), ('avx_x86_float32x8_fmadd', '_mm256_fmadd_ps', FuncType(['float32x8', 'float32x8', 'float32x8'], 'float32x8')), ('avx_x86_float32x8_load', '_mm256_loadu_ps', FuncType([PointerType('float32')], 'float32x8')), + ('avx_x86_float32x8_load_aligned', '_mm256_load_ps', FuncType([PointerType('float32')], 'float32x8')), ('avx_x86_float32x8_store', '_mm256_storeu_ps', FuncType([PointerType('float32'), 'float32x8'], VoidType())), + ('avx_x86_float32x8_store_aligned', '_mm256_store_ps', + FuncType([PointerType('float32'), 'float32x8'], VoidType())), ('avx_x86_float32x8_setzero', '_mm256_setzero_ps', FuncType([], 'float32x8')), ('avx_x86_malloc', '_mm_malloc', FuncType(['uint64', 'uint64'], PointerType(VoidType()))), ('avx_x86_free', '_mm_free', FuncType([PointerType(VoidType())], VoidType())), @@ -39,6 +45,20 @@ def register_primitive_functions(): 'memcpy', FuncType([PointerType(VoidType()), PointerType(VoidType()), 'uint64'], PointerType(VoidType())), ), + ('avx_x86_float32x8_unpacklo', '_mm256_unpacklo_ps', FuncType(['float32x8', 'float32x8'], 'float32x8')), + ('avx_x86_float32x8_unpackhi', '_mm256_unpackhi_ps', FuncType(['float32x8', 'float32x8'], 'float32x8')), + ('avx_x86_float32x8_shuffle', '_mm256_shuffle_ps', FuncType(['float32x8', 'float32x8', 'int32'], 'float32x8')), + ('avx_x86_float32x8_cast_float32x4', '_mm256_castps256_ps128', FuncType(['float32x8'], 'float32x4')), + ( + 'avx_x86_float32x8_insert_float32x4', + '_mm256_insertf128_ps', + FuncType(['float32x8', 'float32x4', 'int32'], 'float32x8') + ), + ( + 'avx_x86_float32x8_permute2float32x4', + '_mm256_permute2f128_ps', + FuncType(['float32x8', 'float32x8', 'int32'], 'float32x8') + ), ] for name, codegen_name, func_type in functions: register_primitive_function(name=name, func_or_type=func_type, codegen_name=codegen_name) @@ -92,13 +112,58 @@ def avx_f32x4_load(addr: Expr) -> Call: return call_primitive_func('avx_x86_float32x4_load', [addr]) +def avx_f32x4_load_aligned(addr: Expr) -> Call: + return call_primitive_func('avx_x86_float32x4_load_aligned', [addr]) + + def avx_f32x8_load(addr: Expr) -> Call: return call_primitive_func('avx_x86_float32x8_load', [addr]) +def avx_f32x8_load_aligned(addr: Expr) -> Call: + return call_primitive_func('avx_x86_float32x8_load_aligned', [addr]) + + def avx_f32x4_store(addr: Expr, src: Expr) -> Call: return call_primitive_func('avx_x86_float32x4_store', [addr, src]) +def avx_f32x4_store_aligned(addr: Expr, src: Expr) -> Call: + return call_primitive_func('avx_x86_float32x4_store_aligned', [addr, src]) + + def avx_f32x8_store(addr: Expr, src: Expr) -> Call: return call_primitive_func('avx_x86_float32x8_store', [addr, src]) + + +def avx_f32x8_store_aligned(addr: Expr, src: Expr) -> Call: + return call_primitive_func('avx_x86_float32x8_store_aligned', [addr, src]) + + +def avx_f32x8_unpacklo(a: Expr, b: Expr) -> Call: + return call_primitive_func('avx_x86_float32x8_unpacklo', [a, b]) + + +def avx_f32x8_unpackhi(a: Expr, b: Expr) -> Call: + return call_primitive_func('avx_x86_float32x8_unpackhi', [a, b]) + + +def avx_f32x8_shuffle(a: Expr, b: 
Expr, imm: Union[int, Expr]) -> Call: + return call_primitive_func('avx_x86_float32x8_shuffle', [a, b, imm]) + + +def avx_f32x8_cast_f32x4(a: Expr) -> Call: + return call_primitive_func('avx_x86_float32x8_cast_float32x4', [a]) + + +def avx_f32x8_insert_f32x4(a: Expr, b: Expr, imm: Union[int, Expr]) -> Call: + return call_primitive_func('avx_x86_float32x8_insert_float32x4', [a, b, imm]) + + +def avx_f32x8_permute2f32x4(a: Expr, b: Expr, imm: Union[int, Expr]) -> Call: + return call_primitive_func('avx_x86_float32x8_permute2float32x4', [a, b, imm]) + + + + + From f211a48669cd431012581c97d5ba01934a6d1f4e Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Sat, 12 Aug 2023 21:32:42 -0400 Subject: [PATCH 004/148] let me pretend like it's all good for tonight --- python/hidet/ir/primitives/cpu/__init__.py | 13 ++++- python/hidet/ir/primitives/cpu/atomic.py | 59 ++++++++++++++++++++++ python/hidet/lang/cpu.py | 20 +++++++- 3 files changed, 90 insertions(+), 2 deletions(-) create mode 100644 python/hidet/ir/primitives/cpu/atomic.py diff --git a/python/hidet/ir/primitives/cpu/__init__.py b/python/hidet/ir/primitives/cpu/__init__.py index 60fdd534a..f62df917e 100644 --- a/python/hidet/ir/primitives/cpu/__init__.py +++ b/python/hidet/ir/primitives/cpu/__init__.py @@ -14,5 +14,16 @@ from .avx import avx_f32x4_broadcast, avx_f32x4_fmadd, avx_f32x4_load, avx_f32x4_store, avx_f32x4_setzero from .avx import avx_f32x8_broadcast, avx_f32x8_fmadd, avx_f32x8_load, avx_f32x8_store, avx_f32x8_setzero from .avx import avx_free, avx_malloc, x86_memcpy, x86_memset, aligned_alloc +from .avx import avx_f32x8_store_aligned, avx_f32x8_load_aligned +from .avx import avx_f32x4_store_aligned, avx_f32x4_load_aligned +from .avx import ( + avx_f32x8_unpackhi, + avx_f32x8_unpacklo, + avx_f32x8_shuffle, + avx_f32x8_cast_f32x4, + avx_f32x8_insert_f32x4, + avx_f32x8_permute2f32x4, +) + +from .atomic import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor -from .openmp import openmp_get_thread_num, openmp_get_num_threads diff --git a/python/hidet/ir/primitives/cpu/atomic.py b/python/hidet/ir/primitives/cpu/atomic.py new file mode 100644 index 000000000..b38a52247 --- /dev/null +++ b/python/hidet/ir/primitives/cpu/atomic.py @@ -0,0 +1,59 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
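+
+# The wrappers in this file are minimal bindings to GCC's __atomic_* builtins.
+# The integer `order` argument is the builtin's memory-order constant
+# (__ATOMIC_RELAXED = 0, __ATOMIC_CONSUME = 1, __ATOMIC_ACQUIRE = 2,
+# __ATOMIC_RELEASE = 3, __ATOMIC_ACQ_REL = 4, __ATOMIC_SEQ_CST = 5).
+# For example, a spin-wait on a flag published by another thread (a sketch,
+# assuming `flag` points to an int32) would look like:
+#     while cpu_atomic_load_n(flag, 2) == 0:  # __ATOMIC_ACQUIRE
+#         pass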
+from typing import Union + +from hidet.ir.expr import Expr, Call +from hidet.ir.type import FuncType, VoidType, PointerType +from hidet.ir.primitives.func import register_primitive_function +from hidet.utils import initialize +from hidet.ir.primitives.func import call_primitive_func + + +@initialize() +def register_primitive_functions(): + functions = [ + ('cpu_atomic_load_n', '__atomic_load_n', FuncType([PointerType(VoidType()), 'int32'], 'int32')), + ('cpu_atomic_add_fetch', '__atomic_add_fetch', FuncType([PointerType(VoidType()), 'int32', 'int32'], 'int32')), + ('cpu_atomic_fetch_xor', '__atomic_fetch_xor', FuncType([PointerType(VoidType()), 'int32', 'int32'], 'int32')), + ] + + for name, codegen_name, func_type in functions: + register_primitive_function(name=name, func_or_type=func_type, codegen_name=codegen_name) + + +def cpu_atomic_load_n(ptr: Expr, order: Union[Expr, int]) -> Expr: + return call_primitive_func('cpu_atomic_load_n', [ptr, order]) + + +def cpu_atomic_add_fetch(ptr: Expr, val: Union[Expr, int], order: Union[Expr, int]) -> Expr: + return call_primitive_func('cpu_atomic_add_fetch', [ptr, val, order]) + + +def cpu_atomic_fetch_xor(ptr: Expr, val: Union[Expr, int], order: Union[Expr, int]) -> Expr: + return call_primitive_func('cpu_atomic_fetch_xor', [ptr, val, order]) + + + + + + + + + + + + + + + + + diff --git a/python/hidet/lang/cpu.py b/python/hidet/lang/cpu.py index a1073aa0e..0a2da1da8 100644 --- a/python/hidet/lang/cpu.py +++ b/python/hidet/lang/cpu.py @@ -27,4 +27,22 @@ ) from hidet.ir.primitives.cpu import avx_free, avx_malloc, x86_memcpy, x86_memset, aligned_alloc -from hidet.ir.primitives.cpu import openmp_get_thread_num, openmp_get_num_threads +# from hidet.ir.primitives.cpu import openmp_get_thread_num, openmp_get_num_threads + +from hidet.ir.primitives.cpu import ( + avx_f32x8_store_aligned, + avx_f32x8_load_aligned, + avx_f32x4_store_aligned, + avx_f32x4_load_aligned, +) + +from hidet.ir.primitives.cpu import ( + avx_f32x8_unpackhi, + avx_f32x8_unpacklo, + avx_f32x8_shuffle, + avx_f32x8_cast_f32x4, + avx_f32x8_insert_f32x4, + avx_f32x8_permute2f32x4, +) + +from hidet.ir.primitives.cpu import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor \ No newline at end of file From bbb5afc211d36cda4c05ccdab85dbfb84a456d23 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Sun, 13 Aug 2023 17:22:05 -0400 Subject: [PATCH 005/148] ... --- .../graph/ops/matmul/matmul_f32_x86_v3.py | 974 ++++++++++++++++++ 1 file changed, 974 insertions(+) create mode 100644 python/hidet/graph/ops/matmul/matmul_f32_x86_v3.py diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_v3.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_v3.py new file mode 100644 index 000000000..0c3abff7f --- /dev/null +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_v3.py @@ -0,0 +1,974 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
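+
+# A sketch of the overall scheme (BLIS-style five-loop blocked GEMM, which is
+# what the code below follows): loop 5 walks N in NC-wide panels, loop 4 walks
+# K in KC-deep panels (packing the B panel), loop 3 walks M in MC-tall panels
+# (packing the A panel), and the macro kernel tiles the packed panels into
+# MR x KC and KC x NR micropanels for the MR x NR (6 x 16) micro kernel.
+# The `ways` tuple assigns a number of worker threads to each of these loops.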
+from typing import List, Union +from hidet.ir.dtypes import float32, int32 +from hidet.ir.expr import cast +from hidet.ir.module import IRModule +from hidet.ir.compute import TensorNode +from hidet.ir.primitives import avx_malloc +from hidet.ir.primitives.cpu import avx_f32x8_setzero, avx_f32x8_load_aligned +from hidet.ir.stmt import DeclareScope +from hidet.ir.task import Task +from hidet.ir.compute import compute, reduce +from hidet.graph.ops.utils import input_like, broadcast_shape, can_mutually_broadcast +from hidet.ir.library import tune +from hidet.graph.operator import Operator, Tensor +from hidet.graph.ops.utils import broadcast_indices + + +class MatmulF32Taskx86_v2(Task): + + def __init__(self, a: TensorNode, b: TensorNode): + a_shape = a.const_shape + b_shape = b.const_shape + + if not a.type.dtype == float32 or not b.type.dtype == float32: + raise ValueError('Both inputs must be float32 tensors') + + if len(a_shape) < 2 or len(b_shape) < 2: + raise ValueError('Matrix multiplication expect at least 2D tensor, got {} and {}'.format(a_shape, b_shape)) + + self._assert( + a_shape[-1] == b_shape[-2], + msg=( + 'Matrix multiplication expect tensor A and B with shape [..., M, K] and [..., K, N]' + ', got {} and {}'.format(a_shape, b_shape) + ), + ) + + self._assert( + can_mutually_broadcast(a_shape[:-2], b_shape[:-2]), + msg=( + 'Matrix multiplication expect tensor A and B with compatible broadcast shape, ' + 'got {} and {}'.format(a_shape, b_shape) + ), + ) + + k_size = a_shape[-1] + c_shape = broadcast_shape(a_shape[:-2], b_shape[:-2]) + [a_shape[-2], b_shape[-1]] + + c = compute( + name='c', + shape=c_shape, + fcompute=lambda *indices: reduce( + shape=[k_size], + fcompute=lambda k: a[broadcast_indices(indices[:-2], a_shape[:-2], c_shape[1:-2]) + [indices[-2], k]] + * b[broadcast_indices(indices[:-2], b_shape[:-2], c_shape[1:-2]) + [k, indices[-1]]], + reduce_type='sum', + ), + ) + + super().__init__( + name='matmul_f32_x86_v2', + inputs=[a, b], + outputs=[c], + attributes={'m_size': a_shape[-2], 'n_size': b_shape[-1], 'k_size': a_shape[-1]}, + ) + + def allow_epilogue(self) -> bool: + return True + + def allow_prologue(self) -> bool: + return False + + def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: + return tune.extract_ir_modules(self.schedule_matmulf32_x86) + + # @tune.space( + # 2, + # block_m=[2016, 3024], + # block_n=[64, 144, 192, 256, 384, 512, 592, 672, 752, 896, 1024], + # block_k=[96, 128, 256, 384, 512, 560, 688, 784], + # nthreads=[4, 8, 16, 32], + # ) + @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], nthreads=[8, 16]) + def schedule_matmulf32_x86( + self, MC=2016, NC=896, KC=512, ways=(1, 8, 4, 1) + ) -> IRModule: + import hidet + from hidet.ir.type import tensor_type + from hidet.lang import tensor, grid, as_tensor_pointer + from hidet.lang.layout import row_major, column_major + from hidet.lang.cpu import avx_f32x8_store, avx_f32x8_fmadd, avx_f32x8_load, avx_f32x8_broadcast + from hidet.lang.cpu import avx_f32x4_broadcast, avx_f32x4_fmadd, avx_f32x4_load, avx_f32x4_store + from hidet.lang.cpu import avx_f32x8_store_aligned, avx_f32x8_load_aligned + from hidet.lang.cpu import avx_f32x4_store_aligned, avx_f32x4_load_aligned + from hidet.lang.cpu import avx_f32x8_unpacklo, avx_f32x8_unpackhi + from hidet.lang.cpu import avx_f32x8_shuffle, avx_f32x8_cast_f32x4 + from hidet.lang.cpu import avx_f32x8_insert_f32x4, avx_f32x8_permute2f32x4 + from hidet.lang.cpu import cpu_atomic_load_n, cpu_atomic_add_fetch, 
cpu_atomic_fetch_xor + + node_a, node_b = self.inputs[0], self.inputs[1] + a_shape = node_a.const_shape + b_shape = node_b.const_shape + m_size, n_size, k_size = a_shape[-2], b_shape[-1], a_shape[-1] + + MR, NR = 6, 16 + + tune.check(MC % MR == NC % NR == 0, 'Tile size must divide the corresponding block size') + + packed_a_type = tensor_type('float32', layout=row_major(MC // MR, 1) * column_major(MR, KC)) + packed_b_type = tensor_type('float32', layout=row_major(1, NC // NR) * row_major(KC, NR)) + + # Get the number of threads... + loop5_nways, loop3_nways, macro_nways, loop1_nways = ways + loop4_nways = 1 + nthreads = loop5_nways * loop3_nways * macro_nways * loop1_nways + + # Get the number of threads remaining at each level + loop5_nthreads = nthreads + loop4_nthreads = loop5_nthreads // loop5_nways + loop3_nthreads = loop4_nthreads + macro_nthreads = loop3_nthreads // loop3_nways + loop1_nthreads = macro_nthreads // macro_nways + + packb_nthreads = loop3_nthreads + packa_nthreads = macro_nthreads + + # TODO: Since Hidet doesn't support the parallel region syntax as in OpenMP, + # TODO: We instead use a loop to simulate the parallel region, with the "thread id" being the loop index. + outermost_iters = nthreads + + loop5_thrcomm_barrier_sense = 0 + loop5_thrcomm_barrier_threads_arrived = 0 + + packb_thrcomm_barrier_sense = tensor('int32', shape=[loop4_nways], is_static=True) + # for idx in range(loop4_nways): + # packb_thrcomm_barrier_sense[idx] = 0 TODO: This shouldn't be necessary, as static arrays are 0-initialized + packb_thrcomm_barrier_threads_arrived = tensor('int32', shape=[loop4_nways], is_static=True) + + packa_thrcomm_barrier_sense = tensor('int32', shape=[loop3_nways], is_static=True) + packa_thrcomm_threads_arrived = tensor('int32', shape=[loop3_nways], is_static=True) + + # The buffer for storing the starting offset of the packed B buffers for thread, + # indexed by the work ID of Loop5 + packb_start_offsets = tensor('int32', shape=[loop5_nways], is_static=True) + # The buffer for storing the starting offset of the packed A buffers for thread, + # indexed by the work ID of Loop3 + packa_start_offsets = tensor('int32', shape=[loop3_nways], is_static=True) + + # The array to store the needed size for each packed B buffer, indexed by the work ID of Loop5 + packb_sizes = tensor('int32', shape=[loop5_nways], is_static=True) + # The array to store the needed size for each packed A buffer, indexed by the work ID of Loop3 + packa_sizes = tensor('int32', shape=[loop3_nways], is_static=True) + + with hidet.script_module() as module: + # Helpers + @hidet.script + def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~int32, end: ~int32): + if n_way == 1: + start[0] = 0 + end[0] = n + return + all_start = 0 + all_end = n + size = all_end - all_start + + n_bf_whole = size // bf + n_bf_left = size % bf + + n_bf_lo = n_bf_whole // n_way + n_bf_hi = n_bf_whole // n_way + + n_th_lo = n_bf_whole % n_way + # If some partitions must have more block_factors than others, assign the slightly larger partitions to lower index threads + if n_th_lo != 0: + n_bf_lo += 1 + # Compute the actual widths (in units of rows/columns) of individual threads in the low and high groups + size_lo = n_bf_lo * bf + size_hi = n_bf_hi * bf + + # Pre-compute the starting indices of the low and high groups + lo_start = all_start + hi_start = all_start + n_th_lo * size_lo + + # Compute the start and end of individual threads' ranges + if work_id < n_th_lo: + start[0] = lo_start + work_id * 
size_lo
+                    end[0] = lo_start + (work_id + 1) * size_lo
+                else:
+                    start[0] = hi_start + (work_id - n_th_lo) * size_hi
+                    end[0] = hi_start + (work_id - n_th_lo + 1) * size_hi
+
+                # Add the remainder to the last thread's end
+                if work_id == n_way - 1:
+                    end[0] += n_bf_left
+
+            @hidet.script
+            def thread_range_jrir(work_id: int32, n_way: int32, n: int32, bf: int32,
+                                  start: ~int32, end: ~int32, inc: ~int32):
+                start[0] = work_id
+                end[0] = n
+                inc[0] = n_way
+
+            @hidet.script
+            def determine_blocksize_f_sub(i: int32, dim: int32, b_alg: int32) -> int32:
+                dim_left_now = dim - i
+                if dim_left_now <= b_alg:
+                    b_now = dim_left_now
+                else:
+                    b_now = b_alg
+                return b_now
+
+            @hidet.script
+            def not_edge(i: int32, n_iter: int32, n_left: int32) -> bool:
+                return i != n_iter - 1 or n_left == 0
+
+            # Thread barrier
+            @hidet.script
+            def thrcomm_barrier(tid: int32, barrier_sense: ~int32,
+                                barrier_threads_arrived: ~int32, nthreads: int32):
+                if nthreads == 1:
+                    return
+                orig_sense = cpu_atomic_load_n(barrier_sense, 0)  # _ATOMIC_RELAXED
+
+                # Register the current thread's arrival by incrementing the counter
+                my_threads_arrived = cpu_atomic_add_fetch(
+                    barrier_threads_arrived, 1, 4)  # _ATOMIC_ACQ_REL
+
+                if my_threads_arrived == nthreads:
+                    barrier_threads_arrived[0] = 0
+                    cpu_atomic_fetch_xor(barrier_sense, 1, 3)  # _ATOMIC_RELEASE
+                else:
+                    while cpu_atomic_load_n(barrier_sense, 2) == orig_sense:  # _ATOMIC_ACQUIRE
+                        pass
+
+            @hidet.script
+            def micro_kernel(
+                a: packed_a_type, b: packed_b_type, c_ptr: ~float32, pb: int32, msize: int32, nsize: int32,
+                is_first: bool
+            ):
+                c = as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize])
+                c0 = avx_f32x8_load(~c[0, 0])
+                c08 = avx_f32x8_load(~c[0, 8])
+                c1 = avx_f32x8_load(~c[1, 0])
+                c18 = avx_f32x8_load(~c[1, 8])
+                c2 = avx_f32x8_load(~c[2, 0])
+                c28 = avx_f32x8_load(~c[2, 8])
+                c3 = avx_f32x8_load(~c[3, 0])
+                c38 = avx_f32x8_load(~c[3, 8])
+                c4 = avx_f32x8_load(~c[4, 0])
+                c48 = avx_f32x8_load(~c[4, 8])
+                c5 = avx_f32x8_load(~c[5, 0])
+                c58 = avx_f32x8_load(~c[5, 8])
+
+                if is_first:
+                    c0 = avx_f32x8_setzero()
+                    c08 = avx_f32x8_setzero()
+                    c1 = avx_f32x8_setzero()
+                    c18 = avx_f32x8_setzero()
+                    c2 = avx_f32x8_setzero()
+                    c28 = avx_f32x8_setzero()
+                    c3 = avx_f32x8_setzero()
+                    c38 = avx_f32x8_setzero()
+                    c4 = avx_f32x8_setzero()
+                    c48 = avx_f32x8_setzero()
+                    c5 = avx_f32x8_setzero()
+                    c58 = avx_f32x8_setzero()
+                a_ptr = cast(a, ~float32)
+                b_ptr = cast(b, ~float32)
+
+                # The loop below walks the k dimension of the packed panels, so the
+                # trip counts must come from pb (the k extent), not from msize.
+                niters = pb // 4
+                nleft = pb % 4
+                # Outer iterations with step 4
+                for _ in range(niters):
+                    # First of the 4 unrolled iterations
+                    bb0to7 = avx_f32x8_load_aligned(b_ptr)
+                    bb8to15 = avx_f32x8_load_aligned(b_ptr + 8)
+
+                    aa = avx_f32x8_broadcast(a_ptr)
+                    c0 = avx_f32x8_fmadd(aa, bb0to7, c0)
+                    c08 = avx_f32x8_fmadd(aa, bb8to15, c08)
+
+                    aa = avx_f32x8_broadcast(a_ptr + 1)
+                    c1 = avx_f32x8_fmadd(aa, bb0to7, c1)
+                    c18 = avx_f32x8_fmadd(aa, bb8to15, c18)
+
+                    aa = avx_f32x8_broadcast(a_ptr + 2)
+                    c2 = avx_f32x8_fmadd(aa, bb0to7, c2)
+                    c28 = avx_f32x8_fmadd(aa, bb8to15, c28)
+
+                    aa = avx_f32x8_broadcast(a_ptr + 3)
+                    c3 = avx_f32x8_fmadd(aa, bb0to7, c3)
+                    c38 = avx_f32x8_fmadd(aa, bb8to15, c38)
+
+                    aa = avx_f32x8_broadcast(a_ptr + 4)
+                    c4 = avx_f32x8_fmadd(aa, bb0to7, c4)
+                    c48 = avx_f32x8_fmadd(aa, bb8to15, c48)
+
+                    aa = avx_f32x8_broadcast(a_ptr + 5)
+                    c5 = avx_f32x8_fmadd(aa, bb0to7, c5)
+                    c58 = avx_f32x8_fmadd(aa, bb8to15, c58)
+
+                    # Second of the 4 unrolled iterations
+                    bb0to7 = avx_f32x8_load_aligned(b_ptr + 16)
+                    bb8to15 = avx_f32x8_load_aligned(b_ptr + 24)
+
+                    aa = avx_f32x8_broadcast(a_ptr + 6)
+ c0 = avx_f32x8_fmadd(aa, bb0to7, c0) + c08 = avx_f32x8_fmadd(aa, bb8to15, c08) + + aa = avx_f32x8_broadcast(a_ptr + 7) + c1 = avx_f32x8_fmadd(aa, bb0to7, c1) + c18 = avx_f32x8_fmadd(aa, bb8to15, c18) + + aa = avx_f32x8_broadcast(a_ptr + 8) + c2 = avx_f32x8_fmadd(aa, bb0to7, c2) + c28 = avx_f32x8_fmadd(aa, bb8to15, c28) + + aa = avx_f32x8_broadcast(a_ptr + 9) + c3 = avx_f32x8_fmadd(aa, bb0to7, c3) + c38 = avx_f32x8_fmadd(aa, bb8to15, c38) + + aa = avx_f32x8_broadcast(a_ptr + 10) + c4 = avx_f32x8_fmadd(aa, bb0to7, c4) + c48 = avx_f32x8_fmadd(aa, bb8to15, c48) + + aa = avx_f32x8_broadcast(a_ptr + 11) + c5 = avx_f32x8_fmadd(aa, bb0to7, c5) + c58 = avx_f32x8_fmadd(aa, bb8to15, c58) + + # Third of the 4 unrolled iterations + bb0to7 = avx_f32x8_load_aligned(b_ptr + 32) + bb8to15 = avx_f32x8_load_aligned(b_ptr + 40) + + aa = avx_f32x8_broadcast(a_ptr + 12) + c0 = avx_f32x8_fmadd(aa, bb0to7, c0) + c08 = avx_f32x8_fmadd(aa, bb8to15, c08) + + aa = avx_f32x8_broadcast(a_ptr + 13) + c1 = avx_f32x8_fmadd(aa, bb0to7, c1) + c18 = avx_f32x8_fmadd(aa, bb8to15, c18) + + aa = avx_f32x8_broadcast(a_ptr + 14) + c2 = avx_f32x8_fmadd(aa, bb0to7, c2) + c28 = avx_f32x8_fmadd(aa, bb8to15, c28) + + aa = avx_f32x8_broadcast(a_ptr + 15) + c3 = avx_f32x8_fmadd(aa, bb0to7, c3) + c38 = avx_f32x8_fmadd(aa, bb8to15, c38) + + aa = avx_f32x8_broadcast(a_ptr + 16) + c4 = avx_f32x8_fmadd(aa, bb0to7, c4) + c48 = avx_f32x8_fmadd(aa, bb8to15, c48) + + aa = avx_f32x8_broadcast(a_ptr + 17) + c5 = avx_f32x8_fmadd(aa, bb0to7, c5) + c58 = avx_f32x8_fmadd(aa, bb8to15, c58) + + # Fourth of the 4 unrolled iterations + bb0to7 = avx_f32x8_load_aligned(b_ptr + 48) + bb8to15 = avx_f32x8_load_aligned(b_ptr + 56) + + aa = avx_f32x8_broadcast(a_ptr + 18) + c0 = avx_f32x8_fmadd(aa, bb0to7, c0) + c08 = avx_f32x8_fmadd(aa, bb8to15, c08) + + aa = avx_f32x8_broadcast(a_ptr + 19) + c1 = avx_f32x8_fmadd(aa, bb0to7, c1) + c18 = avx_f32x8_fmadd(aa, bb8to15, c18) + + aa = avx_f32x8_broadcast(a_ptr + 20) + c2 = avx_f32x8_fmadd(aa, bb0to7, c2) + c28 = avx_f32x8_fmadd(aa, bb8to15, c28) + + aa = avx_f32x8_broadcast(a_ptr + 21) + c3 = avx_f32x8_fmadd(aa, bb0to7, c3) + c38 = avx_f32x8_fmadd(aa, bb8to15, c38) + + aa = avx_f32x8_broadcast(a_ptr + 22) + c4 = avx_f32x8_fmadd(aa, bb0to7, c4) + c48 = avx_f32x8_fmadd(aa, bb8to15, c48) + + aa = avx_f32x8_broadcast(a_ptr + 23) + c5 = avx_f32x8_fmadd(aa, bb0to7, c5) + c58 = avx_f32x8_fmadd(aa, bb8to15, c58) + + # Increment the a_ptr and b_ptr for the next iteration of the outermost loop + a_ptr += 24 + b_ptr += 64 + + # process the edge + for _ in range(nleft): + aa = avx_f32x8_broadcast(a_ptr) + bb0to7 = avx_f32x8_load_aligned(b_ptr) + bb8to15 = avx_f32x8_load_aligned(b_ptr + 8) + + c0 = avx_f32x8_fmadd(aa, bb0to7, c0) + c08 = avx_f32x8_fmadd(aa, bb8to15, c08) + + aa = avx_f32x8_broadcast(a_ptr + 1) + c1 = avx_f32x8_fmadd(aa, bb0to7, c1) + c18 = avx_f32x8_fmadd(aa, bb8to15, c18) + + aa = avx_f32x8_broadcast(a_ptr + 2) + c2 = avx_f32x8_fmadd(aa, bb0to7, c2) + c28 = avx_f32x8_fmadd(aa, bb8to15, c28) + + aa = avx_f32x8_broadcast(a_ptr + 3) + c3 = avx_f32x8_fmadd(aa, bb0to7, c3) + c38 = avx_f32x8_fmadd(aa, bb8to15, c38) + + aa = avx_f32x8_broadcast(a_ptr + 4) + c4 = avx_f32x8_fmadd(aa, bb0to7, c4) + c48 = avx_f32x8_fmadd(aa, bb8to15, c48) + + aa = avx_f32x8_broadcast(a_ptr + 5) + c5 = avx_f32x8_fmadd(aa, bb0to7, c5) + c58 = avx_f32x8_fmadd(aa, bb8to15, c58) + + a_ptr += 6 + b_ptr += 16 + + # Store the results + avx_f32x8_store(c_ptr, c0) + avx_f32x8_store(c_ptr + 8, c08) + + avx_f32x8_store(c_ptr + nsize, c1) + 
avx_f32x8_store(c_ptr + (nsize + 8), c18) + + avx_f32x8_store(c_ptr + 2 * nsize, c2) + avx_f32x8_store(c_ptr + (2 * nsize + 8), c28) + + avx_f32x8_store(c_ptr + 3 * nsize, c3) + avx_f32x8_store(c_ptr + (3 * nsize + 8), c38) + + avx_f32x8_store(c_ptr + 4 * nsize, c4) + avx_f32x8_store(c_ptr + (4 * nsize + 8), c48) + + avx_f32x8_store(c_ptr + 5 * nsize, c5) + avx_f32x8_store(c_ptr + (5 * nsize + 8), c58) + + @hidet.script + def macro_kernel( + a: packed_a_type, b: packed_b_type, c_in_macro: float32[m_size, n_size], ib: int32, jb: int32, + pb: int32 + ): + return + + #### Some setup code #### + packed_b_total_width = 0 + for workid_loop5 in range(loop5_nways): + loop5_start = 0 + loop5_end = 0 + thread_range_sub(loop5_nways, workid_loop5, n_size, NR, ~loop5_start, ~loop5_end) + curr_width = loop5_end - loop5_start + # packed_b_total_width += curr_width + # packb_start_offsets[workid_loop5] = temp_prev + # temp_prev += curr_width + packb_start_offsets[workid_loop5] = packed_b_total_width + packed_b_total_width += curr_width + + packed_b_height = KC + if packed_b_height > k_size: + packed_b_height = (k_size + NR - 1) // NR * NR + packed_b_total_size = packed_b_total_width * packed_b_height + + a_height_mr_partitions = (m_size + MR - 1) // MR + a_height_mr_remainder = m_size % MR + packed_a_individual_height = MC + packed_a_total_height = packed_a_individual_height * loop3_nways + # if packed_a_total_height > m_size: + # packed_a_total_height = a_height_mr_partitions * MR + packed_a_width = KC + if packed_a_width > k_size: + packed_a_width = (k_size + MR - 1) // MR * MR + packed_a_total_size = packed_a_total_height * packed_a_width + packed_a_individual_size = packed_a_width * packed_a_individual_height + + packb_buf_ptr = avx_malloc(packed_b_total_size * 4, 4096) + packa_buf_ptr = avx_malloc(packed_a_total_size * 4, 4096) + + packb_buf = as_tensor_pointer(packb_buf_ptr, dtype=float32, shape=[packed_b_total_size]) + packa_buf = as_tensor_pointer(packa_buf_ptr, dtype=float32, shape=[packed_a_total_size]) + + packed_a_type = tensor_type( + dtype='float32', + layout=row_major(packed_a_individual_height // MR, 1) * column_major(MR, packed_a_width) + ) + + ################### Start of the main kernel ################### + @hidet.script + def matmul_kernel_x86_v2(a: float32[m_size, k_size], b: float32[k_size, n_size], + c: float32[m_size, n_size]): + b_width_nr_partitions = (n_size + NR - 1) // NR + b_width_nr_remainder = n_size % NR + # TODO: Since we(they, BLIS) use a memory broker... Allocate a little more memory is OK I think??? + # packed_b_individual_width = NC + + parallel_attr = 'p' + str(nthreads) + # The outermost loop spawning threads + for tidx in grid(nthreads, attrs=parallel_attr): + tid_5th_loop = tidx + work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways) + comm_id_5th_loop = tid_5th_loop + + # Before each loop, we compute the work id and comm id for the loop after it. 
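+                    # E.g. with the default ways = (1, 8, 4, 1) there are 32 threads:
+                    # one loop-5 group holding all of them, split into 8 loop-3 teams
+                    # of 4 threads each; the 4 threads of a team then divide the
+                    # macro-kernel tiles among themselves, and loop 1 is not split.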
+ comm_id_4th_loop = comm_id_5th_loop % loop4_nways + work_id_4th_loop = comm_id_4th_loop // (loop4_nthreads // loop4_nways) + + my_start = -1 + my_end = -1 + b_alg_loop5 = NC + thread_range_sub(loop5_nways, work_id_5th_loop, n_size, NR, ~my_start, ~my_end) + loop5_iter = my_start + while loop5_iter < my_end: + b_alg_loop5 = determine_blocksize_f_sub(loop5_iter, my_end, NC) + + loop5_partition_c_width = b_alg_loop5 + loop5_partition_c_start_col = loop5_iter + + loop5_partition_b_width = b_alg_loop5 + loop5_partition_b_start_col = loop5_iter + + comm_id_3rd_loop = comm_id_4th_loop % loop3_nthreads + work_id_3rd_loop = comm_id_3rd_loop // (loop3_nthreads // loop3_nways) + + # After getting the communicator and work id for the 3rd loop, + # we can now get the ids for the packing of B. + comm_id_packb = comm_id_3rd_loop + work_id_packb = comm_id_3rd_loop + packb_nways = loop3_nthreads + + # Below: The start of loop4 + b_alg_loop4 = KC + i_loop4 = 0 + while i_loop4 < k_size: + b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, KC) + loop4_partition_b_height = b_alg_loop4 + loop4_partition_b_width = loop5_partition_b_width + loop4_partition_b_start_row = i_loop4 + loop4_partition_b_start_col = loop5_partition_b_start_col + + loop4_partition_a_start_col = i_loop4 + + is_first = (i_loop4 == 0) + + # Get the thread's partition of buffer and matrix + packed_b_buf = packb_buf + ( + packb_start_offsets[work_id_5th_loop] * packed_b_height) # TODO: Check this + loop4_partition_b = b + (loop4_partition_b_start_row * n_size + loop4_partition_b_start_col) + npanels_full_b = loop4_partition_b_width // NR + npanels_b_remainder = loop4_partition_b_width % NR + + npanels_b = npanels_full_b + (npanels_b_remainder != 0) + packedb_panel_stride = packed_b_height * NR + + # TODO: If passed, see if this barrier is really needed + thrcomm_barrier( + comm_id_packb, + ~packb_thrcomm_barrier_sense[work_id_4th_loop], + ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], + packb_nthreads # TODO: Check this last parameter + ) + + # Start of the packing of B + for i_panel in range(npanels_b): + if i_panel % packb_nways != work_id_packb % packb_nways: + continue + packed_b_buf_curr = packed_b_buf + (i_panel * packedb_panel_stride) + + curr_panel_start = i_panel * NR + curr_panel_width = min(NR, loop4_partition_b_width - curr_panel_start) + if curr_panel_width == NR: + k_iters = loop4_partition_b_height // 8 + k_remainder = loop4_partition_b_height % 8 + row = 0 + for k_iter in range(k_iters): + row = k_iter * 8 + b_panel = loop4_partition_b + (row * n_size + curr_panel_start) + b00 = avx_f32x8_load(b_panel) + b08 = avx_f32x8_load(b_panel + 8) + + avx_f32x8_store_aligned(packed_b_buf_curr, b00) + avx_f32x8_store_aligned(packed_b_buf_curr + 8, b08) + packed_b_buf_curr += 16 + + b10 = avx_f32x8_load(b_panel + n_size) + b18 = avx_f32x8_load(b_panel + (n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buf_curr, b10) + avx_f32x8_store_aligned(packed_b_buf_curr + 8, b18) + packed_b_buf_curr += 16 + + b20 = avx_f32x8_load(b_panel + (2 * n_size)) + b28 = avx_f32x8_load(b_panel + (2 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buf_curr, b20) + avx_f32x8_store_aligned(packed_b_buf_curr + 8, b28) + + packed_b_buf_curr += 16 + + b30 = avx_f32x8_load(b_panel + (3 * n_size)) + b38 = avx_f32x8_load(b_panel + (3 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buf_curr, b30) + avx_f32x8_store_aligned(packed_b_buf_curr + 8, b38) + + packed_b_buf_curr += 16 + + b40 = avx_f32x8_load(b_panel + (4 * n_size)) + b48 = 
avx_f32x8_load(b_panel + (4 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buf_curr, b40) + avx_f32x8_store_aligned(packed_b_buf_curr + 8, b48) + + packed_b_buf_curr += 16 + + b50 = avx_f32x8_load(b_panel + (5 * n_size)) + b58 = avx_f32x8_load(b_panel + (5 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buf_curr, b50) + avx_f32x8_store_aligned(packed_b_buf_curr + 8, b58) + + packed_b_buf_curr += 16 + + b60 = avx_f32x8_load(b_panel + (6 * n_size)) + b68 = avx_f32x8_load(b_panel + (6 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buf_curr, b60) + avx_f32x8_store_aligned(packed_b_buf_curr + 8, b68) + + packed_b_buf_curr += 16 + + b70 = avx_f32x8_load(b_panel + (7 * n_size)) + b78 = avx_f32x8_load(b_panel + (7 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buf_curr, b70) + avx_f32x8_store_aligned(packed_b_buf_curr + 8, b78) + + packed_b_buf_curr += 16 + + row += 8 + for remaining_row in range(k_remainder): + b_panel = loop4_partition_b + (row * n_size + curr_panel_start) + b00 = avx_f32x8_load(b_panel) + b08 = avx_f32x8_load(b_panel + 8) + + avx_f32x8_store_aligned(packed_b_buf_curr, b00) + avx_f32x8_store_aligned(packed_b_buf_curr + 8, b08) + packed_b_buf_curr += 16 + row += 1 + else: + packed_b_remaining_buf = packed_b_buf + (npanels_full_b * packedb_panel_stride) + if npanels_b_remainder > 0: + remain_col_start = npanels_full_b * NR + for remain_row in range(loop4_partition_b_height): + packed_b_remaining_buf_curr = packed_b_remaining_buf + (remain_row * NR) + for remain_col in range(npanels_b_remainder): + packed_b_remaining_buf_curr[0] = loop4_partition_b[ + (remain_row * n_size) + (remain_col_start + remain_col)] + packed_b_remaining_buf_curr += 1 + zero_fill_col = npanels_b_remainder + while zero_fill_col < NR: + packed_b_remaining_buf_curr[0] = 0.0 + packed_b_remaining_buf_curr += 1 + zero_fill_col += 1 + + # The barrier at the end of packing + thrcomm_barrier(comm_id_packb, + ~packb_thrcomm_barrier_sense[work_id_4th_loop], + ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], + packb_nthreads + ) + + # TODO: Loop 3 should start here! + # Loop 3 + comm_id_macro = work_id_3rd_loop % macro_nthreads + work_id_macro = comm_id_macro // (macro_nthreads // macro_nways) + + comm_id_packa = comm_id_macro + work_id_packa = comm_id_macro + packa_nways = macro_nthreads + + m_start_loop3 = 0 + m_end_loop3 = 0 + # Partition the dimension M for loop 3 + thread_range_sub( + loop3_nways, + work_id_3rd_loop, + m_size, + MR, + ~m_start_loop3, + ~m_end_loop3 + ) + + b_alg_loop3 = -1 + ii = m_start_loop3 + while ii < m_end_loop3: + b_alg_loop3 = determine_blocksize_f_sub(ii, m_size, MC) + + # Acquire the partition at Loop 3 + loop3_partition_c_start_row = ii + loop3_partition_a_start_row = ii + + loop3_partition_a_start_col = loop4_partition_a_start_col + loop3_partition_b_start_col = loop4_partition_b_start_col + loop3_partition_c_start_col = loop4_partition_b_start_col + + loop3_partition_height = b_alg_loop3 + # TODO: Is this right? 
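+                                # (Loop 4 steps through K, so the A block consumed
+                                # here spans the same k extent as the packed B block;
+                                # taking loop4_partition_b_height as the width of A
+                                # looks consistent with that.)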
+ loop3_partition_a_width = loop4_partition_b_height + loop3_partition_b_width = loop4_partition_b_width + loop3_partition_a_height = b_alg_loop3 + loop3_partition_c_height = b_alg_loop3 + + loop3_partition_a = a + ( + loop3_partition_a_start_row * k_size + loop3_partition_a_start_col) + npanels_full_a = loop3_partition_a_height // MR + panel_a_remainder = loop3_partition_a_height % MR + + npanels_a = npanels_full_a + (1 if panel_a_remainder > 0 else 0) + packeda_panel_stride = MR * loop3_partition_a_width + + # Get our position within the A panel + packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) + packed_a_tensor = as_tensor_pointer( + packed_a_buf, + float32, + layout=row_major(packed_a_individual_height // MR, 1) * column_major(MR, packed_a_width) + ) + + thrcomm_barrier( + comm_id_packa, + ~packa_thrcomm_barrier_sense[work_id_3rd_loop], + ~packa_thrcomm_threads_arrived[work_id_3rd_loop], + packa_nthreads + ) + + # Pack A + for ii_panel in range(npanels_a): + if ii_panel % packa_nways != work_id_packa % packa_nways: + continue + a_curr_panel_row_start = ii_panel * MR + a_curr_panel_height = min(MR, loop3_partition_a_height - a_curr_panel_row_start) + + if a_curr_panel_height == MR: # we unroll the packing by 8 + k_iters = loop3_partition_a_width // 8 + k_remainder = loop3_partition_a_width % 8 + col = 0 + for k_iter in range(k_iters): + col = k_iter * 8 + a_curr_panel_col = loop3_partition_a + ( + a_curr_panel_row_start * k_size + col) + v0 = avx_f32x8_load(a_curr_panel_col) + v1 = avx_f32x8_load(a_curr_panel_col + k_size) + v2 = avx_f32x8_load(a_curr_panel_col + (2 * k_size)) + v3 = avx_f32x8_load(a_curr_panel_col + (3 * k_size)) + v4 = avx_f32x8_load(a_curr_panel_col + (4 * k_size)) + v5 = avx_f32x8_load(a_curr_panel_col + (5 * k_size)) + + unpack0 = avx_f32x8_unpacklo(v0, v1) + unpack1 = avx_f32x8_unpackhi(v0, v1) + unpack2 = avx_f32x8_unpacklo(v2, v3) + unpack3 = avx_f32x8_unpackhi(v2, v3) + unpack4 = avx_f32x8_unpacklo(v4, v5) + unpack5 = avx_f32x8_unpackhi(v4, v5) + + shf0 = avx_f32x8_shuffle(unpack0, unpack2, 0x44) + shf1 = avx_f32x8_shuffle(unpack4, unpack0, 0xE4) + shf2 = avx_f32x8_shuffle(unpack2, unpack4, 0xEE) + shf3 = avx_f32x8_shuffle(unpack5, unpack1, 0xE4) + shf4 = avx_f32x8_shuffle(unpack3, unpack5, 0xEE) + shf5 = avx_f32x8_shuffle(unpack1, unpack3, 0x44) + + low_shf1 = avx_f32x8_cast_f32x4(shf1) + res0 = avx_f32x8_insert_f32x4(shf0, low_shf1, 0x1) + res1 = avx_f32x8_permute2f32x4(shf0, shf1, 0x31) + + low_shf5 = avx_f32x8_cast_f32x4(shf5) + res2 = avx_f32x8_insert_f32x4(shf2, low_shf5, 0x1) + res3 = avx_f32x8_permute2f32x4(shf2, shf5, 0x31) + + low_shf4 = avx_f32x8_cast_f32x4(shf4) + res4 = avx_f32x8_insert_f32x4(shf3, low_shf4, 0x1) + res5 = avx_f32x8_permute2f32x4(shf3, shf4, 0x31) + + avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start, col], res0) + avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start + 2, col + 1], res2) + avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start + 4, col + 2], res4) + avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start, col + 4], res1) + avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start + 2, col + 5], res3) + avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start + 4, col + 6], res5) + + remaining_start_col = k_iters * 8 + for remain_off in range(k_remainder): + curr_remain_col = remaining_start_col + remain_off + for micropanel_row in range(MR): + packed_a_tensor[a_curr_panel_row_start + micropanel_row, curr_remain_col] = \ + 
loop3_partition_a[(micropanel_row + a_curr_panel_row_start) * k_size + curr_remain_col] + else: + remain_start_row = npanels_a * MR + for remain_col in range(loop3_partition_a_width): + for remain_row in range(panel_a_remainder): + packed_a_tensor[remain_start_row + remain_row, remain_col] = \ + loop3_partition_a[(remain_row + remain_start_row) * k_size + remain_col] + remain_row = panel_a_remainder + while remain_row < MR: + packed_a_tensor[remain_start_row + remain_row, remain_col] = 0 + remain_row += 1 + + # This marks the end of the packing of A, or so I wish + # Now let's go to the macrokernel + # But first, barrier... + thrcomm_barrier( + comm_id_packa, + ~packa_thrcomm_barrier_sense[work_id_3rd_loop], + ~packa_thrcomm_threads_arrived[work_id_3rd_loop], + packa_nthreads + ) + + comm_id_1st_loop = comm_id_macro % loop1_nthreads + work_id_1st_loop = comm_id_macro // (loop1_nthreads // loop1_nways) + + jr_nt = macro_nways + jr_tid = work_id_macro + ir_nt = loop1_nways + ir_tid = work_id_1st_loop + + jr_start = -1 + jr_end = -1 + ir_start = -1 + ir_end = -1 + jr_inc = -1 + ir_inc = -1 + + macro_m = loop3_partition_a_height + macro_n = loop3_partition_b_width + macro_k = loop3_partition_a_width + + n_iter = macro_n // NR + n_remainder = macro_n % NR + m_iter = macro_m // MR + m_remainder = macro_m % MR + + if n_remainder > 0: + n_iter += 1 + if m_remainder > 0: + m_iter += 1 + + thread_range_jrir( + work_id_macro, + macro_nways, + n_iter, + 1, + ~jr_start, + ~jr_end, + ~jr_inc + ) + + thread_range_jrir( + work_id_1st_loop, + loop1_nways, + m_iter, + 1, + ~ir_start, + ~ir_end, + ~ir_inc + ) + + # Some variables as in the original code... + # TODO: There must be some useless ones, delete after passing tests + rs_packeda = 1 + cs_packeda = MR + panel_dim_packeda = MR + ps_packed_a = packeda_panel_stride + rs_packedb = MR + cs_packedb = 1 + ps_packed_b = packedb_panel_stride + + rstep_a = ps_packed_a + cstep_b = ps_packed_b + + rstep_c = rs_packeda * MR + + cstep_c = NR + rstep_c = n_size * MR + + macro_c_cast = as_tensor_pointer( + ~c[loop3_partition_a_start_row, loop3_partition_b_start_col], + dtype=float32, + shape=(m_size, n_size) + ) + + temp_c = tensor( + scope=DeclareScope.Default, + dtype='float32', + layout=row_major(MR, NR), + is_static=True + ) + + + j = jr_start + while j < jr_end: + b1 = packed_b_buf + j * cstep_b + c1 = macro_c_cast + j * cstep_c + + n_cur = NR if not_edge(j, n_iter, n_remainder) else n_remainder + b2 = b1 + # Loop over the m dimension, MR rows at a time + i = ir_start + while i < ir_end: + a1 = packed_a_buf + i * rstep_a + c11 = c1 + i * rstep_c + + m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder + + if m_cur == MR and n_cur == NR: + micro_kernel(a1, b1, c11, macro_k, macro_m, macro_n, is_first) + else: + for i in range(MR): + for j in range(NR): + temp_c[i, j] = 0 + micro_kernel(a1, b1, temp_c, macro_k, macro_m, macro_n, is_first) + if not is_first: + for mm in range(m_cur): + for nn in range(n_cur): + c11[mm, nn] += temp_c[mm, nn] + else: + for mm in range(m_cur): + for nn in range(n_cur): + c11[mm, nn] = temp_c[mm, nn] + i += ir_inc + j += jr_inc + ii += b_alg_loop3 + # End of loop4 + # According to the original code, there seems to be a barrier here + # TODO: Looks weird, check later, especially about whether it's really the ids of packb + # arrays that are used here + thrcomm_barrier( + comm_id_packb, + ~packb_thrcomm_barrier_sense[work_id_4th_loop], + ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], + packb_nthreads + ) + i_loop4 
+= b_alg_loop4 + # End of loop5 + loop5_iter += b_alg_loop5 + + return + + assert isinstance(matmul_kernel_x86_v2, hidet.ir.Function) + matmul_kernel_x86_v2.kind = "cpu_kernel" + ir_module = module.ir_module() + return ir_module + + +class Matmulx86Op_v2(Operator): + def __init__(self, a: Tensor, b: Tensor): + if not (len(a.shape) == len(b.shape) == 2 and a.shape[1] == b.shape[0]): + raise ValueError('Matrix multiplication: incompatible sizes: {} and {}'.format(a.shape, b.shape)) + task = MatmulF32Taskx86_v2(input_like(a, 'a'), input_like(b, 'b')) + super().__init__(inputs=[a, b], attributes={}, task=task) + + +def matmul_x86_v2(a: Tensor, b: Tensor) -> Tensor: + return Matmulx86Op_v2(a, b).outputs[0] From 569fb49254b5e56c9cdd342a4a4806d972f866de Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Thu, 17 Aug 2023 03:06:49 -0400 Subject: [PATCH 006/148] working on refactoring --- .../ops/matmul/matmul_f32_x86_refactored.py | 633 ++++++++++++++++++ 1 file changed, 633 insertions(+) create mode 100644 python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py new file mode 100644 index 000000000..67619213b --- /dev/null +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -0,0 +1,633 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
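+
+# Work-in-progress refactoring of matmul_f32_x86_v3.py: the intent appears to
+# be to re-express the monolithic kernel as one hidet.script function per loop
+# level (gemm_5th_loop, gemm_4th_loop, gemm_pack_b, ...), with the remaining
+# inner loops still to be ported over.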
+from typing import List, Union +from hidet.ir.dtypes import float32, int32 +from hidet.ir.expr import cast +from hidet.ir.module import IRModule +from hidet.ir.compute import TensorNode +from hidet.ir.primitives import avx_malloc +from hidet.ir.primitives.cpu import avx_f32x8_setzero, avx_f32x8_load_aligned +from hidet.ir.stmt import DeclareScope +from hidet.ir.task import Task +from hidet.ir.compute import compute, reduce +from hidet.graph.ops.utils import input_like, broadcast_shape, can_mutually_broadcast +from hidet.ir.library import tune +from hidet.graph.operator import Operator, Tensor +from hidet.graph.ops.utils import broadcast_indices + + +class MatmulF32Taskx86_v3(Task): + + def __init__(self, a: TensorNode, b: TensorNode): + a_shape = a.const_shape + b_shape = b.const_shape + + if not a.type.dtype == float32 or not b.type.dtype == float32: + raise ValueError('Both inputs must be float32 tensors') + + if len(a_shape) < 2 or len(b_shape) < 2: + raise ValueError('Matrix multiplication expect at least 2D tensor, got {} and {}'.format(a_shape, b_shape)) + + self._assert( + a_shape[-1] == b_shape[-2], + msg=( + 'Matrix multiplication expect tensor A and B with shape [..., M, K] and [..., K, N]' + ', got {} and {}'.format(a_shape, b_shape) + ), + ) + + self._assert( + can_mutually_broadcast(a_shape[:-2], b_shape[:-2]), + msg=( + 'Matrix multiplication expect tensor A and B with compatible broadcast shape, ' + 'got {} and {}'.format(a_shape, b_shape) + ), + ) + + k_size = a_shape[-1] + c_shape = broadcast_shape(a_shape[:-2], b_shape[:-2]) + [a_shape[-2], b_shape[-1]] + + c = compute( + name='c', + shape=c_shape, + fcompute=lambda *indices: reduce( + shape=[k_size], + fcompute=lambda k: a[broadcast_indices(indices[:-2], a_shape[:-2], c_shape[1:-2]) + [indices[-2], k]] + * b[broadcast_indices(indices[:-2], b_shape[:-2], c_shape[1:-2]) + [k, indices[-1]]], + reduce_type='sum', + ), + ) + + super().__init__( + name='matmul_f32_x86_v2', + inputs=[a, b], + outputs=[c], + attributes={'m_size': a_shape[-2], 'n_size': b_shape[-1], 'k_size': a_shape[-1]}, + ) + + def allow_epilogue(self) -> bool: + return True + + def allow_prologue(self) -> bool: + return False + + def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: + return tune.extract_ir_modules(self.schedule_matmulf32_x86) + + # @tune.space( + # 2, + # block_m=[2016, 3024], + # block_n=[64, 144, 192, 256, 384, 512, 592, 672, 752, 896, 1024], + # block_k=[96, 128, 256, 384, 512, 560, 688, 784], + # nthreads=[4, 8, 16, 32], + # ) + @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], nthreads=[8, 16]) + def schedule_matmulf32_x86( + self, MC=2016, NC=896, KC=512, ways=(1, 8, 4, 1) + ) -> IRModule: + import hidet + from hidet.ir.type import tensor_type + from hidet.lang import tensor, grid, as_tensor_pointer + from hidet.lang.layout import row_major, column_major + from hidet.lang.cpu import avx_f32x8_store, avx_f32x8_fmadd, avx_f32x8_load, avx_f32x8_broadcast + from hidet.lang.cpu import avx_f32x4_broadcast, avx_f32x4_fmadd, avx_f32x4_load, avx_f32x4_store + from hidet.lang.cpu import avx_f32x8_store_aligned, avx_f32x8_load_aligned + from hidet.lang.cpu import avx_f32x4_store_aligned, avx_f32x4_load_aligned + from hidet.lang.cpu import avx_f32x8_unpacklo, avx_f32x8_unpackhi + from hidet.lang.cpu import avx_f32x8_shuffle, avx_f32x8_cast_f32x4 + from hidet.lang.cpu import avx_f32x8_insert_f32x4, avx_f32x8_permute2f32x4 + from hidet.lang.cpu import cpu_atomic_load_n, cpu_atomic_add_fetch, 
cpu_atomic_fetch_xor + + node_a, node_b = self.inputs[0], self.inputs[1] + a_shape = node_a.const_shape + b_shape = node_b.const_shape + m_size, n_size, k_size = a_shape[-2], b_shape[-1], a_shape[-1] + + MR, NR = 6, 16 + + tune.check(MC % MR == NC % NR == 0, 'Tile size must divide the corresponding block size') + + packed_a_type = tensor_type('float32', layout=row_major(MC // MR, 1) * column_major(MR, KC)) + packed_b_type = tensor_type('float32', layout=row_major(1, NC // NR) * row_major(KC, NR)) + + # Get the number of threads... + loop5_nways, loop3_nways, macro_nways, loop1_nways = ways + loop4_nways = 1 + nthreads = loop5_nways * loop3_nways * macro_nways * loop1_nways + + # Get the number of threads remaining at each level + loop5_nthreads = nthreads + loop4_nthreads = loop5_nthreads // loop5_nways + loop3_nthreads = loop4_nthreads + macro_nthreads = loop3_nthreads // loop3_nways + loop1_nthreads = macro_nthreads // macro_nways + + packb_nthreads = loop3_nthreads + packa_nthreads = macro_nthreads + + # TODO: Since Hidet doesn't support the parallel region syntax as in OpenMP, + # TODO: We instead use a loop to simulate the parallel region, with the "thread id" being the loop index. + outermost_iters = nthreads + + loop5_thrcomm_barrier_sense = 0 + loop5_thrcomm_barrier_threads_arrived = 0 + + packb_thrcomm_barrier_sense = tensor('int32', shape=[loop4_nways], is_static=True) + # for idx in range(loop4_nways): + # packb_thrcomm_barrier_sense[idx] = 0 TODO: This shouldn't be necessary, as static arrays are 0-initialized + packb_thrcomm_barrier_threads_arrived = tensor('int32', shape=[loop4_nways], is_static=True) + + packa_thrcomm_barrier_sense = tensor('int32', shape=[loop3_nways], is_static=True) + packa_thrcomm_threads_arrived = tensor('int32', shape=[loop3_nways], is_static=True) + + # The buffer for storing the starting offset of the packed B buffers for thread, + # indexed by the work ID of Loop5 + packb_start_offsets = tensor('int32', shape=[loop5_nways], is_static=True) + # The buffer for storing the starting offset of the packed A buffers for thread, + # indexed by the work ID of Loop3 + packa_start_offsets = tensor('int32', shape=[loop3_nways], is_static=True) + + # The array to store the needed size for each packed B buffer, indexed by the work ID of Loop5 + packb_sizes = tensor('int32', shape=[loop5_nways], is_static=True) + # The array to store the needed size for each packed A buffer, indexed by the work ID of Loop3 + packa_sizes = tensor('int32', shape=[loop3_nways], is_static=True) + + with hidet.script_module() as module: + # Helpers + @hidet.script + def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~int32, end: ~int32): + if n_way == 1: + start[0] = 0 + end[0] = n + return + all_start = 0 + all_end = n + size = all_end - all_start + + n_bf_whole = size // bf + n_bf_left = size % bf + + n_bf_lo = n_bf_whole // n_way + n_bf_hi = n_bf_whole // n_way + + n_th_lo = n_bf_whole % n_way + # If some partitions must have more block_factors than others, assign the slightly larger partitions to lower index threads + if n_th_lo != 0: + n_bf_lo += 1 + # Compute the actual widths (in units of rows/columns) of individual threads in the low and high groups + size_lo = n_bf_lo * bf + size_hi = n_bf_hi * bf + + # Pre-compute the starting indices of the low and high groups + lo_start = all_start + hi_start = all_start + n_th_lo * size_lo + + # Compute the start and end of individual threads' ranges + if work_id < n_th_lo: + start[0] = lo_start + work_id * 
size_lo
+                    end[0] = lo_start + (work_id + 1) * size_lo
+                else:
+                    start[0] = hi_start + (work_id - n_th_lo) * size_hi
+                    end[0] = hi_start + (work_id - n_th_lo + 1) * size_hi
+
+                # Add the remainder to the last thread's end
+                if work_id == n_way - 1:
+                    end[0] += n_bf_left
+
+            @hidet.script
+            def thread_range_jrir(work_id: int32, n_way: int32, n: int32, bf: int32,
+                                  start: ~int32, end: ~int32, inc: ~int32):
+                start[0] = work_id
+                end[0] = n
+                inc[0] = n_way
+
+            @hidet.script
+            def determine_blocksize_f_sub(i: int32, dim: int32, b_alg: int32) -> int32:
+                dim_left_now = dim - i
+                if dim_left_now <= b_alg:
+                    b_now = dim_left_now
+                else:
+                    b_now = b_alg
+                return b_now
+
+            @hidet.script
+            def not_edge(i: int32, n_iter: int32, n_left: int32) -> bool:
+                return i != n_iter - 1 or n_left == 0
+
+            # Thread barrier
+            @hidet.script
+            def thrcomm_barrier(tid: int32, barrier_sense: ~int32,
+                                barrier_threads_arrived: ~int32, nthreads: int32):
+                if nthreads == 1:
+                    return
+                orig_sense = cpu_atomic_load_n(barrier_sense, 0)  # _ATOMIC_RELAXED
+
+                # Register the current thread's arrival by incrementing the counter
+                my_threads_arrived = cpu_atomic_add_fetch(
+                    barrier_threads_arrived, 1, 4)  # _ATOMIC_ACQ_REL
+
+                if my_threads_arrived == nthreads:
+                    barrier_threads_arrived[0] = 0
+                    cpu_atomic_fetch_xor(barrier_sense, 1, 3)  # _ATOMIC_RELEASE
+                else:
+                    while cpu_atomic_load_n(barrier_sense, 2) == orig_sense:  # _ATOMIC_ACQUIRE
+                        pass
+
+            @hidet.script
+            def micro_kernel(
+                a: packed_a_type, b: packed_b_type, c_ptr: ~float32, pb: int32, msize: int32, nsize: int32,
+                is_first: bool
+            ):
+                c = as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize])
+                c0 = avx_f32x8_load(~c[0, 0])
+                c08 = avx_f32x8_load(~c[0, 8])
+                c1 = avx_f32x8_load(~c[1, 0])
+                c18 = avx_f32x8_load(~c[1, 8])
+                c2 = avx_f32x8_load(~c[2, 0])
+                c28 = avx_f32x8_load(~c[2, 8])
+                c3 = avx_f32x8_load(~c[3, 0])
+                c38 = avx_f32x8_load(~c[3, 8])
+                c4 = avx_f32x8_load(~c[4, 0])
+                c48 = avx_f32x8_load(~c[4, 8])
+                c5 = avx_f32x8_load(~c[5, 0])
+                c58 = avx_f32x8_load(~c[5, 8])
+
+                if is_first:
+                    c0 = avx_f32x8_setzero()
+                    c08 = avx_f32x8_setzero()
+                    c1 = avx_f32x8_setzero()
+                    c18 = avx_f32x8_setzero()
+                    c2 = avx_f32x8_setzero()
+                    c28 = avx_f32x8_setzero()
+                    c3 = avx_f32x8_setzero()
+                    c38 = avx_f32x8_setzero()
+                    c4 = avx_f32x8_setzero()
+                    c48 = avx_f32x8_setzero()
+                    c5 = avx_f32x8_setzero()
+                    c58 = avx_f32x8_setzero()
+                a_ptr = cast(a, ~float32)
+                b_ptr = cast(b, ~float32)
+
+                # The loop below walks the k dimension of the packed panels, so the
+                # trip counts must come from pb (the k extent), not from msize.
+                niters = pb // 4
+                nleft = pb % 4
+                # Outer iterations with step 4
+                for _ in range(niters):
+                    # First of the 4 unrolled iterations
+                    bb0to7 = avx_f32x8_load_aligned(b_ptr)
+                    bb8to15 = avx_f32x8_load_aligned(b_ptr + 8)
+
+                    aa = avx_f32x8_broadcast(a_ptr)
+                    c0 = avx_f32x8_fmadd(aa, bb0to7, c0)
+                    c08 = avx_f32x8_fmadd(aa, bb8to15, c08)
+
+                    aa = avx_f32x8_broadcast(a_ptr + 1)
+                    c1 = avx_f32x8_fmadd(aa, bb0to7, c1)
+                    c18 = avx_f32x8_fmadd(aa, bb8to15, c18)
+
+                    aa = avx_f32x8_broadcast(a_ptr + 2)
+                    c2 = avx_f32x8_fmadd(aa, bb0to7, c2)
+                    c28 = avx_f32x8_fmadd(aa, bb8to15, c28)
+
+                    aa = avx_f32x8_broadcast(a_ptr + 3)
+                    c3 = avx_f32x8_fmadd(aa, bb0to7, c3)
+                    c38 = avx_f32x8_fmadd(aa, bb8to15, c38)
+
+                    aa = avx_f32x8_broadcast(a_ptr + 4)
+                    c4 = avx_f32x8_fmadd(aa, bb0to7, c4)
+                    c48 = avx_f32x8_fmadd(aa, bb8to15, c48)
+
+                    aa = avx_f32x8_broadcast(a_ptr + 5)
+                    c5 = avx_f32x8_fmadd(aa, bb0to7, c5)
+                    c58 = avx_f32x8_fmadd(aa, bb8to15, c58)
+
+                    # Second of the 4 unrolled iterations
+                    bb0to7 = avx_f32x8_load_aligned(b_ptr + 16)
+                    bb8to15 = avx_f32x8_load_aligned(b_ptr + 24)
+
+                    aa = avx_f32x8_broadcast(a_ptr + 6)
+ c0 = avx_f32x8_fmadd(aa, bb0to7, c0) + c08 = avx_f32x8_fmadd(aa, bb8to15, c08) + + aa = avx_f32x8_broadcast(a_ptr + 7) + c1 = avx_f32x8_fmadd(aa, bb0to7, c1) + c18 = avx_f32x8_fmadd(aa, bb8to15, c18) + + aa = avx_f32x8_broadcast(a_ptr + 8) + c2 = avx_f32x8_fmadd(aa, bb0to7, c2) + c28 = avx_f32x8_fmadd(aa, bb8to15, c28) + + aa = avx_f32x8_broadcast(a_ptr + 9) + c3 = avx_f32x8_fmadd(aa, bb0to7, c3) + c38 = avx_f32x8_fmadd(aa, bb8to15, c38) + + aa = avx_f32x8_broadcast(a_ptr + 10) + c4 = avx_f32x8_fmadd(aa, bb0to7, c4) + c48 = avx_f32x8_fmadd(aa, bb8to15, c48) + + aa = avx_f32x8_broadcast(a_ptr + 11) + c5 = avx_f32x8_fmadd(aa, bb0to7, c5) + c58 = avx_f32x8_fmadd(aa, bb8to15, c58) + + # Third of the 4 unrolled iterations + bb0to7 = avx_f32x8_load_aligned(b_ptr + 32) + bb8to15 = avx_f32x8_load_aligned(b_ptr + 40) + + aa = avx_f32x8_broadcast(a_ptr + 12) + c0 = avx_f32x8_fmadd(aa, bb0to7, c0) + c08 = avx_f32x8_fmadd(aa, bb8to15, c08) + + aa = avx_f32x8_broadcast(a_ptr + 13) + c1 = avx_f32x8_fmadd(aa, bb0to7, c1) + c18 = avx_f32x8_fmadd(aa, bb8to15, c18) + + aa = avx_f32x8_broadcast(a_ptr + 14) + c2 = avx_f32x8_fmadd(aa, bb0to7, c2) + c28 = avx_f32x8_fmadd(aa, bb8to15, c28) + + aa = avx_f32x8_broadcast(a_ptr + 15) + c3 = avx_f32x8_fmadd(aa, bb0to7, c3) + c38 = avx_f32x8_fmadd(aa, bb8to15, c38) + + aa = avx_f32x8_broadcast(a_ptr + 16) + c4 = avx_f32x8_fmadd(aa, bb0to7, c4) + c48 = avx_f32x8_fmadd(aa, bb8to15, c48) + + aa = avx_f32x8_broadcast(a_ptr + 17) + c5 = avx_f32x8_fmadd(aa, bb0to7, c5) + c58 = avx_f32x8_fmadd(aa, bb8to15, c58) + + # Fourth of the 4 unrolled iterations + bb0to7 = avx_f32x8_load_aligned(b_ptr + 48) + bb8to15 = avx_f32x8_load_aligned(b_ptr + 56) + + aa = avx_f32x8_broadcast(a_ptr + 18) + c0 = avx_f32x8_fmadd(aa, bb0to7, c0) + c08 = avx_f32x8_fmadd(aa, bb8to15, c08) + + aa = avx_f32x8_broadcast(a_ptr + 19) + c1 = avx_f32x8_fmadd(aa, bb0to7, c1) + c18 = avx_f32x8_fmadd(aa, bb8to15, c18) + + aa = avx_f32x8_broadcast(a_ptr + 20) + c2 = avx_f32x8_fmadd(aa, bb0to7, c2) + c28 = avx_f32x8_fmadd(aa, bb8to15, c28) + + aa = avx_f32x8_broadcast(a_ptr + 21) + c3 = avx_f32x8_fmadd(aa, bb0to7, c3) + c38 = avx_f32x8_fmadd(aa, bb8to15, c38) + + aa = avx_f32x8_broadcast(a_ptr + 22) + c4 = avx_f32x8_fmadd(aa, bb0to7, c4) + c48 = avx_f32x8_fmadd(aa, bb8to15, c48) + + aa = avx_f32x8_broadcast(a_ptr + 23) + c5 = avx_f32x8_fmadd(aa, bb0to7, c5) + c58 = avx_f32x8_fmadd(aa, bb8to15, c58) + + # Increment the a_ptr and b_ptr for the next iteration of the outermost loop + a_ptr += 24 + b_ptr += 64 + + # process the edge + for _ in range(nleft): + aa = avx_f32x8_broadcast(a_ptr) + bb0to7 = avx_f32x8_load_aligned(b_ptr) + bb8to15 = avx_f32x8_load_aligned(b_ptr + 8) + + c0 = avx_f32x8_fmadd(aa, bb0to7, c0) + c08 = avx_f32x8_fmadd(aa, bb8to15, c08) + + aa = avx_f32x8_broadcast(a_ptr + 1) + c1 = avx_f32x8_fmadd(aa, bb0to7, c1) + c18 = avx_f32x8_fmadd(aa, bb8to15, c18) + + aa = avx_f32x8_broadcast(a_ptr + 2) + c2 = avx_f32x8_fmadd(aa, bb0to7, c2) + c28 = avx_f32x8_fmadd(aa, bb8to15, c28) + + aa = avx_f32x8_broadcast(a_ptr + 3) + c3 = avx_f32x8_fmadd(aa, bb0to7, c3) + c38 = avx_f32x8_fmadd(aa, bb8to15, c38) + + aa = avx_f32x8_broadcast(a_ptr + 4) + c4 = avx_f32x8_fmadd(aa, bb0to7, c4) + c48 = avx_f32x8_fmadd(aa, bb8to15, c48) + + aa = avx_f32x8_broadcast(a_ptr + 5) + c5 = avx_f32x8_fmadd(aa, bb0to7, c5) + c58 = avx_f32x8_fmadd(aa, bb8to15, c58) + + a_ptr += 6 + b_ptr += 16 + + # Store the results + avx_f32x8_store(c_ptr, c0) + avx_f32x8_store(c_ptr + 8, c08) + + avx_f32x8_store(c_ptr + nsize, c1) + 
avx_f32x8_store(c_ptr + (nsize + 8), c18)
+
+            avx_f32x8_store(c_ptr + 2 * nsize, c2)
+            avx_f32x8_store(c_ptr + (2 * nsize + 8), c28)
+
+            avx_f32x8_store(c_ptr + 3 * nsize, c3)
+            avx_f32x8_store(c_ptr + (3 * nsize + 8), c38)
+
+            avx_f32x8_store(c_ptr + 4 * nsize, c4)
+            avx_f32x8_store(c_ptr + (4 * nsize + 8), c48)
+
+            avx_f32x8_store(c_ptr + 5 * nsize, c5)
+            avx_f32x8_store(c_ptr + (5 * nsize + 8), c58)
+
+        @hidet.script
+        def macro_kernel(
+            a: packed_a_type, b: packed_b_type, c_in_macro: float32[m_size, n_size], ib: int32, jb: int32,
+            pb: int32
+        ):
+            return
+
+        #### Some setup code ####
+        packed_b_total_width = 0
+        for workid_loop5 in range(loop5_nways):
+            loop5_start = 0
+            loop5_end = 0
+            thread_range_sub(loop5_nways, workid_loop5, n_size, NR, ~loop5_start, ~loop5_end)
+            curr_width = loop5_end - loop5_start
+            # packed_b_total_width += curr_width
+            # packb_start_offsets[workid_loop5] = temp_prev
+            # temp_prev += curr_width
+            packb_start_offsets[workid_loop5] = packed_b_total_width
+            packed_b_total_width += curr_width
+
+        packed_b_height = KC
+        if packed_b_height > k_size:
+            packed_b_height = (k_size + NR - 1) // NR * NR
+        packed_b_total_size = packed_b_total_width * packed_b_height
+
+        a_height_mr_partitions = (m_size + MR - 1) // MR
+        a_height_mr_remainder = m_size % MR
+        packed_a_individual_height = MC
+        packed_a_total_height = packed_a_individual_height * loop3_nways
+        # if packed_a_total_height > m_size:
+        #     packed_a_total_height = a_height_mr_partitions * MR
+        packed_a_width = KC
+        if packed_a_width > k_size:
+            packed_a_width = (k_size + MR - 1) // MR * MR
+        packed_a_total_size = packed_a_total_height * packed_a_width
+        packed_a_individual_size = packed_a_width * packed_a_individual_height
+
+        packb_buf_ptr = avx_malloc(packed_b_total_size * 4, 4096)
+        packa_buf_ptr = avx_malloc(packed_a_total_size * 4, 4096)
+
+        packb_buf = as_tensor_pointer(packb_buf_ptr, dtype=float32, shape=[packed_b_total_size])
+        packa_buf = as_tensor_pointer(packa_buf_ptr, dtype=float32, shape=[packed_a_total_size])
+
+        packed_a_type = tensor_type(
+            dtype='float32',
+            layout=row_major(packed_a_individual_height // MR, 1) * column_major(MR, packed_a_width)
+        )
+
+
+        ##### Start of the loops around micro kernel #####
+        @hidet.script
+        def gemm_pack_b(
+            loop4_partition_b: ~float32,
+            loop4_partition_b_width: int32,
+            loop4_partition_b_height: int32,
+            packed_b_buf: ~float32,
+            comm_id_packb: int32, work_id_packb: int32
+        ):
+
+
+
+        @hidet.script
+        def gemm_4th_loop(a: float32[m_size, k_size],
+                          b: float32[k_size, n_size],
+                          c: float32[m_size, n_size],
+                          loop5_partition_b_width: int32,
+                          loop5_partition_b_start_col: int32,
+                          comm_id_4th_loop: int32,
+                          work_id_4th_loop: int32,
+                          work_id_5th_loop: int32):
+            b_alg_loop4 = KC
+            i_loop4 = 0
+
+            comm_id_3rd_loop = comm_id_4th_loop % loop3_nthreads
+            work_id_3rd_loop = comm_id_3rd_loop // (loop3_nthreads // loop3_nways)
+            comm_id_packb = comm_id_3rd_loop
+            work_id_packb = comm_id_3rd_loop
+            # packb_nways = loop3_nthreads
+
+            while i_loop4 < k_size:
+                b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, NC)
+                loop4_partition_b_height = b_alg_loop4
+                loop4_partition_b_width = loop5_partition_b_width
+                loop4_partition_b_start_row = i_loop4
+                loop4_partition_b_start_col = loop5_partition_b_start_col
+
+                loop4_partition_a_start_col = i_loop4
+                is_first = (i_loop4 == 0)
+                # Get the thread's partition of the buffer and the matrix
+                packed_b_buf = packb_buf + (
+                    packb_start_offsets[work_id_5th_loop] * packed_b_height
+                )
+
+                loop4_partition_b = b + \
+                    (loop4_partition_b_start_row * n_size +
+                     loop4_partition_b_start_col)
+
+                # TODO: If the tests pass, check whether this barrier is really needed
+                thrcomm_barrier(
+                    comm_id_packb,
+                    ~packb_thrcomm_barrier_sense[work_id_4th_loop],
+                    ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop],
+                    packb_nthreads
+                )
+
+                # Start the packing of B
+                gemm_pack_b(loop4_partition_b, loop4_partition_b_width,
+                            loop4_partition_b_height, packed_b_buf,
+                            comm_id_packb, work_id_packb)
+
+
+
+
+                thrcomm_barrier(
+                    comm_id_packb,
+                    ~packb_thrcomm_barrier_sense[work_id_4th_loop],
+                    ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop],
+                    packb_nthreads
+                )
+                i_loop4 += b_alg_loop4
+
+
+        @hidet.script
+        def gemm_5th_loop(a: float32[m_size, k_size],
+                          b: float32[k_size, n_size],
+                          c: float32[m_size, n_size],
+                          work_id_5th_loop: int32,
+                          comm_id_5th_loop: int32):
+            comm_id_4th_loop = comm_id_5th_loop % loop4_nways
+            work_id_4th_loop = comm_id_4th_loop // (loop4_nthreads // loop4_nways)
+
+            loop5_my_start = -1
+            loop5_my_end = -1
+            thread_range_sub(loop5_nways, work_id_5th_loop, n_size,
+                             NR, ~loop5_my_start, ~loop5_my_end)
+            loop5_iter = loop5_my_start
+            while loop5_iter < loop5_my_end:
+                b_alg_loop5 = determine_blocksize_f_sub(loop5_iter,
+                                                        loop5_my_end, NC)
+                loop5_partition_c_width = b_alg_loop5
+                loop5_partition_c_start_col = loop5_iter
+                loop5_partition_b_width = b_alg_loop5
+                loop5_partition_b_start_col = loop5_iter
+                gemm_4th_loop(a, b, c,
+                              loop5_partition_b_width,
+                              loop5_partition_b_start_col,
+                              comm_id_4th_loop,
+                              work_id_4th_loop,
+                              work_id_5th_loop)
+                loop5_iter += b_alg_loop5
+
+
+        ################### Start of the main kernel ###################
+        @hidet.script
+        def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size],
+                                 c: float32[m_size, n_size]):
+            b_width_nr_partitions = (n_size + NR - 1) // NR
+            b_width_nr_remainder = n_size % NR
+            # TODO: BLIS obtains this buffer from a memory broker, so allocating a little more memory than strictly needed should be acceptable
+ # packed_b_individual_width = NC + + parallel_attr = 'p' + str(nthreads) + # The outermost loop spawning threads + for tidx in grid(nthreads, attrs=parallel_attr): + tid_5th_loop = tidx + work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways) + comm_id_5th_loop = tid_5th_loop + gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop) + + + + return ir_module + + +class Matmulx86Op_v2(Operator): + def __init__(self, a: Tensor, b: Tensor): + if not (len(a.shape) == len(b.shape) == 2 and a.shape[1] == b.shape[0]): + raise ValueError('Matrix multiplication: incompatible sizes: {} and {}'.format(a.shape, b.shape)) + task = MatmulF32Taskx86_v2(input_like(a, 'a'), input_like(b, 'b')) + super().__init__(inputs=[a, b], attributes={}, task=task) + + +def matmul_x86_v2(a: Tensor, b: Tensor) -> Tensor: + return Matmulx86Op_v2(a, b).outputs[0] From b32ea73deae0cf7b2921c76bb193577bfc0cd463 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sun, 20 Aug 2023 06:18:26 -0400 Subject: [PATCH 007/148] ready to be tested on the eco server --- python/hidet/graph/ops/__init__.py | 1 + python/hidet/graph/ops/matmul/__init__.py | 4 + .../ops/matmul/matmul_f32_x86_refactored.py | 466 +++++++++++++++++- .../graph/ops/matmul/matmul_f32_x86_v3.py | 2 +- python/hidet/mat_new.py | 95 ++++ 5 files changed, 550 insertions(+), 18 deletions(-) create mode 100644 python/hidet/mat_new.py diff --git a/python/hidet/graph/ops/__init__.py b/python/hidet/graph/ops/__init__.py index a10466a6a..81865205a 100644 --- a/python/hidet/graph/ops/__init__.py +++ b/python/hidet/graph/ops/__init__.py @@ -11,6 +11,7 @@ # limitations under the License. # pylint: disable=redefined-builtin from .matmul import batch_matmul, matmul, matmul_x86 +from .matmul import matmul_x86_refactored from .conv1d import conv1d, conv1d_gemm from .conv1d_transpose import conv1d_transpose from .conv2d import conv2d, conv2d_channel_last, conv2d_winograd, conv2d_gemm, conv2d_gemm_fp16 diff --git a/python/hidet/graph/ops/matmul/__init__.py b/python/hidet/graph/ops/matmul/__init__.py index 18c4da549..c33a08573 100644 --- a/python/hidet/graph/ops/matmul/__init__.py +++ b/python/hidet/graph/ops/matmul/__init__.py @@ -16,3 +16,7 @@ from .matmul_f32_x86 import matmul_x86 from .matmul_f32_x86 import MatmulF32Taskx86, Matmulx86Op + +from .matmul_f32_x86_refactored import Matmulx86Op_refactored, MatmulF32Taskx86_refactored +from .matmul_f32_x86_refactored import matmul_x86_refactored + diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 67619213b..f0d75f6fa 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -25,7 +25,7 @@ from hidet.graph.ops.utils import broadcast_indices -class MatmulF32Taskx86_v3(Task): +class MatmulF32Taskx86_refactored(Task): def __init__(self, a: TensorNode, b: TensorNode): a_shape = a.const_shape @@ -447,12 +447,7 @@ def micro_kernel( avx_f32x8_store(c_ptr + 5 * nsize, c5) avx_f32x8_store(c_ptr + (5 * nsize + 8), c58) - @hidet.script - def macro_kernel( - a: packed_a_type, b: packed_b_type, c_in_macro: float32[m_size, n_size], ib: int32, jb: int32, - pb: int32 - ): - return + #### Some setup code #### packed_b_total_width = 0 @@ -497,15 +492,434 @@ def macro_kernel( ##### Start of the loops around micro kernel ##### + # gemm_macro(packed_a_buf, + # packed_b, + # c, + # loop3_partition_a_height, + # loop3_partition_b_width, + # loop3_partition_a_width, + # comm_id_macro, + 
# work_id_macro + # ) + @hidet.script + def gemm_macro( + packed_a: ~float32, + packed_b: ~float32, + c: float32[m_size, n_size], + c_row_off: int32, + c_col_off: int32, + macro_m: int32, + macro_n: int32, + macro_k: int32, + ps_packed_a, + ps_packed_b, + comm_id_macro: int32, + work_id_macro: int32, + is_first: bool + ): + comm_id_1st_loop = comm_id_macro % loop1_nthreads + work_id_1st_loop = comm_id_macro // (loop1_nthreads // loop1_nways) + + n_iter = macro_n // NR + n_remainder = macro_n % NR + m_iter = macro_m // MR + m_remainder = macro_m % MR + + if n_remainder > 0: + n_iter += 1 + if m_remainder > 0: + m_iter += 1 + + jr_start = -1 + jr_end = -1 + ir_start = -1 + ir_end = -1 + jr_inc = -1 + ir_inc = -1 + + thread_range_jrir( + work_id_macro, + macro_nways, + n_iter, + 1, + ~jr_start, + ~jr_end, + ~jr_inc + ) + + thread_range_jrir( + work_id_1st_loop, + m_iter, + 1, + ~ir_start, + ~ir_end, + ~ir_inc + ) + + rs_packeda = 1 + rstep_a = ps_packed_a + cstep_b = ps_packed_b + + cstep_c = NR + rstep_c = n_size * MR + + macro_c_cast = as_tensor_pointer( + ~c[c_row_off, c_col_off], + dtype=float32, + shape=(m_size, n_size) + ) + temp_c = tensor(scope=DeclareScope.Default, + dtype=float32, + layout=row_major(MR, NR), + is_static=True) + j = jr_start + while j < jr_end: + b1 = packed_b + j * cstep_b + c1 = macro_c_cast + j * cstep_c + + n_cur = NR if not_edge(j, n_iter, n_remainder) else n_remainder + # Loop over the m dimension, MR rows at a time + i = ir_start + while i < ir_end: + a1 = packed_a + i * rstep_a + c11 = c1 + i * rstep_a + m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder + + if m_cur == MR and n_cur == NR: + micro_kernel(a1, b1, c11, macro_k, macro_m, macro_n, is_first) + else: + for i, j in grid(MR, NR): + temp_c[i, j] = 0.0 + micro_kernel(a1, b1, temp_c, macro_k, macro_m, macro_n, is_first) + if not is_first: + for mm, nn in grid(m_cur, n_cur): + c11[mm, nn] += temp_c[mm, nn] + else: + for mm, nn in grid(m_cur, n_cur): + c11[mm, nn] = temp_c[mm, nn] + + i += ir_inc + j += jr_inc + + + @hidet.script + def gemm_3rd_loop( + a: float32[m_size, k_size], + packed_b: ~float32, + c: float32[m_size, n_size], + loop3_partition_a_start_col: int32, + loop3_partition_b_start_col: int32, + loop3_partition_a_width: int32, + loop3_partition_b_width: int32, + comm_id_3rd_loop: int32, + work_id_3rd_loop: int32, + is_first: bool + ): + comm_id_macro = work_id_3rd_loop % macro_nthreads + work_id_macro = comm_id_macro // (macro_nthreads // macro_nways) + comm_id_packa = comm_id_macro + work_id_packa = comm_id_macro + packa_nways = macro_nthreads + + m_start_loop3 = 0 + m_end_loop3 = 0 + thread_range_sub( + loop3_nways, + work_id_3rd_loop, + m_size, + MR, + ~m_start_loop3, + ~m_end_loop3 + ) + ii = m_start_loop3 + while ii < m_end_loop3: + b_alg_loop3 = determine_blocksize_f_sub( + ii, m_size, MC + ) + # Acquire the partition at loop 3 + loop3_partition_a_start_row = ii + loop3_partition_a_height = b_alg_loop3 + + loop3_partition_a = a + ( + loop3_partition_a_start_row * k_size + + loop3_partition_a_start_col + ) + + # Get our position within the packed A global buffer + packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) + + # TODO: If passed, see if this barrier is necessary + thrcomm_barrier( + comm_id_packa, + ~packa_thrcomm_barrier_sense[work_id_3rd_loop], + ~packa_thrcomm_threads_arrived[work_id_3rd_loop], + packa_nthreads + ) + + gemm_pack_a( + loop3_partition_a, + loop3_partition_a_width, + loop3_partition_a_height, + packed_a_buf, + comm_id_packa, + 
work_id_packa, + packa_nways + ) + + # This marks the end of the packing of A, + # so a barrier is needed + thrcomm_barrier( + comm_id_packa, + ~packa_thrcomm_barrier_sense[work_id_3rd_loop], + ~packa_thrcomm_threads_arrived[work_id_3rd_loop], + packa_nthreads + ) + + gemm_macro(packed_a_buf, + packed_b, + c, + loop3_partition_a_start_row, + loop3_partition_b_start_col, + loop3_partition_a_height, + loop3_partition_b_width, + loop3_partition_a_width, + MR * loop3_partition_a_width, + packed_b_height * NR, + comm_id_macro, + work_id_macro, + is_first + ) + + @hidet.script + def gemm_pack_a( + loop3_partition_a: ~float32, + loop3_partition_a_width: int32, + loop3_partition_a_height: int32, + packed_a_buf: ~float32, + comm_id_packa: int32, + work_id_packa: int32, + packa_nways: int32 + ): + packed_a_tensor = as_tensor_pointer( + packed_a_buf, + float32, + layout=row_major(packed_a_individual_height // MR, 1) * + column_major(MR, packed_a_width) + ) + + + npanels_full_a = loop3_partition_a_height // MR + panel_a_remainder = loop3_partition_a_height % MR + + npanels_a = npanels_full_a + (1 if panel_a_remainder > 0 else 0) + for ii_panel in range(npanels_a): + if ii_panel % packa_nways != work_id_packa % packa_nways: + continue + a_curr_panel_row_start = ii_panel * MR + a_curr_panel_height = min(MR, loop3_partition_a_height - a_curr_panel_row_start) + + if a_curr_panel_height == MR: # unroll the packing by 8 + k_iters = loop3_partition_a_width // 8 + k_remainder = loop3_partition_a_width % 8 + col = 0 + for k_iter in range(k_iters): + col = k_iter * 8 + a_curr_panel_col = loop3_partition_a + ( + a_curr_panel_row_start * k_size + col + ) + v0 = avx_f32x8_load(a_curr_panel_col) + v1 = avx_f32x8_load(a_curr_panel_col * k_size) + v2 = avx_f32x8_load(a_curr_panel_col + (2 * k_size)) + v3 = avx_f32x8_load(a_curr_panel_col + (3 * k_size)) + v4 = avx_f32x8_load(a_curr_panel_col + (4 * k_size)) + v5 = avx_f32x8_load(a_curr_panel_col + (5 * k_size)) + + unpack0 = avx_f32x8_unpacklo(v0, v1) + unpack1 = avx_f32x8_unpackhi(v0, v1) + unpack2 = avx_f32x8_unpacklo(v2, v3) + unpack3 = avx_f32x8_unpackhi(v2, v3) + unpack4 = avx_f32x8_unpacklo(v4, v5) + unpack5 = avx_f32x8_unpackhi(v4, v5) + + shf0 = avx_f32x8_shuffle(unpack0, unpack2, 0x44) + shf1 = avx_f32x8_shuffle(unpack4, unpack0, 0xE4) + shf2 = avx_f32x8_shuffle(unpack2, unpack4, 0xEE) + shf3 = avx_f32x8_shuffle(unpack5, unpack1, 0xE4) + shf4 = avx_f32x8_shuffle(unpack3, unpack5, 0xEE) + shf5 = avx_f32x8_shuffle(unpack1, unpack3, 0x44) + + low_shf1 = avx_f32x8_cast_f32x4(shf1) + res0 = avx_f32x8_insert_f32x4(shf0, low_shf1, 0x1) + res1 = avx_f32x8_permute2f32x4(shf0, shf1, 0x31) + + low_shf5 = avx_f32x8_cast_f32x4(shf5) + res2 = avx_f32x8_insert_f32x4(shf2, low_shf5, 0x1) + res3 = avx_f32x8_permute2f32x4(shf2, shf5, 0x31) + + low_shf4 = avx_f32x8_cast_f32x4(shf4) + res4 = avx_f32x8_insert_f32x4(shf3, low_shf4, 0x1) + res5 = avx_f32x8_permute2f32x4(shf3, shf4, 0x31) + + avx_f32x8_store_aligned( + ~packed_a_tensor[a_curr_panel_row_start, col], + res0 + ) + avx_f32x8_store_aligned( + ~packed_a_tensor[a_curr_panel_row_start + 2, + col + 1], + res2 + ) + avx_f32x8_store_aligned( + ~packed_a_tensor[a_curr_panel_row_start + 4, + col + 2], + res4) + avx_f32x8_store_aligned( + ~packed_a_tensor[a_curr_panel_row_start, + col + 4], + res1 + ) + avx_f32x8_store_aligned( + ~packed_a_tensor[a_curr_panel_row_start + 2, + col + 5], + res3 + ) + avx_f32x8_store_aligned( + ~packed_a_tensor[a_curr_panel_row_start + 4, + col + 6], + res5 + ) + remaining_start_col = k_iters * 8 + 
for remain_off in range(k_remainder): + curr_remain_col = remaining_start_col + remain_off + for micropanel_row in range(MR): + packed_a_tensor[a_curr_panel_row_start + micropanel_row, + curr_remain_col] = \ + loop3_partition_a[(micropanel_row + a_curr_panel_row_start) * k_size + curr_remain_col] + else: + remain_start_row = npanels_a * MR + for remain_col in range(loop3_partition_a_width): + for remain_row in range(panel_a_remainder): + packed_a_tensor[remain_start_row + remain_row, remain_col] = \ + loop3_partition_a[(remain_row + remain_start_row) * k_size + remain_col] + remain_row = panel_a_remainder + while remain_row < MR: + packed_a_tensor[remain_start_row + remain_row, remain_col] = 0 + remain_row += 1 + + @hidet.script def gemm_pack_b( loop4_partition_b: ~float32, loop4_partition_b_width: int32, loop4_partition_b_height: int32, packed_b_buf: ~float32, - comm_id_packb: int32, work_id_packb: int32 + comm_id_packb: int32, work_id_packb: int32, + packb_nways: int32 ): - + npanels_full_b = loop4_partition_b_width // NR + npanels_b_remainder = loop4_partition_b_width % NR + + npanels_b = npanels_full_b + (npanels_b_remainder != 0) + packedb_panel_stride = packed_b_height * NR + + # Loop for the packing of B + for i_panel in range(npanels_b): + if i_panel % packb_nways != work_id_packb % packb_nways: + continue + packed_b_buff_curr = packed_b_buf + (i_panel * packedb_panel_stride) + curr_panel_start = i_panel * NR + curr_panel_width = min(NR, loop4_partition_b_width - curr_panel_start) + if curr_panel_width == NR: + k_iters = loop4_partition_b_height // 8 + k_remainder = loop4_partition_b_height % 8 + row = 0 + for k_iter in range(k_iters): + row = k_iter * 8 + b_panel = loop4_partition_b + (row * n_size + curr_panel_start) + b00 = avx_f32x8_load(b_panel) + b08 = avx_f32x8_load(b_panel + 8) + + avx_f32x8_store_aligned(packed_b_buff_curr, b00) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b08) + packed_b_buff_curr += 16 + + b10 = avx_f32x8_load(b_panel + n_size) + b18 = avx_f32x8_load(b_panel + (n_size * 8)) + + avx_f32x8_store_aligned(packed_b_buff_curr, b10) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b18) + packed_b_buff_curr += 16 + + b20 = avx_f32x8_load(b_panel + (2 * n_size)) + b28 = avx_f32x8_load(b_panel + (2 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buff_curr, b20) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b28) + packed_b_buff_curr += 16 + + b30 = avx_f32x8_load(b_panel + (3 * n_size)) + b38 = avx_f32x8_load(b_panel + (3 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buff_curr, b30) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b38) + packed_b_buff_curr += 16 + + b40 = avx_f32x8_load(b_panel + (4 * n_size)) + b48 = avx_f32x8_load(b_panel + (4 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buff_curr, b40) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b48) + packed_b_buff_curr += 16 + + b50 = avx_f32x8_load(b_panel + (5 * n_size)) + b58 = avx_f32x8_load(b_panel + (5 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buff_curr, b50) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b58) + packed_b_buff_curr += 16 + + b60 = avx_f32x8_load(b_panel + (6 * n_size)) + b68 = avx_f32x8_load(b_panel + (6 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buff_curr, b60) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b68) + packed_b_buff_curr += 16 + + b70 = avx_f32x8_load(b_panel + (7 * n_size)) + b78 = avx_f32x8_load(b_panel + (7 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buff_curr, b70) + 
avx_f32x8_store_aligned(packed_b_buff_curr + 8, b78) + + packed_b_buff_curr += 16 + + row = k_iters + 8 + for _ in range(k_remainder): + b_panel = loop4_partition_b + (row * n_size + curr_panel_start) + b00 = avx_f32x8_load(b_panel) + b08 = avx_f32x8_load(b_panel + 8) + avx_f32x8_store_aligned(packed_b_buff_curr, b00) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b08) + packed_b_buff_curr += 16 + row += 1 + + else: + packed_b_remaining_buf = packed_b_buf + (npanels_full_b * packedb_panel_stride) + if npanels_b_remainder > 0: + # TODO: I think this if should always be true if this is executed? + remain_col_start = npanels_full_b * NR + for remain_row in range(loop4_partition_b_height): + packed_b_remaining_buf_curr = packed_b_remaining_buf + (remain_row * NR) + for remain_col in range(npanels_b_remainder): + packed_b_remaining_buf_curr[0] = loop4_partition_b[ + (remain_row * n_size) + (remain_col_start + remain_col) + ] + packed_b_remaining_buf_curr += 1 + zero_fill_col = npanels_b_remainder + while zero_fill_col < NR: + packed_b_remaining_buf_curr[0] = 0.0 + packed_b_remaining_buf_curr += 1 + zero_fill_col += 1 @hidet.script @@ -553,19 +967,35 @@ def gemm_4th_loop(a: float32[m_size, k_size], ) # Start the packing of B + # TODO: Check this assertion: + # TODO: loop3_nthreads == packb_nthreads gemm_pack_b(loop4_partition_b, loop4_partition_b_width, loop4_partition_b_height, packed_b_buf, - comm_id_packb, work_id_packb) - + comm_id_packb, work_id_packb, loop3_nthreads) + # The barrier at the end of the packing of B thrcomm_barrier( comm_id_packb, ~packb_thrcomm_barrier_sense[work_id_4th_loop], ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], packb_nthreads ) + + # TODO: The loop3 and beyond should start here? + gemm_3rd_loop( + a, packed_b_buf, c, + loop4_partition_a_start_col, + loop4_partition_b_start_col, + loop4_partition_b_height, + loop4_partition_b_width, + comm_id_3rd_loop, + work_id_3rd_loop, + is_first + ) + + i_loop4 += b_alg_loop4 @@ -616,18 +1046,20 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], comm_id_5th_loop = tid_5th_loop gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop) + assert isinstance(matmul_kernel_x86_v3, hidet.ir.Function) + matmul_kernel_x86_v3.kind = "cpu_kernel" + return module.ir_module() - - return ir_module + # return ir_module -class Matmulx86Op_v2(Operator): +class Matmulx86Op_refactored(Operator): def __init__(self, a: Tensor, b: Tensor): if not (len(a.shape) == len(b.shape) == 2 and a.shape[1] == b.shape[0]): raise ValueError('Matrix multiplication: incompatible sizes: {} and {}'.format(a.shape, b.shape)) - task = MatmulF32Taskx86_v2(input_like(a, 'a'), input_like(b, 'b')) + task = MatmulF32Taskx86_refactored(input_like(a, 'a'), input_like(b, 'b')) super().__init__(inputs=[a, b], attributes={}, task=task) -def matmul_x86_v2(a: Tensor, b: Tensor) -> Tensor: - return Matmulx86Op_v2(a, b).outputs[0] +def matmul_x86_refactored(a: Tensor, b: Tensor) -> Tensor: + return Matmulx86Op_refactored(a, b).outputs[0] diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_v3.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_v3.py index 0c3abff7f..b65690138 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_v3.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_v3.py @@ -729,7 +729,7 @@ def matmul_kernel_x86_v2(a: float32[m_size, k_size], b: float32[k_size, n_size], npanels_a = npanels_full_a + (1 if panel_a_remainder > 0 else 0) packeda_panel_stride = MR * loop3_partition_a_width - # Get our position 
within the A panel + # Get our position within the packed A global buffer packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) packed_a_tensor = as_tensor_pointer( packed_a_buf, diff --git a/python/hidet/mat_new.py b/python/hidet/mat_new.py new file mode 100644 index 000000000..587f18a4b --- /dev/null +++ b/python/hidet/mat_new.py @@ -0,0 +1,95 @@ +import numpy as np +import pytest + +import hidet +from hidet.graph.ops import matmul_x86_refactored +from hidet.testing import check_binary +from hidet.option import debug_cache_tuning + +import tvm +from tvm import te, auto_scheduler + +@auto_scheduler.register_workload +def matmul_ansor(M, K, N, dtype): + A = te.placeholder((M, K), name='A', dtype=dtype) + B = te.placeholder((K, N), name='B', dtype=dtype) + + k = te.reduce_axis((0, K), name='k') + rst = te.compute( + (M, N), + lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), + name='matmul_ansor', + attrs={"layout_free_placeholders": [B], + # Enable automatic layout transform for B} + ) + + return [A, B, rst] + +target = tvm.target.Target("llvm -mcpu=core-avx2") +debug_cache_tuning(True) +hidet.option.search_space(0) +for m, n, k in [(384, 256, 256), (512, 512, 512), (1024, 1024, 1024)]: + a = hidet.randn([m, k], device='cpu') + b = hidet.randn([k, n], device='cpu') + x1 = hidet.symbol_like(a) + x2 = hidet.symbol_like(b) + y = matmul_x86_refactored(x1, x2) + graph: hidet.FlowGraph = hidet.trace_from( + y, inputs=[x1, x2] + ) + opt_graph = hidet.graph.optimize(graph) + compiled_func = opt_graph.nodes[0].compiled_task + c = compiled_func(a, b) + np.testing.assert_allclose( + actual=c.numpy(), + desired=a.numpy() @ b.numpy(), + rtol=1e-3, + atol=1e-3 + ) + + print("passed for m={}, n={}, k={}".format(m, n, k)) + + hidet_latency = hidet.utils.benchmark_func( + lambda: compiled_func(a, b), repeat=50 + ) + np_latency = hidet.utils.benchmark_func( + lambda: a.numpy() @ b.numpy(), repeat=50 + ) + + ansor_task = tvm.auto_scheduler.SearchTask( + func=matmul_ansor, args=(m, k, n, "float32"), target=target + ) + log_file = f"matmul_{m}x{k}x{n}.json" + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=1000, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + verbose=2, + ) + + ansor_task.tune(tune_option) + sch, args = ansor_task.apply_best(log_file) + with open(f"./matmul_TIR_{m}x{k}x{n}", 'w') as f: + f.write(str(tvm.lower(sch, args, simple_mode=True))) + ansor_func = tvm.build(sch, args, target) + dev = tvm.cpu() + a_tvm = tvm.nd.array(a.numpy(), device=dev) + b_tvm = tvm.nd.array(b.numpy(), device=dev) + c_tvm = tvm.nd.empty((m, n), device=dev) + + ansor_func(a_tvm, b_tvm, c_tvm) + + np.testing.assert_allclose( + actual=c_tvm.numpy(), + desired=a_tvm.numpy() @ b_tvm.numpy(), + rtol=1e-3, + atol=1e-3 + ) + + ansor_latency = hidet.utils.benchmark_func( + lambda: ansor_func(a_tvm, b_tvm, c_tvm), repeat=30 + ) + + with open(f"./perf_{m}x{k}x{n}.txt", 'w') as f: + f.write(f"m={m}, k={k}, n={n}: hidet takes {hidet_latency:.2f} ms\n") + f.write(f"m={m}, k={k}, n={n}: numpy takes {np_latency: .2f} ms\n") + f.write(f"m={m}, k={k}, n={n}: ansor takes {ansor_latency: .2f} ms\n") From dbbb2b60b99590de1869092e2b218bde00dcefd8 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sun, 20 Aug 2023 08:34:25 -0400 Subject: [PATCH 008/148] fix stupid error --- python/hidet/mat_new.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/hidet/mat_new.py b/python/hidet/mat_new.py index 587f18a4b..fc36bc873 100644 --- a/python/hidet/mat_new.py +++ b/python/hidet/mat_new.py @@ -21,6 
+21,7 @@ def matmul_ansor(M, K, N, dtype): name='matmul_ansor', attrs={"layout_free_placeholders": [B], # Enable automatic layout transform for B} + } ) return [A, B, rst] From 014f5c1c2beb21f8ec9d457dad0ded9e2bc35661 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sun, 20 Aug 2023 08:36:26 -0400 Subject: [PATCH 009/148] .. --- python/hidet/mat_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/mat_new.py b/python/hidet/mat_new.py index fc36bc873..0303a7b29 100644 --- a/python/hidet/mat_new.py +++ b/python/hidet/mat_new.py @@ -35,7 +35,7 @@ def matmul_ansor(M, K, N, dtype): x1 = hidet.symbol_like(a) x2 = hidet.symbol_like(b) y = matmul_x86_refactored(x1, x2) - graph: hidet.FlowGraph = hidet.trace_from( + graph = hidet.trace_from( y, inputs=[x1, x2] ) opt_graph = hidet.graph.optimize(graph) From 2d82325b21d22ae1f7a24bf7ea1e456381b39c0b Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Mon, 21 Aug 2023 08:38:18 -0400 Subject: [PATCH 010/148] fix more error --- python/hidet/ir/primitives/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/hidet/ir/primitives/__init__.py b/python/hidet/ir/primitives/__init__.py index 74ea7e2a7..49866a0af 100644 --- a/python/hidet/ir/primitives/__init__.py +++ b/python/hidet/ir/primitives/__init__.py @@ -26,7 +26,6 @@ from . import cpu from .cpu import avx_f32x4_store, avx_f32x4_broadcast, avx_f32x4_fmadd, avx_f32x4_load, avx_f32x4_setzero from .cpu import avx_free, avx_malloc -from .cpu import openmp_get_num_threads, openmp_get_thread_num # cuda primitive functions and variables from . import cuda From 11c9e70d4e92ecae7d952f363e1ba6de9e1ed42d Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Mon, 21 Aug 2023 08:43:56 -0400 Subject: [PATCH 011/148] .. --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index f0d75f6fa..0c9bcd862 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -213,10 +213,12 @@ def thread_range_jrir(work_id: int32, n_way: int32, n: int32, bf: int32, @hidet.script def determine_blocksize_f_sub(i: int32, dim: int32, b_alg: int32) -> int32: dim_left_now = dim - i + b_now = -1 if dim_left_now <= b_alg: b_now = dim_left_now else: b_now = b_alg + assert b_now >= 0 return b_now @hidet.script From 4586e890bc097f95ff755ae79d44e5d615379ee8 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 23 Aug 2023 00:43:48 -0400 Subject: [PATCH 012/148] fixing hidet script error --- .../ops/matmul/matmul_f32_x86_refactored.py | 35 ++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 0c9bcd862..60f67480a 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -456,7 +456,40 @@ def micro_kernel( for workid_loop5 in range(loop5_nways): loop5_start = 0 loop5_end = 0 - thread_range_sub(loop5_nways, workid_loop5, n_size, NR, ~loop5_start, ~loop5_end) + # thread_range_sub(loop5_nways, workid_loop5, n_size, NR, ~loop5_start, ~loop5_end) + # TODO: For now, substitute the above func call with code + if loop5_nways == 1: + loop5_start = 0 + loop5_end = n_size + else: + all_start = 0 + all_end = n_size + size = all_end - 
all_start + n_bf_whole = n_size // NR + n_bf_left = n_size % NR + n_bf_lo = n_bf_whole // loop5_nways + n_bf_hi = n_bf_whole // loop5_nways + + n_th_lo = n_bf_whole % loop5_nways + if n_th_lo != 0: + n_bf_lo += 1 + size_lo = n_bf_lo * NR + size_hi = n_bf_hi * NR + + lo_start = all_start + hi_start = all_start + n_th_lo * size_lo + + if workid_loop5 < n_th_lo: + loop5_start = lo_start + workid_loop5 * size_lo + loop5_end = lo_start + (workid_loop5 + 1) * size_lo + else: + loop5_start = hi_start + (workid_loop5 - n_th_lo) * size_hi + loop5_end = hi_start + (workid_loop5 - n_th_lo + 1) * size_hi + + if workid_loop5 == loop5_nways - 1: + loop5_end += n_bf_left + + curr_width = loop5_end - loop5_start # packed_b_total_width += curr_width # packb_start_offsets[workid_loop5] = temp_prev From 65c3b9d900236d49a4bd3d30628236534ec920b2 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 23 Aug 2023 00:57:40 -0400 Subject: [PATCH 013/148] ...: --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 60f67480a..b86a59540 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -151,7 +151,7 @@ def schedule_matmulf32_x86( # The buffer for storing the starting offset of the packed B buffers for thread, # indexed by the work ID of Loop5 - packb_start_offsets = tensor('int32', shape=[loop5_nways], is_static=True) + packb_start_offsets = tensor('int32', shape=[loop5_nways, 1], is_static=True) # The buffer for storing the starting offset of the packed A buffers for thread, # indexed by the work ID of Loop3 packa_start_offsets = tensor('int32', shape=[loop3_nways], is_static=True) @@ -494,7 +494,7 @@ def micro_kernel( # packed_b_total_width += curr_width # packb_start_offsets[workid_loop5] = temp_prev # temp_prev += curr_width - packb_start_offsets[workid_loop5] = packed_b_total_width + packb_start_offsets[workid_loop5, 0] = packed_b_total_width packed_b_total_width += curr_width packed_b_height = KC @@ -986,7 +986,7 @@ def gemm_4th_loop(a: float32[m_size, k_size], is_first = (i_loop4 == 0) # Get the thread's partition of the buffer and the matrix packed_b_buf = packb_buf + ( - packb_start_offsets[work_id_5th_loop] * packed_b_height + packb_start_offsets[work_id_5th_loop, 0] * packed_b_height ) loop4_partition_b = b + \ From 286c1074ca0f82ff3d37474011f5cfea1de4f812 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 23 Aug 2023 01:11:17 -0400 Subject: [PATCH 014/148] .... 
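
Note on the thread-communication arrays touched in this commit and the next
few: they back the sense-reversing barrier implemented by `thrcomm_barrier`
(the scheme BLIS uses for its thread communicators), so they must be a single
set of buffers shared by every simulated thread. For readers unfamiliar with
the pattern, a minimal pure-Python sketch of the same algorithm follows; it
uses only the standard library, and the class and method names are
illustrative, not hidet or BLIS API:

    import threading

    class SenseReversingBarrier:
        def __init__(self, nthreads: int):
            self.nthreads = nthreads
            self.sense = 0       # flipped once per barrier episode
            self.arrived = 0     # number of threads that have checked in
            self._lock = threading.Lock()

        def wait(self):
            with self._lock:
                my_sense = self.sense
                self.arrived += 1
                last = (self.arrived == self.nthreads)
                if last:
                    self.arrived = 0  # reset the counter for the next episode
                    self.sense ^= 1   # flipping the sense releases the spinners
            if not last:
                while self.sense == my_sense:
                    pass              # spin until the last arrival flips the sense

The kernel's version does the same thing lock-free, with
`cpu_atomic_add_fetch` on `*_threads_arrived` and `cpu_atomic_fetch_xor` on
`*_barrier_sense`.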
--- .../ops/matmul/matmul_f32_x86_refactored.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index b86a59540..dcb937279 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -141,25 +141,25 @@ def schedule_matmulf32_x86( loop5_thrcomm_barrier_sense = 0 loop5_thrcomm_barrier_threads_arrived = 0 - packb_thrcomm_barrier_sense = tensor('int32', shape=[loop4_nways], is_static=True) + packb_thrcomm_barrier_sense = tensor('int32', shape=[loop4_nways]) # for idx in range(loop4_nways): # packb_thrcomm_barrier_sense[idx] = 0 TODO: This shouldn't be necessary, as static arrays are 0-initialized - packb_thrcomm_barrier_threads_arrived = tensor('int32', shape=[loop4_nways], is_static=True) + packb_thrcomm_barrier_threads_arrived = tensor('int32', shape=[loop4_nways]) - packa_thrcomm_barrier_sense = tensor('int32', shape=[loop3_nways], is_static=True) - packa_thrcomm_threads_arrived = tensor('int32', shape=[loop3_nways], is_static=True) + packa_thrcomm_barrier_sense = tensor('int32', shape=[loop3_nways]) + packa_thrcomm_threads_arrived = tensor('int32', shape=[loop3_nways]) # The buffer for storing the starting offset of the packed B buffers for thread, # indexed by the work ID of Loop5 - packb_start_offsets = tensor('int32', shape=[loop5_nways, 1], is_static=True) + packb_start_offsets = tensor('int32', shape=[loop5_nways, 1]) # The buffer for storing the starting offset of the packed A buffers for thread, # indexed by the work ID of Loop3 - packa_start_offsets = tensor('int32', shape=[loop3_nways], is_static=True) + packa_start_offsets = tensor('int32', shape=[loop3_nways]) # The array to store the needed size for each packed B buffer, indexed by the work ID of Loop5 - packb_sizes = tensor('int32', shape=[loop5_nways], is_static=True) + packb_sizes = tensor('int32', shape=[loop5_nways]) # The array to store the needed size for each packed A buffer, indexed by the work ID of Loop3 - packa_sizes = tensor('int32', shape=[loop3_nways], is_static=True) + packa_sizes = tensor('int32', shape=[loop3_nways]) with hidet.script_module() as module: # Helpers From bfacaf81a6723d686b0644725a8cad38f27dd56c Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 23 Aug 2023 09:42:38 -0400 Subject: [PATCH 015/148] ... 
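
This replaces the (now commented-out) prefix-sum of per-thread
`thread_range_sub` widths with a fixed-stride layout for the packed-B buffer:
every Loop-5 worker gets an identical slab of
`packed_b_width * packed_b_height` floats, so worker `w`'s slab starts at
`w * packed_b_individual_size`. The width rule, rendered as plain Python for
reference (the helper name and the sample NC value are illustrative only; the
hand-written micro kernel assumes NR = 16):

    def packed_b_width_for(n_size: int, NC: int, NR: int) -> int:
        # Cap the packed width at the cache block size NC; when the whole
        # matrix is narrower than NC, round n up to a multiple of NR so the
        # edge micro-panel can be zero-padded in place.
        width = NC
        if width > n_size:
            width = (n_size + NR - 1) // NR * NR
        return width

    assert packed_b_width_for(470, NC=256, NR=16) == 256  # wide matrix: one full NC block
    assert packed_b_width_for(100, NC=256, NR=16) == 112  # narrow matrix: 7 panels of 16

This over-allocates relative to the prefix-sum scheme, but it removes the
start-offset bookkeeping entirely.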
---
 .../ops/matmul/matmul_f32_x86_refactored.py | 28 ++++++++++++++-----
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py
index dcb937279..7386b6516 100644
--- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py
+++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py
@@ -226,6 +226,7 @@ def not_edge(i: int32, n_iter: int32, n_left: int32) -> bool:
             return i != n_iter - 1 or n_left == 0
 
         # Thread barrier
+        @hidet.script
         def thrcomm_barrier(tid: int32, barrier_sense: ~int32,
                             barrier_threads_arrived: ~int32, nthreads: int32):
             if nthreads == 1:
@@ -490,17 +491,27 @@ def micro_kernel(
                loop5_end += n_bf_left
 
-            curr_width = loop5_end - loop5_start
+            # curr_width = loop5_end - loop5_start
+            # # packed_b_total_width += curr_width
+            # # packb_start_offsets[workid_loop5] = temp_prev
+            # # temp_prev += curr_width
+            # packb_start_offsets[workid_loop5] = packed_b_total_width
             # packed_b_total_width += curr_width
-            # packb_start_offsets[workid_loop5] = temp_prev
-            # temp_prev += curr_width
-            packb_start_offsets[workid_loop5, 0] = packed_b_total_width
-            packed_b_total_width += curr_width
+
+        # packed_b_individual_width = min(NC, n_size)
 
         packed_b_height = KC
         if packed_b_height > k_size:
-            packed_b_height = (k_size + NR - 1) // NR * NR
+            packed_b_height = k_size
+            # packed_b_height = (k_size + NR - 1) // NR * NR
+        # packed_b_total_size = packed_b_total_width * packed_b_height
+        packed_b_width = NC
+        if packed_b_width > n_size:
+            packed_b_width = (n_size + NR - 1) // NR * NR
+
+        packed_b_total_width = packed_b_width * loop5_nways
         packed_b_total_size = packed_b_total_width * packed_b_height
+        packed_b_individual_size = packed_b_width * packed_b_height
 
         a_height_mr_partitions = (m_size + MR - 1) // MR
         a_height_mr_remainder = m_size % MR
@@ -985,8 +996,11 @@ def gemm_4th_loop(a: float32[m_size, k_size],
             loop4_partition_a_start_col = i_loop4
             is_first = (i_loop4 == 0)
             # Get the thread's partition of the buffer and the matrix
+            # packed_b_buf = packb_buf + (
+            #     packb_start_offsets[work_id_5th_loop, 0] * packed_b_height
+            # )
             packed_b_buf = packb_buf + (
-                packb_start_offsets[work_id_5th_loop, 0] * packed_b_height
+                packed_b_individual_size * work_id_5th_loop
             )
 
             loop4_partition_b = b + \
From 8246466936398cdc773e07ee95b9ef488950a739 Mon Sep 17 00:00:00 2001
From: BolinSNLHM
Date: Wed, 23 Aug 2023 09:46:34 -0400
Subject: [PATCH 016/148] ..

---
 python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py
index 7386b6516..e77b30bc2 100644
--- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py
+++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py
@@ -557,8 +557,8 @@ def gemm_macro(
             macro_m: int32,
             macro_n: int32,
             macro_k: int32,
-            ps_packed_a,
-            ps_packed_b,
+            ps_packed_a: int32,
+            ps_packed_b: int32,
             comm_id_macro: int32,
             work_id_macro: int32,
             is_first: bool
From 75180421bce41e8565356e6bc336b27e16307e88 Mon Sep 17 00:00:00 2001
From: BolinSNLHM
Date: Wed, 23 Aug 2023 09:58:39 -0400
Subject: [PATCH 017/148] ..
--- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index e77b30bc2..271f0be54 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -1097,7 +1097,8 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], assert isinstance(matmul_kernel_x86_v3, hidet.ir.Function) matmul_kernel_x86_v3.kind = "cpu_kernel" - return module.ir_module() + ir_module = module.ir_module() + return ir_module # return ir_module From f8a97b2a6746aa98ca8ce5e4b17fa5de0e1c10ba Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 23 Aug 2023 10:15:52 -0400 Subject: [PATCH 018/148] fixing strange error --- .../ops/matmul/matmul_f32_x86_refactored.py | 438 +++++++++--------- 1 file changed, 225 insertions(+), 213 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 271f0be54..378a8ece8 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -163,6 +163,7 @@ def schedule_matmulf32_x86( with hidet.script_module() as module: # Helpers + @hidet.script def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~int32, end: ~int32): if n_way == 1: @@ -547,194 +548,6 @@ def micro_kernel( # comm_id_macro, # work_id_macro # ) - @hidet.script - def gemm_macro( - packed_a: ~float32, - packed_b: ~float32, - c: float32[m_size, n_size], - c_row_off: int32, - c_col_off: int32, - macro_m: int32, - macro_n: int32, - macro_k: int32, - ps_packed_a: int32, - ps_packed_b: int32, - comm_id_macro: int32, - work_id_macro: int32, - is_first: bool - ): - comm_id_1st_loop = comm_id_macro % loop1_nthreads - work_id_1st_loop = comm_id_macro // (loop1_nthreads // loop1_nways) - - n_iter = macro_n // NR - n_remainder = macro_n % NR - m_iter = macro_m // MR - m_remainder = macro_m % MR - - if n_remainder > 0: - n_iter += 1 - if m_remainder > 0: - m_iter += 1 - - jr_start = -1 - jr_end = -1 - ir_start = -1 - ir_end = -1 - jr_inc = -1 - ir_inc = -1 - - thread_range_jrir( - work_id_macro, - macro_nways, - n_iter, - 1, - ~jr_start, - ~jr_end, - ~jr_inc - ) - - thread_range_jrir( - work_id_1st_loop, - m_iter, - 1, - ~ir_start, - ~ir_end, - ~ir_inc - ) - - rs_packeda = 1 - rstep_a = ps_packed_a - cstep_b = ps_packed_b - - cstep_c = NR - rstep_c = n_size * MR - - macro_c_cast = as_tensor_pointer( - ~c[c_row_off, c_col_off], - dtype=float32, - shape=(m_size, n_size) - ) - temp_c = tensor(scope=DeclareScope.Default, - dtype=float32, - layout=row_major(MR, NR), - is_static=True) - j = jr_start - while j < jr_end: - b1 = packed_b + j * cstep_b - c1 = macro_c_cast + j * cstep_c - - n_cur = NR if not_edge(j, n_iter, n_remainder) else n_remainder - # Loop over the m dimension, MR rows at a time - i = ir_start - while i < ir_end: - a1 = packed_a + i * rstep_a - c11 = c1 + i * rstep_a - m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder - - if m_cur == MR and n_cur == NR: - micro_kernel(a1, b1, c11, macro_k, macro_m, macro_n, is_first) - else: - for i, j in grid(MR, NR): - temp_c[i, j] = 0.0 - micro_kernel(a1, b1, temp_c, macro_k, macro_m, macro_n, is_first) - if not is_first: - for mm, nn in grid(m_cur, n_cur): - c11[mm, nn] += temp_c[mm, nn] - else: 
- for mm, nn in grid(m_cur, n_cur): - c11[mm, nn] = temp_c[mm, nn] - - i += ir_inc - j += jr_inc - - - @hidet.script - def gemm_3rd_loop( - a: float32[m_size, k_size], - packed_b: ~float32, - c: float32[m_size, n_size], - loop3_partition_a_start_col: int32, - loop3_partition_b_start_col: int32, - loop3_partition_a_width: int32, - loop3_partition_b_width: int32, - comm_id_3rd_loop: int32, - work_id_3rd_loop: int32, - is_first: bool - ): - comm_id_macro = work_id_3rd_loop % macro_nthreads - work_id_macro = comm_id_macro // (macro_nthreads // macro_nways) - comm_id_packa = comm_id_macro - work_id_packa = comm_id_macro - packa_nways = macro_nthreads - - m_start_loop3 = 0 - m_end_loop3 = 0 - thread_range_sub( - loop3_nways, - work_id_3rd_loop, - m_size, - MR, - ~m_start_loop3, - ~m_end_loop3 - ) - ii = m_start_loop3 - while ii < m_end_loop3: - b_alg_loop3 = determine_blocksize_f_sub( - ii, m_size, MC - ) - # Acquire the partition at loop 3 - loop3_partition_a_start_row = ii - loop3_partition_a_height = b_alg_loop3 - - loop3_partition_a = a + ( - loop3_partition_a_start_row * k_size + - loop3_partition_a_start_col - ) - - # Get our position within the packed A global buffer - packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) - - # TODO: If passed, see if this barrier is necessary - thrcomm_barrier( - comm_id_packa, - ~packa_thrcomm_barrier_sense[work_id_3rd_loop], - ~packa_thrcomm_threads_arrived[work_id_3rd_loop], - packa_nthreads - ) - - gemm_pack_a( - loop3_partition_a, - loop3_partition_a_width, - loop3_partition_a_height, - packed_a_buf, - comm_id_packa, - work_id_packa, - packa_nways - ) - - # This marks the end of the packing of A, - # so a barrier is needed - thrcomm_barrier( - comm_id_packa, - ~packa_thrcomm_barrier_sense[work_id_3rd_loop], - ~packa_thrcomm_threads_arrived[work_id_3rd_loop], - packa_nthreads - ) - - gemm_macro(packed_a_buf, - packed_b, - c, - loop3_partition_a_start_row, - loop3_partition_b_start_col, - loop3_partition_a_height, - loop3_partition_b_width, - loop3_partition_a_width, - MR * loop3_partition_a_width, - packed_b_height * NR, - comm_id_macro, - work_id_macro, - is_first - ) @hidet.script def gemm_pack_a( @@ -750,10 +563,9 @@ def gemm_pack_a( packed_a_buf, float32, layout=row_major(packed_a_individual_height // MR, 1) * - column_major(MR, packed_a_width) + column_major(MR, packed_a_width) ) - npanels_full_a = loop3_partition_a_height // MR panel_a_remainder = loop3_partition_a_height % MR @@ -762,7 +574,8 @@ def gemm_pack_a( if ii_panel % packa_nways != work_id_packa % packa_nways: continue a_curr_panel_row_start = ii_panel * MR - a_curr_panel_height = min(MR, loop3_partition_a_height - a_curr_panel_row_start) + a_curr_panel_height = min(MR, + loop3_partition_a_height - a_curr_panel_row_start) if a_curr_panel_height == MR: # unroll the packing by 8 k_iters = loop3_partition_a_width // 8 @@ -771,7 +584,7 @@ def gemm_pack_a( for k_iter in range(k_iters): col = k_iter * 8 a_curr_panel_col = loop3_partition_a + ( - a_curr_panel_row_start * k_size + col + a_curr_panel_row_start * k_size + col ) v0 = avx_f32x8_load(a_curr_panel_col) v1 = avx_f32x8_load(a_curr_panel_col * k_size) @@ -812,47 +625,51 @@ def gemm_pack_a( ) avx_f32x8_store_aligned( ~packed_a_tensor[a_curr_panel_row_start + 2, - col + 1], + col + 1], res2 ) avx_f32x8_store_aligned( ~packed_a_tensor[a_curr_panel_row_start + 4, - col + 2], + col + 2], res4) avx_f32x8_store_aligned( ~packed_a_tensor[a_curr_panel_row_start, - col + 4], + col + 4], res1 ) avx_f32x8_store_aligned( 
~packed_a_tensor[a_curr_panel_row_start + 2, - col + 5], + col + 5], res3 ) avx_f32x8_store_aligned( ~packed_a_tensor[a_curr_panel_row_start + 4, - col + 6], + col + 6], res5 ) remaining_start_col = k_iters * 8 for remain_off in range(k_remainder): curr_remain_col = remaining_start_col + remain_off for micropanel_row in range(MR): - packed_a_tensor[a_curr_panel_row_start + micropanel_row, - curr_remain_col] = \ - loop3_partition_a[(micropanel_row + a_curr_panel_row_start) * k_size + curr_remain_col] + packed_a_tensor[ + a_curr_panel_row_start + micropanel_row, + curr_remain_col] = \ + loop3_partition_a[( + micropanel_row + a_curr_panel_row_start) * k_size + curr_remain_col] else: remain_start_row = npanels_a * MR for remain_col in range(loop3_partition_a_width): for remain_row in range(panel_a_remainder): - packed_a_tensor[remain_start_row + remain_row, remain_col] = \ - loop3_partition_a[(remain_row + remain_start_row) * k_size + remain_col] + packed_a_tensor[ + remain_start_row + remain_row, remain_col] = \ + loop3_partition_a[( + remain_row + remain_start_row) * k_size + remain_col] remain_row = panel_a_remainder while remain_row < MR: - packed_a_tensor[remain_start_row + remain_row, remain_col] = 0 + packed_a_tensor[ + remain_start_row + remain_row, remain_col] = 0 remain_row += 1 - @hidet.script def gemm_pack_b( loop4_partition_b: ~float32, @@ -872,16 +689,19 @@ def gemm_pack_b( for i_panel in range(npanels_b): if i_panel % packb_nways != work_id_packb % packb_nways: continue - packed_b_buff_curr = packed_b_buf + (i_panel * packedb_panel_stride) + packed_b_buff_curr = packed_b_buf + ( + i_panel * packedb_panel_stride) curr_panel_start = i_panel * NR - curr_panel_width = min(NR, loop4_partition_b_width - curr_panel_start) + curr_panel_width = min(NR, + loop4_partition_b_width - curr_panel_start) if curr_panel_width == NR: k_iters = loop4_partition_b_height // 8 k_remainder = loop4_partition_b_height % 8 row = 0 for k_iter in range(k_iters): row = k_iter * 8 - b_panel = loop4_partition_b + (row * n_size + curr_panel_start) + b_panel = loop4_partition_b + ( + row * n_size + curr_panel_start) b00 = avx_f32x8_load(b_panel) b08 = avx_f32x8_load(b_panel + 8) @@ -941,7 +761,8 @@ def gemm_pack_b( row = k_iters + 8 for _ in range(k_remainder): - b_panel = loop4_partition_b + (row * n_size + curr_panel_start) + b_panel = loop4_partition_b + ( + row * n_size + curr_panel_start) b00 = avx_f32x8_load(b_panel) b08 = avx_f32x8_load(b_panel + 8) avx_f32x8_store_aligned(packed_b_buff_curr, b00) @@ -950,16 +771,20 @@ def gemm_pack_b( row += 1 else: - packed_b_remaining_buf = packed_b_buf + (npanels_full_b * packedb_panel_stride) + packed_b_remaining_buf = packed_b_buf + ( + npanels_full_b * packedb_panel_stride) if npanels_b_remainder > 0: # TODO: I think this if should always be true if this is executed? 
remain_col_start = npanels_full_b * NR for remain_row in range(loop4_partition_b_height): - packed_b_remaining_buf_curr = packed_b_remaining_buf + (remain_row * NR) + packed_b_remaining_buf_curr = packed_b_remaining_buf + ( + remain_row * NR) for remain_col in range(npanels_b_remainder): - packed_b_remaining_buf_curr[0] = loop4_partition_b[ - (remain_row * n_size) + (remain_col_start + remain_col) - ] + packed_b_remaining_buf_curr[0] = \ + loop4_partition_b[ + (remain_row * n_size) + ( + remain_col_start + remain_col) + ] packed_b_remaining_buf_curr += 1 zero_fill_col = npanels_b_remainder while zero_fill_col < NR: @@ -967,6 +792,193 @@ def gemm_pack_b( packed_b_remaining_buf_curr += 1 zero_fill_col += 1 + @hidet.script + def gemm_macro( + packed_a: ~float32, + packed_b: ~float32, + c: float32[m_size, n_size], + c_row_off: int32, + c_col_off: int32, + macro_m: int32, + macro_n: int32, + macro_k: int32, + ps_packed_a: int32, + ps_packed_b: int32, + comm_id_macro: int32, + work_id_macro: int32, + is_first: bool + ): + comm_id_1st_loop = comm_id_macro % loop1_nthreads + work_id_1st_loop = comm_id_macro // (loop1_nthreads // loop1_nways) + + n_iter = macro_n // NR + n_remainder = macro_n % NR + m_iter = macro_m // MR + m_remainder = macro_m % MR + + if n_remainder > 0: + n_iter += 1 + if m_remainder > 0: + m_iter += 1 + + jr_start = -1 + jr_end = -1 + ir_start = -1 + ir_end = -1 + jr_inc = -1 + ir_inc = -1 + + thread_range_jrir( + work_id_macro, + macro_nways, + n_iter, + 1, + ~jr_start, + ~jr_end, + ~jr_inc + ) + + thread_range_jrir( + work_id_1st_loop, + m_iter, + 1, + ~ir_start, + ~ir_end, + ~ir_inc + ) + + rs_packeda = 1 + rstep_a = ps_packed_a + cstep_b = ps_packed_b + + cstep_c = NR + rstep_c = n_size * MR + + macro_c_cast = as_tensor_pointer( + ~c[c_row_off, c_col_off], + dtype=float32, + shape=(m_size, n_size) + ) + temp_c = tensor(scope=DeclareScope.Default, + dtype=float32, + layout=row_major(MR, NR), + is_static=True) + j = jr_start + while j < jr_end: + b1 = packed_b + j * cstep_b + c1 = macro_c_cast + j * cstep_c + + n_cur = NR if not_edge(j, n_iter, n_remainder) else n_remainder + # Loop over the m dimension, MR rows at a time + i = ir_start + while i < ir_end: + a1 = packed_a + i * rstep_a + c11 = c1 + i * rstep_a + m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder + + if m_cur == MR and n_cur == NR: + micro_kernel(a1, b1, c11, macro_k, macro_m, macro_n, is_first) + else: + for i, j in grid(MR, NR): + temp_c[i, j] = 0.0 + micro_kernel(a1, b1, temp_c, macro_k, macro_m, macro_n, is_first) + if not is_first: + for mm, nn in grid(m_cur, n_cur): + c11[mm, nn] += temp_c[mm, nn] + else: + for mm, nn in grid(m_cur, n_cur): + c11[mm, nn] = temp_c[mm, nn] + + i += ir_inc + j += jr_inc + + + @hidet.script + def gemm_3rd_loop( + a: float32[m_size, k_size], + packed_b: ~float32, + c: float32[m_size, n_size], + loop3_partition_a_start_col: int32, + loop3_partition_b_start_col: int32, + loop3_partition_a_width: int32, + loop3_partition_b_width: int32, + comm_id_3rd_loop: int32, + work_id_3rd_loop: int32, + is_first: bool): + comm_id_macro = work_id_3rd_loop % macro_nthreads + work_id_macro = comm_id_macro // (macro_nthreads // macro_nways) + comm_id_packa = comm_id_macro + work_id_packa = comm_id_macro + packa_nways = macro_nthreads + + m_start_loop3 = 0 + m_end_loop3 = 0 + thread_range_sub( + loop3_nways, + work_id_3rd_loop, + m_size, + MR, + ~m_start_loop3, + ~m_end_loop3 + ) + ii = m_start_loop3 + while ii < m_end_loop3: + b_alg_loop3 = determine_blocksize_f_sub( + ii, m_size, 
MC + ) + # Acquire the partition at loop 3 + loop3_partition_a_start_row = ii + loop3_partition_a_height = b_alg_loop3 + + loop3_partition_a = a + ( + loop3_partition_a_start_row * k_size + + loop3_partition_a_start_col + ) + + # Get our position within the packed A global buffer + packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) + + # TODO: If passed, see if this barrier is necessary + thrcomm_barrier( + comm_id_packa, + ~packa_thrcomm_barrier_sense[work_id_3rd_loop], + ~packa_thrcomm_threads_arrived[work_id_3rd_loop], + packa_nthreads + ) + + gemm_pack_a( + loop3_partition_a, + loop3_partition_a_width, + loop3_partition_a_height, + packed_a_buf, + comm_id_packa, + work_id_packa, + packa_nways + ) + + # This marks the end of the packing of A, + # so a barrier is needed + thrcomm_barrier( + comm_id_packa, + ~packa_thrcomm_barrier_sense[work_id_3rd_loop], + ~packa_thrcomm_threads_arrived[work_id_3rd_loop], + packa_nthreads + ) + + gemm_macro(packed_a_buf, + packed_b, + c, + loop3_partition_a_start_row, + loop3_partition_b_start_col, + loop3_partition_a_height, + loop3_partition_b_width, + loop3_partition_a_width, + MR * loop3_partition_a_width, + packed_b_height * NR, + comm_id_macro, + work_id_macro, + is_first + ) @hidet.script def gemm_4th_loop(a: float32[m_size, k_size], From 1a87c2736dca7fcc6b67a69c563872f7e0244328 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 23 Aug 2023 10:16:58 -0400 Subject: [PATCH 019/148] more errors --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 378a8ece8..5f10b6406 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -587,7 +587,7 @@ def gemm_pack_a( a_curr_panel_row_start * k_size + col ) v0 = avx_f32x8_load(a_curr_panel_col) - v1 = avx_f32x8_load(a_curr_panel_col * k_size) + v1 = avx_f32x8_load(a_curr_panel_col + k_size) v2 = avx_f32x8_load(a_curr_panel_col + (2 * k_size)) v3 = avx_f32x8_load(a_curr_panel_col + (3 * k_size)) v4 = avx_f32x8_load(a_curr_panel_col + (4 * k_size)) From 31044733f3a41ee9f75c93757f4b3c61cb3fc522 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 23 Aug 2023 10:17:49 -0400 Subject: [PATCH 020/148] more err --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 5f10b6406..3c8ef9bd4 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -682,7 +682,7 @@ def gemm_pack_b( npanels_full_b = loop4_partition_b_width // NR npanels_b_remainder = loop4_partition_b_width % NR - npanels_b = npanels_full_b + (npanels_b_remainder != 0) + npanels_b = npanels_full_b + (1 if npanels_b_remainder != 0 else 0) packedb_panel_stride = packed_b_height * NR # Loop for the packing of B From 68bc03dc0df1ae03638375bb36b61facf0f6b5ff Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 23 Aug 2023 11:31:10 -0400 Subject: [PATCH 021/148] ... 
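
This moves the tile-type definitions and thread-count bookkeeping inside the
`with hidet.script_module() as module:` block. A `@hidet.script` function
resolves names from the enclosing Python scope when it is traced, which is how
the kernels pick up `m_size`, `MR`, the `*_nways` counts, and the barrier
arrays; keeping every captured declaration inside the module context makes
that dependency explicit. A minimal sketch of the capture pattern follows
(names are illustrative, and exact import paths may differ between hidet
versions):

    import hidet
    from hidet.lang import int32

    with hidet.script_module() as module:
        SCALE = 3  # plain Python value, captured when the script below is traced

        @hidet.script
        def triple(x: int32) -> int32:
            # SCALE is baked into the generated IR as a compile-time constant
            return SCALE * x

    ir_module = module.ir_module()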
--- .../ops/matmul/matmul_f32_x86_refactored.py | 94 ++++++++++--------- 1 file changed, 50 insertions(+), 44 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 3c8ef9bd4..022bcfa4b 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -116,53 +116,59 @@ def schedule_matmulf32_x86( tune.check(MC % MR == NC % NR == 0, 'Tile size must divide the corresponding block size') - packed_a_type = tensor_type('float32', layout=row_major(MC // MR, 1) * column_major(MR, KC)) - packed_b_type = tensor_type('float32', layout=row_major(1, NC // NR) * row_major(KC, NR)) - - # Get the number of threads... - loop5_nways, loop3_nways, macro_nways, loop1_nways = ways - loop4_nways = 1 - nthreads = loop5_nways * loop3_nways * macro_nways * loop1_nways - - # Get the number of threads remaining at each level - loop5_nthreads = nthreads - loop4_nthreads = loop5_nthreads // loop5_nways - loop3_nthreads = loop4_nthreads - macro_nthreads = loop3_nthreads // loop3_nways - loop1_nthreads = macro_nthreads // macro_nways - - packb_nthreads = loop3_nthreads - packa_nthreads = macro_nthreads - - # TODO: Since Hidet doesn't support the parallel region syntax as in OpenMP, - # TODO: We instead use a loop to simulate the parallel region, with the "thread id" being the loop index. - outermost_iters = nthreads - - loop5_thrcomm_barrier_sense = 0 - loop5_thrcomm_barrier_threads_arrived = 0 - - packb_thrcomm_barrier_sense = tensor('int32', shape=[loop4_nways]) - # for idx in range(loop4_nways): - # packb_thrcomm_barrier_sense[idx] = 0 TODO: This shouldn't be necessary, as static arrays are 0-initialized - packb_thrcomm_barrier_threads_arrived = tensor('int32', shape=[loop4_nways]) - - packa_thrcomm_barrier_sense = tensor('int32', shape=[loop3_nways]) - packa_thrcomm_threads_arrived = tensor('int32', shape=[loop3_nways]) - - # The buffer for storing the starting offset of the packed B buffers for thread, - # indexed by the work ID of Loop5 - packb_start_offsets = tensor('int32', shape=[loop5_nways, 1]) - # The buffer for storing the starting offset of the packed A buffers for thread, - # indexed by the work ID of Loop3 - packa_start_offsets = tensor('int32', shape=[loop3_nways]) - - # The array to store the needed size for each packed B buffer, indexed by the work ID of Loop5 - packb_sizes = tensor('int32', shape=[loop5_nways]) - # The array to store the needed size for each packed A buffer, indexed by the work ID of Loop3 - packa_sizes = tensor('int32', shape=[loop3_nways]) + with hidet.script_module() as module: # Helpers + packed_a_type = tensor_type('float32', layout=row_major(MC // MR, + 1) * column_major( + MR, KC)) + packed_b_type = tensor_type('float32', layout=row_major(1, + NC // NR) * row_major( + KC, NR)) + + # Get the number of threads... 
+ loop5_nways, loop3_nways, macro_nways, loop1_nways = ways + loop4_nways = 1 + nthreads = loop5_nways * loop3_nways * macro_nways * loop1_nways + + # Get the number of threads remaining at each level + loop5_nthreads = nthreads + loop4_nthreads = loop5_nthreads // loop5_nways + loop3_nthreads = loop4_nthreads + macro_nthreads = loop3_nthreads // loop3_nways + loop1_nthreads = macro_nthreads // macro_nways + + packb_nthreads = loop3_nthreads + packa_nthreads = macro_nthreads + + # TODO: Since Hidet doesn't support the parallel region syntax as in OpenMP, + # TODO: We instead use a loop to simulate the parallel region, with the "thread id" being the loop index. + outermost_iters = nthreads + + loop5_thrcomm_barrier_sense = 0 + loop5_thrcomm_barrier_threads_arrived = 0 + + packb_thrcomm_barrier_sense = tensor('int32', shape=[loop4_nways]) + # for idx in range(loop4_nways): + # packb_thrcomm_barrier_sense[idx] = 0 TODO: This shouldn't be necessary, as static arrays are 0-initialized + packb_thrcomm_barrier_threads_arrived = tensor('int32', + shape=[loop4_nways]) + + packa_thrcomm_barrier_sense = tensor('int32', shape=[loop3_nways]) + packa_thrcomm_threads_arrived = tensor('int32', shape=[loop3_nways]) + + # The buffer for storing the starting offset of the packed B buffers for thread, + # indexed by the work ID of Loop5 + packb_start_offsets = tensor('int32', shape=[loop5_nways, 1]) + # The buffer for storing the starting offset of the packed A buffers for thread, + # indexed by the work ID of Loop3 + packa_start_offsets = tensor('int32', shape=[loop3_nways]) + + # The array to store the needed size for each packed B buffer, indexed by the work ID of Loop5 + packb_sizes = tensor('int32', shape=[loop5_nways]) + # The array to store the needed size for each packed A buffer, indexed by the work ID of Loop3 + packa_sizes = tensor('int32', shape=[loop3_nways]) @hidet.script def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~int32, end: ~int32): From 9059ca3ef70b3fec390b4ef1bcbb846a8f224e55 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 23 Aug 2023 11:46:57 -0400 Subject: [PATCH 022/148] ... 
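For reference, the hunk above carries the thread bookkeeping the rest of the
series depends on: ways factors the thread pool across the loop nest, and each
level's communicator size is the number of threads left after dividing out the
ways of the levels above it. Plain-Python restatement with the default
ways=(1, 8, 4, 1) used at this point in the series:

    loop5_ways, loop3_ways, macro_ways, loop1_ways = 1, 8, 4, 1
    nthreads = loop5_ways * loop3_ways * macro_ways * loop1_ways  # 32

    loop5_nthreads = nthreads                       # 32
    loop4_nthreads = loop5_nthreads // loop5_ways   # 32 (loop 4 is single-way)
    loop3_nthreads = loop4_nthreads                 # 32
    macro_nthreads = loop3_nthreads // loop3_ways   # 4
    loop1_nthreads = macro_nthreads // macro_ways   # 1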
--- .../hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 022bcfa4b..dbc266ae7 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -160,15 +160,15 @@ def schedule_matmulf32_x86( # The buffer for storing the starting offset of the packed B buffers for thread, # indexed by the work ID of Loop5 - packb_start_offsets = tensor('int32', shape=[loop5_nways, 1]) + packb_start_offsets = tensor('int32', shape=[loop5_nways,]) # The buffer for storing the starting offset of the packed A buffers for thread, # indexed by the work ID of Loop3 - packa_start_offsets = tensor('int32', shape=[loop3_nways]) + packa_start_offsets = tensor('int32', shape=[loop3_nways,]) # The array to store the needed size for each packed B buffer, indexed by the work ID of Loop5 - packb_sizes = tensor('int32', shape=[loop5_nways]) + packb_sizes = tensor('int32', shape=[loop5_nways,]) # The array to store the needed size for each packed A buffer, indexed by the work ID of Loop3 - packa_sizes = tensor('int32', shape=[loop3_nways]) + packa_sizes = tensor('int32', shape=[loop3_nways,]) @hidet.script def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~int32, end: ~int32): From df5a177fb95f71147f95d158a0ca73429f0bf9d5 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 23 Aug 2023 11:49:20 -0400 Subject: [PATCH 023/148] global --- .../ops/matmul/matmul_f32_x86_refactored.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index dbc266ae7..873def7dc 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -149,14 +149,23 @@ def schedule_matmulf32_x86( loop5_thrcomm_barrier_sense = 0 loop5_thrcomm_barrier_threads_arrived = 0 - packb_thrcomm_barrier_sense = tensor('int32', shape=[loop4_nways]) + packb_thrcomm_barrier_sense = tensor(dtype='int32', + shape=[loop4_nways], + scope=DeclareScope.Default, + is_static=True) # for idx in range(loop4_nways): # packb_thrcomm_barrier_sense[idx] = 0 TODO: This shouldn't be necessary, as static arrays are 0-initialized - packb_thrcomm_barrier_threads_arrived = tensor('int32', - shape=[loop4_nways]) - - packa_thrcomm_barrier_sense = tensor('int32', shape=[loop3_nways]) - packa_thrcomm_threads_arrived = tensor('int32', shape=[loop3_nways]) + packb_thrcomm_barrier_threads_arrived = tensor(dtype='int32', + shape=[loop4_nways], + scope=DeclareScope.Default, + is_static=True) + + packa_thrcomm_barrier_sense = tensor('int32', shape=[loop3_nways], + scope=DeclareScope.Default, + is_static=True) + packa_thrcomm_threads_arrived = tensor('int32', shape=[loop3_nways], + scope=DeclareScope.Default, + is_static=True) # The buffer for storing the starting offset of the packed B buffers for thread, # indexed by the work ID of Loop5 From 27da1ba78ce40d68ee131ef7b927eaeb3f660caa Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Fri, 25 Aug 2023 00:41:16 -0400 Subject: [PATCH 024/148] global var --- .../ops/matmul/matmul_f32_x86_refactored.py | 63 ++++++++++++++----- 1 file changed, 49 insertions(+), 14 deletions(-) diff --git 
a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 873def7dc..51a1d99b5 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -119,6 +119,45 @@ def schedule_matmulf32_x86( with hidet.script_module() as module: + # Get the number of threads... + loop5_nways, loop3_nways, macro_nways, loop1_nways = ways + loop4_nways = 1 + nthreads = loop5_nways * loop3_nways * macro_nways * loop1_nways + + # Use the define_global_var functionality. + # packa_thrcomm_barrier_sense = tensor('int32', shape=[loop3_nways], + # scope=DeclareScope.Default, + # is_static=True) + # packa_thrcomm_threads_arrived = tensor('int32', shape=[loop3_nways], + # scope=DeclareScope.Default, + # is_static=True) + + packa_thrcomm_barrier_sense = module.define_global_var( + name="pack_a_barrier_sense", + var_type=int32[loop3_nways] + ) + packa_thrcomm_threads_arrived = module.define_global_var( + name="pack_a_threads_arrived", + var_type=int32[loop3_nways] + ) + + packb_thrcomm_barrier_sense = module.define_global_var( + name='pack_b_barrier_sense', + var_type=int32[loop5_nways] + ) + packb_thrcomm_barrier_threads_arrived = module.define_global_var( + name="pack_b_threads_arrived", + var_type=int32[loop5_nways] + ) + for i in range(loop3_nways): + packa_thrcomm_barrier_sense[i] = 0 + packa_thrcomm_threads_arrived[i] = 0 + for i in range(loop5_nways): + packb_thrcomm_barrier_sense[i] = 0 + packb_thrcomm_barrier_threads_arrived = [0] + + + # Helpers packed_a_type = tensor_type('float32', layout=row_major(MC // MR, 1) * column_major( @@ -127,10 +166,6 @@ def schedule_matmulf32_x86( NC // NR) * row_major( KC, NR)) - # Get the number of threads... - loop5_nways, loop3_nways, macro_nways, loop1_nways = ways - loop4_nways = 1 - nthreads = loop5_nways * loop3_nways * macro_nways * loop1_nways # Get the number of threads remaining at each level loop5_nthreads = nthreads @@ -149,16 +184,16 @@ def schedule_matmulf32_x86( loop5_thrcomm_barrier_sense = 0 loop5_thrcomm_barrier_threads_arrived = 0 - packb_thrcomm_barrier_sense = tensor(dtype='int32', - shape=[loop4_nways], - scope=DeclareScope.Default, - is_static=True) - # for idx in range(loop4_nways): - # packb_thrcomm_barrier_sense[idx] = 0 TODO: This shouldn't be necessary, as static arrays are 0-initialized - packb_thrcomm_barrier_threads_arrived = tensor(dtype='int32', - shape=[loop4_nways], - scope=DeclareScope.Default, - is_static=True) + # packb_thrcomm_barrier_sense = tensor(dtype='int32', + # shape=[loop4_nways], + # scope=DeclareScope.Default, + # is_static=True) + # # for idx in range(loop4_nways): + # # packb_thrcomm_barrier_sense[idx] = 0 TODO: This shouldn't be necessary, as static arrays are 0-initialized + # packb_thrcomm_barrier_threads_arrived = tensor(dtype='int32', + # shape=[loop4_nways], + # scope=DeclareScope.Default, + # is_static=True) packa_thrcomm_barrier_sense = tensor('int32', shape=[loop3_nways], scope=DeclareScope.Default, From fca3694c5ab26c282b62e66724906d5b3728b08b Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Fri, 25 Aug 2023 00:49:46 -0400 Subject: [PATCH 025/148] . 
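The loops this patch replaces executed at module-construction time, not at
kernel run time, so the global barrier arrays were never zeroed when the
kernel actually ran; worse, "packb_thrcomm_barrier_threads_arrived = [0]"
only rebound the Python name to a fresh list. Zeroing therefore has to be a
compiled @hidet.script helper invoked from the kernel itself. Intended call
sites, as later patches in this series settle them (the pack-B barrier is
eventually sized by loop5_nways):

    # At kernel entry, before any thread reaches a barrier:
    init_thr(packa_thrcomm_barrier_sense,
             packa_thrcomm_threads_arrived, loop3_nways)
    init_thr(packb_thrcomm_barrier_sense,
             packb_thrcomm_barrier_threads_arrived, loop5_nways)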
--- .../ops/matmul/matmul_f32_x86_refactored.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 51a1d99b5..f92ba60e8 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -149,12 +149,28 @@ def schedule_matmulf32_x86( name="pack_b_threads_arrived", var_type=int32[loop5_nways] ) - for i in range(loop3_nways): - packa_thrcomm_barrier_sense[i] = 0 - packa_thrcomm_threads_arrived[i] = 0 - for i in range(loop5_nways): - packb_thrcomm_barrier_sense[i] = 0 - packb_thrcomm_barrier_threads_arrived = [0] + + @hidet.script + def init_thr(sense: ~int32, arrived: ~int32, size: int32): + for i in range(size): + sense[i] = 0 + arrived[i] = 0 + + init_thr(packa_thrcomm_barrier_sense, + packa_thrcomm_threads_arrived, + loop3_nways) + init_thr(packb_thrcomm_barrier_sense, + packb_thrcomm_barrier_threads_arrived, + loop3_nways) + + + + # for i in range(loop3_nways): + # packa_thrcomm_barrier_sense[i] = 0 + # packa_thrcomm_threads_arrived[i] = 0 + # for i in range(loop5_nways): + # packb_thrcomm_barrier_sense[i] = 0 + # packb_thrcomm_barrier_threads_arrived = [0] From 14973b4968b83ca28f940b0ef244202b90b29bb9 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Fri, 25 Aug 2023 00:52:47 -0400 Subject: [PATCH 026/148] . --- .../ops/matmul/matmul_f32_x86_refactored.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index f92ba60e8..e2bd2ffba 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -156,15 +156,6 @@ def init_thr(sense: ~int32, arrived: ~int32, size: int32): sense[i] = 0 arrived[i] = 0 - init_thr(packa_thrcomm_barrier_sense, - packa_thrcomm_threads_arrived, - loop3_nways) - init_thr(packb_thrcomm_barrier_sense, - packb_thrcomm_barrier_threads_arrived, - loop3_nways) - - - # for i in range(loop3_nways): # packa_thrcomm_barrier_sense[i] = 0 # packa_thrcomm_threads_arrived[i] = 0 @@ -1165,6 +1156,13 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], # TODO: Since we(they, BLIS) use a memory broker... Allocate a little more memory is OK I think??? # packed_b_individual_width = NC + init_thr(packa_thrcomm_barrier_sense, + packa_thrcomm_threads_arrived, + loop3_nways) + init_thr(packb_thrcomm_barrier_sense, + packb_thrcomm_barrier_threads_arrived, + loop3_nways) + parallel_attr = 'p' + str(nthreads) # The outermost loop spawning threads for tidx in grid(nthreads, attrs=parallel_attr): From 45ad16a771b8a75355e54d4914d91711fa1064c8 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Fri, 25 Aug 2023 00:54:33 -0400 Subject: [PATCH 027/148] ... 
--- .../ops/matmul/matmul_f32_x86_refactored.py | 25 ------------------- 1 file changed, 25 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index e2bd2ffba..626e2339e 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -191,31 +191,6 @@ def init_thr(sense: ~int32, arrived: ~int32, size: int32): loop5_thrcomm_barrier_sense = 0 loop5_thrcomm_barrier_threads_arrived = 0 - # packb_thrcomm_barrier_sense = tensor(dtype='int32', - # shape=[loop4_nways], - # scope=DeclareScope.Default, - # is_static=True) - # # for idx in range(loop4_nways): - # # packb_thrcomm_barrier_sense[idx] = 0 TODO: This shouldn't be necessary, as static arrays are 0-initialized - # packb_thrcomm_barrier_threads_arrived = tensor(dtype='int32', - # shape=[loop4_nways], - # scope=DeclareScope.Default, - # is_static=True) - - packa_thrcomm_barrier_sense = tensor('int32', shape=[loop3_nways], - scope=DeclareScope.Default, - is_static=True) - packa_thrcomm_threads_arrived = tensor('int32', shape=[loop3_nways], - scope=DeclareScope.Default, - is_static=True) - - # The buffer for storing the starting offset of the packed B buffers for thread, - # indexed by the work ID of Loop5 - packb_start_offsets = tensor('int32', shape=[loop5_nways,]) - # The buffer for storing the starting offset of the packed A buffers for thread, - # indexed by the work ID of Loop3 - packa_start_offsets = tensor('int32', shape=[loop3_nways,]) - # The array to store the needed size for each packed B buffer, indexed by the work ID of Loop5 packb_sizes = tensor('int32', shape=[loop5_nways,]) # The array to store the needed size for each packed A buffer, indexed by the work ID of Loop3 From 36b3c52ea213fdaaf03e6af47f2de3bd4ffaf7b3 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Fri, 25 Aug 2023 08:33:50 -0400 Subject: [PATCH 028/148] ..: --- .../ops/matmul/matmul_f32_x86_refactored.py | 76 +------------------ 1 file changed, 2 insertions(+), 74 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 626e2339e..b17897886 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -83,13 +83,7 @@ def allow_prologue(self) -> bool: def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: return tune.extract_ir_modules(self.schedule_matmulf32_x86) - # @tune.space( - # 2, - # block_m=[2016, 3024], - # block_n=[64, 144, 192, 256, 384, 512, 592, 672, 752, 896, 1024], - # block_k=[96, 128, 256, 384, 512, 560, 688, 784], - # nthreads=[4, 8, 16, 32], - # ) + @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], nthreads=[8, 16]) def schedule_matmulf32_x86( self, MC=2016, NC=896, KC=512, ways=(1, 8, 4, 1) @@ -99,9 +93,7 @@ def schedule_matmulf32_x86( from hidet.lang import tensor, grid, as_tensor_pointer from hidet.lang.layout import row_major, column_major from hidet.lang.cpu import avx_f32x8_store, avx_f32x8_fmadd, avx_f32x8_load, avx_f32x8_broadcast - from hidet.lang.cpu import avx_f32x4_broadcast, avx_f32x4_fmadd, avx_f32x4_load, avx_f32x4_store from hidet.lang.cpu import avx_f32x8_store_aligned, avx_f32x8_load_aligned - from hidet.lang.cpu import avx_f32x4_store_aligned, avx_f32x4_load_aligned from hidet.lang.cpu import avx_f32x8_unpacklo, avx_f32x8_unpackhi from hidet.lang.cpu 
import avx_f32x8_shuffle, avx_f32x8_cast_f32x4 from hidet.lang.cpu import avx_f32x8_insert_f32x4, avx_f32x8_permute2f32x4 @@ -124,13 +116,6 @@ def schedule_matmulf32_x86( loop4_nways = 1 nthreads = loop5_nways * loop3_nways * macro_nways * loop1_nways - # Use the define_global_var functionality. - # packa_thrcomm_barrier_sense = tensor('int32', shape=[loop3_nways], - # scope=DeclareScope.Default, - # is_static=True) - # packa_thrcomm_threads_arrived = tensor('int32', shape=[loop3_nways], - # scope=DeclareScope.Default, - # is_static=True) packa_thrcomm_barrier_sense = module.define_global_var( name="pack_a_barrier_sense", @@ -184,17 +169,6 @@ def init_thr(sense: ~int32, arrived: ~int32, size: int32): packb_nthreads = loop3_nthreads packa_nthreads = macro_nthreads - # TODO: Since Hidet doesn't support the parallel region syntax as in OpenMP, - # TODO: We instead use a loop to simulate the parallel region, with the "thread id" being the loop index. - outermost_iters = nthreads - - loop5_thrcomm_barrier_sense = 0 - loop5_thrcomm_barrier_threads_arrived = 0 - - # The array to store the needed size for each packed B buffer, indexed by the work ID of Loop5 - packb_sizes = tensor('int32', shape=[loop5_nways,]) - # The array to store the needed size for each packed A buffer, indexed by the work ID of Loop3 - packa_sizes = tensor('int32', shape=[loop3_nways,]) @hidet.script def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~int32, end: ~int32): @@ -487,57 +461,10 @@ def micro_kernel( #### Some setup code #### packed_b_total_width = 0 - for workid_loop5 in range(loop5_nways): - loop5_start = 0 - loop5_end = 0 - # thread_range_sub(loop5_nways, workid_loop5, n_size, NR, ~loop5_start, ~loop5_end) - # TODO: For now, substitute the above func call with code - if loop5_nways == 1: - loop5_start = 0 - loop5_end = n_size - else: - all_start = 0 - all_end = n_size - size = all_end - all_start - n_bf_whole = n_size // NR - n_bf_left = n_size % NR - n_bf_lo = n_bf_whole // loop5_nways - n_bf_hi = n_bf_whole // loop5_nways - - n_th_lo = n_bf_whole % loop5_nways - if n_th_lo != 0: - n_bf_lo += 1 - size_lo = n_bf_lo * NR - size_hi = n_bf_hi * NR - - lo_start = all_start - hi_start = all_start + n_th_lo * size_lo - - if workid_loop5 < n_th_lo: - loop5_start = lo_start + workid_loop5 * size_lo - loop5_end = lo_start + (workid_loop5 + 1) * size_lo - else: - loop5_start = hi_start + (workid_loop5 - n_th_lo) * size_hi - loop5_end = hi_start + (workid_loop5 - n_th_lo + 1) * size_hi - - if workid_loop5 == loop5_nways - 1: - loop5_end += n_bf_left - - - # curr_width = loop5_end - loop5_start - # # packed_b_total_width += curr_width - # # packb_start_offsets[workid_loop5] = temp_prev - # # temp_prev += curr_width - # packb_start_offsets[workid_loop5] = packed_b_total_width - # packed_b_total_width += curr_width - - # packed_b_individual_width = min(NC, n_size) packed_b_height = KC if packed_b_height > k_size: packed_b_height = k_size - # packed_b_height = (k_size + NR - 1) // NR * NR - # packed_b_total_size = packed_b_total_width * packed_b_height packed_b_width = NC if packed_b_width > n_size: packed_b_widht = (n_size + NR - 1) // NR * NR @@ -872,6 +799,7 @@ def gemm_macro( thread_range_jrir( work_id_1st_loop, + loop1_nways, m_iter, 1, ~ir_start, From c79fcca699fae5480c483059181cfe8f36afc4e5 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Fri, 25 Aug 2023 09:12:38 -0400 Subject: [PATCH 029/148] cast --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 5 +++-- 1 file changed, 3 
insertions(+), 2 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index b17897886..f266d2f36 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -890,7 +890,7 @@ def gemm_3rd_loop( loop3_partition_a_start_row = ii loop3_partition_a_height = b_alg_loop3 - loop3_partition_a = a + ( + loop3_partition_a = cast(a, ~float32) + ( loop3_partition_a_start_row * k_size + loop3_partition_a_start_col ) @@ -975,7 +975,8 @@ def gemm_4th_loop(a: float32[m_size, k_size], packed_b_individual_size * work_id_5th_loop ) - loop4_partition_b = b + \ + + loop4_partition_b = cast(b, ~float32) + \ (loop4_partition_b_start_row * n_size + loop4_partition_b_start_col) From 87cdd76c8c05986dfe2b67d6f612f49ce5478393 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Fri, 25 Aug 2023 09:16:01 -0400 Subject: [PATCH 030/148] cast --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index f266d2f36..8eef45c54 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -841,7 +841,7 @@ def gemm_macro( else: for i, j in grid(MR, NR): temp_c[i, j] = 0.0 - micro_kernel(a1, b1, temp_c, macro_k, macro_m, macro_n, is_first) + micro_kernel(a1, b1, cast(temp_c, ~float32), macro_k, macro_m, macro_n, is_first) if not is_first: for mm, nn in grid(m_cur, n_cur): c11[mm, nn] += temp_c[mm, nn] From 8648cedd3ca0d7ff77b1ac186adfb2817ae1d539 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Fri, 25 Aug 2023 09:29:06 -0400 Subject: [PATCH 031/148] ... --- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 8eef45c54..83a71e953 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -751,6 +751,10 @@ def gemm_pack_b( packed_b_remaining_buf_curr += 1 zero_fill_col += 1 + gemm_pack_b.kind = "cpu_internal" + gemm_pack_a.kind = "cpu_internal" + micro_kernel.kind = "cpu_internal" + @hidet.script def gemm_macro( packed_a: ~float32, @@ -852,6 +856,8 @@ def gemm_macro( i += ir_inc j += jr_inc + gemm_macro.kind = "cpu_internal" + @hidet.script def gemm_3rd_loop( @@ -940,6 +946,8 @@ def gemm_3rd_loop( is_first ) + gemm_3rd_loop.kind = "cpu_internal" + @hidet.script def gemm_4th_loop(a: float32[m_size, k_size], b: float32[k_size, n_size], @@ -1019,6 +1027,7 @@ def gemm_4th_loop(a: float32[m_size, k_size], i_loop4 += b_alg_loop4 + gemm_4th_loop.kind = "cpu_internal" @hidet.script @@ -1049,6 +1058,7 @@ def gemm_5th_loop(a: float32[m_size, k_size], work_id_4th_loop, work_id_5th_loop) loop5_iter += b_alg_loop5 + gemm_5th_loop.kind = 'cpu_internal' ################### Start of the main kernel ################### From 075cc64ce03c73c94028deba5c8d8ef41405b212 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Fri, 25 Aug 2023 09:30:54 -0400 Subject: [PATCH 032/148] . 
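Marking each helper with kind = "cpu_internal" appears to make codegen emit it
as an internal C function rather than an additional public entry point, with
only the top-level kernel keeping the default kind; that reading is an
assumption from how this series treats the entry function, not a documented
contract. Since thrcomm_barrier is among the helpers touched below, a
pure-Python model of its sense-reversing algorithm may help readers. This is
illustrative only; the real code spins on the cpu_atomic_* primitives:

    import threading

    class SenseReversingBarrier:
        # The last thread to arrive resets the count and flips the shared
        # sense; everyone else spins until the sense differs from the
        # value sampled on entry.
        def __init__(self, num_threads: int):
            self.num_threads = num_threads
            self.sense = 0
            self.arrived = 0
            self.lock = threading.Lock()

        def wait(self):
            with self.lock:
                my_sense = self.sense
                self.arrived += 1
                if self.arrived == self.num_threads:
                    self.arrived = 0
                    self.sense ^= 1
                    return
            while True:
                with self.lock:
                    if self.sense != my_sense:
                        return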
--- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 83a71e953..9a36fc5f8 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -141,12 +141,7 @@ def init_thr(sense: ~int32, arrived: ~int32, size: int32): sense[i] = 0 arrived[i] = 0 - # for i in range(loop3_nways): - # packa_thrcomm_barrier_sense[i] = 0 - # packa_thrcomm_threads_arrived[i] = 0 - # for i in range(loop5_nways): - # packb_thrcomm_barrier_sense[i] = 0 - # packb_thrcomm_barrier_threads_arrived = [0] + init_thr.kind = "cpu_internal" @@ -209,6 +204,7 @@ def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~ # Add the remainder to the last thread's end if work_id == n_way - 1: end[0] += n_bf_left + thread_range_sub.kind = "cpu_internal" @hidet.script def thread_range_jrir(work_id: int32, n_way: int32, n: int32, bf: int32, @@ -217,6 +213,8 @@ def thread_range_jrir(work_id: int32, n_way: int32, n: int32, bf: int32, end[0] = n inc[0] = n_way + thread_range_jrir.kind = "cpu_internal" + @hidet.script def determine_blocksize_f_sub(i: int32, dim: int32, b_alg: int32) -> int32: dim_left_now = dim - i @@ -228,9 +226,12 @@ def determine_blocksize_f_sub(i: int32, dim: int32, b_alg: int32) -> int32: assert b_now >= 0 return b_now + determine_blocksize_f_sub.kind = "cpu_internal" + @hidet.script def not_edge(i: int32, n_iter: int32, n_left: int32) -> bool: return i != n_iter - 1 or n_left == 0 + not_edge.kind = 'cpu_internal' # Thread barrier @hidet.script @@ -251,6 +252,8 @@ def thrcomm_barrier(tid: int32, barrier_sense: ~int32, while cpu_atomic_load_n(barrier_sense, 2) == orig_sense: # _ATOMIC_ACQUIRE pass + thrcomm_barrier.kind = 'cpu_internal' + @hidet.script def micro_kernel( a: packed_a_type, b: packed_b_type, c_ptr: ~float32, pb: int32, msize: int32, nsize: int32, From 7814d6dbd996d6626b4e9100dcd7efe2b45627a8 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Fri, 25 Aug 2023 10:04:06 -0400 Subject: [PATCH 033/148] now segfault not internal errors --- python/hidet/mat_new.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/hidet/mat_new.py b/python/hidet/mat_new.py index 0303a7b29..30a554caa 100644 --- a/python/hidet/mat_new.py +++ b/python/hidet/mat_new.py @@ -25,6 +25,7 @@ def matmul_ansor(M, K, N, dtype): ) return [A, B, rst] +hidet.option.cache_dir("./wtf") target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) From ff058bfca43fd68841c1aad5e00283cbf7bb1c88 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sat, 26 Aug 2023 03:57:55 -0400 Subject: [PATCH 034/148] stupid error --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 9a36fc5f8..a329dd9f5 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -1078,7 +1078,7 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], loop3_nways) init_thr(packb_thrcomm_barrier_sense, packb_thrcomm_barrier_threads_arrived, - loop3_nways) + loop5_nways) parallel_attr = 'p' + str(nthreads) # The outermost loop spawning threads From 
f9f3b816f7ccf7f0e2e14229ab8e508afe537e4f Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sat, 26 Aug 2023 04:05:28 -0400 Subject: [PATCH 035/148] err --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index a329dd9f5..2d5536aab 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -672,7 +672,7 @@ def gemm_pack_b( packed_b_buff_curr += 16 b10 = avx_f32x8_load(b_panel + n_size) - b18 = avx_f32x8_load(b_panel + (n_size * 8)) + b18 = avx_f32x8_load(b_panel + (n_size + 8)) avx_f32x8_store_aligned(packed_b_buff_curr, b10) avx_f32x8_store_aligned(packed_b_buff_curr + 8, b18) From 0a7b2fe491e477055375baf08b572a7ec39e1857 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sat, 26 Aug 2023 05:58:06 -0400 Subject: [PATCH 036/148] ... --- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 2d5536aab..e1eaadb2c 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -14,7 +14,7 @@ from hidet.ir.expr import cast from hidet.ir.module import IRModule from hidet.ir.compute import TensorNode -from hidet.ir.primitives import avx_malloc +from hidet.ir.primitives import avx_malloc, printf from hidet.ir.primitives.cpu import avx_f32x8_setzero, avx_f32x8_load_aligned from hidet.ir.stmt import DeclareScope from hidet.ir.task import Task @@ -259,6 +259,7 @@ def micro_kernel( a: packed_a_type, b: packed_b_type, c_ptr: ~float32, pb: int32, msize: int32, nsize: int32, is_first: bool ): + printf("The start of the micro_kernel.....") c = as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize]) c0 = avx_f32x8_load(~c[0, 0]) c08 = avx_f32x8_load(~c[0, 8]) @@ -459,6 +460,7 @@ def micro_kernel( avx_f32x8_store(c_ptr + 5 * nsize, c5) avx_f32x8_store(c_ptr + (5 * nsize + 8), c58) + printf("The end of micro kernel....") @@ -491,8 +493,10 @@ def micro_kernel( packb_buf_ptr = avx_malloc(packed_b_total_size * 4, 4096) packa_buf_ptr = avx_malloc(packed_a_total_size * 4, 4096) - packb_buf = as_tensor_pointer(packb_buf_ptr, dtype=float32, shape=[packed_b_total_size]) - packa_buf = as_tensor_pointer(packa_buf_ptr, dtype=float32, shape=[packed_a_total_size]) + # packb_buf = as_tensor_pointer(packb_buf_ptr, dtype=float32, shape=[packed_b_total_size]) + # packa_buf = as_tensor_pointer(packa_buf_ptr, dtype=float32, shape=[packed_a_total_size]) + packb_buf = cast(packb_buf_ptr, ~float32) + packa_buf = cast(packa_buf_ptr, ~float32) packed_a_type = tensor_type( dtype='float32', From 99954e17a57f742e64d7ca65aada19ea240c4b81 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sat, 26 Aug 2023 06:03:01 -0400 Subject: [PATCH 037/148] .. 
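The printf toggling below narrows the search to the packing path, where patch
035 already caught an addressing slip: in a row-major k_size x n_size matrix,
element (i, j) sits at flat offset i * n_size + j, so the second 8-float half
of row 1 of a 16-wide panel lives at n_size + 8, not n_size * 8. A NumPy
reference for what one packed KC x NR panel of B should contain (reference
semantics only; the kernel writes it through AVX stores):

    import numpy as np

    def pack_b_panel_ref(b, k0, j0, kc, nr=16):
        # b: row-major (k_size, n_size). The packed panel holds the kc
        # rows of columns [j0, j0 + nr) back to back, zero-padded when
        # the source runs short (mirrors gemm_pack_b's zero-fill loops).
        panel = np.zeros((kc, nr), dtype=b.dtype)
        rows = min(kc, b.shape[0] - k0)
        cols = min(nr, b.shape[1] - j0)
        panel[:rows, :cols] = b[k0:k0 + rows, j0:j0 + cols]
        return panel.reshape(-1)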
--- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 7 +++++-- python/hidet/mat_new.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index e1eaadb2c..b2cfa952d 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -259,7 +259,7 @@ def micro_kernel( a: packed_a_type, b: packed_b_type, c_ptr: ~float32, pb: int32, msize: int32, nsize: int32, is_first: bool ): - printf("The start of the micro_kernel.....") + # printf("The start of the micro_kernel.....") c = as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize]) c0 = avx_f32x8_load(~c[0, 0]) c08 = avx_f32x8_load(~c[0, 8]) @@ -460,7 +460,7 @@ def micro_kernel( avx_f32x8_store(c_ptr + 5 * nsize, c5) avx_f32x8_store(c_ptr + (5 * nsize + 8), c58) - printf("The end of micro kernel....") + # printf("The end of micro kernel....") @@ -525,6 +525,7 @@ def gemm_pack_a( work_id_packa: int32, packa_nways: int32 ): + printf("The start of the pack a, comm id: %d, work id: %d\n", comm_id_packa, work_id_packa) packed_a_tensor = as_tensor_pointer( packed_a_buf, float32, @@ -635,6 +636,8 @@ def gemm_pack_a( packed_a_tensor[ remain_start_row + remain_row, remain_col] = 0 remain_row += 1 + printf("The end of the pack a, comm id: %d, work id: %d\n", + comm_id_packa, work_id_packa) @hidet.script def gemm_pack_b( diff --git a/python/hidet/mat_new.py b/python/hidet/mat_new.py index 30a554caa..d5b37518d 100644 --- a/python/hidet/mat_new.py +++ b/python/hidet/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(384, 256, 256), (512, 512, 512), (1024, 1024, 1024)]: +for m, n, k in [(100, 100, 100), (384, 256, 256), (512, 512, 512), (1024, 1024, 1024)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From b884a95af80f4be9f348f9b48af75da039b63734 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sat, 26 Aug 2023 23:01:31 -0400 Subject: [PATCH 038/148] .. 
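Symmetrically to B, the pack-A path being silenced below stores each MR x KC
panel of A column-major (the row_major(MC // MR, 1) * column_major(MR, KC)
layout), so the micro kernel reads one MR-tall column of A per k step from
consecutive memory. A NumPy reference for a single panel, with MR = 6 as the
micro kernel assumes, including the zero padding of a short final panel:

    import numpy as np

    def pack_a_panel_ref(a, i0, p0, kc, mr=6):
        # a: row-major (m_size, k_size). Flattening the transpose gives
        # column-major order: mr consecutive floats per k index.
        panel = np.zeros((mr, kc), dtype=a.dtype)
        rows = min(mr, a.shape[0] - i0)
        cols = min(kc, a.shape[1] - p0)
        panel[:rows, :cols] = a[i0:i0 + rows, p0:p0 + cols]
        return panel.T.reshape(-1)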
--- .../hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index b2cfa952d..71b435388 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -525,7 +525,7 @@ def gemm_pack_a( work_id_packa: int32, packa_nways: int32 ): - printf("The start of the pack a, comm id: %d, work id: %d\n", comm_id_packa, work_id_packa) + # printf("The start of the pack a, comm id: %d, work id: %d\n", comm_id_packa, work_id_packa) packed_a_tensor = as_tensor_pointer( packed_a_buf, float32, @@ -636,8 +636,8 @@ def gemm_pack_a( packed_a_tensor[ remain_start_row + remain_row, remain_col] = 0 remain_row += 1 - printf("The end of the pack a, comm id: %d, work id: %d\n", - comm_id_packa, work_id_packa) + # printf("The end of the pack a, comm id: %d, work id: %d\n", + # comm_id_packa, work_id_packa) @hidet.script def gemm_pack_b( @@ -1093,6 +1093,9 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], tid_5th_loop = tidx work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways) comm_id_5th_loop = tid_5th_loop + + printf("tidx: %d, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", tidx, work_id_5th_loop, comm_id_5th_loop) + gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop) assert isinstance(matmul_kernel_x86_v3, hidet.ir.Function) From 12a139ad87358c59fc6ff40f1c0ff7f825b59733 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sat, 26 Aug 2023 23:06:43 -0400 Subject: [PATCH 039/148] . --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 71b435388..d9fb9ba71 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -1048,6 +1048,7 @@ def gemm_5th_loop(a: float32[m_size, k_size], comm_id_5th_loop: int32): comm_id_4th_loop = comm_id_5th_loop % loop4_nways work_id_4th_loop = comm_id_4th_loop // (loop4_nthreads // loop4_nways) + printf("Start of 5th loop, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", work_id_5th_loop, comm_id_5th_loop) loop5_my_start = -1 loop5_my_end = -1 @@ -1068,6 +1069,7 @@ def gemm_5th_loop(a: float32[m_size, k_size], work_id_4th_loop, work_id_5th_loop) loop5_iter += b_alg_loop5 + printf("End of 5th loop, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", work_id_5th_loop, comm_id_5th_loop) gemm_5th_loop.kind = 'cpu_internal' @@ -1094,7 +1096,7 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways) comm_id_5th_loop = tid_5th_loop - printf("tidx: %d, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", tidx, work_id_5th_loop, comm_id_5th_loop) + # printf("tidx: %d, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", tidx, work_id_5th_loop, comm_id_5th_loop) gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop) From 8cf009da9b0b7f41259bdc30d7c421fdc3df7455 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sat, 26 Aug 2023 23:14:16 -0400 Subject: [PATCH 040/148] . 
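Both loop drivers instrumented here advance by determine_blocksize_f_sub,
defined earlier in the file; it is simply "a full algorithmic block until the
tail". Restated for reference:

    def determine_blocksize_f_sub(i, dim, b_alg):
        # Extent remaining from position i, capped at the block size.
        dim_left_now = dim - i
        return b_alg if dim_left_now >= b_alg else dim_left_now

    # e.g. dim = 1000, b_alg = 384 gives blocks of 384, 384, 232.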
--- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index d9fb9ba71..28e07dcff 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -975,6 +975,7 @@ def gemm_4th_loop(a: float32[m_size, k_size], comm_id_packb = comm_id_3rd_loop work_id_packb = comm_id_3rd_loop # packb_nways = loop3_nthreads + printf("The start of the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n") while i_loop4 < k_size: b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, NC) @@ -1037,6 +1038,7 @@ def gemm_4th_loop(a: float32[m_size, k_size], i_loop4 += b_alg_loop4 + printf("The end of the 4th loop. work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) gemm_4th_loop.kind = "cpu_internal" @@ -1048,7 +1050,7 @@ def gemm_5th_loop(a: float32[m_size, k_size], comm_id_5th_loop: int32): comm_id_4th_loop = comm_id_5th_loop % loop4_nways work_id_4th_loop = comm_id_4th_loop // (loop4_nthreads // loop4_nways) - printf("Start of 5th loop, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", work_id_5th_loop, comm_id_5th_loop) + # printf("Start of 5th loop, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", work_id_5th_loop, comm_id_5th_loop) loop5_my_start = -1 loop5_my_end = -1 @@ -1069,7 +1071,7 @@ def gemm_5th_loop(a: float32[m_size, k_size], work_id_4th_loop, work_id_5th_loop) loop5_iter += b_alg_loop5 - printf("End of 5th loop, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", work_id_5th_loop, comm_id_5th_loop) + # printf("End of 5th loop, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", work_id_5th_loop, comm_id_5th_loop) gemm_5th_loop.kind = 'cpu_internal' From 717069f36f00816f328da922e235e48981dbdd5c Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sat, 26 Aug 2023 23:38:35 -0400 Subject: [PATCH 041/148] ... 
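The first hunk below fixes a real bug rather than debug noise: the 4th-loop
printf declared two %d conversions but passed no arguments, which is
undefined behavior once it reaches the generated C. The bounds it now prints
come from thread_range_sub, the BLIS-style partitioner defined earlier in the
file; a plain-Python restatement of that helper for readers tracing the
printed values:

    def thread_range_sub(n_way, work_id, n, bf):
        # Split [0, n) into n_way contiguous chunks measured in bf-sized
        # blocks; the first (n // bf) % n_way workers take one extra
        # block, and the last worker absorbs the sub-bf remainder.
        n_bf_whole, n_bf_left = n // bf, n % bf
        n_bf_lo = n_bf_hi = n_bf_whole // n_way
        n_th_lo = n_bf_whole % n_way
        if n_th_lo != 0:
            n_bf_lo += 1
        size_lo, size_hi = n_bf_lo * bf, n_bf_hi * bf
        if work_id < n_th_lo:
            start = work_id * size_lo
            end = start + size_lo
        else:
            start = n_th_lo * size_lo + (work_id - n_th_lo) * size_hi
            end = start + size_hi
        if work_id == n_way - 1:
            end += n_bf_left
        return start, end

    # thread_range_sub(2, 1, 100, 16) -> (48, 100): worker 1 takes the tail.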
--- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 28e07dcff..8124095f4 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -975,7 +975,7 @@ def gemm_4th_loop(a: float32[m_size, k_size], comm_id_packb = comm_id_3rd_loop work_id_packb = comm_id_3rd_loop # packb_nways = loop3_nthreads - printf("The start of the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n") + printf("The start of the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) while i_loop4 < k_size: b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, NC) @@ -1050,7 +1050,9 @@ def gemm_5th_loop(a: float32[m_size, k_size], comm_id_5th_loop: int32): comm_id_4th_loop = comm_id_5th_loop % loop4_nways work_id_4th_loop = comm_id_4th_loop // (loop4_nthreads // loop4_nways) - # printf("Start of 5th loop, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", work_id_5th_loop, comm_id_5th_loop) + printf("Start of 5th loop, work_id_5th_loop: %d, comm_id_5th_loop: %d, comm_id_4th_loop: %d, work_id_4th_loop: %d\n", + work_id_5th_loop, comm_id_5th_loop, + comm_id_4th_loop, work_id_4th_loop) loop5_my_start = -1 loop5_my_end = -1 From 7b53554e619f91107ebfdfe642487f5f24518477 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sun, 27 Aug 2023 01:09:01 -0400 Subject: [PATCH 042/148] . --- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 8124095f4..264f48655 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -648,6 +648,7 @@ def gemm_pack_b( comm_id_packb: int32, work_id_packb: int32, packb_nways: int32 ): + printf("The start of pack B, comm_id_packb: %d, work_id_packb: %d\n", comm_id_packb, work_id_packb) npanels_full_b = loop4_partition_b_width // NR npanels_b_remainder = loop4_partition_b_width % NR @@ -760,6 +761,8 @@ def gemm_pack_b( packed_b_remaining_buf_curr[0] = 0.0 packed_b_remaining_buf_curr += 1 zero_fill_col += 1 + printf("The end of pack B, comm_id_packb: %d, work_id_packb: %d\n", comm_id_packb, work_id_packb) + gemm_pack_b.kind = "cpu_internal" gemm_pack_a.kind = "cpu_internal" @@ -975,7 +978,7 @@ def gemm_4th_loop(a: float32[m_size, k_size], comm_id_packb = comm_id_3rd_loop work_id_packb = comm_id_3rd_loop # packb_nways = loop3_nthreads - printf("The start of the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) + # printf("The start of the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) while i_loop4 < k_size: b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, NC) @@ -1038,7 +1041,7 @@ def gemm_4th_loop(a: float32[m_size, k_size], i_loop4 += b_alg_loop4 - printf("The end of the 4th loop. work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) + # printf("The end of the 4th loop. 
work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) gemm_4th_loop.kind = "cpu_internal" @@ -1050,9 +1053,9 @@ def gemm_5th_loop(a: float32[m_size, k_size], comm_id_5th_loop: int32): comm_id_4th_loop = comm_id_5th_loop % loop4_nways work_id_4th_loop = comm_id_4th_loop // (loop4_nthreads // loop4_nways) - printf("Start of 5th loop, work_id_5th_loop: %d, comm_id_5th_loop: %d, comm_id_4th_loop: %d, work_id_4th_loop: %d\n", - work_id_5th_loop, comm_id_5th_loop, - comm_id_4th_loop, work_id_4th_loop) + # printf("Start of 5th loop, work_id_5th_loop: %d, comm_id_5th_loop: %d, comm_id_4th_loop: %d, work_id_4th_loop: %d\n", + # work_id_5th_loop, comm_id_5th_loop, + # comm_id_4th_loop, work_id_4th_loop) loop5_my_start = -1 loop5_my_end = -1 From f933711f2841f4ba9814703ee6c6620f5ef9c572 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sun, 27 Aug 2023 02:26:26 -0400 Subject: [PATCH 043/148] small fix --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 264f48655..c6ce881d1 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -1051,7 +1051,7 @@ def gemm_5th_loop(a: float32[m_size, k_size], c: float32[m_size, n_size], work_id_5th_loop: int32, comm_id_5th_loop: int32): - comm_id_4th_loop = comm_id_5th_loop % loop4_nways + comm_id_4th_loop = comm_id_5th_loop % loop4_nthreads work_id_4th_loop = comm_id_4th_loop // (loop4_nthreads // loop4_nways) # printf("Start of 5th loop, work_id_5th_loop: %d, comm_id_5th_loop: %d, comm_id_4th_loop: %d, work_id_4th_loop: %d\n", # work_id_5th_loop, comm_id_5th_loop, From 42054a46ef2f147867606c1ef6ca7b816f754cea Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sun, 27 Aug 2023 02:46:34 -0400 Subject: [PATCH 044/148] .. --- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index c6ce881d1..a670bf180 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -648,7 +648,7 @@ def gemm_pack_b( comm_id_packb: int32, work_id_packb: int32, packb_nways: int32 ): - printf("The start of pack B, comm_id_packb: %d, work_id_packb: %d\n", comm_id_packb, work_id_packb) + # printf("The start of pack B, comm_id_packb: %d, work_id_packb: %d\n", comm_id_packb, work_id_packb) npanels_full_b = loop4_partition_b_width // NR npanels_b_remainder = loop4_partition_b_width % NR @@ -761,7 +761,7 @@ def gemm_pack_b( packed_b_remaining_buf_curr[0] = 0.0 packed_b_remaining_buf_curr += 1 zero_fill_col += 1 - printf("The end of pack B, comm_id_packb: %d, work_id_packb: %d\n", comm_id_packb, work_id_packb) + # printf("The end of pack B, comm_id_packb: %d, work_id_packb: %d\n", comm_id_packb, work_id_packb) gemm_pack_b.kind = "cpu_internal" @@ -884,6 +884,9 @@ def gemm_3rd_loop( comm_id_3rd_loop: int32, work_id_3rd_loop: int32, is_first: bool): + + printf("The start of 3rd loop. 
comm_id_3rd_loop: %d, work_id_3rd_loop: %d\n", comm_id_3rd_loop, work_id_3rd_loop) + comm_id_macro = work_id_3rd_loop % macro_nthreads work_id_macro = comm_id_macro // (macro_nthreads // macro_nways) comm_id_packa = comm_id_macro @@ -958,6 +961,7 @@ def gemm_3rd_loop( work_id_macro, is_first ) + printf("The end of 3rd loop. comm_id_3rd_loop: %d, work_id_3rd_loop: %d\n", comm_id_3rd_loop, work_id_3rd_loop) gemm_3rd_loop.kind = "cpu_internal" @@ -1103,8 +1107,6 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways) comm_id_5th_loop = tid_5th_loop - # printf("tidx: %d, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", tidx, work_id_5th_loop, comm_id_5th_loop) - gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop) assert isinstance(matmul_kernel_x86_v3, hidet.ir.Function) From 60599c2bb594b726cfe4df4e9cf27712d69a77b8 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sun, 27 Aug 2023 02:51:04 -0400 Subject: [PATCH 045/148] .. --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index a670bf180..a567eb9f4 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -525,7 +525,7 @@ def gemm_pack_a( work_id_packa: int32, packa_nways: int32 ): - # printf("The start of the pack a, comm id: %d, work id: %d\n", comm_id_packa, work_id_packa) + printf("The start of the pack a, comm id: %d, work id: %d\n", comm_id_packa, work_id_packa) packed_a_tensor = as_tensor_pointer( packed_a_buf, float32, @@ -533,6 +533,7 @@ def gemm_pack_a( column_major(MR, packed_a_width) ) + npanels_full_a = loop3_partition_a_height // MR panel_a_remainder = loop3_partition_a_height % MR @@ -636,8 +637,8 @@ def gemm_pack_a( packed_a_tensor[ remain_start_row + remain_row, remain_col] = 0 remain_row += 1 - # printf("The end of the pack a, comm id: %d, work id: %d\n", - # comm_id_packa, work_id_packa) + printf("The end of the pack a, comm id: %d, work id: %d\n", + comm_id_packa, work_id_packa) @hidet.script def gemm_pack_b( From 2d650058e99e7a54d01474077e75c12f02c97963 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sun, 27 Aug 2023 02:52:15 -0400 Subject: [PATCH 046/148] . --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index a567eb9f4..d12c53ef5 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -886,7 +886,7 @@ def gemm_3rd_loop( work_id_3rd_loop: int32, is_first: bool): - printf("The start of 3rd loop. comm_id_3rd_loop: %d, work_id_3rd_loop: %d\n", comm_id_3rd_loop, work_id_3rd_loop) + # printf("The start of 3rd loop. comm_id_3rd_loop: %d, work_id_3rd_loop: %d\n", comm_id_3rd_loop, work_id_3rd_loop) comm_id_macro = work_id_3rd_loop % macro_nthreads work_id_macro = comm_id_macro // (macro_nthreads // macro_nways) @@ -962,7 +962,7 @@ def gemm_3rd_loop( work_id_macro, is_first ) - printf("The end of 3rd loop. comm_id_3rd_loop: %d, work_id_3rd_loop: %d\n", comm_id_3rd_loop, work_id_3rd_loop) + # printf("The end of 3rd loop. 
comm_id_3rd_loop: %d, work_id_3rd_loop: %d\n", comm_id_3rd_loop, work_id_3rd_loop) gemm_3rd_loop.kind = "cpu_internal" From 747508bf6dd2fd02ae0cf0ed1c533132d6ce8fc6 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sun, 27 Aug 2023 02:58:02 -0400 Subject: [PATCH 047/148] . --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index d12c53ef5..de0c8cf35 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -888,7 +888,7 @@ def gemm_3rd_loop( # printf("The start of 3rd loop. comm_id_3rd_loop: %d, work_id_3rd_loop: %d\n", comm_id_3rd_loop, work_id_3rd_loop) - comm_id_macro = work_id_3rd_loop % macro_nthreads + comm_id_macro = comm_id_3rd_loop % macro_nthreads work_id_macro = comm_id_macro // (macro_nthreads // macro_nways) comm_id_packa = comm_id_macro work_id_packa = comm_id_macro From 4e5c7daa4019e689dd25a923b5bfb8632174ccb1 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sun, 27 Aug 2023 03:17:53 -0400 Subject: [PATCH 048/148] . --- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index de0c8cf35..e4c972363 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -587,30 +587,31 @@ def gemm_pack_a( res4 = avx_f32x8_insert_f32x4(shf3, low_shf4, 0x1) res5 = avx_f32x8_permute2f32x4(shf3, shf4, 0x31) - avx_f32x8_store_aligned( + # TODO: Now I changed to unaligned to debug... 
+ avx_f32x8_store( ~packed_a_tensor[a_curr_panel_row_start, col], res0 ) - avx_f32x8_store_aligned( + avx_f32x8_store( ~packed_a_tensor[a_curr_panel_row_start + 2, col + 1], res2 ) - avx_f32x8_store_aligned( + avx_f32x8_store( ~packed_a_tensor[a_curr_panel_row_start + 4, col + 2], res4) - avx_f32x8_store_aligned( + avx_f32x8_store( ~packed_a_tensor[a_curr_panel_row_start, col + 4], res1 ) - avx_f32x8_store_aligned( + avx_f32x8_store( ~packed_a_tensor[a_curr_panel_row_start + 2, col + 5], res3 ) - avx_f32x8_store_aligned( + avx_f32x8_store( ~packed_a_tensor[a_curr_panel_row_start + 4, col + 6], res5 From 23f2768b82e4841280e016ef24fe4e4956278754 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sun, 27 Aug 2023 03:25:08 -0400 Subject: [PATCH 049/148] try single thread first --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index e4c972363..547fa707a 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -84,9 +84,9 @@ def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: return tune.extract_ir_modules(self.schedule_matmulf32_x86) - @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], nthreads=[8, 16]) + @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 1, 1, 1)]) def schedule_matmulf32_x86( - self, MC=2016, NC=896, KC=512, ways=(1, 8, 4, 1) + self, MC=2016, NC=896, KC=512, ways=(1, 1, 1, 1) ) -> IRModule: import hidet from hidet.ir.type import tensor_type From 0ab48882401bed4253fc48a3f308c1687328eb2b Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sun, 27 Aug 2023 06:14:04 -0400 Subject: [PATCH 050/148] .. --- python/hidet/mat_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/mat_new.py b/python/hidet/mat_new.py index d5b37518d..5ad168ee5 100644 --- a/python/hidet/mat_new.py +++ b/python/hidet/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(100, 100, 100), (384, 256, 256), (512, 512, 512), (1024, 1024, 1024)]: +for m, n, k in [(384, 256, 256)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From 1631d77ee7acdc62d2b0c291650c8c40e59f2722 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sun, 27 Aug 2023 06:26:02 -0400 Subject: [PATCH 051/148] dumb mistake again --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 547fa707a..0f5f4c0d3 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -963,6 +963,7 @@ def gemm_3rd_loop( work_id_macro, is_first ) + ii += b_alg_loop3 # printf("The end of 3rd loop. comm_id_3rd_loop: %d, work_id_3rd_loop: %d\n", comm_id_3rd_loop, work_id_3rd_loop) gemm_3rd_loop.kind = "cpu_internal" From 62c075c27dc7cea412597832039f933fb7e5e884 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sun, 27 Aug 2023 06:30:47 -0400 Subject: [PATCH 052/148] .. 
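The start/end/blocksize printfs added below expose exactly the quantities
that decide termination of these blocked loops, the same hazard patch 051
just fixed in the third loop: a while-loop whose induction variable is never
advanced spins forever. The intended driver shape, sketched with loop-3 names
(the bound names follow the loop5_my_start/loop5_my_end analogue):

    ii = loop3_my_start
    while ii < loop3_my_end:
        b_alg_loop3 = determine_blocksize_f_sub(ii, loop3_my_end, MC)
        # ... pack the A block and run the macro kernel on
        #     rows [ii, ii + b_alg_loop3) ...
        ii += b_alg_loop3  # the advance that patch 051 restored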
--- .../hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 0f5f4c0d3..d9aba64c2 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -1068,10 +1068,14 @@ def gemm_5th_loop(a: float32[m_size, k_size], loop5_my_end = -1 thread_range_sub(loop5_nways, work_id_5th_loop, n_size, NR, ~loop5_my_start, ~loop5_my_end) + + printf("loop5_my_start: %d, loop5_my_end: %d\n", loop5_my_start, loop5_my_end) + loop5_iter = loop5_my_start while loop5_iter < loop5_my_end: b_alg_loop5 = determine_blocksize_f_sub(loop5_iter, loop5_my_end, NC) + printf("b_alg_loop5: %d\n", b_alg_loop5) loop5_partition_c_width = b_alg_loop5 loop5_partition_c_start_col = loop5_iter loop5_partition_b_width = b_alg_loop5, @@ -1091,10 +1095,6 @@ def gemm_5th_loop(a: float32[m_size, k_size], @hidet.script def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], c: float32[m_size, n_size]): - b_width_nr_partitions = (n_size + NR - 1) // NR - b_width_nr_remainder = n_size % NR - # TODO: Since we(they, BLIS) use a memory broker... Allocate a little more memory is OK I think??? - # packed_b_individual_width = NC init_thr(packa_thrcomm_barrier_sense, packa_thrcomm_threads_arrived, From 5d4a3149b52751a123394835b99cf1d679efe1d1 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Sun, 27 Aug 2023 10:17:28 -0400 Subject: [PATCH 053/148] .. --- .../ops/matmul/matmul_f32_x86_refactored.py | 158 ++---------------- 1 file changed, 11 insertions(+), 147 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index d9aba64c2..8c0fdcc8f 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -236,8 +236,8 @@ def not_edge(i: int32, n_iter: int32, n_left: int32) -> bool: # Thread barrier @hidet.script def thrcomm_barrier(tid: int32, barrier_sense: ~int32, - barrier_threads_arrived: ~int32, nthreads: int32): - if nthreads == 1: + barrier_threads_arrived: ~int32, num_threads: int32): + if num_threads == 1: return orig_sense = cpu_atomic_load_n(barrier_sense, 0) # _ATOMIC_RELAXED @@ -245,7 +245,7 @@ def thrcomm_barrier(tid: int32, barrier_sense: ~int32, my_threads_arrived = cpu_atomic_add_fetch( barrier_threads_arrived, 1, 4) # _ATOMIC_ACQ_REL - if my_threads_arrived == nthreads: + if my_threads_arrived == num_threads: barrier_threads_arrived[0] = 0 cpu_atomic_fetch_xor(barrier_sense, 1, 3) # _ATOMIC_RELEASE else: @@ -290,13 +290,11 @@ def micro_kernel( a_ptr = cast(a, ~float32) b_ptr = cast(b, ~float32) - niters = msize // 4 - nleft = msize % 4 - # Outer iterations with step 4 - for _ in range(niters): - # First of the 4 unrolled iterations + # TODO: For now, let's forget about unrolling for now. 
+ for _ in range(pb): bb0to7 = avx_f32x8_load_aligned(b_ptr) bb8to15 = avx_f32x8_load_aligned(b_ptr + 8) + b_ptr = b_ptr + 16 aa = avx_f32x8_broadcast(a_ptr) c0 = avx_f32x8_fmadd(aa, bb0to7, c0) @@ -322,125 +320,7 @@ def micro_kernel( c5 = avx_f32x8_fmadd(aa, bb0to7, c5) c58 = avx_f32x8_fmadd(aa, bb8to15, c58) - # Second of the 4 unrolled iterations - bb0to7 = avx_f32x8_load_aligned(b_ptr + 16) - bb8to15 = avx_f32x8_load_aligned(b_ptr + 24) - - aa = avx_f32x8_broadcast(a_ptr + 6) - c0 = avx_f32x8_fmadd(aa, bb0to7, c0) - c08 = avx_f32x8_fmadd(aa, bb8to15, c08) - - aa = avx_f32x8_broadcast(a_ptr + 7) - c1 = avx_f32x8_fmadd(aa, bb0to7, c1) - c18 = avx_f32x8_fmadd(aa, bb8to15, c18) - - aa = avx_f32x8_broadcast(a_ptr + 8) - c2 = avx_f32x8_fmadd(aa, bb0to7, c2) - c28 = avx_f32x8_fmadd(aa, bb8to15, c28) - - aa = avx_f32x8_broadcast(a_ptr + 9) - c3 = avx_f32x8_fmadd(aa, bb0to7, c3) - c38 = avx_f32x8_fmadd(aa, bb8to15, c38) - - aa = avx_f32x8_broadcast(a_ptr + 10) - c4 = avx_f32x8_fmadd(aa, bb0to7, c4) - c48 = avx_f32x8_fmadd(aa, bb8to15, c48) - - aa = avx_f32x8_broadcast(a_ptr + 11) - c5 = avx_f32x8_fmadd(aa, bb0to7, c5) - c58 = avx_f32x8_fmadd(aa, bb8to15, c58) - - # Third of the 4 unrolled iterations - bb0to7 = avx_f32x8_load_aligned(b_ptr + 32) - bb8to15 = avx_f32x8_load_aligned(b_ptr + 40) - - aa = avx_f32x8_broadcast(a_ptr + 12) - c0 = avx_f32x8_fmadd(aa, bb0to7, c0) - c08 = avx_f32x8_fmadd(aa, bb8to15, c08) - - aa = avx_f32x8_broadcast(a_ptr + 13) - c1 = avx_f32x8_fmadd(aa, bb0to7, c1) - c18 = avx_f32x8_fmadd(aa, bb8to15, c18) - - aa = avx_f32x8_broadcast(a_ptr + 14) - c2 = avx_f32x8_fmadd(aa, bb0to7, c2) - c28 = avx_f32x8_fmadd(aa, bb8to15, c28) - - aa = avx_f32x8_broadcast(a_ptr + 15) - c3 = avx_f32x8_fmadd(aa, bb0to7, c3) - c38 = avx_f32x8_fmadd(aa, bb8to15, c38) - - aa = avx_f32x8_broadcast(a_ptr + 16) - c4 = avx_f32x8_fmadd(aa, bb0to7, c4) - c48 = avx_f32x8_fmadd(aa, bb8to15, c48) - - aa = avx_f32x8_broadcast(a_ptr + 17) - c5 = avx_f32x8_fmadd(aa, bb0to7, c5) - c58 = avx_f32x8_fmadd(aa, bb8to15, c58) - - # Fourth of the 4 unrolled iterations - bb0to7 = avx_f32x8_load_aligned(b_ptr + 48) - bb8to15 = avx_f32x8_load_aligned(b_ptr + 56) - - aa = avx_f32x8_broadcast(a_ptr + 18) - c0 = avx_f32x8_fmadd(aa, bb0to7, c0) - c08 = avx_f32x8_fmadd(aa, bb8to15, c08) - - aa = avx_f32x8_broadcast(a_ptr + 19) - c1 = avx_f32x8_fmadd(aa, bb0to7, c1) - c18 = avx_f32x8_fmadd(aa, bb8to15, c18) - - aa = avx_f32x8_broadcast(a_ptr + 20) - c2 = avx_f32x8_fmadd(aa, bb0to7, c2) - c28 = avx_f32x8_fmadd(aa, bb8to15, c28) - - aa = avx_f32x8_broadcast(a_ptr + 21) - c3 = avx_f32x8_fmadd(aa, bb0to7, c3) - c38 = avx_f32x8_fmadd(aa, bb8to15, c38) - - aa = avx_f32x8_broadcast(a_ptr + 22) - c4 = avx_f32x8_fmadd(aa, bb0to7, c4) - c48 = avx_f32x8_fmadd(aa, bb8to15, c48) - - aa = avx_f32x8_broadcast(a_ptr + 23) - c5 = avx_f32x8_fmadd(aa, bb0to7, c5) - c58 = avx_f32x8_fmadd(aa, bb8to15, c58) - - # Increment the a_ptr and b_ptr for the next iteration of the outermost loop - a_ptr += 24 - b_ptr += 64 - - # process the edge - for _ in range(nleft): - aa = avx_f32x8_broadcast(a_ptr) - bb0to7 = avx_f32x8_load_aligned(b_ptr) - bb8to15 = avx_f32x8_load_aligned(b_ptr + 8) - - c0 = avx_f32x8_fmadd(aa, bb0to7, c0) - c08 = avx_f32x8_fmadd(aa, bb8to15, c08) - - aa = avx_f32x8_broadcast(a_ptr + 1) - c1 = avx_f32x8_fmadd(aa, bb0to7, c1) - c18 = avx_f32x8_fmadd(aa, bb8to15, c18) - - aa = avx_f32x8_broadcast(a_ptr + 2) - c2 = avx_f32x8_fmadd(aa, bb0to7, c2) - c28 = avx_f32x8_fmadd(aa, bb8to15, c28) - - aa = avx_f32x8_broadcast(a_ptr + 3) - c3 
= avx_f32x8_fmadd(aa, bb0to7, c3) - c38 = avx_f32x8_fmadd(aa, bb8to15, c38) - - aa = avx_f32x8_broadcast(a_ptr + 4) - c4 = avx_f32x8_fmadd(aa, bb0to7, c4) - c48 = avx_f32x8_fmadd(aa, bb8to15, c48) - - aa = avx_f32x8_broadcast(a_ptr + 5) - c5 = avx_f32x8_fmadd(aa, bb0to7, c5) - c58 = avx_f32x8_fmadd(aa, bb8to15, c58) - - a_ptr += 6 - b_ptr += 16 + a_ptr = a_ptr + 6 # Store the results avx_f32x8_store(c_ptr, c0) @@ -462,8 +342,6 @@ def micro_kernel( avx_f32x8_store(c_ptr + (5 * nsize + 8), c58) # printf("The end of micro kernel....") - - #### Some setup code #### packed_b_total_width = 0 @@ -472,18 +350,15 @@ def micro_kernel( packed_b_height = k_size packed_b_width = NC if packed_b_width > n_size: - packed_b_widht = (n_size + NR - 1) // NR * NR + packed_b_width = (n_size + NR - 1) // NR * NR packed_b_total_width = packed_b_width * loop5_nways packed_b_total_size = packed_b_total_width * packed_b_height packed_b_individual_size = packed_b_width * packed_b_height - a_height_mr_partitions = (m_size + MR - 1) // MR - a_height_mr_remainder = m_size % MR packed_a_individual_height = MC packed_a_total_height = packed_a_individual_height * loop3_nways - # if packed_a_total_height > m_size: - # packed_a_total_height = a_height_mr_partitions * MR + packed_a_width = KC if packed_a_width > k_size: packed_a_width = (k_size + MR - 1) // MR * MR @@ -493,8 +368,6 @@ def micro_kernel( packb_buf_ptr = avx_malloc(packed_b_total_size * 4, 4096) packa_buf_ptr = avx_malloc(packed_a_total_size * 4, 4096) - # packb_buf = as_tensor_pointer(packb_buf_ptr, dtype=float32, shape=[packed_b_total_size]) - # packa_buf = as_tensor_pointer(packa_buf_ptr, dtype=float32, shape=[packed_a_total_size]) packb_buf = cast(packb_buf_ptr, ~float32) packa_buf = cast(packa_buf_ptr, ~float32) @@ -505,15 +378,6 @@ def micro_kernel( ##### Start of the loops around micro kernel ##### - # gemm_macro(packed_a_buf, - # packed_b, - # c, - # loop3_partition_a_height, - # loop3_partition_b_width, - # loop3_partition_a_width, - # comm_id_macro, - # work_id_macro - # ) @hidet.script def gemm_pack_a( @@ -636,7 +500,7 @@ def gemm_pack_a( remain_row = panel_a_remainder while remain_row < MR: packed_a_tensor[ - remain_start_row + remain_row, remain_col] = 0 + remain_start_row + remain_row, remain_col] = 0.0 remain_row += 1 printf("The end of the pack a, comm id: %d, work id: %d\n", comm_id_packa, work_id_packa) @@ -731,7 +595,7 @@ def gemm_pack_b( packed_b_buff_curr += 16 - row = k_iters + 8 + row = k_iters * 8 for _ in range(k_remainder): b_panel = loop4_partition_b + ( row * n_size + curr_panel_start) From e30ab31daa5864fcee4916ba26091cc93f325be3 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Mon, 28 Aug 2023 03:48:02 -0400 Subject: [PATCH 054/148] keep debugging --- python/hidet/mat_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/mat_new.py b/python/hidet/mat_new.py index 5ad168ee5..19e67a438 100644 --- a/python/hidet/mat_new.py +++ b/python/hidet/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(384, 256, 256)]: +for m, n, k in [(128, 128, 128)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From 134a1d50fcf5378b12f531738e5bef2de2bbc3ab Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Mon, 28 Aug 2023 03:50:25 -0400 Subject: [PATCH 055/148] .. 
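For reference while reading the unrolled micro kernel above: a minimal NumPy sketch of the same MR x NR = 6 x 16 register-blocked update (broadcast one element of packed A, FMA it against a 16-wide row of packed B). Illustration only, not the generated code:

import numpy as np

MR, NR = 6, 16

def micro_kernel_ref(a_panel, b_panel, c_tile):
    # a_panel: (k, MR) slice of packed A; b_panel: (k, NR) slice of packed B
    for p in range(a_panel.shape[0]):
        for i in range(MR):
            # one broadcast + fmadd step: a scalar of A scales a full B row
            c_tile[i, :] += a_panel[p, i] * b_panel[p, :]
    return c_tile

c = micro_kernel_ref(np.random.rand(32, MR), np.random.rand(32, NR), np.zeros((MR, NR)))

The unrolling by 4 in the kernel above only amortizes loop overhead; the arithmetic per k step is exactly this rank-1 update, held in twelve ymm accumulators (c0..c5 and c08..c58).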
--- python/mat_new.py | 97 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 python/mat_new.py diff --git a/python/mat_new.py b/python/mat_new.py new file mode 100644 index 000000000..19e67a438 --- /dev/null +++ b/python/mat_new.py @@ -0,0 +1,97 @@ +import numpy as np +import pytest + +import hidet +from hidet.graph.ops import matmul_x86_refactored +from hidet.testing import check_binary +from hidet.option import debug_cache_tuning + +import tvm +from tvm import te, auto_scheduler + +@auto_scheduler.register_workload +def matmul_ansor(M, K, N, dtype): + A = te.placeholder((M, K), name='A', dtype=dtype) + B = te.placeholder((K, N), name='B', dtype=dtype) + + k = te.reduce_axis((0, K), name='k') + rst = te.compute( + (M, N), + lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), + name='matmul_ansor', + attrs={"layout_free_placeholders": [B], + # Enable automatic layout transform for B} + } + ) + + return [A, B, rst] +hidet.option.cache_dir("./wtf") + +target = tvm.target.Target("llvm -mcpu=core-avx2") +debug_cache_tuning(True) +hidet.option.search_space(0) +for m, n, k in [(128, 128, 128)]: + a = hidet.randn([m, k], device='cpu') + b = hidet.randn([k, n], device='cpu') + x1 = hidet.symbol_like(a) + x2 = hidet.symbol_like(b) + y = matmul_x86_refactored(x1, x2) + graph = hidet.trace_from( + y, inputs=[x1, x2] + ) + opt_graph = hidet.graph.optimize(graph) + compiled_func = opt_graph.nodes[0].compiled_task + c = compiled_func(a, b) + np.testing.assert_allclose( + actual=c.numpy(), + desired=a.numpy() @ b.numpy(), + rtol=1e-3, + atol=1e-3 + ) + + print("passed for m={}, n={}, k={}".format(m, n, k)) + + hidet_latency = hidet.utils.benchmark_func( + lambda: compiled_func(a, b), repeat=50 + ) + np_latency = hidet.utils.benchmark_func( + lambda: a.numpy() @ b.numpy(), repeat=50 + ) + + ansor_task = tvm.auto_scheduler.SearchTask( + func=matmul_ansor, args=(m, k, n, "float32"), target=target + ) + log_file = f"matmul_{m}x{k}x{n}.json" + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=1000, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + verbose=2, + ) + + ansor_task.tune(tune_option) + sch, args = ansor_task.apply_best(log_file) + with open(f"./matmul_TIR_{m}x{k}x{n}", 'w') as f: + f.write(str(tvm.lower(sch, args, simple_mode=True))) + ansor_func = tvm.build(sch, args, target) + dev = tvm.cpu() + a_tvm = tvm.nd.array(a.numpy(), device=dev) + b_tvm = tvm.nd.array(b.numpy(), device=dev) + c_tvm = tvm.nd.empty((m, n), device=dev) + + ansor_func(a_tvm, b_tvm, c_tvm) + + np.testing.assert_allclose( + actual=c_tvm.numpy(), + desired=a_tvm.numpy() @ b_tvm.numpy(), + rtol=1e-3, + atol=1e-3 + ) + + ansor_latency = hidet.utils.benchmark_func( + lambda: ansor_func(a_tvm, b_tvm, c_tvm), repeat=30 + ) + + with open(f"./perf_{m}x{k}x{n}.txt", 'w') as f: + f.write(f"m={m}, k={k}, n={n}: hidet takes {hidet_latency:.2f} ms\n") + f.write(f"m={m}, k={k}, n={n}: numpy takes {np_latency: .2f} ms\n") + f.write(f"m={m}, k={k}, n={n}: ansor takes {ansor_latency: .2f} ms\n") From e1e2d290ec8bc08b36a31a7e06eb2d2dd05a50e0 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Mon, 28 Aug 2023 04:02:20 -0400 Subject: [PATCH 056/148] .. 
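The tracing added below prints i_loop4 and the block size b_alg_loop4 that advances it. Assuming determine_blocksize_f_sub is the usual BLIS-style helper (its definition is not shown in this patch), it simply clamps the full cache block at the edge; a hypothetical reconstruction:

def determine_blocksize_f_sub(i, dim, b_alg):
    # assumption: take a full block of b_alg, clamped to what remains of dim
    return min(b_alg, dim - i)

i, blocks = 0, []
while i < 1000:                 # e.g. k_size = 1000 with KC = 256
    b = determine_blocksize_f_sub(i, 1000, 256)
    blocks.append(b)
    i += b
print(blocks)                   # [256, 256, 256, 232]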
--- .../ops/matmul/matmul_f32_x86_refactored.py | 7 ++ python/hidet/mat_new.py | 97 ------------------- 2 files changed, 7 insertions(+), 97 deletions(-) delete mode 100644 python/hidet/mat_new.py diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 8c0fdcc8f..1e84137d4 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -853,10 +853,17 @@ def gemm_4th_loop(a: float32[m_size, k_size], while i_loop4 < k_size: b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, NC) + + printf("i_loop4: %d\n", i_loop4) + loop4_partition_b_height = b_alg_loop4 loop4_partition_b_width = loop5_partition_b_width loop4_partition_b_start_row = i_loop4 loop4_partition_b_start_col = loop5_partition_b_start_col + printf("loop4_partition_b_height: %d\n", loop4_partition_b_height) + printf("loop4_partition_b_width: %d\n", loop4_partition_b_width) + printf("loop4_partition_b_start_row: %d\n", loop4_partition_b_start_row) + printf("loop4_partition_b_start_col: %d\n", loop4_partition_b_start_col) loop4_partition_a_start_col = i_loop4 is_first = (i_loop4 == 0) diff --git a/python/hidet/mat_new.py b/python/hidet/mat_new.py deleted file mode 100644 index 19e67a438..000000000 --- a/python/hidet/mat_new.py +++ /dev/null @@ -1,97 +0,0 @@ -import numpy as np -import pytest - -import hidet -from hidet.graph.ops import matmul_x86_refactored -from hidet.testing import check_binary -from hidet.option import debug_cache_tuning - -import tvm -from tvm import te, auto_scheduler - -@auto_scheduler.register_workload -def matmul_ansor(M, K, N, dtype): - A = te.placeholder((M, K), name='A', dtype=dtype) - B = te.placeholder((K, N), name='B', dtype=dtype) - - k = te.reduce_axis((0, K), name='k') - rst = te.compute( - (M, N), - lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), - name='matmul_ansor', - attrs={"layout_free_placeholders": [B], - # Enable automatic layout transform for B} - } - ) - - return [A, B, rst] -hidet.option.cache_dir("./wtf") - -target = tvm.target.Target("llvm -mcpu=core-avx2") -debug_cache_tuning(True) -hidet.option.search_space(0) -for m, n, k in [(128, 128, 128)]: - a = hidet.randn([m, k], device='cpu') - b = hidet.randn([k, n], device='cpu') - x1 = hidet.symbol_like(a) - x2 = hidet.symbol_like(b) - y = matmul_x86_refactored(x1, x2) - graph = hidet.trace_from( - y, inputs=[x1, x2] - ) - opt_graph = hidet.graph.optimize(graph) - compiled_func = opt_graph.nodes[0].compiled_task - c = compiled_func(a, b) - np.testing.assert_allclose( - actual=c.numpy(), - desired=a.numpy() @ b.numpy(), - rtol=1e-3, - atol=1e-3 - ) - - print("passed for m={}, n={}, k={}".format(m, n, k)) - - hidet_latency = hidet.utils.benchmark_func( - lambda: compiled_func(a, b), repeat=50 - ) - np_latency = hidet.utils.benchmark_func( - lambda: a.numpy() @ b.numpy(), repeat=50 - ) - - ansor_task = tvm.auto_scheduler.SearchTask( - func=matmul_ansor, args=(m, k, n, "float32"), target=target - ) - log_file = f"matmul_{m}x{k}x{n}.json" - tune_option = auto_scheduler.TuningOptions( - num_measure_trials=1000, - measure_callbacks=[auto_scheduler.RecordToFile(log_file)], - verbose=2, - ) - - ansor_task.tune(tune_option) - sch, args = ansor_task.apply_best(log_file) - with open(f"./matmul_TIR_{m}x{k}x{n}", 'w') as f: - f.write(str(tvm.lower(sch, args, simple_mode=True))) - ansor_func = tvm.build(sch, args, target) - dev = tvm.cpu() - a_tvm = tvm.nd.array(a.numpy(), device=dev) - b_tvm = 
tvm.nd.array(b.numpy(), device=dev) - c_tvm = tvm.nd.empty((m, n), device=dev) - - ansor_func(a_tvm, b_tvm, c_tvm) - - np.testing.assert_allclose( - actual=c_tvm.numpy(), - desired=a_tvm.numpy() @ b_tvm.numpy(), - rtol=1e-3, - atol=1e-3 - ) - - ansor_latency = hidet.utils.benchmark_func( - lambda: ansor_func(a_tvm, b_tvm, c_tvm), repeat=30 - ) - - with open(f"./perf_{m}x{k}x{n}.txt", 'w') as f: - f.write(f"m={m}, k={k}, n={n}: hidet takes {hidet_latency:.2f} ms\n") - f.write(f"m={m}, k={k}, n={n}: numpy takes {np_latency: .2f} ms\n") - f.write(f"m={m}, k={k}, n={n}: ansor takes {ansor_latency: .2f} ms\n") From 7a7ff5e7ebe69e9b67859a20bfe240128b61f936 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Mon, 28 Aug 2023 04:42:35 -0400 Subject: [PATCH 057/148] . --- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 1e84137d4..79d95accc 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -520,6 +520,9 @@ def gemm_pack_b( npanels_b = npanels_full_b + (1 if npanels_b_remainder != 0 else 0) packedb_panel_stride = packed_b_height * NR + printf("Start of the packing of B...") + printf("packed_b_height: %d", packed_b_height) + printf("packedb_panel_stride: %d\n", packedb_panel_stride) # Loop for the packing of B for i_panel in range(npanels_b): @@ -530,12 +533,21 @@ def gemm_pack_b( curr_panel_start = i_panel * NR curr_panel_width = min(NR, loop4_partition_b_width - curr_panel_start) + + printf("i_panel: %d\n", i_panel) + printf("curr_panel_start: %d\n", curr_panel_start) + printf("curr_panel_width: %d\n", curr_panel_width) + if curr_panel_width == NR: k_iters = loop4_partition_b_height // 8 k_remainder = loop4_partition_b_height % 8 + + printf("k_iters: %d\n", k_iters) + printf("k_remainder: %d\n", k_remainder) row = 0 for k_iter in range(k_iters): row = k_iter * 8 + printf('row: %d\n', row) b_panel = loop4_partition_b + ( row * n_size + curr_panel_start) b00 = avx_f32x8_load(b_panel) @@ -596,6 +608,7 @@ def gemm_pack_b( packed_b_buff_curr += 16 row = k_iters * 8 + printf("After the unrolled-by-8 loop, row: %d\n", row) for _ in range(k_remainder): b_panel = loop4_partition_b + ( row * n_size + curr_panel_start) From 29de46f8ff8b595754a70bf38f58fae0eff2bd86 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Mon, 28 Aug 2023 04:57:17 -0400 Subject: [PATCH 058/148] .. 
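The pack-B routine instrumented above copies NR = 16-wide column panels of B into one contiguous buffer, 16 floats (two 8-float vectors) per k step, so the micro kernel reads it with unit stride. A NumPy sketch of the destination layout, including the zero-padding the narrow-panel branch performs:

import numpy as np

def pack_b_ref(B, NR=16):
    k, n = B.shape
    panels = []
    for j0 in range(0, n, NR):
        panel = np.zeros((k, NR), dtype=B.dtype)
        w = min(NR, n - j0)
        panel[:, :w] = B[:, j0:j0 + w]    # zero-pad the last, narrower panel
        panels.append(panel.reshape(-1))  # row-major inside the panel
    return np.concatenate(panels)         # panel stride = packed_b_height * NR

packed = pack_b_ref(np.random.rand(20, 40))
assert packed.size == 20 * 16 * 3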
--- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 2 -- python/mat_new.py | 10 ++++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 79d95accc..25a6f01d6 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -908,8 +908,6 @@ def gemm_4th_loop(a: float32[m_size, k_size], loop4_partition_b_height, packed_b_buf, comm_id_packb, work_id_packb, loop3_nthreads) - - # The barrier at the end of the packing of B thrcomm_barrier( comm_id_packb, diff --git a/python/mat_new.py b/python/mat_new.py index 19e67a438..5a663ca3b 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -42,6 +42,16 @@ def matmul_ansor(M, K, N, dtype): opt_graph = hidet.graph.optimize(graph) compiled_func = opt_graph.nodes[0].compiled_task c = compiled_func(a, b) + + actual = c.numpy() + desired = a.numpy() @ b.numpy() + + for i in range(m): + for j in range(n): + if abs(actual[i, j] - desired[i, j]) > 1e-3: + print(f"mismatch at {i}, {j}") + + np.testing.assert_allclose( actual=c.numpy(), desired=a.numpy() @ b.numpy(), From ca9e67d5d9def2662c85afd316c98c5599036587 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Mon, 28 Aug 2023 09:04:54 -0400 Subject: [PATCH 059/148] ... --- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 4 ++++ python/mat_new.py | 14 +++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 25a6f01d6..27610ad65 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -782,11 +782,14 @@ def gemm_3rd_loop( ~m_start_loop3, ~m_end_loop3 ) + printf("In loop 3: m_start_loop3: %d, m_end_loop3: %d\n", m_start_loop3, m_end_loop3) ii = m_start_loop3 while ii < m_end_loop3: b_alg_loop3 = determine_blocksize_f_sub( ii, m_size, MC ) + printf("The ii in loop3: %d\n", ii) + printf("b_alg_loop3: %d\n", b_alg_loop3) # Acquire the partition at loop 3 loop3_partition_a_start_row = ii loop3_partition_a_height = b_alg_loop3 @@ -893,6 +896,7 @@ def gemm_4th_loop(a: float32[m_size, k_size], (loop4_partition_b_start_row * n_size + loop4_partition_b_start_col) + # TODO: If passed, see if this barrier is really needed thrcomm_barrier( comm_id_packb, diff --git a/python/mat_new.py b/python/mat_new.py index 5a663ca3b..2ea2d8654 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -45,16 +45,16 @@ def matmul_ansor(M, K, N, dtype): actual = c.numpy() desired = a.numpy() @ b.numpy() - - for i in range(m): - for j in range(n): - if abs(actual[i, j] - desired[i, j]) > 1e-3: - print(f"mismatch at {i}, {j}") + # + # for i in range(m): + # for j in range(n): + # if abs(actual[i, j] - desired[i, j]) > 1e-3: + # print(f"mismatch at {i}, {j}") np.testing.assert_allclose( - actual=c.numpy(), - desired=a.numpy() @ b.numpy(), + actual=actual, + desired=desired, rtol=1e-3, atol=1e-3 ) From 43d4a60346a37d6643f02aa43d3a0fbcecf18be9 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Mon, 28 Aug 2023 09:49:36 -0400 Subject: [PATCH 060/148] ..: --- .../ops/matmul/matmul_f32_x86_refactored.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py 
index 27610ad65..19acf876b 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -870,28 +870,29 @@ def gemm_4th_loop(a: float32[m_size, k_size], while i_loop4 < k_size: b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, NC) - printf("i_loop4: %d\n", i_loop4) + # printf("i_loop4: %d\n", i_loop4) loop4_partition_b_height = b_alg_loop4 loop4_partition_b_width = loop5_partition_b_width loop4_partition_b_start_row = i_loop4 loop4_partition_b_start_col = loop5_partition_b_start_col - printf("loop4_partition_b_height: %d\n", loop4_partition_b_height) - printf("loop4_partition_b_width: %d\n", loop4_partition_b_width) - printf("loop4_partition_b_start_row: %d\n", loop4_partition_b_start_row) - printf("loop4_partition_b_start_col: %d\n", loop4_partition_b_start_col) + # printf("loop4_partition_b_height: %d\n", loop4_partition_b_height) + # printf("loop4_partition_b_width: %d\n", loop4_partition_b_width) + # printf("loop4_partition_b_start_row: %d\n", loop4_partition_b_start_row) + # printf("loop4_partition_b_start_col: %d\n", loop4_partition_b_start_col) loop4_partition_a_start_col = i_loop4 is_first = (i_loop4 == 0) + # Get the thread's partition of the buffer and the matrix # packed_b_buf = packb_buf + ( # packb_start_offsets[work_id_5th_loop, 0] * packed_b_height # ) + packed_b_buf = packb_buf + ( packed_b_individual_size * work_id_5th_loop ) - loop4_partition_b = cast(b, ~float32) + \ (loop4_partition_b_start_row * n_size + loop4_partition_b_start_col) @@ -905,6 +906,7 @@ def gemm_4th_loop(a: float32[m_size, k_size], packb_nthreads ) + # Start the packing of B # TODO: Check this assertion: # TODO: loop3_nthreads == packb_nthreads @@ -1003,8 +1005,6 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], ir_module = module.ir_module() return ir_module - # return ir_module - class Matmulx86Op_refactored(Operator): def __init__(self, a: Tensor, b: Tensor): From 3d67673a512837d48b5f380c4aab8fe30ed3f7da Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Mon, 28 Aug 2023 11:01:26 -0400 Subject: [PATCH 061/148] . --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 19acf876b..f99dcbdf3 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -764,7 +764,7 @@ def gemm_3rd_loop( work_id_3rd_loop: int32, is_first: bool): - # printf("The start of 3rd loop. comm_id_3rd_loop: %d, work_id_3rd_loop: %d\n", comm_id_3rd_loop, work_id_3rd_loop) + printf("The start of 3rd loop. comm_id_3rd_loop: %d, work_id_3rd_loop: %d\n", comm_id_3rd_loop, work_id_3rd_loop) comm_id_macro = comm_id_3rd_loop % macro_nthreads work_id_macro = comm_id_macro // (macro_nthreads // macro_nways) From 3c9d7928b75011ebd9851c7c838fe35d1aefe74a Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Mon, 28 Aug 2023 23:00:17 -0400 Subject: [PATCH 062/148] .. 
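The loop5_my_start / loop5_my_end trace silenced below comes from thread_range_sub, which hands each of the loop5_nways workers a contiguous share of the n dimension in NR multiples. A sketch of that partitioning, under the assumption that the per-thread share is rounded up to the blocking factor:

def thread_range_sub(nways, work_id, dim, bf):
    # hypothetical reconstruction of the splitter; bf = NR here
    per = (dim + nways - 1) // nways
    per = (per + bf - 1) // bf * bf
    start = min(work_id * per, dim)
    return start, min(start + per, dim)

print(thread_range_sub(4, 1, 1000, 16))   # (256, 512)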
--- .../ops/matmul/matmul_f32_x86_refactored.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index f99dcbdf3..169daccf8 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -520,9 +520,9 @@ def gemm_pack_b( npanels_b = npanels_full_b + (1 if npanels_b_remainder != 0 else 0) packedb_panel_stride = packed_b_height * NR - printf("Start of the packing of B...") - printf("packed_b_height: %d", packed_b_height) - printf("packedb_panel_stride: %d\n", packedb_panel_stride) + # printf("Start of the packing of B...") + # printf("packed_b_height: %d", packed_b_height) + # printf("packedb_panel_stride: %d\n", packedb_panel_stride) # Loop for the packing of B for i_panel in range(npanels_b): @@ -534,9 +534,9 @@ def gemm_pack_b( curr_panel_width = min(NR, loop4_partition_b_width - curr_panel_start) - printf("i_panel: %d\n", i_panel) - printf("curr_panel_start: %d\n", curr_panel_start) - printf("curr_panel_width: %d\n", curr_panel_width) + # printf("i_panel: %d\n", i_panel) + # printf("curr_panel_start: %d\n", curr_panel_start) + # printf("curr_panel_width: %d\n", curr_panel_width) if curr_panel_width == NR: k_iters = loop4_partition_b_height // 8 @@ -547,7 +547,7 @@ def gemm_pack_b( row = 0 for k_iter in range(k_iters): row = k_iter * 8 - printf('row: %d\n', row) + # printf('row: %d\n', row) b_panel = loop4_partition_b + ( row * n_size + curr_panel_start) b00 = avx_f32x8_load(b_panel) @@ -957,13 +957,13 @@ def gemm_5th_loop(a: float32[m_size, k_size], thread_range_sub(loop5_nways, work_id_5th_loop, n_size, NR, ~loop5_my_start, ~loop5_my_end) - printf("loop5_my_start: %d, loop5_my_end: %d\n", loop5_my_start, loop5_my_end) + # printf("loop5_my_start: %d, loop5_my_end: %d\n", loop5_my_start, loop5_my_end) loop5_iter = loop5_my_start while loop5_iter < loop5_my_end: b_alg_loop5 = determine_blocksize_f_sub(loop5_iter, loop5_my_end, NC) - printf("b_alg_loop5: %d\n", b_alg_loop5) + # printf("b_alg_loop5: %d\n", b_alg_loop5) loop5_partition_c_width = b_alg_loop5 loop5_partition_c_start_col = loop5_iter loop5_partition_b_width = b_alg_loop5, From 6782047051da4bfab6d4d2f9d6099da5f1aead0f Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Mon, 28 Aug 2023 23:12:46 -0400 Subject: [PATCH 063/148] . 
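The jr_* / ir_* values printed below describe how micro-panels inside one macro tile are divided among threads. The start/end/inc triple suggests a round-robin split: worker work_id takes panels work_id, work_id + nways, and so on. A sketch of that assumed scheme:

def thread_range_jrir(work_id, nways, n_iter):
    # assumption: interleaved assignment of micro-panel indices
    return work_id, n_iter, nways          # start, end, inc

jr_start, jr_end, jr_inc = thread_range_jrir(1, 2, 9)
print(list(range(jr_start, jr_end, jr_inc)))   # panels 1, 3, 5, 7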
--- .../ops/matmul/matmul_f32_x86_refactored.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 169daccf8..aad824d59 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -542,8 +542,8 @@ def gemm_pack_b( k_iters = loop4_partition_b_height // 8 k_remainder = loop4_partition_b_height % 8 - printf("k_iters: %d\n", k_iters) - printf("k_remainder: %d\n", k_remainder) + # printf("k_iters: %d\n", k_iters) + # printf("k_remainder: %d\n", k_remainder) row = 0 for k_iter in range(k_iters): row = k_iter * 8 @@ -608,7 +608,7 @@ def gemm_pack_b( packed_b_buff_curr += 16 row = k_iters * 8 - printf("After the unrolled-by-8 loop, row: %d\n", row) + # printf("After the unrolled-by-8 loop, row: %d\n", row) for _ in range(k_remainder): b_panel = loop4_partition_b + ( row * n_size + curr_panel_start) @@ -671,6 +671,12 @@ def gemm_macro( m_iter = macro_m // MR m_remainder = macro_m % MR + printf("The start of the macro kernel.\n") + printf("n_iter: %d\n", n_iter) + printf("n_remainder: %d\n", n_remainder) + printf("m_iter: %d\n", m_iter) + printf("m_remainder: %d\n", m_remainder) + if n_remainder > 0: n_iter += 1 if m_remainder > 0: @@ -703,6 +709,14 @@ def gemm_macro( ~ir_inc ) + printf("jr_start: %d\n", jr_start) + printf("jr_end: %d\n", jr_end) + printf("jr_inc: %d\n", jr_inc) + + printf("ir_start: %d\n", ir_start) + printf("ir_end: %d\n", ir_end) + printf("ir_inc: %d\n", ir_inc) + rs_packeda = 1 rstep_a = ps_packed_a cstep_b = ps_packed_b @@ -710,6 +724,11 @@ def gemm_macro( cstep_c = NR rstep_c = n_size * MR + printf("rstep_a: %d\n", rstep_a) + printf("cstep_b: %d\n", cstep_b) + printf("cstep_c: %d\n", cstep_c) + printf("rstep_c: %d\n", rstep_c) + macro_c_cast = as_tensor_pointer( ~c[c_row_off, c_col_off], dtype=float32, From 9401c1e2cdf026846e430d0f7af036313640253e Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 02:14:15 -0400 Subject: [PATCH 064/148] .. --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index aad824d59..c4c2a7361 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -749,6 +749,7 @@ def gemm_macro( while i < ir_end: a1 = packed_a + i * rstep_a c11 = c1 + i * rstep_a + c11 = as_tensor_pointer(c11, dtype=float32, shape=(m_size, n_size)) m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder if m_cur == MR and n_cur == NR: From e655035ed42350348710c971737be90de9e55bb9 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 02:18:48 -0400 Subject: [PATCH 065/148] .. 
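The as_tensor_pointer re-wrap added in the previous patch matters because c1 and c11 are produced by flat pointer arithmetic: c1 = C + j * NR selects the panel's first column and c11 = c1 + i * (MR * n_size) its first row band, so tile element (mm, nn) should land at row i * MR + mm, column j * NR + nn. A quick consistency check of that addressing:

n_size, MR, NR = 32, 6, 16
i, j, mm, nn = 1, 1, 2, 3
flat = i * MR * n_size + mm * n_size + j * NR + nn
assert flat == (i * MR + mm) * n_size + (j * NR + nn)   # 275 both ways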
--- .../hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index c4c2a7361..f009ff944 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -676,6 +676,14 @@ def gemm_macro( printf("n_remainder: %d\n", n_remainder) printf("m_iter: %d\n", m_iter) printf("m_remainder: %d\n", m_remainder) + printf("c_row_off: %d\n", c_row_off) + printf("c_col_off: %d\n", c_col_off) + printf("macro_m: %d\n", macro_m) + printf("macro_n: %d\n", macro_n) + printf("macro_k: %d\n", macro_k) + printf("ps_packed_a: %d\n", ps_packed_a) + printf("ps_packed_b: %d\n", ps_packed_b) + if n_remainder > 0: n_iter += 1 From 4c7ed707db2f0161db69492324a6339cf0e3398b Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 04:29:14 -0400 Subject: [PATCH 066/148] .. --- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index f009ff944..0f3b6eab3 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -396,15 +396,23 @@ def gemm_pack_a( layout=row_major(packed_a_individual_height // MR, 1) * column_major(MR, packed_a_width) ) + printf("pack a: packed_a_individual_height: %d, packed_a_width: %d\n", packed_a_individual_height, + packed_a_width) npanels_full_a = loop3_partition_a_height // MR panel_a_remainder = loop3_partition_a_height % MR + printf("loop3_partition_a_height: %d\n", loop3_partition_a_height) + printf("npanels_full_a: %d\n", npanels_full_a) + assert panel_a_remainder == 0 # TODO: remove after debugging + npanels_a = npanels_full_a + (1 if panel_a_remainder > 0 else 0) for ii_panel in range(npanels_a): if ii_panel % packa_nways != work_id_packa % packa_nways: continue + printf("ii_panel: %d\n", ii_panel) + printf("packa_nways: %d\n", packa_nways) a_curr_panel_row_start = ii_panel * MR a_curr_panel_height = min(MR, loop3_partition_a_height - a_curr_panel_row_start) @@ -490,7 +498,7 @@ def gemm_pack_a( loop3_partition_a[( micropanel_row + a_curr_panel_row_start) * k_size + curr_remain_col] else: - remain_start_row = npanels_a * MR + remain_start_row = npanels_full_a * MR for remain_col in range(loop3_partition_a_width): for remain_row in range(panel_a_remainder): packed_a_tensor[ From 21978bb1ecb4d7d4e23b55b0379a2abb372d6575 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 04:31:45 -0400 Subject: [PATCH 067/148] .. 
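The A-packing instrumented above writes each MR-row slab in the declared layout row_major(panels, 1) * column_major(MR, width): within a panel, the MR values of one k column sit together, so the micro kernel broadcasts them with consecutive loads. A NumPy sketch of that layout, including the zero fill for a ragged last panel:

import numpy as np

def pack_a_ref(A, MR=6):
    m, k = A.shape
    panels = []
    for i0 in range(0, m, MR):
        panel = np.zeros((MR, k), dtype=A.dtype)
        h = min(MR, m - i0)
        panel[:h, :] = A[i0:i0 + h, :]      # zero-pad rows past m
        panels.append(panel.T.reshape(-1))  # column-major: MR fastest
    return np.concatenate(panels)

packed = pack_a_ref(np.random.rand(20, 12))
assert packed.size == 24 * 12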
--- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 0f3b6eab3..174755a8d 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -405,7 +405,6 @@ def gemm_pack_a( printf("loop3_partition_a_height: %d\n", loop3_partition_a_height) printf("npanels_full_a: %d\n", npanels_full_a) - assert panel_a_remainder == 0 # TODO: remove after debugging npanels_a = npanels_full_a + (1 if panel_a_remainder > 0 else 0) for ii_panel in range(npanels_a): From c90991f8563411fa26510bd8c30caf1d96ec7d3d Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 04:33:06 -0400 Subject: [PATCH 068/148] .. --- python/mat_new.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/mat_new.py b/python/mat_new.py index 2ea2d8654..018f93c1e 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -45,11 +45,11 @@ def matmul_ansor(M, K, N, dtype): actual = c.numpy() desired = a.numpy() @ b.numpy() - # - # for i in range(m): - # for j in range(n): - # if abs(actual[i, j] - desired[i, j]) > 1e-3: - # print(f"mismatch at {i}, {j}") + + for i in range(m): + for j in range(n): + if abs(actual[i, j] - desired[i, j]) < 1e-3: + print(f"actually match at {i}, {j}") np.testing.assert_allclose( From 7c3ef0ab6eba92777032c78a88e1becc7b8df0d8 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 04:52:40 -0400 Subject: [PATCH 069/148] continue fixing --- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 174755a8d..b84fef181 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -389,7 +389,7 @@ def gemm_pack_a( work_id_packa: int32, packa_nways: int32 ): - printf("The start of the pack a, comm id: %d, work id: %d\n", comm_id_packa, work_id_packa) + # printf("The start of the pack a, comm id: %d, work id: %d\n", comm_id_packa, work_id_packa) packed_a_tensor = as_tensor_pointer( packed_a_buf, float32, @@ -410,8 +410,7 @@ def gemm_pack_a( for ii_panel in range(npanels_a): if ii_panel % packa_nways != work_id_packa % packa_nways: continue - printf("ii_panel: %d\n", ii_panel) - printf("packa_nways: %d\n", packa_nways) + a_curr_panel_row_start = ii_panel * MR a_curr_panel_height = min(MR, loop3_partition_a_height - a_curr_panel_row_start) @@ -509,8 +508,8 @@ def gemm_pack_a( packed_a_tensor[ remain_start_row + remain_row, remain_col] = 0.0 remain_row += 1 - printf("The end of the pack a, comm id: %d, work id: %d\n", - comm_id_packa, work_id_packa) + # printf("The end of the pack a, comm id: %d, work id: %d\n", + # comm_id_packa, work_id_packa) @hidet.script def gemm_pack_b( @@ -763,7 +762,7 @@ def gemm_macro( i = ir_start while i < ir_end: a1 = packed_a + i * rstep_a - c11 = c1 + i * rstep_a + c11 = c1 + i * rstep_c c11 = as_tensor_pointer(c11, dtype=float32, shape=(m_size, n_size)) m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder From 4acf6c085ba626ecb98f3785e751bbb25556f1fe Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 05:05:20 -0400 Subject: [PATCH 070/148] .. 
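The one-character fix in patch 069 above ("continue fixing") is the payoff of all this tracing: moving down to the next micro-panel of C must step by rstep_c = MR * n_size floats, while i * rstep_a is an offset inside packed A (stride ps_packed_a); applying the A stride to the C pointer scattered whole row bands. The stride arithmetic:

n_size, MR = 512, 6
i = 3                                # fourth band of MR rows
assert i * (MR * n_size) == 9216     # C offset of the band starting at row 18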
--- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 4 ++++ python/mat_new.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index b84fef181..f3cdb1ea8 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -756,6 +756,8 @@ def gemm_macro( while j < jr_end: b1 = packed_b + j * cstep_b c1 = macro_c_cast + j * cstep_c + printf("j = %d\n", j) + printf("The offset j * cstep_c: %d\n", j * cstep_c) n_cur = NR if not_edge(j, n_iter, n_remainder) else n_remainder # Loop over the m dimension, MR rows at a time @@ -763,6 +765,8 @@ def gemm_macro( while i < ir_end: a1 = packed_a + i * rstep_a c11 = c1 + i * rstep_c + printf("The offset i * rstep_a: %d\n", i * rstep_a) + printf("The offset i * rstep_c: %d\n", i * rstep_c) c11 = as_tensor_pointer(c11, dtype=float32, shape=(m_size, n_size)) m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder diff --git a/python/mat_new.py b/python/mat_new.py index 018f93c1e..66a83b8d9 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(128, 128, 128)]: +for m, n, k in [(64, 64, 64)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From c740a3a3d6aed0e81a02b8dfd5eeea64ab624d1b Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 05:07:29 -0400 Subject: [PATCH 071/148] . --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index f3cdb1ea8..4d82de371 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -757,16 +757,17 @@ def gemm_macro( b1 = packed_b + j * cstep_b c1 = macro_c_cast + j * cstep_c printf("j = %d\n", j) - printf("The offset j * cstep_c: %d\n", j * cstep_c) + printf("The offset j * cstep_c: %d\n\n", j * cstep_c) n_cur = NR if not_edge(j, n_iter, n_remainder) else n_remainder # Loop over the m dimension, MR rows at a time i = ir_start while i < ir_end: + printf("i = %d\n", i) a1 = packed_a + i * rstep_a c11 = c1 + i * rstep_c printf("The offset i * rstep_a: %d\n", i * rstep_a) - printf("The offset i * rstep_c: %d\n", i * rstep_c) + printf("The offset i * rstep_c: %d\n\n", i * rstep_c) c11 = as_tensor_pointer(c11, dtype=float32, shape=(m_size, n_size)) m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder From 8f0ee0e5c70a192a4837cc52c2b0d57d8983dbc1 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 05:33:25 -0400 Subject: [PATCH 072/148] ... 
--- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 4d82de371..9bc121fd1 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -876,7 +876,7 @@ def gemm_3rd_loop( loop3_partition_a_height, loop3_partition_b_width, loop3_partition_a_width, - MR * loop3_partition_a_width, + MR * k_size, packed_b_height * NR, comm_id_macro, work_id_macro, From 01e84ecc6a5c8a19d4c274d85730b5776723da65 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 05:35:36 -0400 Subject: [PATCH 073/148] ... --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 9bc121fd1..ff2af2a17 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -876,7 +876,7 @@ def gemm_3rd_loop( loop3_partition_a_height, loop3_partition_b_width, loop3_partition_a_width, - MR * k_size, + MR * packed_a_width, packed_b_height * NR, comm_id_macro, work_id_macro, From 90505e7c6b085b828c443a9e469514655c9ad78c Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 05:46:15 -0400 Subject: [PATCH 074/148] .. --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index ff2af2a17..04fa042fe 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -776,7 +776,7 @@ def gemm_macro( else: for i, j in grid(MR, NR): temp_c[i, j] = 0.0 - micro_kernel(a1, b1, cast(temp_c, ~float32), macro_k, macro_m, macro_n, is_first) + micro_kernel(a1, b1, cast(temp_c, ~float32), macro_k, MR, NR, is_first) if not is_first: for mm, nn in grid(m_cur, n_cur): c11[mm, nn] += temp_c[mm, nn] From 805959e893921f95b6c9861f565475b1c811ce76 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 05:48:19 -0400 Subject: [PATCH 075/148] ... 
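Patch 074 above fixes the edge path: micro_kernel always computes a full MR x NR tile, so for edge tiles it must run at full size into the zeroed scratch temp_c, after which only the valid m_cur x n_cur corner is accumulated into C. The same pattern in NumPy (the matmul stands in for the micro kernel):

import numpy as np
MR, NR = 6, 16

def edge_update(C, i0, j0, a_panel, b_panel):
    m_cur = min(MR, C.shape[0] - i0)
    n_cur = min(NR, C.shape[1] - j0)
    temp = np.zeros((MR, NR))
    temp += a_panel @ b_panel                                 # full-size compute
    C[i0:i0 + m_cur, j0:j0 + n_cur] += temp[:m_cur, :n_cur]   # partial copy-back

C = np.zeros((10, 20))
edge_update(C, 6, 16, np.random.rand(MR, 8), np.random.rand(8, NR))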
--- python/mat_new.py | 95 ++++++++++++++++++++++------------------------- 1 file changed, 45 insertions(+), 50 deletions(-) diff --git a/python/mat_new.py b/python/mat_new.py index 66a83b8d9..bb43edc72 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(64, 64, 64)]: +for m, n, k in [(64, 64, 64), (444, 3072, 768), (768, 768, 768), (123, 456, 789)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) @@ -46,11 +46,6 @@ def matmul_ansor(M, K, N, dtype): actual = c.numpy() desired = a.numpy() @ b.numpy() - for i in range(m): - for j in range(n): - if abs(actual[i, j] - desired[i, j]) < 1e-3: - print(f"actually match at {i}, {j}") - np.testing.assert_allclose( actual=actual, @@ -61,47 +56,47 @@ def matmul_ansor(M, K, N, dtype): print("passed for m={}, n={}, k={}".format(m, n, k)) - hidet_latency = hidet.utils.benchmark_func( - lambda: compiled_func(a, b), repeat=50 - ) - np_latency = hidet.utils.benchmark_func( - lambda: a.numpy() @ b.numpy(), repeat=50 - ) - - ansor_task = tvm.auto_scheduler.SearchTask( - func=matmul_ansor, args=(m, k, n, "float32"), target=target - ) - log_file = f"matmul_{m}x{k}x{n}.json" - tune_option = auto_scheduler.TuningOptions( - num_measure_trials=1000, - measure_callbacks=[auto_scheduler.RecordToFile(log_file)], - verbose=2, - ) - - ansor_task.tune(tune_option) - sch, args = ansor_task.apply_best(log_file) - with open(f"./matmul_TIR_{m}x{k}x{n}", 'w') as f: - f.write(str(tvm.lower(sch, args, simple_mode=True))) - ansor_func = tvm.build(sch, args, target) - dev = tvm.cpu() - a_tvm = tvm.nd.array(a.numpy(), device=dev) - b_tvm = tvm.nd.array(b.numpy(), device=dev) - c_tvm = tvm.nd.empty((m, n), device=dev) - - ansor_func(a_tvm, b_tvm, c_tvm) - - np.testing.assert_allclose( - actual=c_tvm.numpy(), - desired=a_tvm.numpy() @ b_tvm.numpy(), - rtol=1e-3, - atol=1e-3 - ) - - ansor_latency = hidet.utils.benchmark_func( - lambda: ansor_func(a_tvm, b_tvm, c_tvm), repeat=30 - ) - - with open(f"./perf_{m}x{k}x{n}.txt", 'w') as f: - f.write(f"m={m}, k={k}, n={n}: hidet takes {hidet_latency:.2f} ms\n") - f.write(f"m={m}, k={k}, n={n}: numpy takes {np_latency: .2f} ms\n") - f.write(f"m={m}, k={k}, n={n}: ansor takes {ansor_latency: .2f} ms\n") + # hidet_latency = hidet.utils.benchmark_func( + # lambda: compiled_func(a, b), repeat=50 + # ) + # np_latency = hidet.utils.benchmark_func( + # lambda: a.numpy() @ b.numpy(), repeat=50 + # ) + # + # ansor_task = tvm.auto_scheduler.SearchTask( + # func=matmul_ansor, args=(m, k, n, "float32"), target=target + # ) + # log_file = f"matmul_{m}x{k}x{n}.json" + # tune_option = auto_scheduler.TuningOptions( + # num_measure_trials=1000, + # measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + # verbose=2, + # ) + # + # ansor_task.tune(tune_option) + # sch, args = ansor_task.apply_best(log_file) + # with open(f"./matmul_TIR_{m}x{k}x{n}", 'w') as f: + # f.write(str(tvm.lower(sch, args, simple_mode=True))) + # ansor_func = tvm.build(sch, args, target) + # dev = tvm.cpu() + # a_tvm = tvm.nd.array(a.numpy(), device=dev) + # b_tvm = tvm.nd.array(b.numpy(), device=dev) + # c_tvm = tvm.nd.empty((m, n), device=dev) + # + # ansor_func(a_tvm, b_tvm, c_tvm) + # + # np.testing.assert_allclose( + # actual=c_tvm.numpy(), + # desired=a_tvm.numpy() @ b_tvm.numpy(), + # rtol=1e-3, + # atol=1e-3 + # ) + # + # ansor_latency = 
hidet.utils.benchmark_func( + # lambda: ansor_func(a_tvm, b_tvm, c_tvm), repeat=30 + # ) + # + # with open(f"./perf_{m}x{k}x{n}.txt", 'w') as f: + # f.write(f"m={m}, k={k}, n={n}: hidet takes {hidet_latency:.2f} ms\n") + # f.write(f"m={m}, k={k}, n={n}: numpy takes {np_latency: .2f} ms\n") + # f.write(f"m={m}, k={k}, n={n}: ansor takes {ansor_latency: .2f} ms\n") From 8bb52d3ea83d1c9e053eab47e4f4df588eba1365 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 05:50:34 -0400 Subject: [PATCH 076/148] .. --- python/mat_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mat_new.py b/python/mat_new.py index bb43edc72..5c58b1fed 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(64, 64, 64), (444, 3072, 768), (768, 768, 768), (123, 456, 789)]: +for m, n, k in [(64, 64, 64), (768, 768, 768)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From 94abfa7dc8f701f225bc4dca8590a60a5b105ce8 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 05:56:41 -0400 Subject: [PATCH 077/148] .. --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 7 ++++--- python/mat_new.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 04fa042fe..c66d7bd1b 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -343,8 +343,6 @@ def micro_kernel( # printf("The end of micro kernel....") #### Some setup code #### - packed_b_total_width = 0 - packed_b_height = KC if packed_b_height > k_size: packed_b_height = k_size @@ -352,6 +350,9 @@ def micro_kernel( if packed_b_width > n_size: packed_b_width = (n_size + NR - 1) // NR * NR + printf("packed_b_height: %d\n", packed_b_height) + printf("packed_b_width: %d\n", packed_b_width) + packed_b_total_width = packed_b_width * loop5_nways packed_b_total_size = packed_b_total_width * packed_b_height packed_b_individual_size = packed_b_width * packed_b_height @@ -361,7 +362,7 @@ def micro_kernel( packed_a_width = KC if packed_a_width > k_size: - packed_a_width = (k_size + MR - 1) // MR * MR + packed_a_width = k_size packed_a_total_size = packed_a_total_height * packed_a_width packed_a_individual_size = packed_a_width * packed_a_individual_height diff --git a/python/mat_new.py b/python/mat_new.py index 5c58b1fed..1528feac7 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(64, 64, 64), (768, 768, 768)]: +for m, n, k in [(768, 768, 768)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From a3f35dcbbc26fa7134948ce98200d4b4891f6729 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 05:57:00 -0400 Subject: [PATCH 078/148] . 
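Patch 077 above drops the MR rounding from packed_a_width: the width of packed A is its k extent, and only its height is padded to MR panels (B is the mirror image, padded to NR in width). The sizing that results, mirroring the setup code (the KC/NC values here are assumptions for the example):

KC, NC, MR, NR = 256, 1024, 6, 16
k_size, n_size = 64, 64
packed_b_height = min(KC, k_size)                                     # 64
packed_b_width = NC if NC <= n_size else (n_size + NR - 1) // NR * NR # 64
packed_a_width = min(KC, k_size)                                      # 64: plain k, no rounding
print(packed_b_height, packed_b_width, packed_a_width)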
--- python/mat_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mat_new.py b/python/mat_new.py index 1528feac7..5c58b1fed 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(768, 768, 768)]: +for m, n, k in [(64, 64, 64), (768, 768, 768)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From 230e6d0630c9d81c8c994c54a4e751a20d7e10d2 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 06:00:10 -0400 Subject: [PATCH 079/148] .. --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index c66d7bd1b..35d01ec25 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -838,9 +838,11 @@ def gemm_3rd_loop( loop3_partition_a_start_row * k_size + loop3_partition_a_start_col ) + printf("Got the loop3_partition_a\n") # Get our position within the packed A global buffer packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) + printf("Got the packed_a_buf\n") # TODO: If passed, see if this barrier is necessary thrcomm_barrier( From e3bf60aac68425b21343eca1fc254ba4db878e82 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 06:05:13 -0400 Subject: [PATCH 080/148] .. --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 35d01ec25..9f30fac5a 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -839,6 +839,9 @@ def gemm_3rd_loop( loop3_partition_a_start_col ) printf("Got the loop3_partition_a\n") + printf("packed_a_individual_size: %d\n", packed_a_individual_size) + printf("work_id_packa: %d\n", work_id_packa) + printf("packed_a_total_size: %d\n", packed_a_total_size) # Get our position within the packed A global buffer packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) From e5e4466ca220ac62258bc996c01d28ba2a2f1e02 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 06:06:32 -0400 Subject: [PATCH 081/148] . 
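The barrier the TODO above is questioning fences the shared pack buffer: every thread in the group must finish writing its slice of packed B (or A) before any thread reads the buffer in the macro kernel, and again before the buffer is overwritten for the next k block. The same shape in plain Python threads (an illustration of the protocol, not the hidet runtime):

import threading

nthreads = 4
barrier = threading.Barrier(nthreads)
packed, data = [0] * nthreads, [1, 2, 3, 4]

def worker(tid):
    packed[tid] = data[tid] * 2       # pack my slice
    barrier.wait()                    # nobody computes before packing finishes
    assert sum(packed) == 20          # now every slice is safely visible

ts = [threading.Thread(target=worker, args=(i,)) for i in range(nthreads)]
for t in ts: t.start()
for t in ts: t.join()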
--- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 9f30fac5a..9b0c78827 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -844,7 +844,8 @@ def gemm_3rd_loop( printf("packed_a_total_size: %d\n", packed_a_total_size) # Get our position within the packed A global buffer - packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) + # packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) + packed_a_buf = ~packa_buf[work_id_packa * packed_a_individual_size] printf("Got the packed_a_buf\n") # TODO: If passed, see if this barrier is necessary From 601e6b2cc58d3311570556d1914b0243d7b435ac Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 06:13:24 -0400 Subject: [PATCH 082/148] . --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 9b0c78827..a611937d4 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -845,6 +845,7 @@ def gemm_3rd_loop( # Get our position within the packed A global buffer # packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) + printf("work_id_packa * packed_a_individual_size: %d\n", work_id_packa * packed_a_individual_size) packed_a_buf = ~packa_buf[work_id_packa * packed_a_individual_size] printf("Got the packed_a_buf\n") From 2df735576ce9bf0d53a1286cd5458a5c2101db8f Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 06:14:49 -0400 Subject: [PATCH 083/148] .. 
--- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index a611937d4..4f250ca60 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -366,8 +366,8 @@ def micro_kernel( packed_a_total_size = packed_a_total_height * packed_a_width packed_a_individual_size = packed_a_width * packed_a_individual_height - packb_buf_ptr = avx_malloc(packed_b_total_size * 4, 4096) - packa_buf_ptr = avx_malloc(packed_a_total_size * 4, 4096) + packb_buf_ptr = avx_malloc(packed_b_total_size * 4, 64) + packa_buf_ptr = avx_malloc(packed_a_total_size * 4, 64) packb_buf = cast(packb_buf_ptr, ~float32) packa_buf = cast(packa_buf_ptr, ~float32) From ee3007848f463ca490698548ad0ebee4eac51f4c Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 06:16:56 -0400 Subject: [PATCH 084/148] bruh --- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 4f250ca60..4334b62b1 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -14,8 +14,8 @@ from hidet.ir.expr import cast from hidet.ir.module import IRModule from hidet.ir.compute import TensorNode -from hidet.ir.primitives import avx_malloc, printf -from hidet.ir.primitives.cpu import avx_f32x8_setzero, avx_f32x8_load_aligned +from hidet.ir.primitives import printf +from hidet.ir.primitives.cpu import avx_f32x8_setzero, avx_f32x8_load_aligned, avx_free, avx_malloc from hidet.ir.stmt import DeclareScope from hidet.ir.task import Task from hidet.ir.compute import compute, reduce @@ -366,8 +366,10 @@ def micro_kernel( packed_a_total_size = packed_a_total_height * packed_a_width packed_a_individual_size = packed_a_width * packed_a_individual_height - packb_buf_ptr = avx_malloc(packed_b_total_size * 4, 64) - packa_buf_ptr = avx_malloc(packed_a_total_size * 4, 64) + packb_buf_ptr = avx_malloc(packed_b_total_size * 4, 4096) + packa_buf_ptr = avx_malloc(packed_a_total_size * 4, 4096) + + packb_buf = cast(packb_buf_ptr, ~float32) packa_buf = cast(packa_buf_ptr, ~float32) @@ -1046,6 +1048,8 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], comm_id_5th_loop = tid_5th_loop gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop) + avx_free(packa_buf) + avx_free(packb_buf) assert isinstance(matmul_kernel_x86_v3, hidet.ir.Function) matmul_kernel_x86_v3.kind = "cpu_kernel" From cb54a7ea2c31a9683eab47fa3be05073cebdf407 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 09:15:55 -0400 Subject: [PATCH 085/148] .. 
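On the alignment experiments above: _mm_malloc(size, align) returns a block whose address is a multiple of align, so both 4096 (a page) and 64 (a cache line) already satisfy the 32-byte requirement of _mm256_load_ps / _mm256_store_ps. The buffer start is therefore not the suspect; what matters is whether every address handed to an aligned intrinsic stays 32-byte aligned. A quick way to inspect an address in Python:

import numpy as np
buf = np.zeros(1024, dtype=np.float32)
addr = buf.ctypes.data
print(addr % 64, addr % 32)   # numpy promises only modest alignment, often 16 bytes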
--- python/mat_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mat_new.py b/python/mat_new.py index 5c58b1fed..b422416fc 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(64, 64, 64), (768, 768, 768)]: +for m, n, k in [(256, 256, 256)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From 8e07dade7d183092217b61588b22b4bb04edbb61 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 09:16:38 -0400 Subject: [PATCH 086/148] . --- python/mat_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mat_new.py b/python/mat_new.py index b422416fc..96f1996ab 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(256, 256, 256)]: +for m, n, k in [(512, 512, 512)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From 8723df660cd02ee0349af1ef7b8a6fd6f54fdb30 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 09:17:11 -0400 Subject: [PATCH 087/148] . --- python/mat_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mat_new.py b/python/mat_new.py index 96f1996ab..1528feac7 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(512, 512, 512)]: +for m, n, k in [(768, 768, 768)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From 0919d12f0621a1232d3e8b366ae1649640fd4be5 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 09:23:24 -0400 Subject: [PATCH 088/148] .. 
--- .../ops/matmul/matmul_f32_x86_refactored.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 4334b62b1..23336bfd2 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -680,18 +680,18 @@ def gemm_macro( m_iter = macro_m // MR m_remainder = macro_m % MR - printf("The start of the macro kernel.\n") - printf("n_iter: %d\n", n_iter) - printf("n_remainder: %d\n", n_remainder) - printf("m_iter: %d\n", m_iter) - printf("m_remainder: %d\n", m_remainder) - printf("c_row_off: %d\n", c_row_off) - printf("c_col_off: %d\n", c_col_off) - printf("macro_m: %d\n", macro_m) - printf("macro_n: %d\n", macro_n) - printf("macro_k: %d\n", macro_k) - printf("ps_packed_a: %d\n", ps_packed_a) - printf("ps_packed_b: %d\n", ps_packed_b) + # printf("The start of the macro kernel.\n") + # printf("n_iter: %d\n", n_iter) + # printf("n_remainder: %d\n", n_remainder) + # printf("m_iter: %d\n", m_iter) + # printf("m_remainder: %d\n", m_remainder) + # printf("c_row_off: %d\n", c_row_off) + # printf("c_col_off: %d\n", c_col_off) + # printf("macro_m: %d\n", macro_m) + # printf("macro_n: %d\n", macro_n) + # printf("macro_k: %d\n", macro_k) + # printf("ps_packed_a: %d\n", ps_packed_a) + # printf("ps_packed_b: %d\n", ps_packed_b) if n_remainder > 0: From b2a6c152f067f47ebb42308f3d677da6945814bd Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 09:24:05 -0400 Subject: [PATCH 089/148] .. --- python/mat_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mat_new.py b/python/mat_new.py index 1528feac7..96f1996ab 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(768, 768, 768)]: +for m, n, k in [(512, 512, 512)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From 43922bb9093af2d08549eb191c983615167105d4 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 11:05:17 -0400 Subject: [PATCH 090/148] .. --- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 23336bfd2..7e99bbbc1 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -366,8 +366,16 @@ def micro_kernel( packed_a_total_size = packed_a_total_height * packed_a_width packed_a_individual_size = packed_a_width * packed_a_individual_height - packb_buf_ptr = avx_malloc(packed_b_total_size * 4, 4096) - packa_buf_ptr = avx_malloc(packed_a_total_size * 4, 4096) + # packb_buf_ptr = avx_malloc(packed_b_total_size * 4, 4096) + # packa_buf_ptr = avx_malloc(packed_a_total_size * 4, 4096) + packb_buf_ptr = module.define_global_var( + name='packb_buf_ptr', + var_type=float32[packed_b_total_size] + ) + packa_buf_ptr = module.define_global_var( + name='packa_buf_ptr', + var_type=float32[packed_a_total_size] + ) From 553dfc409af07fee140ecdfd8d4f3d2ed5f29a15 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 11:08:09 -0400 Subject: [PATCH 091/148] .. 
--- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 7e99bbbc1..9f5821b65 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -1056,8 +1056,8 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], comm_id_5th_loop = tid_5th_loop gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop) - avx_free(packa_buf) - avx_free(packb_buf) + # avx_free(packa_buf) + # avx_free(packb_buf) assert isinstance(matmul_kernel_x86_v3, hidet.ir.Function) matmul_kernel_x86_v3.kind = "cpu_kernel" From ae29fb31e318786476bb9f696dd4966fbdc991e7 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 11:09:09 -0400 Subject: [PATCH 092/148] ... --- python/mat_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mat_new.py b/python/mat_new.py index 96f1996ab..1528feac7 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(512, 512, 512)]: +for m, n, k in [(768, 768, 768)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From 0572ace3495f5994b89311cc73d237ca99024a31 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 11:13:26 -0400 Subject: [PATCH 093/148] . --- python/mat_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mat_new.py b/python/mat_new.py index 1528feac7..96f1996ab 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(768, 768, 768)]: +for m, n, k in [(512, 512, 512)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From ce1f5fd876ae350449c07942f7fe73ef02cbbd0e Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 11:23:44 -0400 Subject: [PATCH 094/148] . --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 9f5821b65..9ee916246 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -856,7 +856,8 @@ def gemm_3rd_loop( # Get our position within the packed A global buffer # packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) printf("work_id_packa * packed_a_individual_size: %d\n", work_id_packa * packed_a_individual_size) - packed_a_buf = ~packa_buf[work_id_packa * packed_a_individual_size] + # packed_a_buf = ~packa_buf[work_id_packa * packed_a_individual_size] + packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) printf("Got the packed_a_buf\n") # TODO: If passed, see if this barrier is necessary From aaa500c34274df569f5d029f11ed7a3fc26fa3f0 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 11:26:22 -0400 Subject: [PATCH 095/148] .. 
--- .../ops/matmul/matmul_f32_x86_refactored.py | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 9ee916246..fde2c40f3 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -570,57 +570,57 @@ def gemm_pack_b( b00 = avx_f32x8_load(b_panel) b08 = avx_f32x8_load(b_panel + 8) - avx_f32x8_store_aligned(packed_b_buff_curr, b00) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b08) + avx_f32x8_store(packed_b_buff_curr, b00) + avx_f32x8_store(packed_b_buff_curr + 8, b08) packed_b_buff_curr += 16 b10 = avx_f32x8_load(b_panel + n_size) b18 = avx_f32x8_load(b_panel + (n_size + 8)) - avx_f32x8_store_aligned(packed_b_buff_curr, b10) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b18) + avx_f32x8_store(packed_b_buff_curr, b10) + avx_f32x8_store(packed_b_buff_curr + 8, b18) packed_b_buff_curr += 16 b20 = avx_f32x8_load(b_panel + (2 * n_size)) b28 = avx_f32x8_load(b_panel + (2 * n_size + 8)) - avx_f32x8_store_aligned(packed_b_buff_curr, b20) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b28) + avx_f32x8_store(packed_b_buff_curr, b20) + avx_f32x8_store(packed_b_buff_curr + 8, b28) packed_b_buff_curr += 16 b30 = avx_f32x8_load(b_panel + (3 * n_size)) b38 = avx_f32x8_load(b_panel + (3 * n_size + 8)) - avx_f32x8_store_aligned(packed_b_buff_curr, b30) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b38) + avx_f32x8_store(packed_b_buff_curr, b30) + avx_f32x8_store(packed_b_buff_curr + 8, b38) packed_b_buff_curr += 16 b40 = avx_f32x8_load(b_panel + (4 * n_size)) b48 = avx_f32x8_load(b_panel + (4 * n_size + 8)) - avx_f32x8_store_aligned(packed_b_buff_curr, b40) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b48) + avx_f32x8_store(packed_b_buff_curr, b40) + avx_f32x8_store(packed_b_buff_curr + 8, b48) packed_b_buff_curr += 16 b50 = avx_f32x8_load(b_panel + (5 * n_size)) b58 = avx_f32x8_load(b_panel + (5 * n_size + 8)) - avx_f32x8_store_aligned(packed_b_buff_curr, b50) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b58) + avx_f32x8_store(packed_b_buff_curr, b50) + avx_f32x8_store(packed_b_buff_curr + 8, b58) packed_b_buff_curr += 16 b60 = avx_f32x8_load(b_panel + (6 * n_size)) b68 = avx_f32x8_load(b_panel + (6 * n_size + 8)) - avx_f32x8_store_aligned(packed_b_buff_curr, b60) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b68) + avx_f32x8_store(packed_b_buff_curr, b60) + avx_f32x8_store(packed_b_buff_curr + 8, b68) packed_b_buff_curr += 16 b70 = avx_f32x8_load(b_panel + (7 * n_size)) b78 = avx_f32x8_load(b_panel + (7 * n_size + 8)) - avx_f32x8_store_aligned(packed_b_buff_curr, b70) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b78) + avx_f32x8_store(packed_b_buff_curr, b70) + avx_f32x8_store(packed_b_buff_curr + 8, b78) packed_b_buff_curr += 16 @@ -631,8 +631,8 @@ def gemm_pack_b( row * n_size + curr_panel_start) b00 = avx_f32x8_load(b_panel) b08 = avx_f32x8_load(b_panel + 8) - avx_f32x8_store_aligned(packed_b_buff_curr, b00) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b08) + avx_f32x8_store(packed_b_buff_curr, b00) + avx_f32x8_store(packed_b_buff_curr + 8, b08) packed_b_buff_curr += 16 row += 1 From 6445811e12807f94c409a824a6d5de02dc62ac36 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 11:26:46 -0400 Subject: [PATCH 096/148] . 
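Back to 768^3, which is the clean-tiling case for this kernel: with the 6x16 micro tile, 768 divides exactly in both dimensions, while 512 leaves a 2-row remainder in m (512 % 6 == 2) and so exercises the partial-tile branch of the macro kernel. Checking the arithmetic standalone:

    MR, NR = 6, 16                          # micro-kernel tile used in this file
    for size in (512, 768):
        print(size, size % MR, size % NR)   # 512 -> (2, 0); 768 -> (0, 0)

If 768 passes while 512 fails, the blame narrows to the m_remainder / not_edge handling rather than the main FMA loop.
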
--- python/mat_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mat_new.py b/python/mat_new.py index 96f1996ab..1528feac7 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(512, 512, 512)]: +for m, n, k in [(768, 768, 768)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From d3e1a1d3f0b262ec8154a8e3497b496dce9e3943 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 11:27:44 -0400 Subject: [PATCH 097/148] . --- python/mat_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mat_new.py b/python/mat_new.py index 1528feac7..96f1996ab 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(768, 768, 768)]: +for m, n, k in [(512, 512, 512)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From 65898487a43565d527bf1dcf433fe0c4a9819b25 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 11:45:18 -0400 Subject: [PATCH 098/148] .. --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index fde2c40f3..b2b1d5534 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -358,6 +358,8 @@ def micro_kernel( packed_b_individual_size = packed_b_width * packed_b_height packed_a_individual_height = MC + if packed_a_individual_height > m_size: + packed_a_individual_height = (m_size + MR - 1) // MR * MR packed_a_total_height = packed_a_individual_height * loop3_nways packed_a_width = KC From 4bc93c882d3920bc599e913a8e387cc2d1eefd62 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 11:51:05 -0400 Subject: [PATCH 099/148] . --- python/mat_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mat_new.py b/python/mat_new.py index 96f1996ab..1528feac7 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(512, 512, 512)]: +for m, n, k in [(768, 768, 768)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From 563b121e47bdb4d8eb500b518badb722799f893f Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 11:53:03 -0400 Subject: [PATCH 100/148] . 
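np.testing.assert_allclose only reports an aggregate summary, so add a per-element scan to see which (i, j) regions actually agree. A more compact equivalent for future runs (standalone sketch):

    import numpy as np

    actual = np.arange(16, dtype=np.float32).reshape(4, 4)
    desired = actual.copy()
    desired[2:, :] += 1.0                                  # fake a bad row band
    bad = np.argwhere(~np.isclose(actual, desired, atol=1e-3))
    print(bad)                                             # -> [[2 0] [2 1] ... [3 3]]

Clusters in the failing indices map back to tiles: a wrong packed panel shows up as a 6-row or 16-column band rather than as scattered points.
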
--- python/mat_new.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/mat_new.py b/python/mat_new.py index 1528feac7..115ef2fb8 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -46,6 +46,11 @@ def matmul_ansor(M, K, N, dtype): actual = c.numpy() desired = a.numpy() @ b.numpy() + for i in range(m): + for j in range(n): + if abs(actual[i, j] - desired[i, j]) < 1e-3: + print(f"Actually passed for i={i}, j={j}") + np.testing.assert_allclose( actual=actual, From 17011a144879d0fd159c4e916f199ce33cab3fe5 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 11:55:58 -0400 Subject: [PATCH 101/148] . --- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index b2b1d5534..d31c527e0 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -769,18 +769,18 @@ def gemm_macro( while j < jr_end: b1 = packed_b + j * cstep_b c1 = macro_c_cast + j * cstep_c - printf("j = %d\n", j) - printf("The offset j * cstep_c: %d\n\n", j * cstep_c) + # printf("j = %d\n", j) + # printf("The offset j * cstep_c: %d\n\n", j * cstep_c) n_cur = NR if not_edge(j, n_iter, n_remainder) else n_remainder # Loop over the m dimension, MR rows at a time i = ir_start while i < ir_end: - printf("i = %d\n", i) + # printf("i = %d\n", i) a1 = packed_a + i * rstep_a c11 = c1 + i * rstep_c - printf("The offset i * rstep_a: %d\n", i * rstep_a) - printf("The offset i * rstep_c: %d\n\n", i * rstep_c) + # printf("The offset i * rstep_a: %d\n", i * rstep_a) + # printf("The offset i * rstep_c: %d\n\n", i * rstep_c) c11 = as_tensor_pointer(c11, dtype=float32, shape=(m_size, n_size)) m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder From 18f8b5377bdefb48d7cd5647640f9dbec11e2b67 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 12:09:39 -0400 Subject: [PATCH 102/148] .. --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index d31c527e0..828270cfd 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -367,6 +367,8 @@ def micro_kernel( packed_a_width = k_size packed_a_total_size = packed_a_total_height * packed_a_width packed_a_individual_size = packed_a_width * packed_a_individual_height + printf("packed_a_individual_height: %d\n", packed_a_individual_height) + printf("packed_a_width: %d\n", packed_a_width) # packb_buf_ptr = avx_malloc(packed_b_total_size * 4, 4096) # packa_buf_ptr = avx_malloc(packed_a_total_size * 4, 4096) From 12e44c239694244c194d602e7729dcf5dc74ce72 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 12:16:37 -0400 Subject: [PATCH 103/148] .. 
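Same bisection strategy as the last few patches: silence the macro/loop-3 prints whose offsets already look right and light up loop 4 and loop 5 instead, walking the bad partition up the blocking hierarchy. Since toggling printf lines by hand is getting repetitive, a lighter pattern would be per-loop debug switches (just a sketch of the idea, not part of this patch):

    DEBUG = {"loop3": False, "loop4": True, "loop5": True}

    def dbg(tag, fmt, *args):
        if DEBUG.get(tag, False):
            print(tag + ": " + (fmt % args))

    dbg("loop4", "i_loop4=%d b_alg_loop4=%d", 0, 512)   # prints
    dbg("loop3", "ii=%d", 0)                            # silenced

In hidet script the equivalent would be a compile-time boolean, so the disabled printf calls fold away entirely.
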
--- .../hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 828270cfd..6c80c88f6 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -817,9 +817,6 @@ def gemm_3rd_loop( comm_id_3rd_loop: int32, work_id_3rd_loop: int32, is_first: bool): - - printf("The start of 3rd loop. comm_id_3rd_loop: %d, work_id_3rd_loop: %d\n", comm_id_3rd_loop, work_id_3rd_loop) - comm_id_macro = comm_id_3rd_loop % macro_nthreads work_id_macro = comm_id_macro // (macro_nthreads // macro_nways) comm_id_packa = comm_id_macro @@ -932,7 +929,7 @@ def gemm_4th_loop(a: float32[m_size, k_size], while i_loop4 < k_size: b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, NC) - # printf("i_loop4: %d\n", i_loop4) + printf("i_loop4: %d\n", i_loop4) loop4_partition_b_height = b_alg_loop4 loop4_partition_b_width = loop5_partition_b_width @@ -1025,7 +1022,8 @@ def gemm_5th_loop(a: float32[m_size, k_size], while loop5_iter < loop5_my_end: b_alg_loop5 = determine_blocksize_f_sub(loop5_iter, loop5_my_end, NC) - # printf("b_alg_loop5: %d\n", b_alg_loop5) + printf("loop5_iter: %d\n", loop5_iter) + printf("b_alg_loop5: %d\n", b_alg_loop5) loop5_partition_c_width = b_alg_loop5 loop5_partition_c_start_col = loop5_iter loop5_partition_b_width = b_alg_loop5, From ceb22dda53b9761bfe29a1f059e08b37b3a7aaf0 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 12:18:14 -0400 Subject: [PATCH 104/148] .. --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 6c80c88f6..5ac78aabc 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -930,6 +930,7 @@ def gemm_4th_loop(a: float32[m_size, k_size], b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, NC) printf("i_loop4: %d\n", i_loop4) + printf("b_alg_loop4: %d\n", b_alg_loop4) loop4_partition_b_height = b_alg_loop4 loop4_partition_b_width = loop5_partition_b_width From 68fbba8cf2eaf0c66a24dac03bf861ef410140b7 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Tue, 29 Aug 2023 12:19:20 -0400 Subject: [PATCH 105/148] .. --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 5ac78aabc..51a46e09f 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -927,7 +927,7 @@ def gemm_4th_loop(a: float32[m_size, k_size], # printf("The start of the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) while i_loop4 < k_size: - b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, NC) + b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, KC) printf("i_loop4: %d\n", i_loop4) printf("b_alg_loop4: %d\n", b_alg_loop4) From 0c3639fb63cfabef898dfe93686c95248b9432be Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 30 Aug 2023 06:30:56 -0400 Subject: [PATCH 106/148] . 
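With the previous one-character fix in (loop 4 now steps the k dimension by KC, not NC), pull the instrumentation back out. The bug it exposed is worth spelling out: loop 4 partitions k, so its blocksize must be the k blocking factor. A plain-Python model (assumption: the real determine_blocksize_f_sub is BLIS-derived and may treat the tail block more carefully; min() is the simplest faithful picture):

    def blocksize_f_sub(i, dim, b_alg):
        return min(b_alg, dim - i)

    KC, NC, k_size = 512, 896, 768
    i = 0
    while i < k_size:
        step = blocksize_f_sub(i, k_size, KC)   # 512, then 256
        print(i, step)
        i += step

Stepping by NC = 896 instead would cover k = 768 in a single bite, while the packed panels are sized for at most KC = 512 rows, so packing would write well past the end of each panel.
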
--- .../ops/matmul/matmul_f32_x86_refactored.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 51a46e09f..0eb425a46 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -738,13 +738,13 @@ def gemm_macro( ~ir_inc ) - printf("jr_start: %d\n", jr_start) - printf("jr_end: %d\n", jr_end) - printf("jr_inc: %d\n", jr_inc) - - printf("ir_start: %d\n", ir_start) - printf("ir_end: %d\n", ir_end) - printf("ir_inc: %d\n", ir_inc) + # printf("jr_start: %d\n", jr_start) + # printf("jr_end: %d\n", jr_end) + # printf("jr_inc: %d\n", jr_inc) + # + # printf("ir_start: %d\n", ir_start) + # printf("ir_end: %d\n", ir_end) + # printf("ir_inc: %d\n", ir_inc) rs_packeda = 1 rstep_a = ps_packed_a @@ -753,10 +753,10 @@ def gemm_macro( cstep_c = NR rstep_c = n_size * MR - printf("rstep_a: %d\n", rstep_a) - printf("cstep_b: %d\n", cstep_b) - printf("cstep_c: %d\n", cstep_c) - printf("rstep_c: %d\n", rstep_c) + # printf("rstep_a: %d\n", rstep_a) + # printf("cstep_b: %d\n", cstep_b) + # printf("cstep_c: %d\n", cstep_c) + # printf("rstep_c: %d\n", rstep_c) macro_c_cast = as_tensor_pointer( ~c[c_row_off, c_col_off], @@ -833,14 +833,14 @@ def gemm_3rd_loop( ~m_start_loop3, ~m_end_loop3 ) - printf("In loop 3: m_start_loop3: %d, m_end_loop3: %d\n", m_start_loop3, m_end_loop3) + # printf("In loop 3: m_start_loop3: %d, m_end_loop3: %d\n", m_start_loop3, m_end_loop3) ii = m_start_loop3 while ii < m_end_loop3: b_alg_loop3 = determine_blocksize_f_sub( ii, m_size, MC ) - printf("The ii in loop3: %d\n", ii) - printf("b_alg_loop3: %d\n", b_alg_loop3) + # printf("The ii in loop3: %d\n", ii) + # printf("b_alg_loop3: %d\n", b_alg_loop3) # Acquire the partition at loop 3 loop3_partition_a_start_row = ii loop3_partition_a_height = b_alg_loop3 @@ -849,17 +849,17 @@ def gemm_3rd_loop( loop3_partition_a_start_row * k_size + loop3_partition_a_start_col ) - printf("Got the loop3_partition_a\n") - printf("packed_a_individual_size: %d\n", packed_a_individual_size) - printf("work_id_packa: %d\n", work_id_packa) - printf("packed_a_total_size: %d\n", packed_a_total_size) + # printf("Got the loop3_partition_a\n") + # printf("packed_a_individual_size: %d\n", packed_a_individual_size) + # printf("work_id_packa: %d\n", work_id_packa) + # printf("packed_a_total_size: %d\n", packed_a_total_size) # Get our position within the packed A global buffer # packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) - printf("work_id_packa * packed_a_individual_size: %d\n", work_id_packa * packed_a_individual_size) + # printf("work_id_packa * packed_a_individual_size: %d\n", work_id_packa * packed_a_individual_size) # packed_a_buf = ~packa_buf[work_id_packa * packed_a_individual_size] packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) - printf("Got the packed_a_buf\n") + # printf("Got the packed_a_buf\n") # TODO: If passed, see if this barrier is necessary thrcomm_barrier( @@ -929,8 +929,8 @@ def gemm_4th_loop(a: float32[m_size, k_size], while i_loop4 < k_size: b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, KC) - printf("i_loop4: %d\n", i_loop4) - printf("b_alg_loop4: %d\n", b_alg_loop4) + # printf("i_loop4: %d\n", i_loop4) + # printf("b_alg_loop4: %d\n", b_alg_loop4) loop4_partition_b_height = b_alg_loop4 loop4_partition_b_width = loop5_partition_b_width From 
76d55a16bcc084af6cbc90fc03757ad5099ce775 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 30 Aug 2023 09:47:05 -0400 Subject: [PATCH 107/148] .. --- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 0eb425a46..fa6eb7a35 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -929,25 +929,15 @@ def gemm_4th_loop(a: float32[m_size, k_size], while i_loop4 < k_size: b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, KC) - # printf("i_loop4: %d\n", i_loop4) - # printf("b_alg_loop4: %d\n", b_alg_loop4) - loop4_partition_b_height = b_alg_loop4 loop4_partition_b_width = loop5_partition_b_width loop4_partition_b_start_row = i_loop4 loop4_partition_b_start_col = loop5_partition_b_start_col - # printf("loop4_partition_b_height: %d\n", loop4_partition_b_height) - # printf("loop4_partition_b_width: %d\n", loop4_partition_b_width) - # printf("loop4_partition_b_start_row: %d\n", loop4_partition_b_start_row) - # printf("loop4_partition_b_start_col: %d\n", loop4_partition_b_start_col) + loop4_partition_a_start_col = i_loop4 is_first = (i_loop4 == 0) - # Get the thread's partition of the buffer and the matrix - # packed_b_buf = packb_buf + ( - # packb_start_offsets[work_id_5th_loop, 0] * packed_b_height - # ) packed_b_buf = packb_buf + ( packed_b_individual_size * work_id_5th_loop From 9e289e491efefa2723ec9d7baab89ac03a7cbe5c Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 30 Aug 2023 09:50:33 -0400 Subject: [PATCH 108/148] .. --- python/mat_new.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/mat_new.py b/python/mat_new.py index 115ef2fb8..912deef16 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -46,10 +46,10 @@ def matmul_ansor(M, K, N, dtype): actual = c.numpy() desired = a.numpy() @ b.numpy() - for i in range(m): - for j in range(n): - if abs(actual[i, j] - desired[i, j]) < 1e-3: - print(f"Actually passed for i={i}, j={j}") + # for i in range(m): + # for j in range(n): + # if abs(actual[i, j] - desired[i, j]) < 1e-3: + # print(f"Actually passed for i={i}, j={j}") np.testing.assert_allclose( From 165c3d573f7533f0060976330246f49fa3a718bb Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 30 Aug 2023 09:51:16 -0400 Subject: [PATCH 109/148] .. --- python/mat_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mat_new.py b/python/mat_new.py index 912deef16..533d6462e 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -30,7 +30,7 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(768, 768, 768)]: +for m, n, k in [(768, 768, 768), (111, 333, 222)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From e898772a983b46ef45204f89c91c111728c2a420 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 30 Aug 2023 09:55:02 -0400 Subject: [PATCH 110/148] .. 
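With single-threaded runs now passing, turn parallelism back on: 8 ways across loop 3 (the MC-tall panels of A) and 4 ways across the macro kernel, with loop 5 and loop 1 left single. Assuming nthreads is the product of the ways (which is how the per-loop counts divide out evenly), the derived thread counts can be recomputed standalone:

    ways = (1, 8, 4, 1)                       # (loop5, loop3, macro, loop1)
    loop5, loop3, macro, loop1 = ways
    nthreads = loop5 * loop3 * macro * loop1  # 32
    loop4_nthreads = nthreads // loop5        # 32 threads share each loop-4 instance
    macro_nthreads = loop4_nthreads // loop3  # 4 threads per macro kernel
    loop1_nthreads = macro_nthreads // macro  # 1
    print(nthreads, loop4_nthreads, macro_nthreads, loop1_nthreads)

Note that each loop-3 way also multiplies packed_a_total_height, so the global A pack buffer grows 8x under this setting.
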
--- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index fa6eb7a35..4f977598e 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -86,7 +86,7 @@ def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 1, 1, 1)]) def schedule_matmulf32_x86( - self, MC=2016, NC=896, KC=512, ways=(1, 1, 1, 1) + self, MC=2016, NC=896, KC=512, ways=(1, 8, 4, 1) ) -> IRModule: import hidet from hidet.ir.type import tensor_type From 4cb35cb160e62997df1e8e5f8be00ea248c1b00a Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 30 Aug 2023 09:57:33 -0400 Subject: [PATCH 111/148] . --- python/mat_new.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/mat_new.py b/python/mat_new.py index 533d6462e..48a782eb1 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -51,6 +51,11 @@ def matmul_ansor(M, K, N, dtype): # if abs(actual[i, j] - desired[i, j]) < 1e-3: # print(f"Actually passed for i={i}, j={j}") + for i in range(m): + for j in range(n): + if actual[i, j] == 0.0: + print(f"element is 0 for i={i}, j={j}") + np.testing.assert_allclose( actual=actual, From 6ba8075e6c02d62f7d99f243569d17838793820a Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 30 Aug 2023 09:58:52 -0400 Subject: [PATCH 112/148] .. --- python/mat_new.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mat_new.py b/python/mat_new.py index 48a782eb1..b719a65fb 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -30,7 +30,8 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -for m, n, k in [(768, 768, 768), (111, 333, 222)]: +# for m, n, k in [(768, 768, 768), (111, 333, 222)]: +for m, n, k in [(64, 64, 64)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) From 073266a45292c8a2305aab69ad26660d643367e5 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 30 Aug 2023 10:30:47 -0400 Subject: [PATCH 113/148] . --- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 4f977598e..fc45825f5 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -1049,6 +1049,9 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways) comm_id_5th_loop = tid_5th_loop + printf("tid_5th_loop: %d, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", + tid_5th_loop, work_id_5th_loop, comm_id_5th_loop) + gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop) # avx_free(packa_buf) # avx_free(packb_buf) From d736d96a9e8da1c2767253fc73722650c6fcdaa7 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 30 Aug 2023 10:37:28 -0400 Subject: [PATCH 114/148] .. 
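Quiet things down again; the pack-A prints and the per-element scans have served their purpose. The zero scan deserves a note, since it is a cheap oracle for partitioning bugs: the first k-iteration overwrites C (is_first zeroes the accumulators before the FMAs), so any element still exactly 0.0 at the end is a tile no thread ever stored to, implicating the thread ranges rather than the arithmetic. Standalone sketch:

    import numpy as np

    c = np.zeros((12, 32), dtype=np.float32)
    c[0:6, 0:16] = 1.0                       # pretend only one 6x16 tile was written
    holes = np.argwhere(c == 0.0)
    print(len(holes), holes[0], holes[-1])   # 288 holes, all outside that tile

With randn inputs a genuinely zero dot product is vanishingly unlikely, so false positives are not a practical concern.
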
--- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index fc45825f5..82884988c 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -350,8 +350,8 @@ def micro_kernel( if packed_b_width > n_size: packed_b_width = (n_size + NR - 1) // NR * NR - printf("packed_b_height: %d\n", packed_b_height) - printf("packed_b_width: %d\n", packed_b_width) + # printf("packed_b_height: %d\n", packed_b_height) + # printf("packed_b_width: %d\n", packed_b_width) packed_b_total_width = packed_b_width * loop5_nways packed_b_total_size = packed_b_total_width * packed_b_height @@ -411,15 +411,15 @@ def gemm_pack_a( layout=row_major(packed_a_individual_height // MR, 1) * column_major(MR, packed_a_width) ) - printf("pack a: packed_a_individual_height: %d, packed_a_width: %d\n", packed_a_individual_height, - packed_a_width) + # printf("pack a: packed_a_individual_height: %d, packed_a_width: %d\n", packed_a_individual_height, + # packed_a_width) npanels_full_a = loop3_partition_a_height // MR panel_a_remainder = loop3_partition_a_height % MR - printf("loop3_partition_a_height: %d\n", loop3_partition_a_height) - printf("npanels_full_a: %d\n", npanels_full_a) + # printf("loop3_partition_a_height: %d\n", loop3_partition_a_height) + # printf("npanels_full_a: %d\n", npanels_full_a) npanels_a = npanels_full_a + (1 if panel_a_remainder > 0 else 0) for ii_panel in range(npanels_a): From 83118f31c549cd366eed589d804455bfc801fdaf Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 30 Aug 2023 10:38:08 -0400 Subject: [PATCH 115/148] . --- python/mat_new.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/mat_new.py b/python/mat_new.py index b719a65fb..5c854db3f 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -52,10 +52,10 @@ def matmul_ansor(M, K, N, dtype): # if abs(actual[i, j] - desired[i, j]) < 1e-3: # print(f"Actually passed for i={i}, j={j}") - for i in range(m): - for j in range(n): - if actual[i, j] == 0.0: - print(f"element is 0 for i={i}, j={j}") + # for i in range(m): + # for j in range(n): + # if actual[i, j] == 0.0: + # print(f"element is 0 for i={i}, j={j}") np.testing.assert_allclose( From df1cc83ec0617a37ffbaae41651bb9485f7d2d36 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 30 Aug 2023 10:40:05 -0400 Subject: [PATCH 116/148] .... 
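Next suspect: the loop-5 column partitioning itself, so print each thread's [loop5_my_start, loop5_my_end) slice as computed by thread_range_sub. The intended behavior is contiguous chunks of the n dimension rounded to the NR = 16 panel width; a simplified model (assumption: the real routine is BLIS-derived and balances remainder panels more carefully than this):

    def thread_range_sub(nways, work_id, n, bf):
        chunk = ((n + nways - 1) // nways + bf - 1) // bf * bf  # per-way share, NR-rounded
        start = min(work_id * chunk, n)
        return start, min(start + chunk, n)

    for tid in range(2):
        print(tid, thread_range_sub(2, tid, 64, 16))   # -> (0, 32) and (32, 64)

Every column of C must fall in exactly one thread's range: a gap here reproduces the zero holes seen earlier, while an overlap shows up as double-accumulated values.
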
--- python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 82884988c..de1cfe9e2 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -1007,6 +1007,9 @@ def gemm_5th_loop(a: float32[m_size, k_size], thread_range_sub(loop5_nways, work_id_5th_loop, n_size, NR, ~loop5_my_start, ~loop5_my_end) + printf("work_id_5th_loop: %d, comm_id_5th_loop: %d, loop5_my_start: %d, loop5_my_end: %d\n", + work_id_5th_loop, comm_id_5th_loop, loop5_my_start, loop5_my_end) + # printf("loop5_my_start: %d, loop5_my_end: %d\n", loop5_my_start, loop5_my_end) loop5_iter = loop5_my_start From a85e56f0d6b5de62955c09ea65bb432cedc43344 Mon Sep 17 00:00:00 2001 From: BolinSNLHM Date: Wed, 30 Aug 2023 12:59:42 -0400 Subject: [PATCH 117/148] .. --- .../graph/ops/matmul/matmul_f32_x86_refactored.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index de1cfe9e2..41afd7884 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -86,7 +86,7 @@ def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 1, 1, 1)]) def schedule_matmulf32_x86( - self, MC=2016, NC=896, KC=512, ways=(1, 8, 4, 1) + self, MC=2016, NC=896, KC=512, ways=(1, 2, 2, 1) ) -> IRModule: import hidet from hidet.ir.type import tensor_type @@ -532,7 +532,7 @@ def gemm_pack_b( loop4_partition_b_width: int32, loop4_partition_b_height: int32, packed_b_buf: ~float32, - comm_id_packb: int32, work_id_packb: int32, + comm_id_packb: int32, workn_id_packb: int32, packb_nways: int32 ): # printf("The start of pack B, comm_id_packb: %d, work_id_packb: %d\n", comm_id_packb, work_id_packb) @@ -545,6 +545,9 @@ def gemm_pack_b( # printf("packed_b_height: %d", packed_b_height) # printf("packedb_panel_stride: %d\n", packedb_panel_stride) + printf("work_id_packb: %d, packed_b_height: %d, loop4_partition_b_width: %d, npanels_full_b: %d, packb_nways: %d", + work_id_packb, packed_b_height, loop4_partition_b_width, npanels_full_b, packb_nways) + # Loop for the packing of B for i_panel in range(npanels_b): if i_panel % packb_nways != work_id_packb % packb_nways: @@ -1007,8 +1010,8 @@ def gemm_5th_loop(a: float32[m_size, k_size], thread_range_sub(loop5_nways, work_id_5th_loop, n_size, NR, ~loop5_my_start, ~loop5_my_end) - printf("work_id_5th_loop: %d, comm_id_5th_loop: %d, loop5_my_start: %d, loop5_my_end: %d\n", - work_id_5th_loop, comm_id_5th_loop, loop5_my_start, loop5_my_end) + # printf("work_id_5th_loop: %d, comm_id_5th_loop: %d, loop5_my_start: %d, loop5_my_end: %d\n", + # work_id_5th_loop, comm_id_5th_loop, loop5_my_start, loop5_my_end) # printf("loop5_my_start: %d, loop5_my_end: %d\n", loop5_my_start, loop5_my_end) @@ -1016,8 +1019,8 @@ def gemm_5th_loop(a: float32[m_size, k_size], while loop5_iter < loop5_my_end: b_alg_loop5 = determine_blocksize_f_sub(loop5_iter, loop5_my_end, NC) - printf("loop5_iter: %d\n", loop5_iter) - printf("b_alg_loop5: %d\n", b_alg_loop5) + # printf("loop5_iter: %d\n", loop5_iter) + # printf("b_alg_loop5: %d\n", b_alg_loop5) loop5_partition_c_width = b_alg_loop5 
loop5_partition_c_start_col = loop5_iter loop5_partition_b_width = b_alg_loop5, From 728ec9aa4c6bbef8485cf938cbb42d5e4c283ccc Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Fri, 6 Oct 2023 16:22:05 -0400 Subject: [PATCH 118/148] kept debugging the matrix mul kernel --- python/hidet/backend/codegen.py | 2 +- .../ops/matmul/matmul_f32_x86_refactored.py | 203 ++++++++---------- python/mat_new.py | 45 +++- 3 files changed, 130 insertions(+), 120 deletions(-) diff --git a/python/hidet/backend/codegen.py b/python/hidet/backend/codegen.py index e5e474636..0c20dc0c5 100644 --- a/python/hidet/backend/codegen.py +++ b/python/hidet/backend/codegen.py @@ -496,7 +496,7 @@ def visit_ForStmt(self, stmt: ForStmt): doc += NewLine() + '#pragma unroll' elif stmt.attr.parallel: if stmt.attr.parallel_threads: - doc += NewLine() + '#pragma omp parallel for schedule(dynamic) num_threads({})'.format( + doc += NewLine() + '#pragma omp parallel for num_threads({})'.format( stmt.attr.parallel_threads ) else: diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 41afd7884..a9e84aead 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -86,7 +86,7 @@ def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 1, 1, 1)]) def schedule_matmulf32_x86( - self, MC=2016, NC=896, KC=512, ways=(1, 2, 2, 1) + self, MC=2016, NC=896, KC=512, ways=(2, 1, 1, 1) ) -> IRModule: import hidet from hidet.ir.type import tensor_type @@ -143,8 +143,6 @@ def init_thr(sense: ~int32, arrived: ~int32, size: int32): init_thr.kind = "cpu_internal" - - # Helpers packed_a_type = tensor_type('float32', layout=row_major(MC // MR, 1) * column_major( @@ -156,7 +154,7 @@ def init_thr(sense: ~int32, arrived: ~int32, size: int32): # Get the number of threads remaining at each level loop5_nthreads = nthreads - loop4_nthreads = loop5_nthreads // loop5_nways + loop4_nthreads = nthreads // loop5_nways loop3_nthreads = loop4_nthreads macro_nthreads = loop3_nthreads // loop3_nways loop1_nthreads = macro_nthreads // macro_nways @@ -260,7 +258,7 @@ def micro_kernel( is_first: bool ): # printf("The start of the micro_kernel.....") - c = as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize]) + c = as_tensor_pointer(c_ptr, dtype=float32, shape=[m_size, n_size]) c0 = avx_f32x8_load(~c[0, 0]) c08 = avx_f32x8_load(~c[0, 8]) c1 = avx_f32x8_load(~c[1, 0]) @@ -274,6 +272,10 @@ def micro_kernel( c5 = avx_f32x8_load(~c[5, 0]) c58 = avx_f32x8_load(~c[5, 8]) + printf("The msize in the micro kernel: %d\n", msize) + printf("The nsize in the micro kernel: %d\n", nsize) + printf("The pb in the micro kernel: %d\n", pb) + if is_first: c0 = avx_f32x8_setzero() c08 = avx_f32x8_setzero() @@ -294,31 +296,42 @@ def micro_kernel( for _ in range(pb): bb0to7 = avx_f32x8_load_aligned(b_ptr) bb8to15 = avx_f32x8_load_aligned(b_ptr + 8) + + + printf("bb0 to bb7: %lf %lf %lf %lf %lf %lf %lf %lf\n", b_ptr[0], b_ptr[1], b_ptr[2], b_ptr[3], b_ptr[4], b_ptr[5], b_ptr[6], b_ptr[7]) + printf("bb8 to bb15: %lf %lf %lf %lf %lf %lf %lf %lf\n", b_ptr[8], b_ptr[9], b_ptr[10], b_ptr[11], b_ptr[12], b_ptr[13], b_ptr[14], b_ptr[15]) + b_ptr = b_ptr + 16 aa = avx_f32x8_broadcast(a_ptr) c0 = avx_f32x8_fmadd(aa, bb0to7, c0) c08 = avx_f32x8_fmadd(aa, bb8to15, c08) + printf("broadcasted aa: %lf\n", a_ptr[0]) aa = avx_f32x8_broadcast(a_ptr + 1) 
c1 = avx_f32x8_fmadd(aa, bb0to7, c1) c18 = avx_f32x8_fmadd(aa, bb8to15, c18) + printf("broadcasted aa: %lf\n", a_ptr[1]) aa = avx_f32x8_broadcast(a_ptr + 2) c2 = avx_f32x8_fmadd(aa, bb0to7, c2) c28 = avx_f32x8_fmadd(aa, bb8to15, c28) + printf("broadcasted aa: %lf\n", a_ptr[2]) aa = avx_f32x8_broadcast(a_ptr + 3) c3 = avx_f32x8_fmadd(aa, bb0to7, c3) c38 = avx_f32x8_fmadd(aa, bb8to15, c38) + printf("broadcasted aa: %lf\n", a_ptr[3]) aa = avx_f32x8_broadcast(a_ptr + 4) c4 = avx_f32x8_fmadd(aa, bb0to7, c4) c48 = avx_f32x8_fmadd(aa, bb8to15, c48) + printf("broadcasted aa: %lf\n", a_ptr[4]) aa = avx_f32x8_broadcast(a_ptr + 5) c5 = avx_f32x8_fmadd(aa, bb0to7, c5) c58 = avx_f32x8_fmadd(aa, bb8to15, c58) + printf("broadcasted aa: %lf\n", a_ptr[5]) a_ptr = a_ptr + 6 @@ -326,21 +339,20 @@ def micro_kernel( avx_f32x8_store(c_ptr, c0) avx_f32x8_store(c_ptr + 8, c08) - avx_f32x8_store(c_ptr + nsize, c1) - avx_f32x8_store(c_ptr + (nsize + 8), c18) + avx_f32x8_store(c_ptr + n_size, c1) + avx_f32x8_store(c_ptr + (n_size + 8), c18) - avx_f32x8_store(c_ptr + 2 * nsize, c2) - avx_f32x8_store(c_ptr + (2 * nsize + 8), c28) + avx_f32x8_store(c_ptr + 2 * n_size, c2) + avx_f32x8_store(c_ptr + (2 * n_size + 8), c28) - avx_f32x8_store(c_ptr + 3 * nsize, c3) - avx_f32x8_store(c_ptr + (3 * nsize + 8), c38) + avx_f32x8_store(c_ptr + 3 * n_size, c3) + avx_f32x8_store(c_ptr + (3 * n_size + 8), c38) - avx_f32x8_store(c_ptr + 4 * nsize, c4) - avx_f32x8_store(c_ptr + (4 * nsize + 8), c48) + avx_f32x8_store(c_ptr + 4 * n_size, c4) + avx_f32x8_store(c_ptr + (4 * n_size + 8), c48) - avx_f32x8_store(c_ptr + 5 * nsize, c5) - avx_f32x8_store(c_ptr + (5 * nsize + 8), c58) - # printf("The end of micro kernel....") + avx_f32x8_store(c_ptr + 5 * n_size, c5) + avx_f32x8_store(c_ptr + (5 * n_size + 8), c58) #### Some setup code #### packed_b_height = KC @@ -350,9 +362,6 @@ def micro_kernel( if packed_b_width > n_size: packed_b_width = (n_size + NR - 1) // NR * NR - # printf("packed_b_height: %d\n", packed_b_height) - # printf("packed_b_width: %d\n", packed_b_width) - packed_b_total_width = packed_b_width * loop5_nways packed_b_total_size = packed_b_total_width * packed_b_height packed_b_individual_size = packed_b_width * packed_b_height @@ -367,11 +376,7 @@ def micro_kernel( packed_a_width = k_size packed_a_total_size = packed_a_total_height * packed_a_width packed_a_individual_size = packed_a_width * packed_a_individual_height - printf("packed_a_individual_height: %d\n", packed_a_individual_height) - printf("packed_a_width: %d\n", packed_a_width) - # packb_buf_ptr = avx_malloc(packed_b_total_size * 4, 4096) - # packa_buf_ptr = avx_malloc(packed_a_total_size * 4, 4096) packb_buf_ptr = module.define_global_var( name='packb_buf_ptr', var_type=float32[packed_b_total_size] @@ -404,15 +409,15 @@ def gemm_pack_a( work_id_packa: int32, packa_nways: int32 ): - # printf("The start of the pack a, comm id: %d, work id: %d\n", comm_id_packa, work_id_packa) packed_a_tensor = as_tensor_pointer( packed_a_buf, float32, layout=row_major(packed_a_individual_height // MR, 1) * column_major(MR, packed_a_width) ) - # printf("pack a: packed_a_individual_height: %d, packed_a_width: %d\n", packed_a_individual_height, - # packed_a_width) + + printf("work_id_packa: %d, packa_nways: %d, loop3_partition_a_width: %d, loop3_partition_a_height: %d\n", + work_id_packa, packa_nways, loop3_partition_a_width, loop3_partition_a_height) npanels_full_a = loop3_partition_a_height // MR @@ -422,6 +427,7 @@ def gemm_pack_a( # printf("npanels_full_a: %d\n", npanels_full_a) 
npanels_a = npanels_full_a + (1 if panel_a_remainder > 0 else 0) + for ii_panel in range(npanels_a): if ii_panel % packa_nways != work_id_packa % packa_nways: continue @@ -429,6 +435,8 @@ def gemm_pack_a( a_curr_panel_row_start = ii_panel * MR a_curr_panel_height = min(MR, loop3_partition_a_height - a_curr_panel_row_start) + # printf("packing of a: the panel %d is taken care of by the thread with pack a work id: %d; the a_curr_panel_row_start: %d, a_curr_panel_height: %d\n\n", + # ii_panel, work_id_packa, a_curr_panel_row_start, a_curr_panel_height) if a_curr_panel_height == MR: # unroll the packing by 8 k_iters = loop3_partition_a_width // 8 @@ -439,6 +447,11 @@ def gemm_pack_a( a_curr_panel_col = loop3_partition_a + ( a_curr_panel_row_start * k_size + col ) + + # printf("In the packing of A: the offset a_curr_panel_row_start * k_size + col for id %d: %d\n", work_id_packa, a_curr_panel_row_start * k_size + col) + printf("work_id_packa: %d, a_curr_panel_row_start: %d, a_curr_panel_height: %d, the offset a_curr_panel_row_start * k_size + col: %d\n", + work_id_packa, a_curr_panel_row_start, a_curr_panel_height, a_curr_panel_row_start * k_size + col) + v0 = avx_f32x8_load(a_curr_panel_col) v1 = avx_f32x8_load(a_curr_panel_col + k_size) v2 = avx_f32x8_load(a_curr_panel_col + (2 * k_size)) @@ -532,21 +545,18 @@ def gemm_pack_b( loop4_partition_b_width: int32, loop4_partition_b_height: int32, packed_b_buf: ~float32, - comm_id_packb: int32, workn_id_packb: int32, + comm_id_packb: int32, work_id_packb: int32, packb_nways: int32 ): - # printf("The start of pack B, comm_id_packb: %d, work_id_packb: %d\n", comm_id_packb, work_id_packb) npanels_full_b = loop4_partition_b_width // NR npanels_b_remainder = loop4_partition_b_width % NR + if comm_id_packb == 0: + printf("loop4_partition_b_width: %d; loop4_partition_b_height: %d\n", loop4_partition_b_width, loop4_partition_b_height) npanels_b = npanels_full_b + (1 if npanels_b_remainder != 0 else 0) packedb_panel_stride = packed_b_height * NR - # printf("Start of the packing of B...") - # printf("packed_b_height: %d", packed_b_height) - # printf("packedb_panel_stride: %d\n", packedb_panel_stride) - printf("work_id_packb: %d, packed_b_height: %d, loop4_partition_b_width: %d, npanels_full_b: %d, packb_nways: %d", - work_id_packb, packed_b_height, loop4_partition_b_width, npanels_full_b, packb_nways) + printf("work_id_packb: %d; npabels_b: %d, packedb_panel_stride: %d\n", work_id_packb, npanels_b, packedb_panel_stride) # Loop for the packing of B for i_panel in range(npanels_b): @@ -558,20 +568,16 @@ def gemm_pack_b( curr_panel_width = min(NR, loop4_partition_b_width - curr_panel_start) - # printf("i_panel: %d\n", i_panel) - # printf("curr_panel_start: %d\n", curr_panel_start) - # printf("curr_panel_width: %d\n", curr_panel_width) + printf("work_id_packb: %d; curr_panel_start: %d; curr_panel_width: %d; the offset i_panel * packedb_panel_stride: %d\n", + work_id_packb, curr_panel_start, curr_panel_width, i_panel * packedb_panel_stride) if curr_panel_width == NR: k_iters = loop4_partition_b_height // 8 k_remainder = loop4_partition_b_height % 8 - # printf("k_iters: %d\n", k_iters) - # printf("k_remainder: %d\n", k_remainder) row = 0 for k_iter in range(k_iters): row = k_iter * 8 - # printf('row: %d\n', row) b_panel = loop4_partition_b + ( row * n_size + curr_panel_start) b00 = avx_f32x8_load(b_panel) @@ -687,27 +693,17 @@ def gemm_macro( work_id_macro: int32, is_first: bool ): + # assert loop1_nthreads == 1 comm_id_1st_loop = comm_id_macro % loop1_nthreads - 
work_id_1st_loop = comm_id_macro // (loop1_nthreads // loop1_nways) + work_id_1st_loop = comm_id_1st_loop // (loop1_nthreads // loop1_nways) n_iter = macro_n // NR n_remainder = macro_n % NR m_iter = macro_m // MR m_remainder = macro_m % MR - # printf("The start of the macro kernel.\n") - # printf("n_iter: %d\n", n_iter) - # printf("n_remainder: %d\n", n_remainder) - # printf("m_iter: %d\n", m_iter) - # printf("m_remainder: %d\n", m_remainder) - # printf("c_row_off: %d\n", c_row_off) - # printf("c_col_off: %d\n", c_col_off) - # printf("macro_m: %d\n", macro_m) - # printf("macro_n: %d\n", macro_n) - # printf("macro_k: %d\n", macro_k) - # printf("ps_packed_a: %d\n", ps_packed_a) - # printf("ps_packed_b: %d\n", ps_packed_b) - + printf("The macro kernel with comm_id_macro: %d, work_id_macro: %d , macro_m: %d, macro_n: %d, macro_k: %d, c_row_off: %d, c_col_off: %d, ps_packed_a: %d, ps_packed_b: %d\n", + comm_id_macro, work_id_macro, macro_m, macro_n, macro_k, c_row_off, c_col_off, ps_packed_a, ps_packed_b) if n_remainder > 0: n_iter += 1 @@ -741,13 +737,11 @@ def gemm_macro( ~ir_inc ) - # printf("jr_start: %d\n", jr_start) - # printf("jr_end: %d\n", jr_end) - # printf("jr_inc: %d\n", jr_inc) - # - # printf("ir_start: %d\n", ir_start) - # printf("ir_end: %d\n", ir_end) - # printf("ir_inc: %d\n", ir_inc) + # printf("jr_start: %d, jr_end: %d, jr_inc: %d, ir_start: %d, ir_end: %d, ir_inc: %d, work_id_macro: %d, work_id_1st_loop: %d\n", + # jr_start, jr_end, jr_inc, ir_start, ir_end, ir_inc, work_id_macro, work_id_1st_loop) + + printf("work_id_macro: %d, work_id_1st_loop: %d, jr_start: %d, jr_end: %d, jr_inc: %d, ir_start: %d, ir_end: %d, ir_inc: %d\n", + work_id_macro, work_id_1st_loop, jr_start, jr_end, jr_inc, ir_start, ir_end, ir_inc) rs_packeda = 1 rstep_a = ps_packed_a @@ -756,11 +750,6 @@ def gemm_macro( cstep_c = NR rstep_c = n_size * MR - # printf("rstep_a: %d\n", rstep_a) - # printf("cstep_b: %d\n", cstep_b) - # printf("cstep_c: %d\n", cstep_c) - # printf("rstep_c: %d\n", rstep_c) - macro_c_cast = as_tensor_pointer( ~c[c_row_off, c_col_off], dtype=float32, @@ -774,21 +763,19 @@ def gemm_macro( while j < jr_end: b1 = packed_b + j * cstep_b c1 = macro_c_cast + j * cstep_c - # printf("j = %d\n", j) - # printf("The offset j * cstep_c: %d\n\n", j * cstep_c) - n_cur = NR if not_edge(j, n_iter, n_remainder) else n_remainder - # Loop over the m dimension, MR rows at a time + i = ir_start while i < ir_end: # printf("i = %d\n", i) a1 = packed_a + i * rstep_a c11 = c1 + i * rstep_c - # printf("The offset i * rstep_a: %d\n", i * rstep_a) - # printf("The offset i * rstep_c: %d\n\n", i * rstep_c) c11 = as_tensor_pointer(c11, dtype=float32, shape=(m_size, n_size)) m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder + printf("The i value is %d for the macro work id %d. 
The offset i * rstep_a: %d, the offset i * rstep_c: %d, m_cur: %d\n", + i, work_id_macro, i * rstep_a, i * rstep_c, m_cur) + if m_cur == MR and n_cur == NR: micro_kernel(a1, b1, c11, macro_k, macro_m, macro_n, is_first) else: @@ -826,6 +813,9 @@ def gemm_3rd_loop( work_id_packa = comm_id_macro packa_nways = macro_nthreads + printf("comm_id_3rd_loop: %d, work_id_3rd_loop: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, is_first: %d\n", + comm_id_3rd_loop, work_id_3rd_loop, loop3_partition_a_start_col, loop3_partition_b_start_col, is_first) + m_start_loop3 = 0 m_end_loop3 = 0 thread_range_sub( @@ -836,15 +826,11 @@ def gemm_3rd_loop( ~m_start_loop3, ~m_end_loop3 ) - # printf("In loop 3: m_start_loop3: %d, m_end_loop3: %d\n", m_start_loop3, m_end_loop3) ii = m_start_loop3 while ii < m_end_loop3: b_alg_loop3 = determine_blocksize_f_sub( ii, m_size, MC ) - # printf("The ii in loop3: %d\n", ii) - # printf("b_alg_loop3: %d\n", b_alg_loop3) - # Acquire the partition at loop 3 loop3_partition_a_start_row = ii loop3_partition_a_height = b_alg_loop3 @@ -852,17 +838,8 @@ def gemm_3rd_loop( loop3_partition_a_start_row * k_size + loop3_partition_a_start_col ) - # printf("Got the loop3_partition_a\n") - # printf("packed_a_individual_size: %d\n", packed_a_individual_size) - # printf("work_id_packa: %d\n", work_id_packa) - # printf("packed_a_total_size: %d\n", packed_a_total_size) - - # Get our position within the packed A global buffer - # packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) - # printf("work_id_packa * packed_a_individual_size: %d\n", work_id_packa * packed_a_individual_size) - # packed_a_buf = ~packa_buf[work_id_packa * packed_a_individual_size] - packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) - # printf("Got the packed_a_buf\n") + + packed_a_buf = packa_buf + (work_id_3rd_loop * packed_a_individual_size) # TODO: If passed, see if this barrier is necessary thrcomm_barrier( @@ -906,7 +883,6 @@ def gemm_3rd_loop( is_first ) ii += b_alg_loop3 - # printf("The end of 3rd loop. 
comm_id_3rd_loop: %d, work_id_3rd_loop: %d\n", comm_id_3rd_loop, work_id_3rd_loop) gemm_3rd_loop.kind = "cpu_internal" @@ -926,8 +902,6 @@ def gemm_4th_loop(a: float32[m_size, k_size], work_id_3rd_loop = comm_id_3rd_loop // (loop3_nthreads // loop3_nways) comm_id_packb = comm_id_3rd_loop work_id_packb = comm_id_3rd_loop - # packb_nways = loop3_nthreads - # printf("The start of the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) while i_loop4 < k_size: b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, KC) @@ -937,6 +911,9 @@ def gemm_4th_loop(a: float32[m_size, k_size], loop4_partition_b_start_row = i_loop4 loop4_partition_b_start_col = loop5_partition_b_start_col + printf("work_id_4th_loop: %d, work_id_5th_loop: %d, loop4_partition_b_height: %d, loop4_partition_b_width: %d, loop4_partition_b_start_row: %d, loop4_partition_b_start_col: %d\n", + work_id_4th_loop, work_id_5th_loop, loop4_partition_b_height, loop4_partition_b_width, loop4_partition_b_start_row, loop4_partition_b_start_col) + loop4_partition_a_start_col = i_loop4 is_first = (i_loop4 == 0) @@ -950,8 +927,11 @@ def gemm_4th_loop(a: float32[m_size, k_size], (loop4_partition_b_start_row * n_size + loop4_partition_b_start_col) + printf("work_id_4th_loop: %d, comm_id_4th_loop: %d, the offset packed_b_individual_size * work_id_5th_loop: %d, the offset loop4_partition_b_start_row * n_size + loop4_partition_b_start_col: %d\n", + work_id_4th_loop, comm_id_4th_loop, packed_b_individual_size * work_id_5th_loop, loop4_partition_b_start_row * n_size + loop4_partition_b_start_col) - # TODO: If passed, see if this barrier is really needed + + # # TODO: If passed, see if this barrier is really needed thrcomm_barrier( comm_id_packb, ~packb_thrcomm_barrier_sense[work_id_4th_loop], @@ -960,9 +940,6 @@ def gemm_4th_loop(a: float32[m_size, k_size], ) - # Start the packing of B - # TODO: Check this assertion: - # TODO: loop3_nthreads == packb_nthreads gemm_pack_b(loop4_partition_b, loop4_partition_b_width, loop4_partition_b_height, packed_b_buf, comm_id_packb, work_id_packb, loop3_nthreads) @@ -987,9 +964,7 @@ def gemm_4th_loop(a: float32[m_size, k_size], is_first ) - i_loop4 += b_alg_loop4 - # printf("The end of the 4th loop. 
work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) gemm_4th_loop.kind = "cpu_internal" @@ -1001,30 +976,27 @@ def gemm_5th_loop(a: float32[m_size, k_size], comm_id_5th_loop: int32): comm_id_4th_loop = comm_id_5th_loop % loop4_nthreads work_id_4th_loop = comm_id_4th_loop // (loop4_nthreads // loop4_nways) - # printf("Start of 5th loop, work_id_5th_loop: %d, comm_id_5th_loop: %d, comm_id_4th_loop: %d, work_id_4th_loop: %d\n", - # work_id_5th_loop, comm_id_5th_loop, - # comm_id_4th_loop, work_id_4th_loop) loop5_my_start = -1 loop5_my_end = -1 thread_range_sub(loop5_nways, work_id_5th_loop, n_size, NR, ~loop5_my_start, ~loop5_my_end) - # printf("work_id_5th_loop: %d, comm_id_5th_loop: %d, loop5_my_start: %d, loop5_my_end: %d\n", - # work_id_5th_loop, comm_id_5th_loop, loop5_my_start, loop5_my_end) - - # printf("loop5_my_start: %d, loop5_my_end: %d\n", loop5_my_start, loop5_my_end) + printf("work_id_5th_loop: %d, comm_id_5th_loop: %d, loop5_my_start: %d, loop5_my_end: %d\n", + work_id_5th_loop, comm_id_5th_loop, loop5_my_start, loop5_my_end) loop5_iter = loop5_my_start while loop5_iter < loop5_my_end: b_alg_loop5 = determine_blocksize_f_sub(loop5_iter, loop5_my_end, NC) - # printf("loop5_iter: %d\n", loop5_iter) - # printf("b_alg_loop5: %d\n", b_alg_loop5) + loop5_partition_c_width = b_alg_loop5 loop5_partition_c_start_col = loop5_iter loop5_partition_b_width = b_alg_loop5, loop5_partition_b_start_col = loop5_iter + + printf("work_id_5th_loop: %d, comm_id_5th_loop: %d, b_alg_loop5: %d, loop5_partition_b_width: %d, loop5_partition_b_start_col: %d\n", + work_id_5th_loop, comm_id_5th_loop, b_alg_loop5, loop5_partition_b_width, loop5_partition_b_start_col) gemm_4th_loop(a, b, c, loop5_partition_b_width, loop5_partition_b_start_col, @@ -1032,10 +1004,8 @@ def gemm_5th_loop(a: float32[m_size, k_size], work_id_4th_loop, work_id_5th_loop) loop5_iter += b_alg_loop5 - # printf("End of 5th loop, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", work_id_5th_loop, comm_id_5th_loop) gemm_5th_loop.kind = 'cpu_internal' - ################### Start of the main kernel ################### @hidet.script def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], @@ -1048,6 +1018,22 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], packb_thrcomm_barrier_threads_arrived, loop5_nways) + # the nthreads and nways print for each loop + printf("loop5_nthreads: %d, loop5_nways: %d\n", loop5_nthreads, loop5_nways) + printf("loop4_nthreads: %d, loop4_nways: %d\n", loop4_nthreads, loop4_nways) + printf("loop3_nthreads: %d, loop3_nways: %d\n", loop3_nthreads, loop3_nways) + printf("macro_nthreads: %d, macro_nways: %d\n", macro_nthreads, macro_nways) + printf("loop1_nthreads: %d, loop1_nways: %d\n", loop1_nthreads, loop1_nways) + + printf("packb_nthreads: %d, packa_nthreads: %d\n", packb_nthreads, packa_nthreads) + + printf("packed_b_width: %d, packed_b_total_width: %d, packed_b_height: %d\n", packed_b_width, packed_b_total_width, packed_b_height) + printf("packed_a_width: %d, packed_a_individual_height: %d, packed_a_total_height: %d\n", packed_a_width, packed_a_individual_height, packed_a_total_height) + + printf("packed_b_total_size: %d, packed_a_total_size: %d\n", packed_b_total_size, packed_a_total_size) + printf("packed_b_individual_size: %d, packed_a_individual_size: %d\n", packed_b_individual_size, packed_a_individual_size) + + parallel_attr = 'p' + str(nthreads) # The outermost loop spawning threads for tidx in grid(nthreads, 
attrs=parallel_attr): @@ -1055,12 +1041,7 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways) comm_id_5th_loop = tid_5th_loop - printf("tid_5th_loop: %d, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", - tid_5th_loop, work_id_5th_loop, comm_id_5th_loop) - gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop) - # avx_free(packa_buf) - # avx_free(packb_buf) assert isinstance(matmul_kernel_x86_v3, hidet.ir.Function) matmul_kernel_x86_v3.kind = "cpu_kernel" diff --git a/python/mat_new.py b/python/mat_new.py index 5c854db3f..b62aa5b39 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -6,6 +6,8 @@ from hidet.testing import check_binary from hidet.option import debug_cache_tuning +import torch + import tvm from tvm import te, auto_scheduler @@ -30,10 +32,35 @@ def matmul_ansor(M, K, N, dtype): target = tvm.target.Target("llvm -mcpu=core-avx2") debug_cache_tuning(True) hidet.option.search_space(0) -# for m, n, k in [(768, 768, 768), (111, 333, 222)]: -for m, n, k in [(64, 64, 64)]: - a = hidet.randn([m, k], device='cpu') - b = hidet.randn([k, n], device='cpu') + +np.random.seed(42) +for m, n, k in [(6, 17, 1)]: +# for m, n, k in [(64, 64, 64)]: +# for m, n, k in [(16, 16, 16), (64, 64, 64), (211, 333, 222), (768, 768, 768)]: + a = hidet.ones([m, k], device='cpu') + b = hidet.ones([k, n], device='cpu') + + # a = hidet.randn([m, k], device='cpu') + # b = hidet.randn([k, n], device='cpu') + an = torch.ones(m, k, dtype=torch.float32) + bn = torch.ones(k, n, dtype=torch.float32) + + counter=0 + for i in range(m): + for j in range(k): + an[i, j] = counter + counter += 1 + counter = 0 + for i in range(k): + for j in range(n): + bn[i, j] = counter + counter += 1 + + a = hidet.from_torch(an) + b = hidet.from_torch(bn) + + + x1 = hidet.symbol_like(a) x2 = hidet.symbol_like(b) y = matmul_x86_refactored(x1, x2) @@ -47,10 +74,12 @@ def matmul_ansor(M, K, N, dtype): actual = c.numpy() desired = a.numpy() @ b.numpy() - # for i in range(m): - # for j in range(n): - # if abs(actual[i, j] - desired[i, j]) < 1e-3: - # print(f"Actually passed for i={i}, j={j}") + for i in range(m): + for j in range(n): + if abs(actual[i, j] - desired[i, j]) < 1e-3: + print(f"Actually passed for i={i}, j={j}") + else: + print(f"Failed for i={i}, j={j}, and we have [i, j] = {actual[i, j]} and desired [i, j] = {desired[i, j]}") # for i in range(m): # for j in range(n): From dfdf0848025842310b5d8d0aba95a7f5ef3131c7 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 26 Oct 2023 20:25:35 -0400 Subject: [PATCH 119/148] bruh --- .../ops/matmul/matmul_f32_x86_refactored.py | 172 ++++++++++-------- python/mat_new.py | 43 +++-- 2 files changed, 114 insertions(+), 101 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index a9e84aead..2e089c7ef 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -86,7 +86,7 @@ def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 1, 1, 1)]) def schedule_matmulf32_x86( - self, MC=2016, NC=896, KC=512, ways=(2, 1, 1, 1) + self, MC=36, NC=32, KC=16, ways=(1, 2, 1, 1) ) -> IRModule: import hidet from hidet.ir.type import tensor_type @@ -109,7 +109,6 @@ def schedule_matmulf32_x86( tune.check(MC % MR == NC % NR == 0, 'Tile size must 
divide the corresponding block size') - with hidet.script_module() as module: # Get the number of threads... loop5_nways, loop3_nways, macro_nways, loop1_nways = ways @@ -257,8 +256,7 @@ def micro_kernel( a: packed_a_type, b: packed_b_type, c_ptr: ~float32, pb: int32, msize: int32, nsize: int32, is_first: bool ): - # printf("The start of the micro_kernel.....") - c = as_tensor_pointer(c_ptr, dtype=float32, shape=[m_size, n_size]) + c = as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize]) c0 = avx_f32x8_load(~c[0, 0]) c08 = avx_f32x8_load(~c[0, 8]) c1 = avx_f32x8_load(~c[1, 0]) @@ -272,9 +270,9 @@ def micro_kernel( c5 = avx_f32x8_load(~c[5, 0]) c58 = avx_f32x8_load(~c[5, 8]) - printf("The msize in the micro kernel: %d\n", msize) - printf("The nsize in the micro kernel: %d\n", nsize) - printf("The pb in the micro kernel: %d\n", pb) + # printf("The msize in the micro kernel: %d\n", msize) + # printf("The nsize in the micro kernel: %d\n", nsize) + # printf("The pb in the micro kernel: %d\n", pb) if is_first: c0 = avx_f32x8_setzero() @@ -298,61 +296,60 @@ def micro_kernel( bb8to15 = avx_f32x8_load_aligned(b_ptr + 8) - printf("bb0 to bb7: %lf %lf %lf %lf %lf %lf %lf %lf\n", b_ptr[0], b_ptr[1], b_ptr[2], b_ptr[3], b_ptr[4], b_ptr[5], b_ptr[6], b_ptr[7]) - printf("bb8 to bb15: %lf %lf %lf %lf %lf %lf %lf %lf\n", b_ptr[8], b_ptr[9], b_ptr[10], b_ptr[11], b_ptr[12], b_ptr[13], b_ptr[14], b_ptr[15]) - b_ptr = b_ptr + 16 - aa = avx_f32x8_broadcast(a_ptr) - c0 = avx_f32x8_fmadd(aa, bb0to7, c0) - c08 = avx_f32x8_fmadd(aa, bb8to15, c08) - printf("broadcasted aa: %lf\n", a_ptr[0]) + aa1 = avx_f32x8_broadcast(a_ptr) + c0 = avx_f32x8_fmadd(aa1, bb0to7, c0) + c08 = avx_f32x8_fmadd(aa1, bb8to15, c08) + # printf("broadcasted aa: %lf\n", a_ptr[0]) - aa = avx_f32x8_broadcast(a_ptr + 1) - c1 = avx_f32x8_fmadd(aa, bb0to7, c1) - c18 = avx_f32x8_fmadd(aa, bb8to15, c18) - printf("broadcasted aa: %lf\n", a_ptr[1]) + aa2 = avx_f32x8_broadcast(a_ptr + 1) + c1 = avx_f32x8_fmadd(aa2, bb0to7, c1) + c18 = avx_f32x8_fmadd(aa2, bb8to15, c18) + # printf("broadcasted aa: %lf\n", a_ptr[1]) - aa = avx_f32x8_broadcast(a_ptr + 2) - c2 = avx_f32x8_fmadd(aa, bb0to7, c2) - c28 = avx_f32x8_fmadd(aa, bb8to15, c28) - printf("broadcasted aa: %lf\n", a_ptr[2]) + aa3 = avx_f32x8_broadcast(a_ptr + 2) + c2 = avx_f32x8_fmadd(aa3, bb0to7, c2) + c28 = avx_f32x8_fmadd(aa3, bb8to15, c28) + # printf("broadcasted aa: %lf\n", a_ptr[2]) - aa = avx_f32x8_broadcast(a_ptr + 3) - c3 = avx_f32x8_fmadd(aa, bb0to7, c3) - c38 = avx_f32x8_fmadd(aa, bb8to15, c38) - printf("broadcasted aa: %lf\n", a_ptr[3]) + aa4 = avx_f32x8_broadcast(a_ptr + 3) + c3 = avx_f32x8_fmadd(aa4, bb0to7, c3) + c38 = avx_f32x8_fmadd(aa4, bb8to15, c38) + # printf("broadcasted aa: %lf\n", a_ptr[3]) - aa = avx_f32x8_broadcast(a_ptr + 4) - c4 = avx_f32x8_fmadd(aa, bb0to7, c4) - c48 = avx_f32x8_fmadd(aa, bb8to15, c48) - printf("broadcasted aa: %lf\n", a_ptr[4]) + aa5 = avx_f32x8_broadcast(a_ptr + 4) + c4 = avx_f32x8_fmadd(aa5, bb0to7, c4) + c48 = avx_f32x8_fmadd(aa5, bb8to15, c48) + # printf("broadcasted aa: %lf\n", a_ptr[4]) - aa = avx_f32x8_broadcast(a_ptr + 5) - c5 = avx_f32x8_fmadd(aa, bb0to7, c5) - c58 = avx_f32x8_fmadd(aa, bb8to15, c58) - printf("broadcasted aa: %lf\n", a_ptr[5]) + aa6 = avx_f32x8_broadcast(a_ptr + 5) + c5 = avx_f32x8_fmadd(aa6, bb0to7, c5) + c58 = avx_f32x8_fmadd(aa6, bb8to15, c58) + # printf("broadcasted aa: %lf\n", a_ptr[5]) + # printf("List of all the aa's broadcasted in this iteration: %lf %lf %lf %lf %lf %lf\n, bb0 to to bb15: %lf %lf %lf %lf %lf %lf %lf 
%lf %lf %lf %lf %lf %lf %lf %lf %lf\n\n", a_ptr[0], a_ptr[1], a_ptr[2], a_ptr[3], a_ptr[4], a_ptr[5], b_ptr[0], b_ptr[1], b_ptr[2], b_ptr[3], b_ptr[4], b_ptr[5], b_ptr[6], b_ptr[7], b_ptr[8], b_ptr[9], b_ptr[10], b_ptr[11], b_ptr[12], b_ptr[13], b_ptr[14], b_ptr[15]) a_ptr = a_ptr + 6 + b_ptr = b_ptr + 16 # Store the results avx_f32x8_store(c_ptr, c0) avx_f32x8_store(c_ptr + 8, c08) - avx_f32x8_store(c_ptr + n_size, c1) - avx_f32x8_store(c_ptr + (n_size + 8), c18) + avx_f32x8_store(c_ptr + nsize, c1) + avx_f32x8_store(c_ptr + (nsize + 8), c18) - avx_f32x8_store(c_ptr + 2 * n_size, c2) - avx_f32x8_store(c_ptr + (2 * n_size + 8), c28) + avx_f32x8_store(c_ptr + 2 * nsize, c2) + avx_f32x8_store(c_ptr + (2 * nsize+ 8), c28) - avx_f32x8_store(c_ptr + 3 * n_size, c3) - avx_f32x8_store(c_ptr + (3 * n_size + 8), c38) + avx_f32x8_store(c_ptr + 3 * nsize, c3) + avx_f32x8_store(c_ptr + (3 * nsize + 8), c38) - avx_f32x8_store(c_ptr + 4 * n_size, c4) - avx_f32x8_store(c_ptr + (4 * n_size + 8), c48) + avx_f32x8_store(c_ptr + 4 * nsize, c4) + avx_f32x8_store(c_ptr + (4 * nsize + 8), c48) - avx_f32x8_store(c_ptr + 5 * n_size, c5) - avx_f32x8_store(c_ptr + (5 * n_size + 8), c58) + avx_f32x8_store(c_ptr + 5 * nsize, c5) + avx_f32x8_store(c_ptr + (5 * nsize + 8), c58) #### Some setup code #### packed_b_height = KC @@ -416,8 +413,8 @@ def gemm_pack_a( column_major(MR, packed_a_width) ) - printf("work_id_packa: %d, packa_nways: %d, loop3_partition_a_width: %d, loop3_partition_a_height: %d\n", - work_id_packa, packa_nways, loop3_partition_a_width, loop3_partition_a_height) + # printf("work_id_packa: %d, packa_nways: %d, loop3_partition_a_width: %d, loop3_partition_a_height: %d\n", + # work_id_packa, packa_nways, loop3_partition_a_width, loop3_partition_a_height) npanels_full_a = loop3_partition_a_height // MR @@ -449,8 +446,8 @@ def gemm_pack_a( ) # printf("In the packing of A: the offset a_curr_panel_row_start * k_size + col for id %d: %d\n", work_id_packa, a_curr_panel_row_start * k_size + col) - printf("work_id_packa: %d, a_curr_panel_row_start: %d, a_curr_panel_height: %d, the offset a_curr_panel_row_start * k_size + col: %d\n", - work_id_packa, a_curr_panel_row_start, a_curr_panel_height, a_curr_panel_row_start * k_size + col) + # printf("work_id_packa: %d, a_curr_panel_row_start: %d, a_curr_panel_height: %d, the offset a_curr_panel_row_start * k_size + col: %d\n", + # work_id_packa, a_curr_panel_row_start, a_curr_panel_height, a_curr_panel_row_start * k_size + col) v0 = avx_f32x8_load(a_curr_panel_col) v1 = avx_f32x8_load(a_curr_panel_col + k_size) @@ -550,13 +547,13 @@ def gemm_pack_b( ): npanels_full_b = loop4_partition_b_width // NR npanels_b_remainder = loop4_partition_b_width % NR - if comm_id_packb == 0: - printf("loop4_partition_b_width: %d; loop4_partition_b_height: %d\n", loop4_partition_b_width, loop4_partition_b_height) + # if comm_id_packb == 0: + # printf("loop4_partition_b_width: %d; loop4_partition_b_height: %d\n", loop4_partition_b_width, loop4_partition_b_height) npanels_b = npanels_full_b + (1 if npanels_b_remainder != 0 else 0) packedb_panel_stride = packed_b_height * NR - printf("work_id_packb: %d; npabels_b: %d, packedb_panel_stride: %d\n", work_id_packb, npanels_b, packedb_panel_stride) + # printf("work_id_packb: %d; npabels_b: %d, packedb_panel_stride: %d\n\n", work_id_packb, npanels_b, packedb_panel_stride) # Loop for the packing of B for i_panel in range(npanels_b): @@ -568,8 +565,8 @@ def gemm_pack_b( curr_panel_width = min(NR, loop4_partition_b_width - curr_panel_start) - 
printf("work_id_packb: %d; curr_panel_start: %d; curr_panel_width: %d; the offset i_panel * packedb_panel_stride: %d\n", - work_id_packb, curr_panel_start, curr_panel_width, i_panel * packedb_panel_stride) + # printf("work_id_packb: %d; curr_panel_start: %d; curr_panel_width: %d; the offset i_panel * packedb_panel_stride: %d\n\n", + # work_id_packb, curr_panel_start, curr_panel_width, i_panel * packedb_panel_stride) if curr_panel_width == NR: k_iters = loop4_partition_b_height // 8 @@ -582,6 +579,14 @@ def gemm_pack_b( row * n_size + curr_panel_start) b00 = avx_f32x8_load(b_panel) b08 = avx_f32x8_load(b_panel + 8) + # b00_debug = b_panel[0] + # b01_debug = b_panel[1] + # b02_debug = b_panel[2] + # b03_debug = b_panel[3] + # b04_debug = b_panel[4] + # b05_debug = b_panel[5] + # b06_debug = b_panel[6] + # b avx_f32x8_store(packed_b_buff_curr, b00) avx_f32x8_store(packed_b_buff_curr + 8, b08) @@ -638,7 +643,6 @@ def gemm_pack_b( packed_b_buff_curr += 16 row = k_iters * 8 - # printf("After the unrolled-by-8 loop, row: %d\n", row) for _ in range(k_remainder): b_panel = loop4_partition_b + ( row * n_size + curr_panel_start) @@ -702,8 +706,8 @@ def gemm_macro( m_iter = macro_m // MR m_remainder = macro_m % MR - printf("The macro kernel with comm_id_macro: %d, work_id_macro: %d , macro_m: %d, macro_n: %d, macro_k: %d, c_row_off: %d, c_col_off: %d, ps_packed_a: %d, ps_packed_b: %d\n", - comm_id_macro, work_id_macro, macro_m, macro_n, macro_k, c_row_off, c_col_off, ps_packed_a, ps_packed_b) + # printf("The macro kernel with comm_id_macro: %d, work_id_macro: %d , macro_m: %d, macro_n: %d, macro_k: %d, c_row_off: %d, c_col_off: %d, ps_packed_a: %d, ps_packed_b: %d\n", + # comm_id_macro, work_id_macro, macro_m, macro_n, macro_k, c_row_off, c_col_off, ps_packed_a, ps_packed_b) if n_remainder > 0: n_iter += 1 @@ -737,10 +741,10 @@ def gemm_macro( ~ir_inc ) - # printf("jr_start: %d, jr_end: %d, jr_inc: %d, ir_start: %d, ir_end: %d, ir_inc: %d, work_id_macro: %d, work_id_1st_loop: %d\n", - # jr_start, jr_end, jr_inc, ir_start, ir_end, ir_inc, work_id_macro, work_id_1st_loop) + printf("jr_start: %d, jr_end: %d, jr_inc: %d, ir_start: %d, ir_end: %d, ir_inc: %d, work_id_macro: %d, work_id_1st_loop: %d\n\n", + jr_start, jr_end, jr_inc, ir_start, ir_end, ir_inc, work_id_macro, work_id_1st_loop) - printf("work_id_macro: %d, work_id_1st_loop: %d, jr_start: %d, jr_end: %d, jr_inc: %d, ir_start: %d, ir_end: %d, ir_inc: %d\n", + printf("work_id_macro: %d, work_id_1st_loop: %d, jr_start: %d, jr_end: %d, jr_inc: %d, ir_start: %d, ir_end: %d, ir_inc: %d\n\n", work_id_macro, work_id_1st_loop, jr_start, jr_end, jr_inc, ir_start, ir_end, ir_inc) rs_packeda = 1 @@ -773,14 +777,16 @@ def gemm_macro( c11 = as_tensor_pointer(c11, dtype=float32, shape=(m_size, n_size)) m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder - printf("The i value is %d for the macro work id %d. The offset i * rstep_a: %d, the offset i * rstep_c: %d, m_cur: %d\n", + printf("The i value is %d for the macro work id %d. 
The offset i * rstep_a: %d, the offset i * rstep_c: %d, m_cur: %d\n\n", i, work_id_macro, i * rstep_a, i * rstep_c, m_cur) if m_cur == MR and n_cur == NR: - micro_kernel(a1, b1, c11, macro_k, macro_m, macro_n, is_first) + # micro_kernel(a1, b1, c11, macro_k, macro_m, macro_n, is_first) + micro_kernel(a1, b1, c11, macro_k, m_size, n_size, is_first) else: for i, j in grid(MR, NR): temp_c[i, j] = 0.0 + # micro_kernel(a1, b1, cast(temp_c, ~float32), macro_k, MR, NR, is_first) micro_kernel(a1, b1, cast(temp_c, ~float32), macro_k, MR, NR, is_first) if not is_first: for mm, nn in grid(m_cur, n_cur): @@ -813,8 +819,6 @@ def gemm_3rd_loop( work_id_packa = comm_id_macro packa_nways = macro_nthreads - printf("comm_id_3rd_loop: %d, work_id_3rd_loop: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, is_first: %d\n", - comm_id_3rd_loop, work_id_3rd_loop, loop3_partition_a_start_col, loop3_partition_b_start_col, is_first) m_start_loop3 = 0 m_end_loop3 = 0 @@ -826,11 +830,16 @@ def gemm_3rd_loop( ~m_start_loop3, ~m_end_loop3 ) + + printf("comm_id_3rd_loop: %d, work_id_3rd_loop: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, m_start_loop3: %d, m_end_loop3: %d, is_first: %d\n\n", + comm_id_3rd_loop, work_id_3rd_loop, loop3_partition_a_start_col, loop3_partition_b_start_col, m_start_loop3, m_end_loop3, is_first) + ii = m_start_loop3 while ii < m_end_loop3: b_alg_loop3 = determine_blocksize_f_sub( ii, m_size, MC ) + b_alg_loop3 = min(b_alg_loop3, m_end_loop3 - ii) loop3_partition_a_start_row = ii loop3_partition_a_height = b_alg_loop3 @@ -838,6 +847,8 @@ def gemm_3rd_loop( loop3_partition_a_start_row * k_size + loop3_partition_a_start_col ) + printf("comm_id_3rd_loop: %d, work_id_3rd_loop: %d, ii: %d, b_alg_loop3: %d\n\n", + comm_id_3rd_loop, work_id_3rd_loop, ii, b_alg_loop3) packed_a_buf = packa_buf + (work_id_3rd_loop * packed_a_individual_size) @@ -911,10 +922,6 @@ def gemm_4th_loop(a: float32[m_size, k_size], loop4_partition_b_start_row = i_loop4 loop4_partition_b_start_col = loop5_partition_b_start_col - printf("work_id_4th_loop: %d, work_id_5th_loop: %d, loop4_partition_b_height: %d, loop4_partition_b_width: %d, loop4_partition_b_start_row: %d, loop4_partition_b_start_col: %d\n", - work_id_4th_loop, work_id_5th_loop, loop4_partition_b_height, loop4_partition_b_width, loop4_partition_b_start_row, loop4_partition_b_start_col) - - loop4_partition_a_start_col = i_loop4 is_first = (i_loop4 == 0) @@ -927,7 +934,7 @@ def gemm_4th_loop(a: float32[m_size, k_size], (loop4_partition_b_start_row * n_size + loop4_partition_b_start_col) - printf("work_id_4th_loop: %d, comm_id_4th_loop: %d, the offset packed_b_individual_size * work_id_5th_loop: %d, the offset loop4_partition_b_start_row * n_size + loop4_partition_b_start_col: %d\n", + printf("work_id_4th_loop: %d, comm_id_4th_loop: %d, the offset packed_b_individual_size * work_id_5th_loop: %d, the offset loop4_partition_b_start_row * n_size + loop4_partition_b_start_col: %d\n\n", work_id_4th_loop, comm_id_4th_loop, packed_b_individual_size * work_id_5th_loop, loop4_partition_b_start_row * n_size + loop4_partition_b_start_col) @@ -964,6 +971,14 @@ def gemm_4th_loop(a: float32[m_size, k_size], is_first ) + # # TODO: Is not adding this barrier at the end the problem? 
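+            # # Note (assumption, inferred from the sense/threads_arrived
+            # # naming): thrcomm_barrier appears to be a sense-reversal
+            # # barrier; each arriving thread bumps the shared counter and
+            # # the last arrival flips the sense flag the others spin on.
+            # # Skipping this final barrier could let a fast thread start
+            # # repacking the shared B buffer for the next k-block while a
+            # # slower thread is still reading the current panel out of it.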
+ # thrcomm_barrier( + # comm_id_packb, + # ~packb_thrcomm_barrier_sense[work_id_4th_loop], + # ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], + # packb_nthreads + # ) + i_loop4 += b_alg_loop4 gemm_4th_loop.kind = "cpu_internal" @@ -982,9 +997,6 @@ def gemm_5th_loop(a: float32[m_size, k_size], thread_range_sub(loop5_nways, work_id_5th_loop, n_size, NR, ~loop5_my_start, ~loop5_my_end) - printf("work_id_5th_loop: %d, comm_id_5th_loop: %d, loop5_my_start: %d, loop5_my_end: %d\n", - work_id_5th_loop, comm_id_5th_loop, loop5_my_start, loop5_my_end) - loop5_iter = loop5_my_start while loop5_iter < loop5_my_end: b_alg_loop5 = determine_blocksize_f_sub(loop5_iter, @@ -995,7 +1007,7 @@ def gemm_5th_loop(a: float32[m_size, k_size], loop5_partition_b_width = b_alg_loop5, loop5_partition_b_start_col = loop5_iter - printf("work_id_5th_loop: %d, comm_id_5th_loop: %d, b_alg_loop5: %d, loop5_partition_b_width: %d, loop5_partition_b_start_col: %d\n", + printf("work_id_5th_loop: %d, comm_id_5th_loop: %d, b_alg_loop5: %d, loop5_partition_b_width: %d, loop5_partition_b_start_col: %d\n\n", work_id_5th_loop, comm_id_5th_loop, b_alg_loop5, loop5_partition_b_width, loop5_partition_b_start_col) gemm_4th_loop(a, b, c, loop5_partition_b_width, @@ -1019,19 +1031,20 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], loop5_nways) # the nthreads and nways print for each loop + printf("nthreads: %d\n", nthreads) printf("loop5_nthreads: %d, loop5_nways: %d\n", loop5_nthreads, loop5_nways) - printf("loop4_nthreads: %d, loop4_nways: %d\n", loop4_nthreads, loop4_nways) - printf("loop3_nthreads: %d, loop3_nways: %d\n", loop3_nthreads, loop3_nways) - printf("macro_nthreads: %d, macro_nways: %d\n", macro_nthreads, macro_nways) - printf("loop1_nthreads: %d, loop1_nways: %d\n", loop1_nthreads, loop1_nways) + # printf("loop4_nthreads: %d, loop4_nways: %d\n", loop4_nthreads, loop4_nways) + # printf("loop3_nthreads: %d, loop3_nways: %d\n", loop3_nthreads, loop3_nways) + # printf("macro_nthreads: %d, macro_nways: %d\n", macro_nthreads, macro_nways) + # printf("loop1_nthreads: %d, loop1_nways: %d\n", loop1_nthreads, loop1_nways) - printf("packb_nthreads: %d, packa_nthreads: %d\n", packb_nthreads, packa_nthreads) + # printf("packb_nthreads: %d, packa_nthreads: %d\n", packb_nthreads, packa_nthreads) printf("packed_b_width: %d, packed_b_total_width: %d, packed_b_height: %d\n", packed_b_width, packed_b_total_width, packed_b_height) printf("packed_a_width: %d, packed_a_individual_height: %d, packed_a_total_height: %d\n", packed_a_width, packed_a_individual_height, packed_a_total_height) - printf("packed_b_total_size: %d, packed_a_total_size: %d\n", packed_b_total_size, packed_a_total_size) - printf("packed_b_individual_size: %d, packed_a_individual_size: %d\n", packed_b_individual_size, packed_a_individual_size) + # printf("packed_b_total_size: %d, packed_a_total_size: %d\n", packed_b_total_size, packed_a_total_size) + # printf("packed_b_individual_size: %d, packed_a_individual_size: %d\n", packed_b_individual_size, packed_a_individual_size) parallel_attr = 'p' + str(nthreads) @@ -1040,6 +1053,7 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], tid_5th_loop = tidx work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways) comm_id_5th_loop = tid_5th_loop + printf("tidx: %d, tid_5th_loop: %d, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", tidx, tid_5th_loop, work_id_5th_loop, comm_id_5th_loop) gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop) diff --git 
a/python/mat_new.py b/python/mat_new.py index b62aa5b39..daf8f6633 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -34,31 +34,24 @@ def matmul_ansor(M, K, N, dtype): hidet.option.search_space(0) np.random.seed(42) -for m, n, k in [(6, 17, 1)]: -# for m, n, k in [(64, 64, 64)]: -# for m, n, k in [(16, 16, 16), (64, 64, 64), (211, 333, 222), (768, 768, 768)]: - a = hidet.ones([m, k], device='cpu') - b = hidet.ones([k, n], device='cpu') - +# for m, n, k in [(33, 65, 60), (32, 92, 128)]: +for m, n, k in [(7, 1, 17)]: # a = hidet.randn([m, k], device='cpu') # b = hidet.randn([k, n], device='cpu') - an = torch.ones(m, k, dtype=torch.float32) - bn = torch.ones(k, n, dtype=torch.float32) - counter=0 - for i in range(m): - for j in range(k): - an[i, j] = counter - counter += 1 - counter = 0 - for i in range(k): - for j in range(n): - bn[i, j] = counter - counter += 1 + a_torch = torch.arange(0, m*k).reshape(m, k).float().to('cpu') + b_torch = torch.arange(0, k*n).reshape(k, n).float().to('cpu') + # + # print(f"a_torch: {a_torch}") + # print(f"b_torch: {b_torch}") - a = hidet.from_torch(an) - b = hidet.from_torch(bn) + a = hidet.from_torch(a_torch).to(dtype='float32', device='cpu') + b = hidet.from_torch(b_torch).to(dtype='float32', device='cpu') + print(f"a: {a}") + print(f"b: {b}") + # a = hidet.ones([m, k], device='cpu') + # b = hidet.ones([k, n], device='cpu') x1 = hidet.symbol_like(a) @@ -74,12 +67,18 @@ def matmul_ansor(M, K, N, dtype): actual = c.numpy() desired = a.numpy() @ b.numpy() + fails = 0 + for i in range(m): for j in range(n): if abs(actual[i, j] - desired[i, j]) < 1e-3: - print(f"Actually passed for i={i}, j={j}") - else: + # print(f"Actually passed for i={i}, j={j}") + continue + else: print(f"Failed for i={i}, j={j}, and we have [i, j] = {actual[i, j]} and desired [i, j] = {desired[i, j]}") + fails += 1 + + print(f"Total fails: {fails}") # for i in range(m): # for j in range(n): From d2e1ab4e409a26dcf030befe96a84d7719d74481 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 9 Nov 2023 13:25:16 -0500 Subject: [PATCH 120/148] fixed a dumb bug that got me stuck for way too much longer than necessary --- .../ops/matmul/matmul_f32_x86_refactored.py | 297 +++++++++++------- python/mat_new.py | 22 +- 2 files changed, 198 insertions(+), 121 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 2e089c7ef..1d2dda345 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -83,10 +83,9 @@ def allow_prologue(self) -> bool: def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: return tune.extract_ir_modules(self.schedule_matmulf32_x86) - @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 1, 1, 1)]) def schedule_matmulf32_x86( - self, MC=36, NC=32, KC=16, ways=(1, 2, 1, 1) + self, MC=6, NC=16, KC=8, ways=(2, 1, 1, 1) ) -> IRModule: import hidet from hidet.ir.type import tensor_type @@ -108,14 +107,12 @@ def schedule_matmulf32_x86( tune.check(MC % MR == NC % NR == 0, 'Tile size must divide the corresponding block size') - with hidet.script_module() as module: # Get the number of threads... 
loop5_nways, loop3_nways, macro_nways, loop1_nways = ways loop4_nways = 1 nthreads = loop5_nways * loop3_nways * macro_nways * loop1_nways - packa_thrcomm_barrier_sense = module.define_global_var( name="pack_a_barrier_sense", var_type=int32[loop3_nways] @@ -150,7 +147,6 @@ def init_thr(sense: ~int32, arrived: ~int32, size: int32): NC // NR) * row_major( KC, NR)) - # Get the number of threads remaining at each level loop5_nthreads = nthreads loop4_nthreads = nthreads // loop5_nways @@ -161,7 +157,6 @@ def init_thr(sense: ~int32, arrived: ~int32, size: int32): packb_nthreads = loop3_nthreads packa_nthreads = macro_nthreads - @hidet.script def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~int32, end: ~int32): if n_way == 1: @@ -201,6 +196,7 @@ def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~ # Add the remainder to the last thread's end if work_id == n_way - 1: end[0] += n_bf_left + thread_range_sub.kind = "cpu_internal" @hidet.script @@ -228,6 +224,7 @@ def determine_blocksize_f_sub(i: int32, dim: int32, b_alg: int32) -> int32: @hidet.script def not_edge(i: int32, n_iter: int32, n_left: int32) -> bool: return i != n_iter - 1 or n_left == 0 + not_edge.kind = 'cpu_internal' # Thread barrier @@ -274,19 +271,19 @@ def micro_kernel( # printf("The nsize in the micro kernel: %d\n", nsize) # printf("The pb in the micro kernel: %d\n", pb) - if is_first: - c0 = avx_f32x8_setzero() - c08 = avx_f32x8_setzero() - c1 = avx_f32x8_setzero() - c18 = avx_f32x8_setzero() - c2 = avx_f32x8_setzero() - c28 = avx_f32x8_setzero() - c3 = avx_f32x8_setzero() - c38 = avx_f32x8_setzero() - c4 = avx_f32x8_setzero() - c48 = avx_f32x8_setzero() - c5 = avx_f32x8_setzero() - c58 = avx_f32x8_setzero() + # if is_first: + # c0 = avx_f32x8_setzero() + # c08 = avx_f32x8_setzero() + # c1 = avx_f32x8_setzero() + # c18 = avx_f32x8_setzero() + # c2 = avx_f32x8_setzero() + # c28 = avx_f32x8_setzero() + # c3 = avx_f32x8_setzero() + # c38 = avx_f32x8_setzero() + # c4 = avx_f32x8_setzero() + # c48 = avx_f32x8_setzero() + # c5 = avx_f32x8_setzero() + # c58 = avx_f32x8_setzero() a_ptr = cast(a, ~float32) b_ptr = cast(b, ~float32) @@ -295,39 +292,34 @@ def micro_kernel( bb0to7 = avx_f32x8_load_aligned(b_ptr) bb8to15 = avx_f32x8_load_aligned(b_ptr + 8) - - - aa1 = avx_f32x8_broadcast(a_ptr) c0 = avx_f32x8_fmadd(aa1, bb0to7, c0) c08 = avx_f32x8_fmadd(aa1, bb8to15, c08) - # printf("broadcasted aa: %lf\n", a_ptr[0]) aa2 = avx_f32x8_broadcast(a_ptr + 1) c1 = avx_f32x8_fmadd(aa2, bb0to7, c1) c18 = avx_f32x8_fmadd(aa2, bb8to15, c18) - # printf("broadcasted aa: %lf\n", a_ptr[1]) aa3 = avx_f32x8_broadcast(a_ptr + 2) c2 = avx_f32x8_fmadd(aa3, bb0to7, c2) c28 = avx_f32x8_fmadd(aa3, bb8to15, c28) - # printf("broadcasted aa: %lf\n", a_ptr[2]) aa4 = avx_f32x8_broadcast(a_ptr + 3) c3 = avx_f32x8_fmadd(aa4, bb0to7, c3) c38 = avx_f32x8_fmadd(aa4, bb8to15, c38) - # printf("broadcasted aa: %lf\n", a_ptr[3]) aa5 = avx_f32x8_broadcast(a_ptr + 4) c4 = avx_f32x8_fmadd(aa5, bb0to7, c4) c48 = avx_f32x8_fmadd(aa5, bb8to15, c48) - # printf("broadcasted aa: %lf\n", a_ptr[4]) aa6 = avx_f32x8_broadcast(a_ptr + 5) c5 = avx_f32x8_fmadd(aa6, bb0to7, c5) c58 = avx_f32x8_fmadd(aa6, bb8to15, c58) - # printf("broadcasted aa: %lf\n", a_ptr[5]) - # printf("List of all the aa's broadcasted in this iteration: %lf %lf %lf %lf %lf %lf\n, bb0 to to bb15: %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf\n\n", a_ptr[0], a_ptr[1], a_ptr[2], a_ptr[3], a_ptr[4], a_ptr[5], b_ptr[0], b_ptr[1], b_ptr[2], b_ptr[3], 
b_ptr[4], b_ptr[5], b_ptr[6], b_ptr[7], b_ptr[8], b_ptr[9], b_ptr[10], b_ptr[11], b_ptr[12], b_ptr[13], b_ptr[14], b_ptr[15]) + printf( + "List of all the aa's broadcasted in this iteration: %lf %lf %lf %lf %lf %lf\n, bb0 to to bb15: %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf\n\n", + a_ptr[0], a_ptr[1], a_ptr[2], a_ptr[3], a_ptr[4], a_ptr[5], b_ptr[0], b_ptr[1], b_ptr[2], + b_ptr[3], b_ptr[4], b_ptr[5], b_ptr[6], b_ptr[7], b_ptr[8], b_ptr[9], b_ptr[10], b_ptr[11], + b_ptr[12], b_ptr[13], b_ptr[14], b_ptr[15]) a_ptr = a_ptr + 6 b_ptr = b_ptr + 16 @@ -340,7 +332,7 @@ def micro_kernel( avx_f32x8_store(c_ptr + (nsize + 8), c18) avx_f32x8_store(c_ptr + 2 * nsize, c2) - avx_f32x8_store(c_ptr + (2 * nsize+ 8), c28) + avx_f32x8_store(c_ptr + (2 * nsize + 8), c28) avx_f32x8_store(c_ptr + 3 * nsize, c3) avx_f32x8_store(c_ptr + (3 * nsize + 8), c38) @@ -366,7 +358,8 @@ def micro_kernel( packed_a_individual_height = MC if packed_a_individual_height > m_size: packed_a_individual_height = (m_size + MR - 1) // MR * MR - packed_a_total_height = packed_a_individual_height * loop3_nways + # packed_a_total_height = packed_a_individual_height * loop3_nways + packed_a_total_height = packed_a_individual_height * loop5_nways packed_a_width = KC if packed_a_width > k_size: @@ -383,8 +376,6 @@ def micro_kernel( var_type=float32[packed_a_total_size] ) - - packb_buf = cast(packb_buf_ptr, ~float32) packa_buf = cast(packa_buf_ptr, ~float32) @@ -393,7 +384,6 @@ def micro_kernel( layout=row_major(packed_a_individual_height // MR, 1) * column_major(MR, packed_a_width) ) - ##### Start of the loops around micro kernel ##### @hidet.script @@ -416,7 +406,6 @@ def gemm_pack_a( # printf("work_id_packa: %d, packa_nways: %d, loop3_partition_a_width: %d, loop3_partition_a_height: %d\n", # work_id_packa, packa_nways, loop3_partition_a_width, loop3_partition_a_height) - npanels_full_a = loop3_partition_a_height // MR panel_a_remainder = loop3_partition_a_height % MR @@ -519,7 +508,7 @@ def gemm_pack_a( a_curr_panel_row_start + micropanel_row, curr_remain_col] = \ loop3_partition_a[( - micropanel_row + a_curr_panel_row_start) * k_size + curr_remain_col] + micropanel_row + a_curr_panel_row_start) * k_size + curr_remain_col] else: remain_start_row = npanels_full_a * MR for remain_col in range(loop3_partition_a_width): @@ -527,7 +516,7 @@ def gemm_pack_a( packed_a_tensor[ remain_start_row + remain_row, remain_col] = \ loop3_partition_a[( - remain_row + remain_start_row) * k_size + remain_col] + remain_row + remain_start_row) * k_size + remain_col] remain_row = panel_a_remainder while remain_row < MR: packed_a_tensor[ @@ -560,13 +549,14 @@ def gemm_pack_b( if i_panel % packb_nways != work_id_packb % packb_nways: continue packed_b_buff_curr = packed_b_buf + ( - i_panel * packedb_panel_stride) + i_panel * packedb_panel_stride) curr_panel_start = i_panel * NR curr_panel_width = min(NR, loop4_partition_b_width - curr_panel_start) - # printf("work_id_packb: %d; curr_panel_start: %d; curr_panel_width: %d; the offset i_panel * packedb_panel_stride: %d\n\n", - # work_id_packb, curr_panel_start, curr_panel_width, i_panel * packedb_panel_stride) + # printf( + # "work_id_packb: %d; curr_panel_start: %d; curr_panel_width: %d; the offset i_panel * packedb_panel_stride: %d\n\n", + # work_id_packb, curr_panel_start, curr_panel_width, i_panel * packedb_panel_stride) if curr_panel_width == NR: k_iters = loop4_partition_b_height // 8 @@ -576,17 +566,9 @@ def gemm_pack_b( for k_iter in range(k_iters): row = k_iter * 8 b_panel = 
loop4_partition_b + ( - row * n_size + curr_panel_start) + row * n_size + curr_panel_start) b00 = avx_f32x8_load(b_panel) b08 = avx_f32x8_load(b_panel + 8) - # b00_debug = b_panel[0] - # b01_debug = b_panel[1] - # b02_debug = b_panel[2] - # b03_debug = b_panel[3] - # b04_debug = b_panel[4] - # b05_debug = b_panel[5] - # b06_debug = b_panel[6] - # b avx_f32x8_store(packed_b_buff_curr, b00) avx_f32x8_store(packed_b_buff_curr + 8, b08) @@ -645,7 +627,7 @@ def gemm_pack_b( row = k_iters * 8 for _ in range(k_remainder): b_panel = loop4_partition_b + ( - row * n_size + curr_panel_start) + row * n_size + curr_panel_start) b00 = avx_f32x8_load(b_panel) b08 = avx_f32x8_load(b_panel + 8) avx_f32x8_store(packed_b_buff_curr, b00) @@ -655,19 +637,19 @@ def gemm_pack_b( else: packed_b_remaining_buf = packed_b_buf + ( - npanels_full_b * packedb_panel_stride) + npanels_full_b * packedb_panel_stride) if npanels_b_remainder > 0: # TODO: I think this if should always be true if this is executed? remain_col_start = npanels_full_b * NR for remain_row in range(loop4_partition_b_height): packed_b_remaining_buf_curr = packed_b_remaining_buf + ( - remain_row * NR) + remain_row * NR) for remain_col in range(npanels_b_remainder): packed_b_remaining_buf_curr[0] = \ - loop4_partition_b[ - (remain_row * n_size) + ( + loop4_partition_b[ + (remain_row * n_size) + ( remain_col_start + remain_col) - ] + ] packed_b_remaining_buf_curr += 1 zero_fill_col = npanels_b_remainder while zero_fill_col < NR: @@ -676,7 +658,6 @@ def gemm_pack_b( zero_fill_col += 1 # printf("The end of pack B, comm_id_packb: %d, work_id_packb: %d\n", comm_id_packb, work_id_packb) - gemm_pack_b.kind = "cpu_internal" gemm_pack_a.kind = "cpu_internal" micro_kernel.kind = "cpu_internal" @@ -695,7 +676,7 @@ def gemm_macro( ps_packed_b: int32, comm_id_macro: int32, work_id_macro: int32, - is_first: bool + is_first: bool, work_id_3rd_loop: int32, work_id_4th_loop: int32, work_id_5th_loop: int32 ): # assert loop1_nthreads == 1 comm_id_1st_loop = comm_id_macro % loop1_nthreads @@ -706,8 +687,10 @@ def gemm_macro( m_iter = macro_m // MR m_remainder = macro_m % MR - # printf("The macro kernel with comm_id_macro: %d, work_id_macro: %d , macro_m: %d, macro_n: %d, macro_k: %d, c_row_off: %d, c_col_off: %d, ps_packed_a: %d, ps_packed_b: %d\n", - # comm_id_macro, work_id_macro, macro_m, macro_n, macro_k, c_row_off, c_col_off, ps_packed_a, ps_packed_b) + # printf( + # "The macro kernel with comm_id_macro: %d, work_id_macro: %d , macro_m: %d, macro_n: %d, macro_k: %d, c_row_off: %d, c_col_off: %d, ps_packed_a: %d, ps_packed_b: %d\n", + # comm_id_macro, work_id_macro, macro_m, macro_n, macro_k, c_row_off, c_col_off, ps_packed_a, + # ps_packed_b) if n_remainder > 0: n_iter += 1 @@ -741,11 +724,13 @@ def gemm_macro( ~ir_inc ) - printf("jr_start: %d, jr_end: %d, jr_inc: %d, ir_start: %d, ir_end: %d, ir_inc: %d, work_id_macro: %d, work_id_1st_loop: %d\n\n", - jr_start, jr_end, jr_inc, ir_start, ir_end, ir_inc, work_id_macro, work_id_1st_loop) - - printf("work_id_macro: %d, work_id_1st_loop: %d, jr_start: %d, jr_end: %d, jr_inc: %d, ir_start: %d, ir_end: %d, ir_inc: %d\n\n", - work_id_macro, work_id_1st_loop, jr_start, jr_end, jr_inc, ir_start, ir_end, ir_inc) + # printf( + # "jr_start: %d, jr_end: %d, jr_inc: %d, ir_start: %d, ir_end: %d, ir_inc: %d, work_id_macro: %d, work_id_1st_loop: %d\n\n", + # jr_start, jr_end, jr_inc, ir_start, ir_end, ir_inc, work_id_macro, work_id_1st_loop) + # + # printf( + # "work_id_macro: %d, work_id_1st_loop: %d, jr_start: %d, jr_end: %d, 
jr_inc: %d, ir_start: %d, ir_end: %d, ir_inc: %d\n\n", + # work_id_macro, work_id_1st_loop, jr_start, jr_end, jr_inc, ir_start, ir_end, ir_inc) rs_packeda = 1 rstep_a = ps_packed_a @@ -760,9 +745,9 @@ def gemm_macro( shape=(m_size, n_size) ) temp_c = tensor(scope=DeclareScope.Default, - dtype=float32, - layout=row_major(MR, NR), - is_static=True) + dtype=float32, + layout=row_major(MR, NR), + is_static=True) j = jr_start while j < jr_end: b1 = packed_b + j * cstep_b @@ -777,8 +762,26 @@ def gemm_macro( c11 = as_tensor_pointer(c11, dtype=float32, shape=(m_size, n_size)) m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder - printf("The i value is %d for the macro work id %d. The offset i * rstep_a: %d, the offset i * rstep_c: %d, m_cur: %d\n\n", - i, work_id_macro, i * rstep_a, i * rstep_c, m_cur) + printf(''' work_id_macro: %d, work_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, c_row_off: %d, c_col_off: %d, + macro_m: %d, macro_n: %d, macro_k: %d, + ps_packed_a: %d, ps_packed_b: %d, , + n_iter: %d, n_remainder: %d, m_iter: %d, m_remainder: %d, + jr_start: %d, jr_end: %d, jr_inc: %d, + ir_start: %d, ir_end: %d, ir_inc: %d, + rstep_a: %d, cstep_b: %d, cstep_c: %d, rstep_c: %d, + n_cur: %d, m_cur: %d \n\n''', + work_id_macro, work_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, c_row_off, c_col_off, + macro_m, macro_n, macro_k, + ps_packed_a, ps_packed_b, + n_iter, n_remainder, m_iter, m_remainder, + jr_start, jr_end, jr_inc, + ir_start, ir_end, ir_inc, + rstep_a, cstep_b, cstep_c, rstep_c, + n_cur, m_cur) + + # printf( + # "The i value is %d for the macro work id %d. The offset i * rstep_a: %d, the offset i * rstep_c: %d, m_cur: %d\n\n", + # i, work_id_macro, i * rstep_a, i * rstep_c, m_cur) if m_cur == MR and n_cur == NR: # micro_kernel(a1, b1, c11, macro_k, macro_m, macro_n, is_first) @@ -793,14 +796,14 @@ def gemm_macro( c11[mm, nn] += temp_c[mm, nn] else: for mm, nn in grid(m_cur, n_cur): - c11[mm, nn] = temp_c[mm, nn] + # c11[mm, nn] = temp_c[mm, nn] FIXME: temporarily changed to see if zero-initing is the problem(well, it is not.....) 
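+                        # With the is_first zero-initialization commented out
+                        # of micro_kernel above, every call now accumulates
+                        # into its destination, so the edge tile must be
+                        # accumulated into C as well; this relies on C being
+                        # zero-filled once before the threaded region starts
+                        # (see the loop added in matmul_kernel_x86_v3 below).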
+ c11[mm, nn] += temp_c[mm, nn] i += ir_inc j += jr_inc gemm_macro.kind = "cpu_internal" - @hidet.script def gemm_3rd_loop( a: float32[m_size, k_size], @@ -812,14 +815,13 @@ def gemm_3rd_loop( loop3_partition_b_width: int32, comm_id_3rd_loop: int32, work_id_3rd_loop: int32, - is_first: bool): + is_first: bool, work_id_4th_loop: int32, work_id_5th_loop: int32): comm_id_macro = comm_id_3rd_loop % macro_nthreads work_id_macro = comm_id_macro // (macro_nthreads // macro_nways) comm_id_packa = comm_id_macro work_id_packa = comm_id_macro packa_nways = macro_nthreads - m_start_loop3 = 0 m_end_loop3 = 0 thread_range_sub( @@ -831,8 +833,15 @@ def gemm_3rd_loop( ~m_end_loop3 ) - printf("comm_id_3rd_loop: %d, work_id_3rd_loop: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, m_start_loop3: %d, m_end_loop3: %d, is_first: %d\n\n", - comm_id_3rd_loop, work_id_3rd_loop, loop3_partition_a_start_col, loop3_partition_b_start_col, m_start_loop3, m_end_loop3, is_first) + # printf( + # "loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, loop3_partition_a_width: %d, loop3_partition_b_width: %d, comm_id_3rd_loop: %d, work_id_3rd_loop: %d, is_first: %d," + # "m_start_loop3: %d, m_end_loop3: %d\n\n", + # loop3_partition_a_start_col, loop3_partition_b_start_col, loop3_partition_a_width, + # loop3_partition_b_width, comm_id_3rd_loop, work_id_3rd_loop, is_first, + # m_start_loop3, m_end_loop3) + + # printf("comm_id_3rd_loop: %d, work_id_3rd_loop: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, m_start_loop3: %d, m_end_loop3: %d, is_first: %d\n\n", + # comm_id_3rd_loop, work_id_3rd_loop, loop3_partition_a_start_col, loop3_partition_b_start_col, m_start_loop3, m_end_loop3, is_first) ii = m_start_loop3 while ii < m_end_loop3: @@ -844,21 +853,43 @@ def gemm_3rd_loop( loop3_partition_a_height = b_alg_loop3 loop3_partition_a = cast(a, ~float32) + ( - loop3_partition_a_start_row * k_size + - loop3_partition_a_start_col + loop3_partition_a_start_row * k_size + + loop3_partition_a_start_col ) - printf("comm_id_3rd_loop: %d, work_id_3rd_loop: %d, ii: %d, b_alg_loop3: %d\n\n", - comm_id_3rd_loop, work_id_3rd_loop, ii, b_alg_loop3) - - packed_a_buf = packa_buf + (work_id_3rd_loop * packed_a_individual_size) + # printf("comm_id_3rd_loop: %d, work_id_3rd_loop: %d, ii: %d, b_alg_loop3: %d\n\n", + # comm_id_3rd_loop, work_id_3rd_loop, ii, b_alg_loop3) + printf( + "work_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, " + "loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, " + "loop3_partition_a_width: %d, loop3_partition_b_width: %d, " + "loop3_partition_a_start_row: %d, loop3_partition_a_height: %d, " + "m_start_loop3: %d, m_end_loop3: %d, ii: %d, b_alg_loop3: %d\n\n", + work_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, + loop3_partition_a_start_col, loop3_partition_b_start_col, + loop3_partition_a_width, loop3_partition_b_width, + loop3_partition_a_start_row, loop3_partition_a_height, + m_start_loop3, m_end_loop3, ii, b_alg_loop3) + + # packed_a_buf = packa_buf + (work_id_3rd_loop * packed_a_individual_size) + packed_a_buf = packa_buf + (work_id_5th_loop * packed_a_individual_size) # TODO: If passed, see if this barrier is necessary + printf( + "Begin: calling the first barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d\n", + work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, + 
loop3_partition_a_start_col, + loop3_partition_b_start_col) thrcomm_barrier( comm_id_packa, ~packa_thrcomm_barrier_sense[work_id_3rd_loop], ~packa_thrcomm_threads_arrived[work_id_3rd_loop], packa_nthreads ) + printf( + "End: calling the first barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d\n", + work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, + loop3_partition_a_start_col, + loop3_partition_b_start_col) gemm_pack_a( loop3_partition_a, @@ -872,12 +903,22 @@ def gemm_3rd_loop( # This marks the end of the packing of A, # so a barrier is needed + printf( + "Begin: calling the second barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d\n", + work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, + loop3_partition_a_start_col, + loop3_partition_b_start_col) thrcomm_barrier( comm_id_packa, ~packa_thrcomm_barrier_sense[work_id_3rd_loop], ~packa_thrcomm_threads_arrived[work_id_3rd_loop], packa_nthreads ) + printf( + "End: calling the second barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d\n", + work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, + loop3_partition_a_start_col, + loop3_partition_b_start_col) gemm_macro(packed_a_buf, packed_b, @@ -891,7 +932,10 @@ def gemm_3rd_loop( packed_b_height * NR, comm_id_macro, work_id_macro, - is_first + is_first, + work_id_3rd_loop, + work_id_4th_loop, + work_id_5th_loop, ) ii += b_alg_loop3 @@ -916,6 +960,7 @@ def gemm_4th_loop(a: float32[m_size, k_size], while i_loop4 < k_size: b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, KC) + b_alg_loop4 = min(b_alg_loop4, k_size - i_loop4) loop4_partition_b_height = b_alg_loop4 loop4_partition_b_width = loop5_partition_b_width @@ -925,39 +970,53 @@ def gemm_4th_loop(a: float32[m_size, k_size], loop4_partition_a_start_col = i_loop4 is_first = (i_loop4 == 0) + printf( + "work_id_5th_loop: %d, work_id_4th_loop: %d, i_loop4: %d, b_alg_loop4: %d, loop4_partition_b_height: %d, loop4_partition_b_width: %d," + "loop4_partition_b_start_row: %d, loop4_partition_b_start_col: %d, loop4_partition_a_start_col: %d, is_first: %d\n", + work_id_5th_loop, work_id_4th_loop, i_loop4, b_alg_loop4, loop4_partition_b_height, loop4_partition_b_width, + loop4_partition_b_start_row, + loop4_partition_b_start_col, loop4_partition_a_start_col, is_first) packed_b_buf = packb_buf + ( - packed_b_individual_size * work_id_5th_loop + packed_b_individual_size * work_id_5th_loop ) loop4_partition_b = cast(b, ~float32) + \ - (loop4_partition_b_start_row * n_size + - loop4_partition_b_start_col) - - printf("work_id_4th_loop: %d, comm_id_4th_loop: %d, the offset packed_b_individual_size * work_id_5th_loop: %d, the offset loop4_partition_b_start_row * n_size + loop4_partition_b_start_col: %d\n\n", - work_id_4th_loop, comm_id_4th_loop, packed_b_individual_size * work_id_5th_loop, loop4_partition_b_start_row * n_size + loop4_partition_b_start_col) + (loop4_partition_b_start_row * n_size + + loop4_partition_b_start_col) + # printf("work_id_4th_loop: %d, comm_id_4th_loop: %d, the offset packed_b_individual_size * work_id_5th_loop: %d, " + # "the offset 
loop4_partition_b_start_row * n_size + loop4_partition_b_start_col: %d," + # "loop4_partition_b_start_row: %d, loop4_partition_b_start_col: %d, \n\n", + # work_id_4th_loop, comm_id_4th_loop, packed_b_individual_size * work_id_5th_loop, loop4_partition_b_start_row * n_size + loop4_partition_b_start_col) - # # TODO: If passed, see if this barrier is really needed - thrcomm_barrier( - comm_id_packb, - ~packb_thrcomm_barrier_sense[work_id_4th_loop], - ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], - packb_nthreads - ) - + # # # TODO: If passed, see if this barrier is really needed + # printf("Begin: calling the first barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) + # thrcomm_barrier( + # comm_id_packb, + # ~packb_thrcomm_barrier_sense[work_id_4th_loop], + # ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], + # packb_nthreads + # ) + # printf("End: calling the first barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) gemm_pack_b(loop4_partition_b, loop4_partition_b_width, loop4_partition_b_height, packed_b_buf, comm_id_packb, work_id_packb, loop3_nthreads) # The barrier at the end of the packing of B + printf( + "Begin: calling the second barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d, work_id_5th_loop: %d, i_loop4: %d, loop5_partition_b_start_col: %d\n", + work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) thrcomm_barrier( comm_id_packb, ~packb_thrcomm_barrier_sense[work_id_4th_loop], ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], packb_nthreads ) + printf( + "End: calling the second barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d, work_id_5th_loop: %d, i_loop4: %d, loop5_partition_b_start_col: %d\n", + work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) # TODO: The loop3 and beyond should start here? gemm_3rd_loop( @@ -968,20 +1027,28 @@ def gemm_4th_loop(a: float32[m_size, k_size], loop4_partition_b_width, comm_id_3rd_loop, work_id_3rd_loop, - is_first + is_first, + work_id_4th_loop, + work_id_5th_loop ) # # TODO: Is not adding this barrier at the end the problem? 
- # thrcomm_barrier( - # comm_id_packb, - # ~packb_thrcomm_barrier_sense[work_id_4th_loop], - # ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], - # packb_nthreads - # ) + printf( + "Begin: calling the third barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d, work_id_5th_loop: %d, i_loop4: %d, loop5_partition_b_start_col: %d, work_id_5th_loop: %d\n", + work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) + thrcomm_barrier( + comm_id_packb, + ~packb_thrcomm_barrier_sense[work_id_4th_loop], + ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], + packb_nthreads + ) + printf( + "End: calling the third barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d, work_id_5th_loop: %d, i_loop4: %d, loop5_partition_b_start_col: %d\n", + work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) i_loop4 += b_alg_loop4 - gemm_4th_loop.kind = "cpu_internal" + gemm_4th_loop.kind = "cpu_internal" @hidet.script def gemm_5th_loop(a: float32[m_size, k_size], @@ -1007,8 +1074,10 @@ def gemm_5th_loop(a: float32[m_size, k_size], loop5_partition_b_width = b_alg_loop5, loop5_partition_b_start_col = loop5_iter - printf("work_id_5th_loop: %d, comm_id_5th_loop: %d, b_alg_loop5: %d, loop5_partition_b_width: %d, loop5_partition_b_start_col: %d\n\n", - work_id_5th_loop, comm_id_5th_loop, b_alg_loop5, loop5_partition_b_width, loop5_partition_b_start_col) + printf( + "work_id_5th_loop: %d, comm_id_5th_loop: %d, b_alg_loop5: %d, loop5_partition_b_width: %d, loop5_partition_b_start_col: %d, loop5_iter: %d, loop5_my_start: %d, loop5_my_end: %d\n\n", + work_id_5th_loop, comm_id_5th_loop, b_alg_loop5, loop5_partition_b_width, + loop5_partition_b_start_col, loop5_iter, loop5_my_start, loop5_my_end) gemm_4th_loop(a, b, c, loop5_partition_b_width, loop5_partition_b_start_col, @@ -1016,6 +1085,7 @@ def gemm_5th_loop(a: float32[m_size, k_size], work_id_4th_loop, work_id_5th_loop) loop5_iter += b_alg_loop5 + gemm_5th_loop.kind = 'cpu_internal' ################### Start of the main kernel ################### @@ -1031,8 +1101,8 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], loop5_nways) # the nthreads and nways print for each loop - printf("nthreads: %d\n", nthreads) - printf("loop5_nthreads: %d, loop5_nways: %d\n", loop5_nthreads, loop5_nways) + # printf("nthreads: %d\n", nthreads) + # printf("loop5_nthreads: %d, loop5_nways: %d\n", loop5_nthreads, loop5_nways) # printf("loop4_nthreads: %d, loop4_nways: %d\n", loop4_nthreads, loop4_nways) # printf("loop3_nthreads: %d, loop3_nways: %d\n", loop3_nthreads, loop3_nways) # printf("macro_nthreads: %d, macro_nways: %d\n", macro_nthreads, macro_nways) @@ -1040,12 +1110,17 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], # printf("packb_nthreads: %d, packa_nthreads: %d\n", packb_nthreads, packa_nthreads) - printf("packed_b_width: %d, packed_b_total_width: %d, packed_b_height: %d\n", packed_b_width, packed_b_total_width, packed_b_height) - printf("packed_a_width: %d, packed_a_individual_height: %d, packed_a_total_height: %d\n", packed_a_width, packed_a_individual_height, packed_a_total_height) + printf("packed_b_width: %d, packed_b_total_width: %d, packed_b_height: %d\n", packed_b_width, + packed_b_total_width, packed_b_height) + printf("packed_a_width: %d, packed_a_individual_height: %d, packed_a_total_height: %d\n", + packed_a_width, packed_a_individual_height, packed_a_total_height) # 
printf("packed_b_total_size: %d, packed_a_total_size: %d\n", packed_b_total_size, packed_a_total_size) # printf("packed_b_individual_size: %d, packed_a_individual_size: %d\n", packed_b_individual_size, packed_a_individual_size) + for i in grid(m_size): + for j in grid(n_size): + c[i, j] = 0.0 parallel_attr = 'p' + str(nthreads) # The outermost loop spawning threads @@ -1053,7 +1128,7 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], tid_5th_loop = tidx work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways) comm_id_5th_loop = tid_5th_loop - printf("tidx: %d, tid_5th_loop: %d, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", tidx, tid_5th_loop, work_id_5th_loop, comm_id_5th_loop) + # printf("tidx: %d, tid_5th_loop: %d, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", tidx, tid_5th_loop, work_id_5th_loop, comm_id_5th_loop) gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop) diff --git a/python/mat_new.py b/python/mat_new.py index daf8f6633..29d8713a8 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -35,8 +35,10 @@ def matmul_ansor(M, K, N, dtype): np.random.seed(42) # for m, n, k in [(33, 65, 60), (32, 92, 128)]: -for m, n, k in [(7, 1, 17)]: - # a = hidet.randn([m, k], device='cpu') +# for m, n, k in [(7, 1, 17), (256, 256, 256), (512, 512, 512), (768, 768, 768)]: +# for m, n, k in [(7, 1, 17), (32, 32, 32), (36, 36, 36), (37, 37, 37)]: +for m, n, k in [(7, 17, 1), (333, 444, 555), (256, 256, 256), (512, 512, 512), (768, 768, 768)]: + # a = hidet.randn([m, k], device='cpuO') # b = hidet.randn([k, n], device='cpu') a_torch = torch.arange(0, m*k).reshape(m, k).float().to('cpu') @@ -69,14 +71,14 @@ def matmul_ansor(M, K, N, dtype): fails = 0 - for i in range(m): - for j in range(n): - if abs(actual[i, j] - desired[i, j]) < 1e-3: - # print(f"Actually passed for i={i}, j={j}") - continue - else: - print(f"Failed for i={i}, j={j}, and we have [i, j] = {actual[i, j]} and desired [i, j] = {desired[i, j]}") - fails += 1 + # for i in range(m): + # for j in range(n): + # if abs(actual[i, j] - desired[i, j]) < 1e-3: + # # print(f"Actually passed for i={i}, j={j}") + # continue + # else: + # print(f"Failed for i={i}, j={j}, and we have [i, j] = {actual[i, j]} and desired [i, j] = {desired[i, j]}") + # fails += 1 print(f"Total fails: {fails}") From 0c0efe0ee98785927f7b82475f6f0056f477161c Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 9 Nov 2023 14:14:24 -0500 Subject: [PATCH 121/148] . 
--- .../ops/matmul/matmul_f32_x86_refactored.py | 70 ++----------------- 1 file changed, 5 insertions(+), 65 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 1d2dda345..5d78ce4d6 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -267,10 +267,6 @@ def micro_kernel( c5 = avx_f32x8_load(~c[5, 0]) c58 = avx_f32x8_load(~c[5, 8]) - # printf("The msize in the micro kernel: %d\n", msize) - # printf("The nsize in the micro kernel: %d\n", nsize) - # printf("The pb in the micro kernel: %d\n", pb) - # if is_first: # c0 = avx_f32x8_setzero() # c08 = avx_f32x8_setzero() @@ -315,11 +311,11 @@ def micro_kernel( aa6 = avx_f32x8_broadcast(a_ptr + 5) c5 = avx_f32x8_fmadd(aa6, bb0to7, c5) c58 = avx_f32x8_fmadd(aa6, bb8to15, c58) - printf( - "List of all the aa's broadcasted in this iteration: %lf %lf %lf %lf %lf %lf\n, bb0 to to bb15: %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf\n\n", - a_ptr[0], a_ptr[1], a_ptr[2], a_ptr[3], a_ptr[4], a_ptr[5], b_ptr[0], b_ptr[1], b_ptr[2], - b_ptr[3], b_ptr[4], b_ptr[5], b_ptr[6], b_ptr[7], b_ptr[8], b_ptr[9], b_ptr[10], b_ptr[11], - b_ptr[12], b_ptr[13], b_ptr[14], b_ptr[15]) + # printf( + # "List of all the aa's broadcasted in this iteration: %lf %lf %lf %lf %lf %lf\n, bb0 to to bb15: %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf\n\n", + # a_ptr[0], a_ptr[1], a_ptr[2], a_ptr[3], a_ptr[4], a_ptr[5], b_ptr[0], b_ptr[1], b_ptr[2], + # b_ptr[3], b_ptr[4], b_ptr[5], b_ptr[6], b_ptr[7], b_ptr[8], b_ptr[9], b_ptr[10], b_ptr[11], + # b_ptr[12], b_ptr[13], b_ptr[14], b_ptr[15]) a_ptr = a_ptr + 6 b_ptr = b_ptr + 16 @@ -403,15 +399,9 @@ def gemm_pack_a( column_major(MR, packed_a_width) ) - # printf("work_id_packa: %d, packa_nways: %d, loop3_partition_a_width: %d, loop3_partition_a_height: %d\n", - # work_id_packa, packa_nways, loop3_partition_a_width, loop3_partition_a_height) - npanels_full_a = loop3_partition_a_height // MR panel_a_remainder = loop3_partition_a_height % MR - # printf("loop3_partition_a_height: %d\n", loop3_partition_a_height) - # printf("npanels_full_a: %d\n", npanels_full_a) - npanels_a = npanels_full_a + (1 if panel_a_remainder > 0 else 0) for ii_panel in range(npanels_a): @@ -421,8 +411,6 @@ def gemm_pack_a( a_curr_panel_row_start = ii_panel * MR a_curr_panel_height = min(MR, loop3_partition_a_height - a_curr_panel_row_start) - # printf("packing of a: the panel %d is taken care of by the thread with pack a work id: %d; the a_curr_panel_row_start: %d, a_curr_panel_height: %d\n\n", - # ii_panel, work_id_packa, a_curr_panel_row_start, a_curr_panel_height) if a_curr_panel_height == MR: # unroll the packing by 8 k_iters = loop3_partition_a_width // 8 @@ -434,10 +422,6 @@ def gemm_pack_a( a_curr_panel_row_start * k_size + col ) - # printf("In the packing of A: the offset a_curr_panel_row_start * k_size + col for id %d: %d\n", work_id_packa, a_curr_panel_row_start * k_size + col) - # printf("work_id_packa: %d, a_curr_panel_row_start: %d, a_curr_panel_height: %d, the offset a_curr_panel_row_start * k_size + col: %d\n", - # work_id_packa, a_curr_panel_row_start, a_curr_panel_height, a_curr_panel_row_start * k_size + col) - v0 = avx_f32x8_load(a_curr_panel_col) v1 = avx_f32x8_load(a_curr_panel_col + k_size) v2 = avx_f32x8_load(a_curr_panel_col + (2 * k_size)) @@ -522,8 +506,6 @@ def gemm_pack_a( packed_a_tensor[ remain_start_row + remain_row, 
remain_col] = 0.0 remain_row += 1 - # printf("The end of the pack a, comm id: %d, work id: %d\n", - # comm_id_packa, work_id_packa) @hidet.script def gemm_pack_b( @@ -536,13 +518,10 @@ def gemm_pack_b( ): npanels_full_b = loop4_partition_b_width // NR npanels_b_remainder = loop4_partition_b_width % NR - # if comm_id_packb == 0: - # printf("loop4_partition_b_width: %d; loop4_partition_b_height: %d\n", loop4_partition_b_width, loop4_partition_b_height) npanels_b = npanels_full_b + (1 if npanels_b_remainder != 0 else 0) packedb_panel_stride = packed_b_height * NR - # printf("work_id_packb: %d; npabels_b: %d, packedb_panel_stride: %d\n\n", work_id_packb, npanels_b, packedb_panel_stride) # Loop for the packing of B for i_panel in range(npanels_b): @@ -554,10 +533,6 @@ def gemm_pack_b( curr_panel_width = min(NR, loop4_partition_b_width - curr_panel_start) - # printf( - # "work_id_packb: %d; curr_panel_start: %d; curr_panel_width: %d; the offset i_panel * packedb_panel_stride: %d\n\n", - # work_id_packb, curr_panel_start, curr_panel_width, i_panel * packedb_panel_stride) - if curr_panel_width == NR: k_iters = loop4_partition_b_height // 8 k_remainder = loop4_partition_b_height % 8 @@ -656,7 +631,6 @@ def gemm_pack_b( packed_b_remaining_buf_curr[0] = 0.0 packed_b_remaining_buf_curr += 1 zero_fill_col += 1 - # printf("The end of pack B, comm_id_packb: %d, work_id_packb: %d\n", comm_id_packb, work_id_packb) gemm_pack_b.kind = "cpu_internal" gemm_pack_a.kind = "cpu_internal" @@ -687,11 +661,6 @@ def gemm_macro( m_iter = macro_m // MR m_remainder = macro_m % MR - # printf( - # "The macro kernel with comm_id_macro: %d, work_id_macro: %d , macro_m: %d, macro_n: %d, macro_k: %d, c_row_off: %d, c_col_off: %d, ps_packed_a: %d, ps_packed_b: %d\n", - # comm_id_macro, work_id_macro, macro_m, macro_n, macro_k, c_row_off, c_col_off, ps_packed_a, - # ps_packed_b) - if n_remainder > 0: n_iter += 1 if m_remainder > 0: @@ -724,14 +693,6 @@ def gemm_macro( ~ir_inc ) - # printf( - # "jr_start: %d, jr_end: %d, jr_inc: %d, ir_start: %d, ir_end: %d, ir_inc: %d, work_id_macro: %d, work_id_1st_loop: %d\n\n", - # jr_start, jr_end, jr_inc, ir_start, ir_end, ir_inc, work_id_macro, work_id_1st_loop) - # - # printf( - # "work_id_macro: %d, work_id_1st_loop: %d, jr_start: %d, jr_end: %d, jr_inc: %d, ir_start: %d, ir_end: %d, ir_inc: %d\n\n", - # work_id_macro, work_id_1st_loop, jr_start, jr_end, jr_inc, ir_start, ir_end, ir_inc) - rs_packeda = 1 rstep_a = ps_packed_a cstep_b = ps_packed_b @@ -756,7 +717,6 @@ def gemm_macro( i = ir_start while i < ir_end: - # printf("i = %d\n", i) a1 = packed_a + i * rstep_a c11 = c1 + i * rstep_c c11 = as_tensor_pointer(c11, dtype=float32, shape=(m_size, n_size)) @@ -779,10 +739,6 @@ def gemm_macro( rstep_a, cstep_b, cstep_c, rstep_c, n_cur, m_cur) - # printf( - # "The i value is %d for the macro work id %d. 
The offset i * rstep_a: %d, the offset i * rstep_c: %d, m_cur: %d\n\n", - # i, work_id_macro, i * rstep_a, i * rstep_c, m_cur) - if m_cur == MR and n_cur == NR: # micro_kernel(a1, b1, c11, macro_k, macro_m, macro_n, is_first) micro_kernel(a1, b1, c11, macro_k, m_size, n_size, is_first) @@ -833,16 +789,6 @@ def gemm_3rd_loop( ~m_end_loop3 ) - # printf( - # "loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, loop3_partition_a_width: %d, loop3_partition_b_width: %d, comm_id_3rd_loop: %d, work_id_3rd_loop: %d, is_first: %d," - # "m_start_loop3: %d, m_end_loop3: %d\n\n", - # loop3_partition_a_start_col, loop3_partition_b_start_col, loop3_partition_a_width, - # loop3_partition_b_width, comm_id_3rd_loop, work_id_3rd_loop, is_first, - # m_start_loop3, m_end_loop3) - - # printf("comm_id_3rd_loop: %d, work_id_3rd_loop: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, m_start_loop3: %d, m_end_loop3: %d, is_first: %d\n\n", - # comm_id_3rd_loop, work_id_3rd_loop, loop3_partition_a_start_col, loop3_partition_b_start_col, m_start_loop3, m_end_loop3, is_first) - ii = m_start_loop3 while ii < m_end_loop3: b_alg_loop3 = determine_blocksize_f_sub( @@ -856,8 +802,6 @@ def gemm_3rd_loop( loop3_partition_a_start_row * k_size + loop3_partition_a_start_col ) - # printf("comm_id_3rd_loop: %d, work_id_3rd_loop: %d, ii: %d, b_alg_loop3: %d\n\n", - # comm_id_3rd_loop, work_id_3rd_loop, ii, b_alg_loop3) printf( "work_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, " "loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, " @@ -985,10 +929,6 @@ def gemm_4th_loop(a: float32[m_size, k_size], (loop4_partition_b_start_row * n_size + loop4_partition_b_start_col) - # printf("work_id_4th_loop: %d, comm_id_4th_loop: %d, the offset packed_b_individual_size * work_id_5th_loop: %d, " - # "the offset loop4_partition_b_start_row * n_size + loop4_partition_b_start_col: %d," - # "loop4_partition_b_start_row: %d, loop4_partition_b_start_col: %d, \n\n", - # work_id_4th_loop, comm_id_4th_loop, packed_b_individual_size * work_id_5th_loop, loop4_partition_b_start_row * n_size + loop4_partition_b_start_col) # # # TODO: If passed, see if this barrier is really needed # printf("Begin: calling the first barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) From 1bd2cfe1b84b29f5d2b1423e2e141c56860ca54a Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 9 Nov 2023 14:16:46 -0500 Subject: [PATCH 122/148] remove prints --- .../ops/matmul/matmul_f32_x86_refactored.py | 138 +++++++++--------- python/mat_new.py | 6 +- 2 files changed, 72 insertions(+), 72 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 5d78ce4d6..d784a0b7a 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -722,22 +722,22 @@ def gemm_macro( c11 = as_tensor_pointer(c11, dtype=float32, shape=(m_size, n_size)) m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder - printf(''' work_id_macro: %d, work_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, c_row_off: %d, c_col_off: %d, - macro_m: %d, macro_n: %d, macro_k: %d, - ps_packed_a: %d, ps_packed_b: %d, , - n_iter: %d, n_remainder: %d, m_iter: %d, m_remainder: %d, - jr_start: %d, jr_end: %d, jr_inc: %d, - ir_start: %d, ir_end: %d, ir_inc: %d, - rstep_a: %d, cstep_b: %d, cstep_c: %d, rstep_c: %d, - n_cur: %d, 
m_cur: %d \n\n''', - work_id_macro, work_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, c_row_off, c_col_off, - macro_m, macro_n, macro_k, - ps_packed_a, ps_packed_b, - n_iter, n_remainder, m_iter, m_remainder, - jr_start, jr_end, jr_inc, - ir_start, ir_end, ir_inc, - rstep_a, cstep_b, cstep_c, rstep_c, - n_cur, m_cur) + # printf(''' work_id_macro: %d, work_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, c_row_off: %d, c_col_off: %d, + # macro_m: %d, macro_n: %d, macro_k: %d, + # ps_packed_a: %d, ps_packed_b: %d, , + # n_iter: %d, n_remainder: %d, m_iter: %d, m_remainder: %d, + # jr_start: %d, jr_end: %d, jr_inc: %d, + # ir_start: %d, ir_end: %d, ir_inc: %d, + # rstep_a: %d, cstep_b: %d, cstep_c: %d, rstep_c: %d, + # n_cur: %d, m_cur: %d \n\n''', + # work_id_macro, work_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, c_row_off, c_col_off, + # macro_m, macro_n, macro_k, + # ps_packed_a, ps_packed_b, + # n_iter, n_remainder, m_iter, m_remainder, + # jr_start, jr_end, jr_inc, + # ir_start, ir_end, ir_inc, + # rstep_a, cstep_b, cstep_c, rstep_c, + # n_cur, m_cur) if m_cur == MR and n_cur == NR: # micro_kernel(a1, b1, c11, macro_k, macro_m, macro_n, is_first) @@ -802,38 +802,38 @@ def gemm_3rd_loop( loop3_partition_a_start_row * k_size + loop3_partition_a_start_col ) - printf( - "work_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, " - "loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, " - "loop3_partition_a_width: %d, loop3_partition_b_width: %d, " - "loop3_partition_a_start_row: %d, loop3_partition_a_height: %d, " - "m_start_loop3: %d, m_end_loop3: %d, ii: %d, b_alg_loop3: %d\n\n", - work_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, - loop3_partition_a_start_col, loop3_partition_b_start_col, - loop3_partition_a_width, loop3_partition_b_width, - loop3_partition_a_start_row, loop3_partition_a_height, - m_start_loop3, m_end_loop3, ii, b_alg_loop3) + # printf( + # "work_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, " + # "loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, " + # "loop3_partition_a_width: %d, loop3_partition_b_width: %d, " + # "loop3_partition_a_start_row: %d, loop3_partition_a_height: %d, " + # "m_start_loop3: %d, m_end_loop3: %d, ii: %d, b_alg_loop3: %d\n\n", + # work_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, + # loop3_partition_a_start_col, loop3_partition_b_start_col, + # loop3_partition_a_width, loop3_partition_b_width, + # loop3_partition_a_start_row, loop3_partition_a_height, + # m_start_loop3, m_end_loop3, ii, b_alg_loop3) # packed_a_buf = packa_buf + (work_id_3rd_loop * packed_a_individual_size) packed_a_buf = packa_buf + (work_id_5th_loop * packed_a_individual_size) # TODO: If passed, see if this barrier is necessary - printf( - "Begin: calling the first barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d\n", - work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, - loop3_partition_a_start_col, - loop3_partition_b_start_col) + # printf( + # "Begin: calling the first barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d\n", + # work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, + # loop3_partition_a_start_col, + # loop3_partition_b_start_col) thrcomm_barrier( comm_id_packa, 
~packa_thrcomm_barrier_sense[work_id_3rd_loop], ~packa_thrcomm_threads_arrived[work_id_3rd_loop], packa_nthreads ) - printf( - "End: calling the first barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d\n", - work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, - loop3_partition_a_start_col, - loop3_partition_b_start_col) + # printf( + # "End: calling the first barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d\n", + # work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, + # loop3_partition_a_start_col, + # loop3_partition_b_start_col) gemm_pack_a( loop3_partition_a, @@ -847,22 +847,22 @@ def gemm_3rd_loop( # This marks the end of the packing of A, # so a barrier is needed - printf( - "Begin: calling the second barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d\n", - work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, - loop3_partition_a_start_col, - loop3_partition_b_start_col) + # printf( + # "Begin: calling the second barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d\n", + # work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, + # loop3_partition_a_start_col, + # loop3_partition_b_start_col) thrcomm_barrier( comm_id_packa, ~packa_thrcomm_barrier_sense[work_id_3rd_loop], ~packa_thrcomm_threads_arrived[work_id_3rd_loop], packa_nthreads ) - printf( - "End: calling the second barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d\n", - work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, - loop3_partition_a_start_col, - loop3_partition_b_start_col) + # printf( + # "End: calling the second barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d\n", + # work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, + # loop3_partition_a_start_col, + # loop3_partition_b_start_col) gemm_macro(packed_a_buf, packed_b, @@ -914,12 +914,12 @@ def gemm_4th_loop(a: float32[m_size, k_size], loop4_partition_a_start_col = i_loop4 is_first = (i_loop4 == 0) - printf( - "work_id_5th_loop: %d, work_id_4th_loop: %d, i_loop4: %d, b_alg_loop4: %d, loop4_partition_b_height: %d, loop4_partition_b_width: %d," - "loop4_partition_b_start_row: %d, loop4_partition_b_start_col: %d, loop4_partition_a_start_col: %d, is_first: %d\n", - work_id_5th_loop, work_id_4th_loop, i_loop4, b_alg_loop4, loop4_partition_b_height, loop4_partition_b_width, - loop4_partition_b_start_row, - loop4_partition_b_start_col, loop4_partition_a_start_col, is_first) + # printf( + # "work_id_5th_loop: %d, work_id_4th_loop: %d, i_loop4: %d, b_alg_loop4: %d, loop4_partition_b_height: %d, loop4_partition_b_width: %d," + # "loop4_partition_b_start_row: %d, loop4_partition_b_start_col: %d, loop4_partition_a_start_col: %d, is_first: 
%d\n", + # work_id_5th_loop, work_id_4th_loop, i_loop4, b_alg_loop4, loop4_partition_b_height, loop4_partition_b_width, + # loop4_partition_b_start_row, + # loop4_partition_b_start_col, loop4_partition_a_start_col, is_first) packed_b_buf = packb_buf + ( packed_b_individual_size * work_id_5th_loop @@ -945,18 +945,18 @@ def gemm_4th_loop(a: float32[m_size, k_size], comm_id_packb, work_id_packb, loop3_nthreads) # The barrier at the end of the packing of B - printf( - "Begin: calling the second barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d, work_id_5th_loop: %d, i_loop4: %d, loop5_partition_b_start_col: %d\n", - work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) + # printf( + # "Begin: calling the second barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d, work_id_5th_loop: %d, i_loop4: %d, loop5_partition_b_start_col: %d\n", + # work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) thrcomm_barrier( comm_id_packb, ~packb_thrcomm_barrier_sense[work_id_4th_loop], ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], packb_nthreads ) - printf( - "End: calling the second barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d, work_id_5th_loop: %d, i_loop4: %d, loop5_partition_b_start_col: %d\n", - work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) + # printf( + # "End: calling the second barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d, work_id_5th_loop: %d, i_loop4: %d, loop5_partition_b_start_col: %d\n", + # work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) # TODO: The loop3 and beyond should start here? gemm_3rd_loop( @@ -973,18 +973,18 @@ def gemm_4th_loop(a: float32[m_size, k_size], ) # # TODO: Is not adding this barrier at the end the problem? 
- printf( - "Begin: calling the third barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d, work_id_5th_loop: %d, i_loop4: %d, loop5_partition_b_start_col: %d, work_id_5th_loop: %d\n", - work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) + # printf( + # "Begin: calling the third barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d, work_id_5th_loop: %d, i_loop4: %d, loop5_partition_b_start_col: %d, work_id_5th_loop: %d\n", + # work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) thrcomm_barrier( comm_id_packb, ~packb_thrcomm_barrier_sense[work_id_4th_loop], ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], packb_nthreads ) - printf( - "End: calling the third barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d, work_id_5th_loop: %d, i_loop4: %d, loop5_partition_b_start_col: %d\n", - work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) + # printf( + # "End: calling the third barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d, work_id_5th_loop: %d, i_loop4: %d, loop5_partition_b_start_col: %d\n", + # work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) i_loop4 += b_alg_loop4 @@ -1014,10 +1014,10 @@ def gemm_5th_loop(a: float32[m_size, k_size], loop5_partition_b_width = b_alg_loop5, loop5_partition_b_start_col = loop5_iter - printf( - "work_id_5th_loop: %d, comm_id_5th_loop: %d, b_alg_loop5: %d, loop5_partition_b_width: %d, loop5_partition_b_start_col: %d, loop5_iter: %d, loop5_my_start: %d, loop5_my_end: %d\n\n", - work_id_5th_loop, comm_id_5th_loop, b_alg_loop5, loop5_partition_b_width, - loop5_partition_b_start_col, loop5_iter, loop5_my_start, loop5_my_end) + # printf( + # "work_id_5th_loop: %d, comm_id_5th_loop: %d, b_alg_loop5: %d, loop5_partition_b_width: %d, loop5_partition_b_start_col: %d, loop5_iter: %d, loop5_my_start: %d, loop5_my_end: %d\n\n", + # work_id_5th_loop, comm_id_5th_loop, b_alg_loop5, loop5_partition_b_width, + # loop5_partition_b_start_col, loop5_iter, loop5_my_start, loop5_my_end) gemm_4th_loop(a, b, c, loop5_partition_b_width, loop5_partition_b_start_col, diff --git a/python/mat_new.py b/python/mat_new.py index 29d8713a8..f6d7f6dfa 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -49,8 +49,8 @@ def matmul_ansor(M, K, N, dtype): a = hidet.from_torch(a_torch).to(dtype='float32', device='cpu') b = hidet.from_torch(b_torch).to(dtype='float32', device='cpu') - print(f"a: {a}") - print(f"b: {b}") + # print(f"a: {a}") + # print(f"b: {b}") # a = hidet.ones([m, k], device='cpu') # b = hidet.ones([k, n], device='cpu') @@ -80,7 +80,7 @@ def matmul_ansor(M, K, N, dtype): # print(f"Failed for i={i}, j={j}, and we have [i, j] = {actual[i, j]} and desired [i, j] = {desired[i, j]}") # fails += 1 - print(f"Total fails: {fails}") + # print(f"Total fails: {fails}") # for i in range(m): # for j in range(n): From 6721ed21c1c3f8814b3fd887ab01cc27237d2c91 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 9 Nov 2023 16:11:41 -0500 Subject: [PATCH 123/148] . 
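
The diff below returns the test inputs to hidet.randn and relaxes the assert_allclose tolerances. Some context on why a looser bound is defensible for float32 GEMM: the result depends on accumulation order, and outputs that land near zero make a purely relative check brittle, so the absolute term matters too. A self-contained illustration (shapes are arbitrary, not taken from the test):

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.standard_normal((128, 3072), dtype=np.float32)
    b = rng.standard_normal((3072, 128), dtype=np.float32)

    # float64 reference vs. plain float32 accumulation: the gap grows with k.
    ref = (a.astype(np.float64) @ b.astype(np.float64)).astype(np.float32)
    out = a @ b
    abs_err = np.max(np.abs(out - ref))
    rel_err = np.max(np.abs(out - ref) / np.maximum(np.abs(ref), 1e-30))
    print(abs_err, rel_err)  # rel_err can spike wherever ref is close to zero
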
--- python/mat_new.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/python/mat_new.py b/python/mat_new.py index f6d7f6dfa..88ff5a0a3 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -37,18 +37,18 @@ def matmul_ansor(M, K, N, dtype): # for m, n, k in [(33, 65, 60), (32, 92, 128)]: # for m, n, k in [(7, 1, 17), (256, 256, 256), (512, 512, 512), (768, 768, 768)]: # for m, n, k in [(7, 1, 17), (32, 32, 32), (36, 36, 36), (37, 37, 37)]: -for m, n, k in [(7, 17, 1), (333, 444, 555), (256, 256, 256), (512, 512, 512), (768, 768, 768)]: - # a = hidet.randn([m, k], device='cpuO') - # b = hidet.randn([k, n], device='cpu') - - a_torch = torch.arange(0, m*k).reshape(m, k).float().to('cpu') - b_torch = torch.arange(0, k*n).reshape(k, n).float().to('cpu') +for m, n, k in [(7, 17, 1), (333, 444, 555)]: + a = hidet.randn([m, k], device='cpu') + b = hidet.randn([k, n], device='cpu') + + # a_torch = torch.arange(0, m*k).reshape(m, k).float().to('cpu') + # b_torch = torch.arange(0, k*n).reshape(k, n).float().to('cpu') + # # + # # print(f"a_torch: {a_torch}") + # # print(f"b_torch: {b_torch}") # - # print(f"a_torch: {a_torch}") - # print(f"b_torch: {b_torch}") - - a = hidet.from_torch(a_torch).to(dtype='float32', device='cpu') - b = hidet.from_torch(b_torch).to(dtype='float32', device='cpu') + # a = hidet.from_torch(a_torch).to(dtype='float32', device='cpu') + # b = hidet.from_torch(b_torch).to(dtype='float32', device='cpu') # print(f"a: {a}") # print(f"b: {b}") @@ -91,8 +91,8 @@ def matmul_ansor(M, K, N, dtype): np.testing.assert_allclose( actual=actual, desired=desired, - rtol=1e-3, - atol=1e-3 + rtol=1e-2, + atol=1e-2 ) print("passed for m={}, n={}, k={}".format(m, n, k)) From 442fbd214162a229e6d1da133e13360ad989e910 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 9 Nov 2023 17:20:57 -0500 Subject: [PATCH 124/148] .. --- .../ops/matmul/matmul_f32_x86_refactored.py | 23 +++++++++++++------ python/mat_new.py | 2 +- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index d784a0b7a..b25c11f5f 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -157,6 +157,8 @@ def init_thr(sense: ~int32, arrived: ~int32, size: int32): packb_nthreads = loop3_nthreads packa_nthreads = macro_nthreads + packed_a_buffers_needed = loop3_nways * loop5_nways + @hidet.script def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~int32, end: ~int32): if n_way == 1: @@ -227,6 +229,12 @@ def not_edge(i: int32, n_iter: int32, n_left: int32) -> bool: not_edge.kind = 'cpu_internal' + # TODO: Is this the way to find out the "index" of the packed A buffer? 
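
On the TODO above: the helper defined just below flattens the (work_id_loop5, work_id_loop3) pair row-major. Together with packed_a_buffers_needed = loop3_nways * loop5_nways from the earlier hunk, that gives every such pair its own packed-A buffer. A quick sanity sketch with assumed example way counts (not the tuned ones):

    loop5_nways, loop3_nways = 2, 4   # example values only

    def packa_index(work_id_loop5, work_id_loop3):
        return work_id_loop5 * loop3_nways + work_id_loop3

    seen = {packa_index(w5, w3)
            for w5 in range(loop5_nways)
            for w3 in range(loop3_nways)}
    # A bijection onto 0..7, so no two (loop5, loop3) groups share a buffer.
    assert seen == set(range(loop5_nways * loop3_nways))
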
+ @hidet.script + def packa_index(work_id_loop5: int32, work_id_loop3: int32): + return work_id_loop5 * loop3_nways + work_id_loop3 + packa_index.kind = 'cpu_internal' + # Thread barrier @hidet.script def thrcomm_barrier(tid: int32, barrier_sense: ~int32, @@ -355,7 +363,8 @@ def micro_kernel( if packed_a_individual_height > m_size: packed_a_individual_height = (m_size + MR - 1) // MR * MR # packed_a_total_height = packed_a_individual_height * loop3_nways - packed_a_total_height = packed_a_individual_height * loop5_nways + # packed_a_total_height = packed_a_individual_height * loop5_nways + packed_a_total_height = packed_a_individual_height * packed_a_buffers_needed packed_a_width = KC if packed_a_width > k_size: @@ -522,7 +531,6 @@ def gemm_pack_b( npanels_b = npanels_full_b + (1 if npanels_b_remainder != 0 else 0) packedb_panel_stride = packed_b_height * NR - # Loop for the packing of B for i_panel in range(npanels_b): if i_panel % packb_nways != work_id_packb % packb_nways: @@ -816,6 +824,8 @@ def gemm_3rd_loop( # packed_a_buf = packa_buf + (work_id_3rd_loop * packed_a_individual_size) packed_a_buf = packa_buf + (work_id_5th_loop * packed_a_individual_size) + # packed_a_idx = packa_index(work_id_5th_loop, work_id_3rd_loop) + # packed_a_buf = packa_buf + (packed_a_idx * packed_a_individual_size) # TODO: If passed, see if this barrier is necessary # printf( @@ -929,7 +939,6 @@ def gemm_4th_loop(a: float32[m_size, k_size], (loop4_partition_b_start_row * n_size + loop4_partition_b_start_col) - # # # TODO: If passed, see if this barrier is really needed # printf("Begin: calling the first barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) # thrcomm_barrier( @@ -950,8 +959,8 @@ def gemm_4th_loop(a: float32[m_size, k_size], # work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) thrcomm_barrier( comm_id_packb, - ~packb_thrcomm_barrier_sense[work_id_4th_loop], - ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], + ~packb_thrcomm_barrier_sense[work_id_5th_loop], + ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], packb_nthreads ) # printf( @@ -978,8 +987,8 @@ def gemm_4th_loop(a: float32[m_size, k_size], # work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) thrcomm_barrier( comm_id_packb, - ~packb_thrcomm_barrier_sense[work_id_4th_loop], - ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], + ~packb_thrcomm_barrier_sense[work_id_5th_loop], + ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], packb_nthreads ) # printf( diff --git a/python/mat_new.py b/python/mat_new.py index 88ff5a0a3..da3f34b55 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -37,7 +37,7 @@ def matmul_ansor(M, K, N, dtype): # for m, n, k in [(33, 65, 60), (32, 92, 128)]: # for m, n, k in [(7, 1, 17), (256, 256, 256), (512, 512, 512), (768, 768, 768)]: # for m, n, k in [(7, 1, 17), (32, 32, 32), (36, 36, 36), (37, 37, 37)]: -for m, n, k in [(7, 17, 1), (333, 444, 555)]: +for m, n, k in [(7, 17, 1), (333, 444, 555), (768, 768, 768)]: a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') From b4e00e90164d806e030bce3b392b69a9b430d347 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 9 Nov 2023 17:22:48 -0500 Subject: [PATCH 125/148] logic error fix in packing of A --- .../hidet/graph/ops/matmul/matmul_f32_x86_refactored.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git 
a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index b25c11f5f..4c85de6fc 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -231,8 +231,9 @@ def not_edge(i: int32, n_iter: int32, n_left: int32) -> bool: # TODO: Is this the way to find out the "index" of the packed A buffer? @hidet.script - def packa_index(work_id_loop5: int32, work_id_loop3: int32): + def packa_index(work_id_loop5: int32, work_id_loop3: int32) -> int32: return work_id_loop5 * loop3_nways + work_id_loop3 + packa_index.kind = 'cpu_internal' # Thread barrier @@ -823,9 +824,9 @@ def gemm_3rd_loop( # m_start_loop3, m_end_loop3, ii, b_alg_loop3) # packed_a_buf = packa_buf + (work_id_3rd_loop * packed_a_individual_size) - packed_a_buf = packa_buf + (work_id_5th_loop * packed_a_individual_size) - # packed_a_idx = packa_index(work_id_5th_loop, work_id_3rd_loop) - # packed_a_buf = packa_buf + (packed_a_idx * packed_a_individual_size) + # packed_a_buf = packa_buf + (work_id_5th_loop * packed_a_individual_size) + packed_a_idx = packa_index(work_id_5th_loop, work_id_3rd_loop) + packed_a_buf = packa_buf + (packed_a_idx * packed_a_individual_size) # TODO: If passed, see if this barrier is necessary # printf( From ad9c4533494ec329d59db0aaec69d7f779878a66 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 9 Nov 2023 20:26:54 -0500 Subject: [PATCH 126/148] seems like still bugs, but they disappear with print... --- .../ops/matmul/matmul_f32_x86_refactored.py | 90 +++++++++---------- python/mat_new.py | 44 ++++----- 2 files changed, 68 insertions(+), 66 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 4c85de6fc..de4e008a7 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -85,7 +85,7 @@ def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 1, 1, 1)]) def schedule_matmulf32_x86( - self, MC=6, NC=16, KC=8, ways=(2, 1, 1, 1) + self, MC=6, NC=16, KC=8, ways=(2, 2, 2, 1) ) -> IRModule: import hidet from hidet.ir.type import tensor_type @@ -115,20 +115,20 @@ def schedule_matmulf32_x86( packa_thrcomm_barrier_sense = module.define_global_var( name="pack_a_barrier_sense", - var_type=int32[loop3_nways] + var_type=int32[nthreads] ) packa_thrcomm_threads_arrived = module.define_global_var( name="pack_a_threads_arrived", - var_type=int32[loop3_nways] + var_type=int32[nthreads] ) packb_thrcomm_barrier_sense = module.define_global_var( name='pack_b_barrier_sense', - var_type=int32[loop5_nways] + var_type=int32[nthreads] ) packb_thrcomm_barrier_threads_arrived = module.define_global_var( name="pack_b_threads_arrived", - var_type=int32[loop5_nways] + var_type=int32[nthreads] ) @hidet.script @@ -320,11 +320,11 @@ def micro_kernel( aa6 = avx_f32x8_broadcast(a_ptr + 5) c5 = avx_f32x8_fmadd(aa6, bb0to7, c5) c58 = avx_f32x8_fmadd(aa6, bb8to15, c58) - # printf( - # "List of all the aa's broadcasted in this iteration: %lf %lf %lf %lf %lf %lf\n, bb0 to to bb15: %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf\n\n", - # a_ptr[0], a_ptr[1], a_ptr[2], a_ptr[3], a_ptr[4], a_ptr[5], b_ptr[0], b_ptr[1], b_ptr[2], - # b_ptr[3], b_ptr[4], b_ptr[5], b_ptr[6], b_ptr[7], b_ptr[8], b_ptr[9], b_ptr[10], b_ptr[11], 
- # b_ptr[12], b_ptr[13], b_ptr[14], b_ptr[15]) + printf( + "List of all the aa's broadcasted in this iteration: %lf %lf %lf %lf %lf %lf\n, bb0 to to bb15: %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf\n\n", + a_ptr[0], a_ptr[1], a_ptr[2], a_ptr[3], a_ptr[4], a_ptr[5], b_ptr[0], b_ptr[1], b_ptr[2], + b_ptr[3], b_ptr[4], b_ptr[5], b_ptr[6], b_ptr[7], b_ptr[8], b_ptr[9], b_ptr[10], b_ptr[11], + b_ptr[12], b_ptr[13], b_ptr[14], b_ptr[15]) a_ptr = a_ptr + 6 b_ptr = b_ptr + 16 @@ -415,7 +415,7 @@ def gemm_pack_a( npanels_a = npanels_full_a + (1 if panel_a_remainder > 0 else 0) for ii_panel in range(npanels_a): - if ii_panel % packa_nways != work_id_packa % packa_nways: + if ii_panel % packa_nthreads != work_id_packa % packa_nthreads: continue a_curr_panel_row_start = ii_panel * MR @@ -534,7 +534,7 @@ def gemm_pack_b( # Loop for the packing of B for i_panel in range(npanels_b): - if i_panel % packb_nways != work_id_packb % packb_nways: + if i_panel % packb_nthreads != work_id_packb % packb_nthreads: continue packed_b_buff_curr = packed_b_buf + ( i_panel * packedb_panel_stride) @@ -731,22 +731,22 @@ def gemm_macro( c11 = as_tensor_pointer(c11, dtype=float32, shape=(m_size, n_size)) m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder - # printf(''' work_id_macro: %d, work_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, c_row_off: %d, c_col_off: %d, - # macro_m: %d, macro_n: %d, macro_k: %d, - # ps_packed_a: %d, ps_packed_b: %d, , - # n_iter: %d, n_remainder: %d, m_iter: %d, m_remainder: %d, - # jr_start: %d, jr_end: %d, jr_inc: %d, - # ir_start: %d, ir_end: %d, ir_inc: %d, - # rstep_a: %d, cstep_b: %d, cstep_c: %d, rstep_c: %d, - # n_cur: %d, m_cur: %d \n\n''', - # work_id_macro, work_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, c_row_off, c_col_off, - # macro_m, macro_n, macro_k, - # ps_packed_a, ps_packed_b, - # n_iter, n_remainder, m_iter, m_remainder, - # jr_start, jr_end, jr_inc, - # ir_start, ir_end, ir_inc, - # rstep_a, cstep_b, cstep_c, rstep_c, - # n_cur, m_cur) + printf(''' work_id_macro: %d, work_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, c_row_off: %d, c_col_off: %d, + macro_m: %d, macro_n: %d, macro_k: %d, + ps_packed_a: %d, ps_packed_b: %d, , + n_iter: %d, n_remainder: %d, m_iter: %d, m_remainder: %d, + jr_start: %d, jr_end: %d, jr_inc: %d, + ir_start: %d, ir_end: %d, ir_inc: %d, + rstep_a: %d, cstep_b: %d, cstep_c: %d, rstep_c: %d, + n_cur: %d, m_cur: %d \n\n''', + work_id_macro, work_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, c_row_off, c_col_off, + macro_m, macro_n, macro_k, + ps_packed_a, ps_packed_b, + n_iter, n_remainder, m_iter, m_remainder, + jr_start, jr_end, jr_inc, + ir_start, ir_end, ir_inc, + rstep_a, cstep_b, cstep_c, rstep_c, + n_cur, m_cur) if m_cur == MR and n_cur == NR: # micro_kernel(a1, b1, c11, macro_k, macro_m, macro_n, is_first) @@ -830,21 +830,21 @@ def gemm_3rd_loop( # TODO: If passed, see if this barrier is necessary # printf( - # "Begin: calling the first barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d\n", + # "Begin: calling the first barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, packed_a_idx: %d\n", # work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, # loop3_partition_a_start_col, - # 
loop3_partition_b_start_col) + # loop3_partition_b_start_col, packed_a_idx) thrcomm_barrier( comm_id_packa, - ~packa_thrcomm_barrier_sense[work_id_3rd_loop], - ~packa_thrcomm_threads_arrived[work_id_3rd_loop], + ~packa_thrcomm_barrier_sense[packed_a_idx], + ~packa_thrcomm_threads_arrived[packed_a_idx], packa_nthreads ) # printf( - # "End: calling the first barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d\n", + # "End: calling the first barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, packed_a_idx: %d\n", # work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, # loop3_partition_a_start_col, - # loop3_partition_b_start_col) + # loop3_partition_b_start_col, packed_a_idx) gemm_pack_a( loop3_partition_a, @@ -865,8 +865,8 @@ def gemm_3rd_loop( # loop3_partition_b_start_col) thrcomm_barrier( comm_id_packa, - ~packa_thrcomm_barrier_sense[work_id_3rd_loop], - ~packa_thrcomm_threads_arrived[work_id_3rd_loop], + ~packa_thrcomm_barrier_sense[packed_a_idx], + ~packa_thrcomm_threads_arrived[packed_a_idx], packa_nthreads ) # printf( @@ -944,8 +944,8 @@ def gemm_4th_loop(a: float32[m_size, k_size], # printf("Begin: calling the first barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) # thrcomm_barrier( # comm_id_packb, - # ~packb_thrcomm_barrier_sense[work_id_4th_loop], - # ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], + # ~packb_thrcomm_barrier_sense[work_id_5th_loop], + # ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], # packb_nthreads # ) # printf("End: calling the first barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) @@ -1057,14 +1057,14 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], # printf("loop3_nthreads: %d, loop3_nways: %d\n", loop3_nthreads, loop3_nways) # printf("macro_nthreads: %d, macro_nways: %d\n", macro_nthreads, macro_nways) # printf("loop1_nthreads: %d, loop1_nways: %d\n", loop1_nthreads, loop1_nways) - + # # printf("packb_nthreads: %d, packa_nthreads: %d\n", packb_nthreads, packa_nthreads) - - printf("packed_b_width: %d, packed_b_total_width: %d, packed_b_height: %d\n", packed_b_width, - packed_b_total_width, packed_b_height) - printf("packed_a_width: %d, packed_a_individual_height: %d, packed_a_total_height: %d\n", - packed_a_width, packed_a_individual_height, packed_a_total_height) - + # + # printf("packed_b_width: %d, packed_b_total_width: %d, packed_b_height: %d\n", packed_b_width, + # packed_b_total_width, packed_b_height) + # printf("packed_a_width: %d, packed_a_individual_height: %d, packed_a_total_height: %d\n", + # packed_a_width, packed_a_individual_height, packed_a_total_height) + # # printf("packed_b_total_size: %d, packed_a_total_size: %d\n", packed_b_total_size, packed_a_total_size) # printf("packed_b_individual_size: %d, packed_a_individual_size: %d\n", packed_b_individual_size, packed_a_individual_size) diff --git a/python/mat_new.py b/python/mat_new.py index da3f34b55..39f52f7b7 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -37,23 +37,25 @@ def matmul_ansor(M, K, N, dtype): # for m, n, k in [(33, 65, 60), (32, 92, 128)]: # for m, n, k in [(7, 1, 17), (256, 256, 256), (512, 512, 512), (768, 
768, 768)]: # for m, n, k in [(7, 1, 17), (32, 32, 32), (36, 36, 36), (37, 37, 37)]: -for m, n, k in [(7, 17, 1), (333, 444, 555), (768, 768, 768)]: - a = hidet.randn([m, k], device='cpu') - b = hidet.randn([k, n], device='cpu') - - # a_torch = torch.arange(0, m*k).reshape(m, k).float().to('cpu') - # b_torch = torch.arange(0, k*n).reshape(k, n).float().to('cpu') +# for m, n, k in [(7, 17, 1), (16, 16, 16), (333, 444, 555), (768, 768, 768)]: +# for m, n, k in [(7, 17, 1), (16, 16, 16), (17, 17, 17), (36, 36, 36), (37, 37, 37), (128, 128, 128), (256, 256, 256), (333, 444, 555), (768, 768, 768)]: +for m, n, k in [(7, 17, 1), (36, 20, 20), (128, 128, 128), (768, 768, 768)]: + # a = hidet.randn([m, k], device='cpu') + # b = hidet.randn([k, n], device='cpu') + + a_torch = torch.arange(0, m*k).reshape(m, k).float().to('cpu') + b_torch = torch.arange(0, k*n).reshape(k, n).float().to('cpu') # # # # print(f"a_torch: {a_torch}") # # print(f"b_torch: {b_torch}") # - # a = hidet.from_torch(a_torch).to(dtype='float32', device='cpu') - # b = hidet.from_torch(b_torch).to(dtype='float32', device='cpu') + a = hidet.from_torch(a_torch).to(dtype='float32', device='cpu') + b = hidet.from_torch(b_torch).to(dtype='float32', device='cpu') # print(f"a: {a}") # print(f"b: {b}") - # a = hidet.ones([m, k], device='cpu') - # b = hidet.ones([k, n], device='cpu') + a = hidet.randn([m, k], device='cpu') + b = hidet.randn([k, n], device='cpu') x1 = hidet.symbol_like(a) @@ -71,16 +73,16 @@ def matmul_ansor(M, K, N, dtype): fails = 0 - # for i in range(m): - # for j in range(n): - # if abs(actual[i, j] - desired[i, j]) < 1e-3: - # # print(f"Actually passed for i={i}, j={j}") - # continue - # else: - # print(f"Failed for i={i}, j={j}, and we have [i, j] = {actual[i, j]} and desired [i, j] = {desired[i, j]}") - # fails += 1 + for i in range(m): + for j in range(n): + if abs(actual[i, j] - desired[i, j]) < 1e-3: + # print(f"Actually passed for i={i}, j={j}") + continue + else: + print(f"Failed for i={i}, j={j}, and we have [i, j] = {actual[i, j]} and desired [i, j] = {desired[i, j]}") + fails += 1 - # print(f"Total fails: {fails}") + print(f"Total fails: {fails}") # for i in range(m): # for j in range(n): @@ -91,8 +93,8 @@ def matmul_ansor(M, K, N, dtype): np.testing.assert_allclose( actual=actual, desired=desired, - rtol=1e-2, - atol=1e-2 + rtol=1e-3, + atol=1e-3 ) print("passed for m={}, n={}, k={}".format(m, n, k)) From d34f03169813e6f895ecf6a5444e4447b0606e9c Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Sat, 11 Nov 2023 12:30:10 -0500 Subject: [PATCH 127/148] fix bug caused by static local vairable --- .../ops/matmul/matmul_f32_x86_refactored.py | 81 ++++++++++--------- python/mat_new.py | 15 ++-- 2 files changed, 52 insertions(+), 44 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index de4e008a7..9d8c4fd8b 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -85,7 +85,7 @@ def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 1, 1, 1)]) def schedule_matmulf32_x86( - self, MC=6, NC=16, KC=8, ways=(2, 2, 2, 1) + self, MC=2016, NC=256, KC=560, ways=(2, 2, 4, 1) ) -> IRModule: import hidet from hidet.ir.type import tensor_type @@ -198,6 +198,7 @@ def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~ # Add the 
remainder to the last thread's end if work_id == n_way - 1: end[0] += n_bf_left + end[0] = min(end[0], all_end) thread_range_sub.kind = "cpu_internal" @@ -320,11 +321,11 @@ def micro_kernel( aa6 = avx_f32x8_broadcast(a_ptr + 5) c5 = avx_f32x8_fmadd(aa6, bb0to7, c5) c58 = avx_f32x8_fmadd(aa6, bb8to15, c58) - printf( - "List of all the aa's broadcasted in this iteration: %lf %lf %lf %lf %lf %lf\n, bb0 to to bb15: %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf\n\n", - a_ptr[0], a_ptr[1], a_ptr[2], a_ptr[3], a_ptr[4], a_ptr[5], b_ptr[0], b_ptr[1], b_ptr[2], - b_ptr[3], b_ptr[4], b_ptr[5], b_ptr[6], b_ptr[7], b_ptr[8], b_ptr[9], b_ptr[10], b_ptr[11], - b_ptr[12], b_ptr[13], b_ptr[14], b_ptr[15]) + # printf( + # "List of all the aa's broadcasted in this iteration: %lf %lf %lf %lf %lf %lf\n, bb0 to to bb15: %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf\n\n", + # a_ptr[0], a_ptr[1], a_ptr[2], a_ptr[3], a_ptr[4], a_ptr[5], b_ptr[0], b_ptr[1], b_ptr[2], + # b_ptr[3], b_ptr[4], b_ptr[5], b_ptr[6], b_ptr[7], b_ptr[8], b_ptr[9], b_ptr[10], b_ptr[11], + # b_ptr[12], b_ptr[13], b_ptr[14], b_ptr[15]) a_ptr = a_ptr + 6 b_ptr = b_ptr + 16 @@ -717,7 +718,7 @@ def gemm_macro( temp_c = tensor(scope=DeclareScope.Default, dtype=float32, layout=row_major(MR, NR), - is_static=True) + is_static=False) j = jr_start while j < jr_end: b1 = packed_b + j * cstep_b @@ -730,23 +731,25 @@ def gemm_macro( c11 = c1 + i * rstep_c c11 = as_tensor_pointer(c11, dtype=float32, shape=(m_size, n_size)) m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder - - printf(''' work_id_macro: %d, work_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, c_row_off: %d, c_col_off: %d, - macro_m: %d, macro_n: %d, macro_k: %d, - ps_packed_a: %d, ps_packed_b: %d, , - n_iter: %d, n_remainder: %d, m_iter: %d, m_remainder: %d, - jr_start: %d, jr_end: %d, jr_inc: %d, - ir_start: %d, ir_end: %d, ir_inc: %d, - rstep_a: %d, cstep_b: %d, cstep_c: %d, rstep_c: %d, - n_cur: %d, m_cur: %d \n\n''', - work_id_macro, work_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, c_row_off, c_col_off, - macro_m, macro_n, macro_k, - ps_packed_a, ps_packed_b, - n_iter, n_remainder, m_iter, m_remainder, - jr_start, jr_end, jr_inc, - ir_start, ir_end, ir_inc, - rstep_a, cstep_b, cstep_c, rstep_c, - n_cur, m_cur) + # + # printf(''' work_id_macro: %d, work_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, c_row_off: %d, c_col_off: %d, + # macro_m: %d, macro_n: %d, macro_k: %d, + # ps_packed_a: %d, ps_packed_b: %d, , + # n_iter: %d, n_remainder: %d, m_iter: %d, m_remainder: %d, + # jr_start: %d, jr_end: %d, jr_inc: %d, + # ir_start: %d, ir_end: %d, ir_inc: %d, + # i: %d, j: %d, + # rstep_a: %d, cstep_b: %d, cstep_c: %d, rstep_c: %d, + # n_cur: %d, m_cur: %d \n\n''', + # work_id_macro, work_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, c_row_off, c_col_off, + # macro_m, macro_n, macro_k, + # ps_packed_a, ps_packed_b, + # n_iter, n_remainder, m_iter, m_remainder, + # jr_start, jr_end, jr_inc, + # ir_start, ir_end, ir_inc, + # i, j, + # rstep_a, cstep_b, cstep_c, rstep_c, + # n_cur, m_cur) if m_cur == MR and n_cur == NR: # micro_kernel(a1, b1, c11, macro_k, macro_m, macro_n, is_first) @@ -811,22 +814,23 @@ def gemm_3rd_loop( loop3_partition_a_start_row * k_size + loop3_partition_a_start_col ) + + # packed_a_buf = packa_buf + (work_id_3rd_loop * packed_a_individual_size) + # packed_a_buf = packa_buf + (work_id_5th_loop * packed_a_individual_size) + packed_a_idx = packa_index(work_id_5th_loop, work_id_3rd_loop) + 
packed_a_buf = packa_buf + (packed_a_idx * packed_a_individual_size) + # printf( # "work_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, " # "loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, " # "loop3_partition_a_width: %d, loop3_partition_b_width: %d, " # "loop3_partition_a_start_row: %d, loop3_partition_a_height: %d, " - # "m_start_loop3: %d, m_end_loop3: %d, ii: %d, b_alg_loop3: %d\n\n", + # "m_start_loop3: %d, m_end_loop3: %d, ii: %d, b_alg_loop3: %d, packed_a_idx: %d\n\n", # work_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, # loop3_partition_a_start_col, loop3_partition_b_start_col, # loop3_partition_a_width, loop3_partition_b_width, # loop3_partition_a_start_row, loop3_partition_a_height, - # m_start_loop3, m_end_loop3, ii, b_alg_loop3) - - # packed_a_buf = packa_buf + (work_id_3rd_loop * packed_a_individual_size) - # packed_a_buf = packa_buf + (work_id_5th_loop * packed_a_individual_size) - packed_a_idx = packa_index(work_id_5th_loop, work_id_3rd_loop) - packed_a_buf = packa_buf + (packed_a_idx * packed_a_individual_size) + # m_start_loop3, m_end_loop3, ii, b_alg_loop3, packed_a_idx) # TODO: If passed, see if this barrier is necessary # printf( @@ -942,12 +946,12 @@ def gemm_4th_loop(a: float32[m_size, k_size], # # # TODO: If passed, see if this barrier is really needed # printf("Begin: calling the first barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) - # thrcomm_barrier( - # comm_id_packb, - # ~packb_thrcomm_barrier_sense[work_id_5th_loop], - # ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], - # packb_nthreads - # ) + thrcomm_barrier( + comm_id_packb, + ~packb_thrcomm_barrier_sense[work_id_5th_loop], + ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], + packb_nthreads + ) # printf("End: calling the first barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) gemm_pack_b(loop4_partition_b, loop4_partition_b_width, @@ -1018,12 +1022,13 @@ def gemm_5th_loop(a: float32[m_size, k_size], while loop5_iter < loop5_my_end: b_alg_loop5 = determine_blocksize_f_sub(loop5_iter, loop5_my_end, NC) + b_alg_loop5 = min(b_alg_loop5, loop5_my_end - loop5_iter) loop5_partition_c_width = b_alg_loop5 loop5_partition_c_start_col = loop5_iter loop5_partition_b_width = b_alg_loop5, loop5_partition_b_start_col = loop5_iter - + # # printf( # "work_id_5th_loop: %d, comm_id_5th_loop: %d, b_alg_loop5: %d, loop5_partition_b_width: %d, loop5_partition_b_start_col: %d, loop5_iter: %d, loop5_my_start: %d, loop5_my_end: %d\n\n", # work_id_5th_loop, comm_id_5th_loop, b_alg_loop5, loop5_partition_b_width, diff --git a/python/mat_new.py b/python/mat_new.py index 39f52f7b7..908fd31e8 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -39,24 +39,27 @@ def matmul_ansor(M, K, N, dtype): # for m, n, k in [(7, 1, 17), (32, 32, 32), (36, 36, 36), (37, 37, 37)]: # for m, n, k in [(7, 17, 1), (16, 16, 16), (333, 444, 555), (768, 768, 768)]: # for m, n, k in [(7, 17, 1), (16, 16, 16), (17, 17, 17), (36, 36, 36), (37, 37, 37), (128, 128, 128), (256, 256, 256), (333, 444, 555), (768, 768, 768)]: -for m, n, k in [(7, 17, 1), (36, 20, 20), (128, 128, 128), (768, 768, 768)]: +# for m, n, k in [(20, 20, 20), (333, 444, 555), (768, 768, 768)]: +for m, n, k in [(20, 20, 20), (333, 444, 555), (768, 768, 768), (555, 256, 3072)]: # a = hidet.randn([m, k], device='cpu') # b = hidet.randn([k, n], device='cpu') - a_torch = torch.arange(0, m*k).reshape(m, 
k).float().to('cpu') - b_torch = torch.arange(0, k*n).reshape(k, n).float().to('cpu') + # a_torch = torch.arange(0, m*k).reshape(m, k).float().to('cpu') + # b_torch = torch.arange(0, k*n).reshape(k, n).float().to('cpu') # # # # print(f"a_torch: {a_torch}") # # print(f"b_torch: {b_torch}") # - a = hidet.from_torch(a_torch).to(dtype='float32', device='cpu') - b = hidet.from_torch(b_torch).to(dtype='float32', device='cpu') + # a = hidet.from_torch(a_torch).to(dtype='float32', device='cpu') + # b = hidet.from_torch(b_torch).to(dtype='float32', device='cpu') # print(f"a: {a}") # print(f"b: {b}") a = hidet.randn([m, k], device='cpu') b = hidet.randn([k, n], device='cpu') - + # a = hidet.ones([m, k], device='cpu') + # b = hidet.ones([k, n], device='cpu') + # x1 = hidet.symbol_like(a) x2 = hidet.symbol_like(b) From 954da89315faa29d26abc2bcf3332c1460e80d46 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Wed, 15 Nov 2023 15:26:31 -0500 Subject: [PATCH 128/148] ... --- .../ops/matmul/matmul_f32_x86_refactored.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index 9d8c4fd8b..bdcc667d2 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -85,7 +85,7 @@ def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 1, 1, 1)]) def schedule_matmulf32_x86( - self, MC=2016, NC=256, KC=560, ways=(2, 2, 4, 1) + self, MC=2016, NC=256, KC=560, ways=(1, 8, 2, 1) ) -> IRModule: import hidet from hidet.ir.type import tensor_type @@ -277,19 +277,19 @@ def micro_kernel( c5 = avx_f32x8_load(~c[5, 0]) c58 = avx_f32x8_load(~c[5, 8]) - # if is_first: - # c0 = avx_f32x8_setzero() - # c08 = avx_f32x8_setzero() - # c1 = avx_f32x8_setzero() - # c18 = avx_f32x8_setzero() - # c2 = avx_f32x8_setzero() - # c28 = avx_f32x8_setzero() - # c3 = avx_f32x8_setzero() - # c38 = avx_f32x8_setzero() - # c4 = avx_f32x8_setzero() - # c48 = avx_f32x8_setzero() - # c5 = avx_f32x8_setzero() - # c58 = avx_f32x8_setzero() + if is_first: + c0 = avx_f32x8_setzero() + c08 = avx_f32x8_setzero() + c1 = avx_f32x8_setzero() + c18 = avx_f32x8_setzero() + c2 = avx_f32x8_setzero() + c28 = avx_f32x8_setzero() + c3 = avx_f32x8_setzero() + c38 = avx_f32x8_setzero() + c4 = avx_f32x8_setzero() + c48 = avx_f32x8_setzero() + c5 = avx_f32x8_setzero() + c58 = avx_f32x8_setzero() a_ptr = cast(a, ~float32) b_ptr = cast(b, ~float32) @@ -764,8 +764,8 @@ def gemm_macro( c11[mm, nn] += temp_c[mm, nn] else: for mm, nn in grid(m_cur, n_cur): - # c11[mm, nn] = temp_c[mm, nn] FIXME: temporarily changed to see if zero-initing is the problem(well, it is not.....) 
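
The two fixes landing here (re-enabling the zeroed accumulators under is_first in the micro-kernel, and assigning instead of accumulating into c11 on the first k-panel just below) implement the usual panel-GEMM protocol: the first KC-panel of the k loop overwrites C, every later panel adds into it, which is also what lets the explicit zero-initialization of C be dropped further down. A NumPy model of that protocol (my sketch, not the kernel's code):

    import numpy as np

    def gemm_by_k_panels(a, b, kc):
        m, n = a.shape[0], b.shape[1]
        c = np.empty((m, n), dtype=a.dtype)    # no zero-init of C needed
        for p in range(0, a.shape[1], kc):
            partial = a[:, p:p + kc] @ b[p:p + kc, :]
            if p == 0:
                c[:] = partial                 # is_first: overwrite
            else:
                c += partial                   # later panels: accumulate
        return c

    rng = np.random.default_rng(0)
    a = rng.standard_normal((20, 24), dtype=np.float32)
    b = rng.standard_normal((24, 20), dtype=np.float32)
    np.testing.assert_allclose(gemm_by_k_panels(a, b, 8), a @ b, rtol=1e-4, atol=1e-4)
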
- c11[mm, nn] += temp_c[mm, nn] + c11[mm, nn] = temp_c[mm, nn] + # c11[mm, nn] += temp_c[mm, nn] i += ir_inc j += jr_inc @@ -1073,9 +1073,9 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], # printf("packed_b_total_size: %d, packed_a_total_size: %d\n", packed_b_total_size, packed_a_total_size) # printf("packed_b_individual_size: %d, packed_a_individual_size: %d\n", packed_b_individual_size, packed_a_individual_size) - for i in grid(m_size): - for j in grid(n_size): - c[i, j] = 0.0 + # for i in grid(m_size): + # for j in grid(n_size): + # c[i, j] = 0.0 parallel_attr = 'p' + str(nthreads) # The outermost loop spawning threads From 78d09c41dd9fc9c8d4b2bb2a196e196dbc1449f3 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Wed, 15 Nov 2023 16:16:37 -0500 Subject: [PATCH 129/148] fix alignment --- .../ops/matmul/matmul_f32_x86_refactored.py | 84 ++++++------------- python/mat_new.py | 2 +- 2 files changed, 28 insertions(+), 58 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index bdcc667d2..d08c8f0ef 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -85,7 +85,7 @@ def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 1, 1, 1)]) def schedule_matmulf32_x86( - self, MC=2016, NC=256, KC=560, ways=(1, 8, 2, 1) + self, MC=2016, NC=256, KC=560, ways=(4, 1, 2, 2) ) -> IRModule: import hidet from hidet.ir.type import tensor_type @@ -371,6 +371,8 @@ def micro_kernel( packed_a_width = KC if packed_a_width > k_size: packed_a_width = k_size + # FIXME: Can this allow us to use align versions of loads once and for all? + packed_a_width = (packed_a_width + 8 - 1) // 8 * 8 packed_a_total_size = packed_a_total_height * packed_a_width packed_a_individual_size = packed_a_width * packed_a_individual_height @@ -467,30 +469,30 @@ def gemm_pack_a( res5 = avx_f32x8_permute2f32x4(shf3, shf4, 0x31) # TODO: Now I changed to unaligned to debug... 
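
Context for the aligned-store switch below: the _mm256_store_ps family faults unless the address is 32-byte aligned, and whether it is safe here hinges on the padding added above, which rounds packed_a_width up to a multiple of 8 floats so that every packed row starts a multiple of 32 bytes from the (aligned) buffer base. A small arithmetic check of that invariant (sizes chosen for illustration):

    FLOAT_BYTES, VEC_FLOATS = 4, 8            # float32, one 8-lane AVX register

    k_size = 557                              # deliberately not a multiple of 8
    packed_a_width = (k_size + VEC_FLOATS - 1) // VEC_FLOATS * VEC_FLOATS

    for row in range(64):                     # buffer base assumed 32B-aligned
        byte_off = row * packed_a_width * FLOAT_BYTES
        assert byte_off % (VEC_FLOATS * FLOAT_BYTES) == 0

Packed B appears to need no extra padding for the same property, since the packing loop below advances its write pointer in 16-float steps.
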
- avx_f32x8_store( + avx_f32x8_store_aligned( ~packed_a_tensor[a_curr_panel_row_start, col], res0 ) - avx_f32x8_store( + avx_f32x8_store_aligned( ~packed_a_tensor[a_curr_panel_row_start + 2, col + 1], res2 ) - avx_f32x8_store( + avx_f32x8_store_aligned( ~packed_a_tensor[a_curr_panel_row_start + 4, col + 2], res4) - avx_f32x8_store( + avx_f32x8_store_aligned( ~packed_a_tensor[a_curr_panel_row_start, col + 4], res1 ) - avx_f32x8_store( + avx_f32x8_store_aligned( ~packed_a_tensor[a_curr_panel_row_start + 2, col + 5], res3 ) - avx_f32x8_store( + avx_f32x8_store_aligned( ~packed_a_tensor[a_curr_panel_row_start + 4, col + 6], res5 @@ -555,57 +557,57 @@ def gemm_pack_b( b00 = avx_f32x8_load(b_panel) b08 = avx_f32x8_load(b_panel + 8) - avx_f32x8_store(packed_b_buff_curr, b00) - avx_f32x8_store(packed_b_buff_curr + 8, b08) + avx_f32x8_store_aligned(packed_b_buff_curr, b00) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b08) packed_b_buff_curr += 16 b10 = avx_f32x8_load(b_panel + n_size) b18 = avx_f32x8_load(b_panel + (n_size + 8)) - avx_f32x8_store(packed_b_buff_curr, b10) - avx_f32x8_store(packed_b_buff_curr + 8, b18) + avx_f32x8_store_aligned(packed_b_buff_curr, b10) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b18) packed_b_buff_curr += 16 b20 = avx_f32x8_load(b_panel + (2 * n_size)) b28 = avx_f32x8_load(b_panel + (2 * n_size + 8)) - avx_f32x8_store(packed_b_buff_curr, b20) - avx_f32x8_store(packed_b_buff_curr + 8, b28) + avx_f32x8_store_aligned(packed_b_buff_curr, b20) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b28) packed_b_buff_curr += 16 b30 = avx_f32x8_load(b_panel + (3 * n_size)) b38 = avx_f32x8_load(b_panel + (3 * n_size + 8)) - avx_f32x8_store(packed_b_buff_curr, b30) - avx_f32x8_store(packed_b_buff_curr + 8, b38) + avx_f32x8_store_aligned(packed_b_buff_curr, b30) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b38) packed_b_buff_curr += 16 b40 = avx_f32x8_load(b_panel + (4 * n_size)) b48 = avx_f32x8_load(b_panel + (4 * n_size + 8)) - avx_f32x8_store(packed_b_buff_curr, b40) - avx_f32x8_store(packed_b_buff_curr + 8, b48) + avx_f32x8_store_aligned(packed_b_buff_curr, b40) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b48) packed_b_buff_curr += 16 b50 = avx_f32x8_load(b_panel + (5 * n_size)) b58 = avx_f32x8_load(b_panel + (5 * n_size + 8)) - avx_f32x8_store(packed_b_buff_curr, b50) - avx_f32x8_store(packed_b_buff_curr + 8, b58) + avx_f32x8_store_aligned(packed_b_buff_curr, b50) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b58) packed_b_buff_curr += 16 b60 = avx_f32x8_load(b_panel + (6 * n_size)) b68 = avx_f32x8_load(b_panel + (6 * n_size + 8)) - avx_f32x8_store(packed_b_buff_curr, b60) - avx_f32x8_store(packed_b_buff_curr + 8, b68) + avx_f32x8_store_aligned(packed_b_buff_curr, b60) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b68) packed_b_buff_curr += 16 b70 = avx_f32x8_load(b_panel + (7 * n_size)) b78 = avx_f32x8_load(b_panel + (7 * n_size + 8)) - avx_f32x8_store(packed_b_buff_curr, b70) - avx_f32x8_store(packed_b_buff_curr + 8, b78) + avx_f32x8_store_aligned(packed_b_buff_curr, b70) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b78) packed_b_buff_curr += 16 @@ -615,8 +617,8 @@ def gemm_pack_b( row * n_size + curr_panel_start) b00 = avx_f32x8_load(b_panel) b08 = avx_f32x8_load(b_panel + 8) - avx_f32x8_store(packed_b_buff_curr, b00) - avx_f32x8_store(packed_b_buff_curr + 8, b08) + avx_f32x8_store_aligned(packed_b_buff_curr, b00) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b08) packed_b_buff_curr += 16 row += 1 @@ -731,25 +733,6 @@ def gemm_macro( 
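
For reference, the gemm_pack_b routine patched above lays B out as NR-wide column panels, each stored k-row by k-row, so the micro-kernel can stream 16 consecutive floats (two 8-float vectors) per k step, matching the b1 panel stepping in gemm_macro below. A NumPy sketch of that layout (edge-panel handling is simplified to zero padding here, which may differ from the kernel's remainder path):

    import numpy as np

    NR = 16  # the micro-kernel consumes two 8-float vectors per k step

    def pack_b(b, nr=NR):
        k, n = b.shape
        panels = []
        for j0 in range(0, n, nr):
            panel = np.zeros((k, nr), dtype=b.dtype)      # zero-pad the edge panel
            panel[:, : min(nr, n - j0)] = b[:, j0 : j0 + nr]
            panels.append(panel.reshape(-1))              # k * nr contiguous floats
        return np.concatenate(panels)

    b = np.arange(20 * 20, dtype=np.float32).reshape(20, 20)
    packed = pack_b(b)
    # Within the first panel, k-row 3 occupies floats 3*NR .. 3*NR + 15:
    np.testing.assert_array_equal(packed[3 * NR : 4 * NR], b[3, :NR])
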
c11 = c1 + i * rstep_c c11 = as_tensor_pointer(c11, dtype=float32, shape=(m_size, n_size)) m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder - # - # printf(''' work_id_macro: %d, work_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, c_row_off: %d, c_col_off: %d, - # macro_m: %d, macro_n: %d, macro_k: %d, - # ps_packed_a: %d, ps_packed_b: %d, , - # n_iter: %d, n_remainder: %d, m_iter: %d, m_remainder: %d, - # jr_start: %d, jr_end: %d, jr_inc: %d, - # ir_start: %d, ir_end: %d, ir_inc: %d, - # i: %d, j: %d, - # rstep_a: %d, cstep_b: %d, cstep_c: %d, rstep_c: %d, - # n_cur: %d, m_cur: %d \n\n''', - # work_id_macro, work_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, c_row_off, c_col_off, - # macro_m, macro_n, macro_k, - # ps_packed_a, ps_packed_b, - # n_iter, n_remainder, m_iter, m_remainder, - # jr_start, jr_end, jr_inc, - # ir_start, ir_end, ir_inc, - # i, j, - # rstep_a, cstep_b, cstep_c, rstep_c, - # n_cur, m_cur) if m_cur == MR and n_cur == NR: # micro_kernel(a1, b1, c11, macro_k, macro_m, macro_n, is_first) @@ -815,22 +798,9 @@ def gemm_3rd_loop( loop3_partition_a_start_col ) - # packed_a_buf = packa_buf + (work_id_3rd_loop * packed_a_individual_size) - # packed_a_buf = packa_buf + (work_id_5th_loop * packed_a_individual_size) packed_a_idx = packa_index(work_id_5th_loop, work_id_3rd_loop) packed_a_buf = packa_buf + (packed_a_idx * packed_a_individual_size) - # printf( - # "work_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, " - # "loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, " - # "loop3_partition_a_width: %d, loop3_partition_b_width: %d, " - # "loop3_partition_a_start_row: %d, loop3_partition_a_height: %d, " - # "m_start_loop3: %d, m_end_loop3: %d, ii: %d, b_alg_loop3: %d, packed_a_idx: %d\n\n", - # work_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, - # loop3_partition_a_start_col, loop3_partition_b_start_col, - # loop3_partition_a_width, loop3_partition_b_width, - # loop3_partition_a_start_row, loop3_partition_a_height, - # m_start_loop3, m_end_loop3, ii, b_alg_loop3, packed_a_idx) # TODO: If passed, see if this barrier is necessary # printf( diff --git a/python/mat_new.py b/python/mat_new.py index 908fd31e8..265fe205f 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -40,7 +40,7 @@ def matmul_ansor(M, K, N, dtype): # for m, n, k in [(7, 17, 1), (16, 16, 16), (333, 444, 555), (768, 768, 768)]: # for m, n, k in [(7, 17, 1), (16, 16, 16), (17, 17, 17), (36, 36, 36), (37, 37, 37), (128, 128, 128), (256, 256, 256), (333, 444, 555), (768, 768, 768)]: # for m, n, k in [(20, 20, 20), (333, 444, 555), (768, 768, 768)]: -for m, n, k in [(20, 20, 20), (333, 444, 555), (768, 768, 768), (555, 256, 3072)]: +for m, n, k in [(20, 20, 20), (333, 444, 555), (768, 768, 768), (555, 256, 3072), (2048, 2048, 2048)]: # a = hidet.randn([m, k], device='cpu') # b = hidet.randn([k, n], device='cpu') From 838a61edb0324999a76e57fd8598c7084090371b Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 16 Nov 2023 20:40:36 -0500 Subject: [PATCH 130/148] cleanup --- .../ops/matmul/matmul_f32_x86_refactored.py | 137 ++---------------- 1 file changed, 10 insertions(+), 127 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py index d08c8f0ef..dcaee628a 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py @@ -14,8 +14,6 @@ from hidet.ir.expr import cast from hidet.ir.module 
import IRModule from hidet.ir.compute import TensorNode -from hidet.ir.primitives import printf -from hidet.ir.primitives.cpu import avx_f32x8_setzero, avx_f32x8_load_aligned, avx_free, avx_malloc from hidet.ir.stmt import DeclareScope from hidet.ir.task import Task from hidet.ir.compute import compute, reduce @@ -83,16 +81,16 @@ def allow_prologue(self) -> bool: def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: return tune.extract_ir_modules(self.schedule_matmulf32_x86) - @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 1, 1, 1)]) + @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 4, 2, 1)]) def schedule_matmulf32_x86( - self, MC=2016, NC=256, KC=560, ways=(4, 1, 2, 2) + self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1) ) -> IRModule: import hidet from hidet.ir.type import tensor_type from hidet.lang import tensor, grid, as_tensor_pointer from hidet.lang.layout import row_major, column_major from hidet.lang.cpu import avx_f32x8_store, avx_f32x8_fmadd, avx_f32x8_load, avx_f32x8_broadcast - from hidet.lang.cpu import avx_f32x8_store_aligned, avx_f32x8_load_aligned + from hidet.lang.cpu import avx_f32x8_store_aligned, avx_f32x8_load_aligned, avx_f32x8_setzero from hidet.lang.cpu import avx_f32x8_unpacklo, avx_f32x8_unpackhi from hidet.lang.cpu import avx_f32x8_shuffle, avx_f32x8_cast_f32x4 from hidet.lang.cpu import avx_f32x8_insert_f32x4, avx_f32x8_permute2f32x4 @@ -148,7 +146,6 @@ def init_thr(sense: ~int32, arrived: ~int32, size: int32): KC, NR)) # Get the number of threads remaining at each level - loop5_nthreads = nthreads loop4_nthreads = nthreads // loop5_nways loop3_nthreads = loop4_nthreads macro_nthreads = loop3_nthreads // loop3_nways @@ -239,7 +236,7 @@ def packa_index(work_id_loop5: int32, work_id_loop3: int32) -> int32: # Thread barrier @hidet.script - def thrcomm_barrier(tid: int32, barrier_sense: ~int32, + def thrcomm_barrier(barrier_sense: ~int32, barrier_threads_arrived: ~int32, num_threads: int32): if num_threads == 1: return @@ -321,11 +318,6 @@ def micro_kernel( aa6 = avx_f32x8_broadcast(a_ptr + 5) c5 = avx_f32x8_fmadd(aa6, bb0to7, c5) c58 = avx_f32x8_fmadd(aa6, bb8to15, c58) - # printf( - # "List of all the aa's broadcasted in this iteration: %lf %lf %lf %lf %lf %lf\n, bb0 to to bb15: %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf %lf\n\n", - # a_ptr[0], a_ptr[1], a_ptr[2], a_ptr[3], a_ptr[4], a_ptr[5], b_ptr[0], b_ptr[1], b_ptr[2], - # b_ptr[3], b_ptr[4], b_ptr[5], b_ptr[6], b_ptr[7], b_ptr[8], b_ptr[9], b_ptr[10], b_ptr[11], - # b_ptr[12], b_ptr[13], b_ptr[14], b_ptr[15]) a_ptr = a_ptr + 6 b_ptr = b_ptr + 16 @@ -364,14 +356,12 @@ def micro_kernel( packed_a_individual_height = MC if packed_a_individual_height > m_size: packed_a_individual_height = (m_size + MR - 1) // MR * MR - # packed_a_total_height = packed_a_individual_height * loop3_nways - # packed_a_total_height = packed_a_individual_height * loop5_nways packed_a_total_height = packed_a_individual_height * packed_a_buffers_needed packed_a_width = KC if packed_a_width > k_size: packed_a_width = k_size - # FIXME: Can this allow us to use align versions of loads once and for all? 
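
Stepping back to the ways tuple being retuned in this cleanup: as I read the surrounding bookkeeping, nthreads is the product of the per-loop way counts and each level splits its parent's thread group, so ways = (1, 4, 2, 1) would mean no loop-5 split of N, four loop-3 groups over MC panels of A, two macro-kernel workers per group, and one thread per loop-1 strip. A sketch of that arithmetic (the tuple-to-loop mapping is my assumption from context):

    ways = (1, 4, 2, 1)  # assumed order: (loop5, loop3, macro, loop1)
    loop5_nways, loop3_nways, macro_nways, loop1_nways = ways

    nthreads = loop5_nways * loop3_nways * macro_nways * loop1_nways
    loop4_nthreads = nthreads // loop5_nways        # threads sharing one packed B
    macro_nthreads = loop4_nthreads // loop3_nways  # threads sharing one packed A
    loop1_nthreads = macro_nthreads // macro_nways  # threads per micro-kernel strip

    assert (nthreads, loop4_nthreads, macro_nthreads, loop1_nthreads) == (8, 8, 2, 1)
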
+ # pad this to be able to use the aligned version of the avx store packed_a_width = (packed_a_width + 8 - 1) // 8 * 8 packed_a_total_size = packed_a_total_height * packed_a_width packed_a_individual_size = packed_a_width * packed_a_individual_height @@ -388,11 +378,6 @@ def micro_kernel( packb_buf = cast(packb_buf_ptr, ~float32) packa_buf = cast(packa_buf_ptr, ~float32) - packed_a_type = tensor_type( - dtype='float32', - layout=row_major(packed_a_individual_height // MR, 1) * column_major(MR, packed_a_width) - ) - ##### Start of the loops around micro kernel ##### @hidet.script @@ -401,9 +386,7 @@ def gemm_pack_a( loop3_partition_a_width: int32, loop3_partition_a_height: int32, packed_a_buf: ~float32, - comm_id_packa: int32, work_id_packa: int32, - packa_nways: int32 ): packed_a_tensor = as_tensor_pointer( packed_a_buf, @@ -526,8 +509,7 @@ def gemm_pack_b( loop4_partition_b_width: int32, loop4_partition_b_height: int32, packed_b_buf: ~float32, - comm_id_packb: int32, work_id_packb: int32, - packb_nways: int32 + work_id_packb: int32 ): npanels_full_b = loop4_partition_b_width // NR npanels_b_remainder = loop4_partition_b_width % NR @@ -662,9 +644,8 @@ def gemm_macro( ps_packed_b: int32, comm_id_macro: int32, work_id_macro: int32, - is_first: bool, work_id_3rd_loop: int32, work_id_4th_loop: int32, work_id_5th_loop: int32 + is_first: bool ): - # assert loop1_nthreads == 1 comm_id_1st_loop = comm_id_macro % loop1_nthreads work_id_1st_loop = comm_id_1st_loop // (loop1_nthreads // loop1_nways) @@ -705,7 +686,6 @@ def gemm_macro( ~ir_inc ) - rs_packeda = 1 rstep_a = ps_packed_a cstep_b = ps_packed_b @@ -735,12 +715,10 @@ def gemm_macro( m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder if m_cur == MR and n_cur == NR: - # micro_kernel(a1, b1, c11, macro_k, macro_m, macro_n, is_first) micro_kernel(a1, b1, c11, macro_k, m_size, n_size, is_first) else: for i, j in grid(MR, NR): temp_c[i, j] = 0.0 - # micro_kernel(a1, b1, cast(temp_c, ~float32), macro_k, MR, NR, is_first) micro_kernel(a1, b1, cast(temp_c, ~float32), macro_k, MR, NR, is_first) if not is_first: for mm, nn in grid(m_cur, n_cur): @@ -748,7 +726,6 @@ def gemm_macro( else: for mm, nn in grid(m_cur, n_cur): c11[mm, nn] = temp_c[mm, nn] - # c11[mm, nn] += temp_c[mm, nn] i += ir_inc j += jr_inc @@ -766,12 +743,10 @@ def gemm_3rd_loop( loop3_partition_b_width: int32, comm_id_3rd_loop: int32, work_id_3rd_loop: int32, - is_first: bool, work_id_4th_loop: int32, work_id_5th_loop: int32): + is_first: bool, work_id_5th_loop: int32): comm_id_macro = comm_id_3rd_loop % macro_nthreads work_id_macro = comm_id_macro // (macro_nthreads // macro_nways) - comm_id_packa = comm_id_macro work_id_packa = comm_id_macro - packa_nways = macro_nthreads m_start_loop3 = 0 m_end_loop3 = 0 @@ -802,52 +777,27 @@ def gemm_3rd_loop( packed_a_buf = packa_buf + (packed_a_idx * packed_a_individual_size) - # TODO: If passed, see if this barrier is necessary - # printf( - # "Begin: calling the first barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, packed_a_idx: %d\n", - # work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, - # loop3_partition_a_start_col, - # loop3_partition_b_start_col, packed_a_idx) thrcomm_barrier( - comm_id_packa, ~packa_thrcomm_barrier_sense[packed_a_idx], ~packa_thrcomm_threads_arrived[packed_a_idx], packa_nthreads ) - # printf( - # "End: calling the first barrier for the 3rd 
loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d, packed_a_idx: %d\n", - # work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, - # loop3_partition_a_start_col, - # loop3_partition_b_start_col, packed_a_idx) gemm_pack_a( loop3_partition_a, loop3_partition_a_width, loop3_partition_a_height, packed_a_buf, - comm_id_packa, work_id_packa, - packa_nways ) # This marks the end of the packing of A, # so a barrier is needed - # printf( - # "Begin: calling the second barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d\n", - # work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, - # loop3_partition_a_start_col, - # loop3_partition_b_start_col) thrcomm_barrier( - comm_id_packa, ~packa_thrcomm_barrier_sense[packed_a_idx], ~packa_thrcomm_threads_arrived[packed_a_idx], packa_nthreads ) - # printf( - # "End: calling the second barrier for the 3rd loop; work_id_3rd_loop: %d, comm_id_3rd_loop: %d, work_id_4th_loop: %d, work_id_5th_loop: %d, ii: %d, loop3_partition_a_start_col: %d, loop3_partition_b_start_col: %d\n", - # work_id_3rd_loop, comm_id_3rd_loop, work_id_4th_loop, work_id_5th_loop, ii, - # loop3_partition_a_start_col, - # loop3_partition_b_start_col) gemm_macro(packed_a_buf, packed_b, @@ -861,10 +811,7 @@ def gemm_3rd_loop( packed_b_height * NR, comm_id_macro, work_id_macro, - is_first, - work_id_3rd_loop, - work_id_4th_loop, - work_id_5th_loop, + is_first ) ii += b_alg_loop3 @@ -877,14 +824,11 @@ def gemm_4th_loop(a: float32[m_size, k_size], loop5_partition_b_width: int32, loop5_partition_b_start_col: int32, comm_id_4th_loop: int32, - work_id_4th_loop: int32, work_id_5th_loop: int32): - b_alg_loop4 = KC i_loop4 = 0 comm_id_3rd_loop = comm_id_4th_loop % loop3_nthreads work_id_3rd_loop = comm_id_3rd_loop // (loop3_nthreads // loop3_nways) - comm_id_packb = comm_id_3rd_loop work_id_packb = comm_id_3rd_loop while i_loop4 < k_size: @@ -899,13 +843,6 @@ def gemm_4th_loop(a: float32[m_size, k_size], loop4_partition_a_start_col = i_loop4 is_first = (i_loop4 == 0) - # printf( - # "work_id_5th_loop: %d, work_id_4th_loop: %d, i_loop4: %d, b_alg_loop4: %d, loop4_partition_b_height: %d, loop4_partition_b_width: %d," - # "loop4_partition_b_start_row: %d, loop4_partition_b_start_col: %d, loop4_partition_a_start_col: %d, is_first: %d\n", - # work_id_5th_loop, work_id_4th_loop, i_loop4, b_alg_loop4, loop4_partition_b_height, loop4_partition_b_width, - # loop4_partition_b_start_row, - # loop4_partition_b_start_col, loop4_partition_a_start_col, is_first) - packed_b_buf = packb_buf + ( packed_b_individual_size * work_id_5th_loop ) @@ -914,35 +851,22 @@ def gemm_4th_loop(a: float32[m_size, k_size], (loop4_partition_b_start_row * n_size + loop4_partition_b_start_col) - # # # TODO: If passed, see if this barrier is really needed - # printf("Begin: calling the first barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) thrcomm_barrier( - comm_id_packb, ~packb_thrcomm_barrier_sense[work_id_5th_loop], ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], packb_nthreads ) - # printf("End: calling the first barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d\n", work_id_4th_loop, comm_id_4th_loop) gemm_pack_b(loop4_partition_b, loop4_partition_b_width, 
loop4_partition_b_height, packed_b_buf, - comm_id_packb, work_id_packb, loop3_nthreads) + work_id_packb) - # The barrier at the end of the packing of B - # printf( - # "Begin: calling the second barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d, work_id_5th_loop: %d, i_loop4: %d, loop5_partition_b_start_col: %d\n", - # work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) thrcomm_barrier( - comm_id_packb, ~packb_thrcomm_barrier_sense[work_id_5th_loop], ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], packb_nthreads ) - # printf( - # "End: calling the second barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d, work_id_5th_loop: %d, i_loop4: %d, loop5_partition_b_start_col: %d\n", - # work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) - # TODO: The loop3 and beyond should start here? gemm_3rd_loop( a, packed_b_buf, c, loop4_partition_a_start_col, @@ -952,23 +876,14 @@ def gemm_4th_loop(a: float32[m_size, k_size], comm_id_3rd_loop, work_id_3rd_loop, is_first, - work_id_4th_loop, work_id_5th_loop ) - # # TODO: Is not adding this barrier at the end the problem? - # printf( - # "Begin: calling the third barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d, work_id_5th_loop: %d, i_loop4: %d, loop5_partition_b_start_col: %d, work_id_5th_loop: %d\n", - # work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) thrcomm_barrier( - comm_id_packb, ~packb_thrcomm_barrier_sense[work_id_5th_loop], ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], packb_nthreads ) - # printf( - # "End: calling the third barrier for the 4th loop; work_id_4th_loop: %d, comm_id_4th_loop: %d, work_id_5th_loop: %d, i_loop4: %d, loop5_partition_b_start_col: %d\n", - # work_id_4th_loop, comm_id_4th_loop, work_id_5th_loop, i_loop4, loop5_partition_b_start_col) i_loop4 += b_alg_loop4 @@ -981,7 +896,6 @@ def gemm_5th_loop(a: float32[m_size, k_size], work_id_5th_loop: int32, comm_id_5th_loop: int32): comm_id_4th_loop = comm_id_5th_loop % loop4_nthreads - work_id_4th_loop = comm_id_4th_loop // (loop4_nthreads // loop4_nways) loop5_my_start = -1 loop5_my_end = -1 @@ -994,20 +908,12 @@ def gemm_5th_loop(a: float32[m_size, k_size], loop5_my_end, NC) b_alg_loop5 = min(b_alg_loop5, loop5_my_end - loop5_iter) - loop5_partition_c_width = b_alg_loop5 - loop5_partition_c_start_col = loop5_iter loop5_partition_b_width = b_alg_loop5, loop5_partition_b_start_col = loop5_iter - # - # printf( - # "work_id_5th_loop: %d, comm_id_5th_loop: %d, b_alg_loop5: %d, loop5_partition_b_width: %d, loop5_partition_b_start_col: %d, loop5_iter: %d, loop5_my_start: %d, loop5_my_end: %d\n\n", - # work_id_5th_loop, comm_id_5th_loop, b_alg_loop5, loop5_partition_b_width, - # loop5_partition_b_start_col, loop5_iter, loop5_my_start, loop5_my_end) gemm_4th_loop(a, b, c, loop5_partition_b_width, loop5_partition_b_start_col, comm_id_4th_loop, - work_id_4th_loop, work_id_5th_loop) loop5_iter += b_alg_loop5 @@ -1025,35 +931,12 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], packb_thrcomm_barrier_threads_arrived, loop5_nways) - # the nthreads and nways print for each loop - # printf("nthreads: %d\n", nthreads) - # printf("loop5_nthreads: %d, loop5_nways: %d\n", loop5_nthreads, loop5_nways) - # printf("loop4_nthreads: %d, loop4_nways: %d\n", loop4_nthreads, loop4_nways) - # printf("loop3_nthreads: %d, loop3_nways: %d\n", loop3_nthreads, loop3_nways) - # 
printf("macro_nthreads: %d, macro_nways: %d\n", macro_nthreads, macro_nways) - # printf("loop1_nthreads: %d, loop1_nways: %d\n", loop1_nthreads, loop1_nways) - # - # printf("packb_nthreads: %d, packa_nthreads: %d\n", packb_nthreads, packa_nthreads) - # - # printf("packed_b_width: %d, packed_b_total_width: %d, packed_b_height: %d\n", packed_b_width, - # packed_b_total_width, packed_b_height) - # printf("packed_a_width: %d, packed_a_individual_height: %d, packed_a_total_height: %d\n", - # packed_a_width, packed_a_individual_height, packed_a_total_height) - # - # printf("packed_b_total_size: %d, packed_a_total_size: %d\n", packed_b_total_size, packed_a_total_size) - # printf("packed_b_individual_size: %d, packed_a_individual_size: %d\n", packed_b_individual_size, packed_a_individual_size) - - # for i in grid(m_size): - # for j in grid(n_size): - # c[i, j] = 0.0 - parallel_attr = 'p' + str(nthreads) # The outermost loop spawning threads for tidx in grid(nthreads, attrs=parallel_attr): tid_5th_loop = tidx work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways) comm_id_5th_loop = tid_5th_loop - # printf("tidx: %d, tid_5th_loop: %d, work_id_5th_loop: %d, comm_id_5th_loop: %d\n", tidx, tid_5th_loop, work_id_5th_loop, comm_id_5th_loop) gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop) From 3fbb6353fe4ce0f3ee3aa929bb57f81b120a1a43 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 16 Nov 2023 21:57:51 -0500 Subject: [PATCH 131/148] ready for PR --- .../hidet/graph/ops/matmul/matmul_f32_x86.py | 1073 +++++++++++++---- .../ops/matmul/matmul_f32_x86_refactored.py | 958 --------------- .../graph/ops/matmul/matmul_f32_x86_v3.py | 974 --------------- 3 files changed, 809 insertions(+), 2196 deletions(-) delete mode 100644 python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py delete mode 100644 python/hidet/graph/ops/matmul/matmul_f32_x86_v3.py diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86.py b/python/hidet/graph/ops/matmul/matmul_f32_x86.py index eeb1a8557..28c355fb2 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86.py @@ -14,6 +14,7 @@ from hidet.ir.expr import cast from hidet.ir.module import IRModule from hidet.ir.compute import TensorNode +from hidet.ir.stmt import DeclareScope from hidet.ir.task import Task from hidet.ir.compute import compute, reduce from hidet.graph.ops.utils import input_like, broadcast_shape, can_mutually_broadcast @@ -22,7 +23,8 @@ from hidet.graph.ops.utils import broadcast_indices -class MatmulF32Taskx86(Task): +class MatmulF32Taskx86_refactored(Task): + def __init__(self, a: TensorNode, b: TensorNode): a_shape = a.const_shape b_shape = b.const_shape @@ -58,13 +60,13 @@ def __init__(self, a: TensorNode, b: TensorNode): fcompute=lambda *indices: reduce( shape=[k_size], fcompute=lambda k: a[broadcast_indices(indices[:-2], a_shape[:-2], c_shape[1:-2]) + [indices[-2], k]] - * b[broadcast_indices(indices[:-2], b_shape[:-2], c_shape[1:-2]) + [k, indices[-1]]], + * b[broadcast_indices(indices[:-2], b_shape[:-2], c_shape[1:-2]) + [k, indices[-1]]], reduce_type='sum', ), ) super().__init__( - name='matmul_f32_x86', + name='matmul_f32_x86_v2', inputs=[a, b], outputs=[c], attributes={'m_size': a_shape[-2], 'n_size': b_shape[-1], 'k_size': a_shape[-1]}, @@ -79,47 +81,183 @@ def allow_prologue(self) -> bool: def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: return tune.extract_ir_modules(self.schedule_matmulf32_x86) - @tune.space( - 2, - block_m=[2016, 3024], - 
block_n=[64, 144, 192, 256, 384, 512, 592, 672, 752, 896, 1024], - block_k=[96, 128, 256, 384, 512, 560, 688, 784], - nthreads=[4, 8, 16, 32], - ) - @tune.space(1, block_m=[2016], block_n=[256, 384, 512], block_k=[384, 512, 560], nthreads=[8, 16]) + @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 4, 2, 1)]) def schedule_matmulf32_x86( - self, block_m=2016, block_n=896, block_k=512, micro_ker=(6, 16), nthreads=16 + self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1) ) -> IRModule: import hidet from hidet.ir.type import tensor_type from hidet.lang import tensor, grid, as_tensor_pointer from hidet.lang.layout import row_major, column_major from hidet.lang.cpu import avx_f32x8_store, avx_f32x8_fmadd, avx_f32x8_load, avx_f32x8_broadcast - from hidet.lang.cpu import avx_f32x4_broadcast, avx_f32x4_fmadd, avx_f32x4_load, avx_f32x4_store + from hidet.lang.cpu import avx_f32x8_store_aligned, avx_f32x8_load_aligned, avx_f32x8_setzero + from hidet.lang.cpu import avx_f32x8_unpacklo, avx_f32x8_unpackhi + from hidet.lang.cpu import avx_f32x8_shuffle, avx_f32x8_cast_f32x4 + from hidet.lang.cpu import avx_f32x8_insert_f32x4, avx_f32x8_permute2f32x4 + from hidet.lang.cpu import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor node_a, node_b = self.inputs[0], self.inputs[1] a_shape = node_a.const_shape b_shape = node_b.const_shape m_size, n_size, k_size = a_shape[-2], b_shape[-1], a_shape[-1] - tile_m, tile_n = micro_ker + MR, NR = 6, 16 - supported_microkers = ((6, 16), (4, 8), (8, 8)) - tune.check(micro_ker in supported_microkers, "The size of the micro-kernel is not supported") + tune.check(MC % MR == NC % NR == 0, 'Tile size must divide the corresponding block size') - tune.check(block_m % tile_m == block_n % tile_n == 0, 'Tile size must divide the corresponding block size') + with hidet.script_module() as module: + # Get the number of threads... 
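+            # `ways` gives the number of ways each parallelizable loop of this
+            # BLIS-style blocked GEMM is split, ordered (5th loop, 3rd loop,
+            # macro-kernel, 1st loop). The 4th loop (over K) always runs
+            # single-way, since its iterations accumulate into the same block
+            # of C; the total thread count is therefore the product of the four
+            # entries, e.g. the default ways=(1, 4, 2, 1) spawns 8 threads.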
+ loop5_nways, loop3_nways, macro_nways, loop1_nways = ways + loop4_nways = 1 + nthreads = loop5_nways * loop3_nways * macro_nways * loop1_nways + + packa_thrcomm_barrier_sense = module.define_global_var( + name="pack_a_barrier_sense", + var_type=int32[nthreads] + ) + packa_thrcomm_threads_arrived = module.define_global_var( + name="pack_a_threads_arrived", + var_type=int32[nthreads] + ) + + packb_thrcomm_barrier_sense = module.define_global_var( + name='pack_b_barrier_sense', + var_type=int32[nthreads] + ) + packb_thrcomm_barrier_threads_arrived = module.define_global_var( + name="pack_b_threads_arrived", + var_type=int32[nthreads] + ) - packed_a_type = tensor_type('float32', layout=row_major(block_m // tile_m, 1) * column_major(tile_m, block_k)) - packed_b_type = tensor_type('float32', layout=row_major(1, block_n // tile_n) * row_major(block_k, tile_n)) + @hidet.script + def init_thr(sense: ~int32, arrived: ~int32, size: int32): + for i in range(size): + sense[i] = 0 + arrived[i] = 0 - aip_outer_rows = block_m // tile_m - bip_outer_cols = block_n // tile_n + init_thr.kind = "cpu_internal" - with hidet.script_module() as module: + # Helpers + packed_a_type = tensor_type('float32', layout=row_major(MC // MR, + 1) * column_major( + MR, KC)) + packed_b_type = tensor_type('float32', layout=row_major(1, + NC // NR) * row_major( + KC, NR)) + + # Get the number of threads remaining at each level + loop4_nthreads = nthreads // loop5_nways + loop3_nthreads = loop4_nthreads + macro_nthreads = loop3_nthreads // loop3_nways + loop1_nthreads = macro_nthreads // macro_nways + + packb_nthreads = loop3_nthreads + packa_nthreads = macro_nthreads + + packed_a_buffers_needed = loop3_nways * loop5_nways + + @hidet.script + def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~int32, end: ~int32): + if n_way == 1: + start[0] = 0 + end[0] = n + return + all_start = 0 + all_end = n + size = all_end - all_start + + n_bf_whole = size // bf + n_bf_left = size % bf + + n_bf_lo = n_bf_whole // n_way + n_bf_hi = n_bf_whole // n_way + + n_th_lo = n_bf_whole % n_way + # If some partitions must have more block_factors than others, assign the slightly larger partitions to lower index threads + if n_th_lo != 0: + n_bf_lo += 1 + # Compute the actual widths (in units of rows/columns) of individual threads in the low and high groups + size_lo = n_bf_lo * bf + size_hi = n_bf_hi * bf + + # Pre-compute the starting indices of the low and high groups + lo_start = all_start + hi_start = all_start + n_th_lo * size_lo + + # Compute the start and end of individual threads' ranges + if work_id < n_th_lo: + start[0] = lo_start + work_id * size_lo + end[0] = lo_start + (work_id + 1) * size_lo + else: + start[0] = hi_start + (work_id - n_th_lo) * size_hi + end[0] = hi_start + (work_id - n_th_lo + 1) * size_hi + + # Add the remainder to the last thread's end + if work_id == n_way - 1: + end[0] += n_bf_left + end[0] = min(end[0], all_end) + + thread_range_sub.kind = "cpu_internal" + + @hidet.script + def thread_range_jrir(work_id: int32, n_way: int32, n: int32, bf: int32, + start: ~int32, end: ~int32, inc: ~int32): + start[0] = work_id + end[0] = n + inc[0] = n_way + + thread_range_jrir.kind = "cpu_internal" + + @hidet.script + def determine_blocksize_f_sub(i: int32, dim: int32, b_alg: int32) -> int32: + dim_left_now = dim - i + b_now = -1 + if dim_left_now <= b_alg: + b_now = dim_left_now + else: + b_now = b_alg + assert b_now >= 0 + return b_now + + determine_blocksize_f_sub.kind = "cpu_internal" + + 
@hidet.script + def not_edge(i: int32, n_iter: int32, n_left: int32) -> bool: + return i != n_iter - 1 or n_left == 0 + + not_edge.kind = 'cpu_internal' @hidet.script - def micro_kernel_6x16( - a: packed_a_type, b: packed_b_type, c_ptr: ~float32, pb: int32, msize: int32, nsize: int32 + def packa_index(work_id_loop5: int32, work_id_loop3: int32) -> int32: + return work_id_loop5 * loop3_nways + work_id_loop3 + + packa_index.kind = 'cpu_internal' + + # Thread barrier + @hidet.script + def thrcomm_barrier(barrier_sense: ~int32, + barrier_threads_arrived: ~int32, num_threads: int32): + if num_threads == 1: + return + orig_sense = cpu_atomic_load_n(barrier_sense, 0) # _ATOMIC_RELAXED + + # Register the current thread's arrival by incrementing the counter + my_threads_arrived = cpu_atomic_add_fetch( + barrier_threads_arrived, 1, 4) # _ATOMIC_ACQ_REL + + if my_threads_arrived == num_threads: + barrier_threads_arrived[0] = 0 + cpu_atomic_fetch_xor(barrier_sense, 1, 3) # _ATOMIC_RELEASE + else: + while cpu_atomic_load_n(barrier_sense, 2) == orig_sense: # _ATOMIC_ACQUIRE + pass + + thrcomm_barrier.kind = 'cpu_internal' + + @hidet.script + def micro_kernel( + a: packed_a_type, b: packed_b_type, c_ptr: ~float32, pb: int32, msize: int32, nsize: int32, + is_first: bool ): c = as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize]) c0 = avx_f32x8_load(~c[0, 0]) @@ -135,275 +273,682 @@ def micro_kernel_6x16( c5 = avx_f32x8_load(~c[5, 0]) c58 = avx_f32x8_load(~c[5, 8]) + if is_first: + c0 = avx_f32x8_setzero() + c08 = avx_f32x8_setzero() + c1 = avx_f32x8_setzero() + c18 = avx_f32x8_setzero() + c2 = avx_f32x8_setzero() + c28 = avx_f32x8_setzero() + c3 = avx_f32x8_setzero() + c38 = avx_f32x8_setzero() + c4 = avx_f32x8_setzero() + c48 = avx_f32x8_setzero() + c5 = avx_f32x8_setzero() + c58 = avx_f32x8_setzero() a_ptr = cast(a, ~float32) b_ptr = cast(b, ~float32) for _ in range(pb): - bb0to7 = avx_f32x8_load(b_ptr) - bb8to15 = avx_f32x8_load(b_ptr + 8) - b_ptr = b_ptr + 16 + bb0to7 = avx_f32x8_load_aligned(b_ptr) + bb8to15 = avx_f32x8_load_aligned(b_ptr + 8) - aa = avx_f32x8_broadcast(a_ptr) - c0 = avx_f32x8_fmadd(aa, bb0to7, c0) - c08 = avx_f32x8_fmadd(aa, bb8to15, c08) + aa1 = avx_f32x8_broadcast(a_ptr) + c0 = avx_f32x8_fmadd(aa1, bb0to7, c0) + c08 = avx_f32x8_fmadd(aa1, bb8to15, c08) - aa = avx_f32x8_broadcast(a_ptr + 1) - c1 = avx_f32x8_fmadd(aa, bb0to7, c1) - c18 = avx_f32x8_fmadd(aa, bb8to15, c18) + aa2 = avx_f32x8_broadcast(a_ptr + 1) + c1 = avx_f32x8_fmadd(aa2, bb0to7, c1) + c18 = avx_f32x8_fmadd(aa2, bb8to15, c18) - aa = avx_f32x8_broadcast(a_ptr + 2) - c2 = avx_f32x8_fmadd(aa, bb0to7, c2) - c28 = avx_f32x8_fmadd(aa, bb8to15, c28) + aa3 = avx_f32x8_broadcast(a_ptr + 2) + c2 = avx_f32x8_fmadd(aa3, bb0to7, c2) + c28 = avx_f32x8_fmadd(aa3, bb8to15, c28) - aa = avx_f32x8_broadcast(a_ptr + 3) - c3 = avx_f32x8_fmadd(aa, bb0to7, c3) - c38 = avx_f32x8_fmadd(aa, bb8to15, c38) + aa4 = avx_f32x8_broadcast(a_ptr + 3) + c3 = avx_f32x8_fmadd(aa4, bb0to7, c3) + c38 = avx_f32x8_fmadd(aa4, bb8to15, c38) - aa = avx_f32x8_broadcast(a_ptr + 4) - c4 = avx_f32x8_fmadd(aa, bb0to7, c4) - c48 = avx_f32x8_fmadd(aa, bb8to15, c48) + aa5 = avx_f32x8_broadcast(a_ptr + 4) + c4 = avx_f32x8_fmadd(aa5, bb0to7, c4) + c48 = avx_f32x8_fmadd(aa5, bb8to15, c48) - aa = avx_f32x8_broadcast(a_ptr + 5) - c5 = avx_f32x8_fmadd(aa, bb0to7, c5) - c58 = avx_f32x8_fmadd(aa, bb8to15, c58) + aa6 = avx_f32x8_broadcast(a_ptr + 5) + c5 = avx_f32x8_fmadd(aa6, bb0to7, c5) + c58 = avx_f32x8_fmadd(aa6, bb8to15, c58) a_ptr = a_ptr + 6 - 
avx_f32x8_store(~c[0, 0], c0) - avx_f32x8_store(~c[0, 8], c08) - avx_f32x8_store(~c[1, 0], c1) - avx_f32x8_store(~c[1, 8], c18) - avx_f32x8_store(~c[2, 0], c2) - avx_f32x8_store(~c[2, 8], c28) - avx_f32x8_store(~c[3, 0], c3) - avx_f32x8_store(~c[3, 8], c38) - avx_f32x8_store(~c[4, 0], c4) - avx_f32x8_store(~c[4, 8], c48) - avx_f32x8_store(~c[5, 0], c5) - avx_f32x8_store(~c[5, 8], c58) + b_ptr = b_ptr + 16 - @hidet.script - def micro_kernel_4x8( - a: packed_a_type, b: packed_b_type, c_ptr: ~float32, pb: int32, msize: int32, nsize: int32 - ): - c = as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize]) - c0 = avx_f32x8_load(~c[0, 0]) - c1 = avx_f32x8_load(~c[1, 0]) - c2 = avx_f32x8_load(~c[2, 0]) - c3 = avx_f32x8_load(~c[3, 0]) + # Store the results + avx_f32x8_store(c_ptr, c0) + avx_f32x8_store(c_ptr + 8, c08) - for pp in range(pb): - bb = avx_f32x8_load(~b[pp, 0]) - - aa = avx_f32x8_broadcast(~a[0, pp]) - c0 = avx_f32x8_fmadd(aa, bb, c0) - aa = avx_f32x8_broadcast(~a[1, pp]) - c1 = avx_f32x8_fmadd(aa, bb, c1) - aa = avx_f32x8_broadcast(~a[2, pp]) - c2 = avx_f32x8_fmadd(aa, bb, c2) - aa = avx_f32x8_broadcast(~a[3, pp]) - c3 = avx_f32x8_fmadd(aa, bb, c3) - avx_f32x8_store(~c[0, 0], c0) - avx_f32x8_store(~c[1, 0], c1) - avx_f32x8_store(~c[2, 0], c2) - avx_f32x8_store(~c[3, 0], c3) + avx_f32x8_store(c_ptr + nsize, c1) + avx_f32x8_store(c_ptr + (nsize + 8), c18) + + avx_f32x8_store(c_ptr + 2 * nsize, c2) + avx_f32x8_store(c_ptr + (2 * nsize + 8), c28) + + avx_f32x8_store(c_ptr + 3 * nsize, c3) + avx_f32x8_store(c_ptr + (3 * nsize + 8), c38) + + avx_f32x8_store(c_ptr + 4 * nsize, c4) + avx_f32x8_store(c_ptr + (4 * nsize + 8), c48) + + avx_f32x8_store(c_ptr + 5 * nsize, c5) + avx_f32x8_store(c_ptr + (5 * nsize + 8), c58) + + #### Some setup code #### + packed_b_height = KC + if packed_b_height > k_size: + packed_b_height = k_size + packed_b_width = NC + if packed_b_width > n_size: + packed_b_width = (n_size + NR - 1) // NR * NR + + packed_b_total_width = packed_b_width * loop5_nways + packed_b_total_size = packed_b_total_width * packed_b_height + packed_b_individual_size = packed_b_width * packed_b_height + + packed_a_individual_height = MC + if packed_a_individual_height > m_size: + packed_a_individual_height = (m_size + MR - 1) // MR * MR + packed_a_total_height = packed_a_individual_height * packed_a_buffers_needed + + packed_a_width = KC + if packed_a_width > k_size: + packed_a_width = k_size + # pad this to be able to use the aligned version of the avx store + packed_a_width = (packed_a_width + 8 - 1) // 8 * 8 + packed_a_total_size = packed_a_total_height * packed_a_width + packed_a_individual_size = packed_a_width * packed_a_individual_height + + packb_buf_ptr = module.define_global_var( + name='packb_buf_ptr', + var_type=float32[packed_b_total_size] + ) + packa_buf_ptr = module.define_global_var( + name='packa_buf_ptr', + var_type=float32[packed_a_total_size] + ) + + packb_buf = cast(packb_buf_ptr, ~float32) + packa_buf = cast(packa_buf_ptr, ~float32) + + ##### Start of the loops around micro kernel ##### @hidet.script - def micro_kernel_8x8( - a: packed_a_type, b: packed_b_type, c_ptr: ~float32, pb: int32, msize: int32, nsize: int32 + def gemm_pack_a( + loop3_partition_a: ~float32, + loop3_partition_a_width: int32, + loop3_partition_a_height: int32, + packed_a_buf: ~float32, + work_id_packa: int32, ): - c = as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize]) - c0 = avx_f32x8_load(~c[0, 0]) - c1 = avx_f32x8_load(~c[1, 0]) - c2 = avx_f32x8_load(~c[2, 0]) - c3 = 
avx_f32x8_load(~c[3, 0]) - c4 = avx_f32x8_load(~c[4, 0]) - c5 = avx_f32x8_load(~c[5, 0]) - c6 = avx_f32x8_load(~c[6, 0]) - c7 = avx_f32x8_load(~c[7, 0]) - - for pp in range(pb): - bb = avx_f32x8_load(~b[pp, 0]) - - aa = avx_f32x8_broadcast(~a[0, pp]) - c0 = avx_f32x8_fmadd(aa, bb, c0) - aa = avx_f32x8_broadcast(~a[1, pp]) - c1 = avx_f32x8_fmadd(aa, bb, c1) - aa = avx_f32x8_broadcast(~a[2, pp]) - c2 = avx_f32x8_fmadd(aa, bb, c2) - aa = avx_f32x8_broadcast(~a[3, pp]) - c3 = avx_f32x8_fmadd(aa, bb, c3) - aa = avx_f32x8_broadcast(~a[4, pp]) - c4 = avx_f32x8_fmadd(aa, bb, c4) - aa = avx_f32x8_broadcast(~a[5, pp]) - c5 = avx_f32x8_fmadd(aa, bb, c5) - aa = avx_f32x8_broadcast(~a[6, pp]) - c6 = avx_f32x8_fmadd(aa, bb, c6) - aa = avx_f32x8_broadcast(~a[7, pp]) - c7 = avx_f32x8_fmadd(aa, bb, c7) - avx_f32x8_store(~c[0, 0], c0) - avx_f32x8_store(~c[1, 0], c1) - avx_f32x8_store(~c[2, 0], c2) - avx_f32x8_store(~c[3, 0], c3) - avx_f32x8_store(~c[4, 0], c4) - avx_f32x8_store(~c[5, 0], c5) - avx_f32x8_store(~c[6, 0], c6) - avx_f32x8_store(~c[7, 0], c7) + packed_a_tensor = as_tensor_pointer( + packed_a_buf, + float32, + layout=row_major(packed_a_individual_height // MR, 1) * + column_major(MR, packed_a_width) + ) + + npanels_full_a = loop3_partition_a_height // MR + panel_a_remainder = loop3_partition_a_height % MR + + npanels_a = npanels_full_a + (1 if panel_a_remainder > 0 else 0) + + for ii_panel in range(npanels_a): + if ii_panel % packa_nthreads != work_id_packa % packa_nthreads: + continue + + a_curr_panel_row_start = ii_panel * MR + a_curr_panel_height = min(MR, + loop3_partition_a_height - a_curr_panel_row_start) + + if a_curr_panel_height == MR: # unroll the packing by 8 + k_iters = loop3_partition_a_width // 8 + k_remainder = loop3_partition_a_width % 8 + col = 0 + for k_iter in range(k_iters): + col = k_iter * 8 + a_curr_panel_col = loop3_partition_a + ( + a_curr_panel_row_start * k_size + col + ) + + v0 = avx_f32x8_load(a_curr_panel_col) + v1 = avx_f32x8_load(a_curr_panel_col + k_size) + v2 = avx_f32x8_load(a_curr_panel_col + (2 * k_size)) + v3 = avx_f32x8_load(a_curr_panel_col + (3 * k_size)) + v4 = avx_f32x8_load(a_curr_panel_col + (4 * k_size)) + v5 = avx_f32x8_load(a_curr_panel_col + (5 * k_size)) + + unpack0 = avx_f32x8_unpacklo(v0, v1) + unpack1 = avx_f32x8_unpackhi(v0, v1) + unpack2 = avx_f32x8_unpacklo(v2, v3) + unpack3 = avx_f32x8_unpackhi(v2, v3) + unpack4 = avx_f32x8_unpacklo(v4, v5) + unpack5 = avx_f32x8_unpackhi(v4, v5) + + shf0 = avx_f32x8_shuffle(unpack0, unpack2, 0x44) + shf1 = avx_f32x8_shuffle(unpack4, unpack0, 0xE4) + shf2 = avx_f32x8_shuffle(unpack2, unpack4, 0xEE) + shf3 = avx_f32x8_shuffle(unpack5, unpack1, 0xE4) + shf4 = avx_f32x8_shuffle(unpack3, unpack5, 0xEE) + shf5 = avx_f32x8_shuffle(unpack1, unpack3, 0x44) + + low_shf1 = avx_f32x8_cast_f32x4(shf1) + res0 = avx_f32x8_insert_f32x4(shf0, low_shf1, 0x1) + res1 = avx_f32x8_permute2f32x4(shf0, shf1, 0x31) + + low_shf5 = avx_f32x8_cast_f32x4(shf5) + res2 = avx_f32x8_insert_f32x4(shf2, low_shf5, 0x1) + res3 = avx_f32x8_permute2f32x4(shf2, shf5, 0x31) + + low_shf4 = avx_f32x8_cast_f32x4(shf4) + res4 = avx_f32x8_insert_f32x4(shf3, low_shf4, 0x1) + res5 = avx_f32x8_permute2f32x4(shf3, shf4, 0x31) + + avx_f32x8_store_aligned( + ~packed_a_tensor[a_curr_panel_row_start, col], + res0 + ) + avx_f32x8_store_aligned( + ~packed_a_tensor[a_curr_panel_row_start + 2, + col + 1], + res2 + ) + avx_f32x8_store_aligned( + ~packed_a_tensor[a_curr_panel_row_start + 4, + col + 2], + res4) + avx_f32x8_store_aligned( + 
~packed_a_tensor[a_curr_panel_row_start, + col + 4], + res1 + ) + avx_f32x8_store_aligned( + ~packed_a_tensor[a_curr_panel_row_start + 2, + col + 5], + res3 + ) + avx_f32x8_store_aligned( + ~packed_a_tensor[a_curr_panel_row_start + 4, + col + 6], + res5 + ) + remaining_start_col = k_iters * 8 + for remain_off in range(k_remainder): + curr_remain_col = remaining_start_col + remain_off + for micropanel_row in range(MR): + packed_a_tensor[ + a_curr_panel_row_start + micropanel_row, + curr_remain_col] = \ + loop3_partition_a[( + micropanel_row + a_curr_panel_row_start) * k_size + curr_remain_col] + else: + remain_start_row = npanels_full_a * MR + for remain_col in range(loop3_partition_a_width): + for remain_row in range(panel_a_remainder): + packed_a_tensor[ + remain_start_row + remain_row, remain_col] = \ + loop3_partition_a[( + remain_row + remain_start_row) * k_size + remain_col] + remain_row = panel_a_remainder + while remain_row < MR: + packed_a_tensor[ + remain_start_row + remain_row, remain_col] = 0.0 + remain_row += 1 @hidet.script - def micro_kernel_4x4( - a: packed_a_type, b: packed_b_type, c_ptr: ~float32, pb: int32, msize: int32, nsize: int32 + def gemm_pack_b( + loop4_partition_b: ~float32, + loop4_partition_b_width: int32, + loop4_partition_b_height: int32, + packed_b_buf: ~float32, + work_id_packb: int32 ): - c = as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize]) - - c0 = avx_f32x4_load(~c[0, 0]) - c1 = avx_f32x4_load(~c[1, 0]) - c2 = avx_f32x4_load(~c[2, 0]) - c3 = avx_f32x4_load(~c[3, 0]) - - for pp in range(pb): - bb = avx_f32x4_load(~b[pp, 0]) - - aa = avx_f32x4_broadcast(~a[0, pp]) - c0 = avx_f32x4_fmadd(aa, bb, c0) - aa = avx_f32x4_broadcast(~a[1, pp]) - c1 = avx_f32x4_fmadd(aa, bb, c1) - aa = avx_f32x4_broadcast(~a[2, pp]) - c2 = avx_f32x4_fmadd(aa, bb, c2) - aa = avx_f32x4_broadcast(~a[3, pp]) - c3 = avx_f32x4_fmadd(aa, bb, c3) - avx_f32x4_store(~c[0, 0], c0) - avx_f32x4_store(~c[1, 0], c1) - avx_f32x4_store(~c[2, 0], c2) - avx_f32x4_store(~c[3, 0], c3) - - micro_kernel = micro_kernel_6x16 - if tile_m == 8 and tile_n == 8: - micro_kernel = micro_kernel_8x8 - elif tile_m == 4 and tile_n == 8: - micro_kernel = micro_kernel_4x8 - elif tile_m == 4 and tile_n == 4: - micro_kernel = micro_kernel_4x4 + npanels_full_b = loop4_partition_b_width // NR + npanels_b_remainder = loop4_partition_b_width % NR + + npanels_b = npanels_full_b + (1 if npanels_b_remainder != 0 else 0) + packedb_panel_stride = packed_b_height * NR + + # Loop for the packing of B + for i_panel in range(npanels_b): + if i_panel % packb_nthreads != work_id_packb % packb_nthreads: + continue + packed_b_buff_curr = packed_b_buf + ( + i_panel * packedb_panel_stride) + curr_panel_start = i_panel * NR + curr_panel_width = min(NR, + loop4_partition_b_width - curr_panel_start) + + if curr_panel_width == NR: + k_iters = loop4_partition_b_height // 8 + k_remainder = loop4_partition_b_height % 8 + + row = 0 + for k_iter in range(k_iters): + row = k_iter * 8 + b_panel = loop4_partition_b + ( + row * n_size + curr_panel_start) + b00 = avx_f32x8_load(b_panel) + b08 = avx_f32x8_load(b_panel + 8) + + avx_f32x8_store_aligned(packed_b_buff_curr, b00) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b08) + packed_b_buff_curr += 16 + + b10 = avx_f32x8_load(b_panel + n_size) + b18 = avx_f32x8_load(b_panel + (n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buff_curr, b10) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b18) + packed_b_buff_curr += 16 + + b20 = avx_f32x8_load(b_panel + (2 * n_size)) + b28 = 
avx_f32x8_load(b_panel + (2 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buff_curr, b20) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b28) + packed_b_buff_curr += 16 + + b30 = avx_f32x8_load(b_panel + (3 * n_size)) + b38 = avx_f32x8_load(b_panel + (3 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buff_curr, b30) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b38) + packed_b_buff_curr += 16 + + b40 = avx_f32x8_load(b_panel + (4 * n_size)) + b48 = avx_f32x8_load(b_panel + (4 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buff_curr, b40) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b48) + packed_b_buff_curr += 16 + + b50 = avx_f32x8_load(b_panel + (5 * n_size)) + b58 = avx_f32x8_load(b_panel + (5 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buff_curr, b50) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b58) + packed_b_buff_curr += 16 + + b60 = avx_f32x8_load(b_panel + (6 * n_size)) + b68 = avx_f32x8_load(b_panel + (6 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buff_curr, b60) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b68) + packed_b_buff_curr += 16 + + b70 = avx_f32x8_load(b_panel + (7 * n_size)) + b78 = avx_f32x8_load(b_panel + (7 * n_size + 8)) + + avx_f32x8_store_aligned(packed_b_buff_curr, b70) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b78) + + packed_b_buff_curr += 16 + + row = k_iters * 8 + for _ in range(k_remainder): + b_panel = loop4_partition_b + ( + row * n_size + curr_panel_start) + b00 = avx_f32x8_load(b_panel) + b08 = avx_f32x8_load(b_panel + 8) + avx_f32x8_store_aligned(packed_b_buff_curr, b00) + avx_f32x8_store_aligned(packed_b_buff_curr + 8, b08) + packed_b_buff_curr += 16 + row += 1 + + else: + packed_b_remaining_buf = packed_b_buf + ( + npanels_full_b * packedb_panel_stride) + if npanels_b_remainder > 0: + remain_col_start = npanels_full_b * NR + for remain_row in range(loop4_partition_b_height): + packed_b_remaining_buf_curr = packed_b_remaining_buf + ( + remain_row * NR) + for remain_col in range(npanels_b_remainder): + packed_b_remaining_buf_curr[0] = \ + loop4_partition_b[ + (remain_row * n_size) + ( + remain_col_start + remain_col) + ] + packed_b_remaining_buf_curr += 1 + zero_fill_col = npanels_b_remainder + while zero_fill_col < NR: + packed_b_remaining_buf_curr[0] = 0.0 + packed_b_remaining_buf_curr += 1 + zero_fill_col += 1 + + gemm_pack_b.kind = "cpu_internal" + gemm_pack_a.kind = "cpu_internal" + micro_kernel.kind = "cpu_internal" @hidet.script - def macro_kernel( - a: packed_a_type, b: packed_b_type, c_in_macro: float32[m_size, n_size], ib: int32, jb: int32, pb: int32 + def gemm_macro( + packed_a: ~float32, + packed_b: ~float32, + c: float32[m_size, n_size], + c_row_off: int32, + c_col_off: int32, + macro_m: int32, + macro_n: int32, + macro_k: int32, + ps_packed_a: int32, + ps_packed_b: int32, + comm_id_macro: int32, + work_id_macro: int32, + is_first: bool ): - mpanels = (ib + tile_m - 1) // tile_m - npanels = (jb + tile_n - 1) // tile_n - _mr = ib % tile_m - _nr = jb % tile_n - - # Loop 2 - para = 'p' + str(nthreads) - for mpanel in grid(mpanels, attrs=para): - mr = tile_m if mpanel != mpanels - 1 or _mr == 0 else _mr - ii = mpanel * tile_m - # Loop 1 - for npanel in range(npanels): - nr = tile_n if npanel != npanels - 1 or _nr == 0 else _nr - jj = npanel * tile_n - # micro-kernel - if mr == tile_m and nr == tile_n: - micro_kernel(~a[ii, 0], ~b[0, jj], ~c_in_macro[ii, jj], pb, m_size, n_size) + comm_id_1st_loop = comm_id_macro % loop1_nthreads + work_id_1st_loop = comm_id_1st_loop // 
(loop1_nthreads // loop1_nways) + + n_iter = macro_n // NR + n_remainder = macro_n % NR + m_iter = macro_m // MR + m_remainder = macro_m % MR + + if n_remainder > 0: + n_iter += 1 + if m_remainder > 0: + m_iter += 1 + + jr_start = -1 + jr_end = -1 + ir_start = -1 + ir_end = -1 + jr_inc = -1 + ir_inc = -1 + + thread_range_jrir( + work_id_macro, + macro_nways, + n_iter, + 1, + ~jr_start, + ~jr_end, + ~jr_inc + ) + + thread_range_jrir( + work_id_1st_loop, + loop1_nways, + m_iter, + 1, + ~ir_start, + ~ir_end, + ~ir_inc + ) + + rstep_a = ps_packed_a + cstep_b = ps_packed_b + + cstep_c = NR + rstep_c = n_size * MR + + macro_c_cast = as_tensor_pointer( + ~c[c_row_off, c_col_off], + dtype=float32, + shape=(m_size, n_size) + ) + temp_c = tensor(scope=DeclareScope.Default, + dtype=float32, + layout=row_major(MR, NR), + is_static=False) + j = jr_start + while j < jr_end: + b1 = packed_b + j * cstep_b + c1 = macro_c_cast + j * cstep_c + n_cur = NR if not_edge(j, n_iter, n_remainder) else n_remainder + + i = ir_start + while i < ir_end: + a1 = packed_a + i * rstep_a + c11 = c1 + i * rstep_c + c11 = as_tensor_pointer(c11, dtype=float32, shape=(m_size, n_size)) + m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder + + if m_cur == MR and n_cur == NR: + micro_kernel(a1, b1, c11, macro_k, m_size, n_size, is_first) else: - temp_c = tensor(dtype='float32', layout=row_major(tile_m, tile_n)) - for tempi in range(tile_m): - for tempj in range(tile_n): - temp_c[tempi, tempj] = 0.0 - micro_kernel(~a[ii, 0], ~b[0, jj], temp_c, pb, tile_m, tile_n) - for remain_row, remain_col in grid(mr, nr): - c_in_macro[ii + remain_row, jj + remain_col] += temp_c[remain_row, remain_col] + for i, j in grid(MR, NR): + temp_c[i, j] = 0.0 + micro_kernel(a1, b1, cast(temp_c, ~float32), macro_k, MR, NR, is_first) + if not is_first: + for mm, nn in grid(m_cur, n_cur): + c11[mm, nn] += temp_c[mm, nn] + else: + for mm, nn in grid(m_cur, n_cur): + c11[mm, nn] = temp_c[mm, nn] + + i += ir_inc + j += jr_inc + + gemm_macro.kind = "cpu_internal" + + @hidet.script + def gemm_3rd_loop( + a: float32[m_size, k_size], + packed_b: ~float32, + c: float32[m_size, n_size], + loop3_partition_a_start_col: int32, + loop3_partition_b_start_col: int32, + loop3_partition_a_width: int32, + loop3_partition_b_width: int32, + comm_id_3rd_loop: int32, + work_id_3rd_loop: int32, + is_first: bool, work_id_5th_loop: int32): + comm_id_macro = comm_id_3rd_loop % macro_nthreads + work_id_macro = comm_id_macro // (macro_nthreads // macro_nways) + work_id_packa = comm_id_macro + + m_start_loop3 = 0 + m_end_loop3 = 0 + thread_range_sub( + loop3_nways, + work_id_3rd_loop, + m_size, + MR, + ~m_start_loop3, + ~m_end_loop3 + ) + + ii = m_start_loop3 + while ii < m_end_loop3: + b_alg_loop3 = determine_blocksize_f_sub( + ii, m_size, MC + ) + b_alg_loop3 = min(b_alg_loop3, m_end_loop3 - ii) + loop3_partition_a_start_row = ii + loop3_partition_a_height = b_alg_loop3 + + loop3_partition_a = cast(a, ~float32) + ( + loop3_partition_a_start_row * k_size + + loop3_partition_a_start_col + ) + + packed_a_idx = packa_index(work_id_5th_loop, work_id_3rd_loop) + packed_a_buf = packa_buf + (packed_a_idx * packed_a_individual_size) + + + thrcomm_barrier( + ~packa_thrcomm_barrier_sense[packed_a_idx], + ~packa_thrcomm_threads_arrived[packed_a_idx], + packa_nthreads + ) + + gemm_pack_a( + loop3_partition_a, + loop3_partition_a_width, + loop3_partition_a_height, + packed_a_buf, + work_id_packa, + ) + + # This marks the end of the packing of A, + # so a barrier is needed + 
thrcomm_barrier( + ~packa_thrcomm_barrier_sense[packed_a_idx], + ~packa_thrcomm_threads_arrived[packed_a_idx], + packa_nthreads + ) + + gemm_macro(packed_a_buf, + packed_b, + c, + loop3_partition_a_start_row, + loop3_partition_b_start_col, + loop3_partition_a_height, + loop3_partition_b_width, + loop3_partition_a_width, + MR * packed_a_width, + packed_b_height * NR, + comm_id_macro, + work_id_macro, + is_first + ) + ii += b_alg_loop3 + + gemm_3rd_loop.kind = "cpu_internal" + + @hidet.script + def gemm_4th_loop(a: float32[m_size, k_size], + b: float32[k_size, n_size], + c: float32[k_size, n_size], + loop5_partition_b_width: int32, + loop5_partition_b_start_col: int32, + comm_id_4th_loop: int32, + work_id_5th_loop: int32): + i_loop4 = 0 + + comm_id_3rd_loop = comm_id_4th_loop % loop3_nthreads + work_id_3rd_loop = comm_id_3rd_loop // (loop3_nthreads // loop3_nways) + work_id_packb = comm_id_3rd_loop + + while i_loop4 < k_size: + b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, KC) + b_alg_loop4 = min(b_alg_loop4, k_size - i_loop4) + + loop4_partition_b_height = b_alg_loop4 + loop4_partition_b_width = loop5_partition_b_width + loop4_partition_b_start_row = i_loop4 + loop4_partition_b_start_col = loop5_partition_b_start_col + + loop4_partition_a_start_col = i_loop4 + is_first = (i_loop4 == 0) + + packed_b_buf = packb_buf + ( + packed_b_individual_size * work_id_5th_loop + ) + + loop4_partition_b = cast(b, ~float32) + \ + (loop4_partition_b_start_row * n_size + + loop4_partition_b_start_col) + + thrcomm_barrier( + ~packb_thrcomm_barrier_sense[work_id_5th_loop], + ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], + packb_nthreads + ) + + gemm_pack_b(loop4_partition_b, loop4_partition_b_width, + loop4_partition_b_height, packed_b_buf, + work_id_packb) + + thrcomm_barrier( + ~packb_thrcomm_barrier_sense[work_id_5th_loop], + ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], + packb_nthreads + ) + + gemm_3rd_loop( + a, packed_b_buf, c, + loop4_partition_a_start_col, + loop4_partition_b_start_col, + loop4_partition_b_height, + loop4_partition_b_width, + comm_id_3rd_loop, + work_id_3rd_loop, + is_first, + work_id_5th_loop + ) + + thrcomm_barrier( + ~packb_thrcomm_barrier_sense[work_id_5th_loop], + ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], + packb_nthreads + ) + + i_loop4 += b_alg_loop4 + + gemm_4th_loop.kind = "cpu_internal" @hidet.script - def matmul_kernel_x86(a: float32[m_size, k_size], b: float32[k_size, n_size], c: float32[m_size, n_size]): - mbs = (m_size + block_m - 1) // block_m - nbs = (n_size + block_n - 1) // block_n - kbs = (k_size + block_k - 1) // block_k - - packed_a = tensor(dtype=float32, layout=row_major(aip_outer_rows, 1) * column_major(tile_m, block_k)) - packed_b = tensor(dtype=float32, layout=row_major(1, bip_outer_cols) * row_major(block_k, tile_n)) - - for mb in range(mbs): - i = mb * block_m - ib = min(block_m, m_size - i) - for kb in range(kbs): - p = kb * block_k - pb = min(block_k, k_size - p) - - mp = ib // tile_m - mr = ib % tile_m - - # Should be working? But error in really strange ways.... 
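                    # (Likely culprit: the flat-index writes below only match the
                    # declared layout when pb == block_k. Each panel of packed_a
                    # occupies tile_m * block_k elements in memory, but sequential
                    # `idx` writes advance panels by only tile_m * pb, so whenever
                    # the K block is partial -- e.g. k_size % block_k != 0 -- the
                    # packed data lands permuted.)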
- # packeda_ptr = cast(packed_a, ~float32) - # idx = 0 - for micropanel_idx in range(mp): - panel_row_start = micropanel_idx * tile_m - for micropanel_col in range(pb): - for micropanel_row in range(tile_m): - packed_a[panel_row_start + micropanel_row, micropanel_col] = a[ - i + micropanel_row + panel_row_start, p + micropanel_col - ] - - # TODO: really strange; the index is indeed incremented by 1 each iteration, - # TODO: but I just can't get this to pass the test... - # packeda_ptr[idx] = a[i + micropanel_row + panel_row_start, p + micropanel_col] - # idx += 1 - if mr > 0: - remain_start_row = mp * tile_m - for remain_col in range(pb): - for remain_row in range(mr): - packed_a[remain_start_row + remain_row, remain_col] = a[ - i + remain_start_row + remain_row, p + remain_col - ] - remain_row = mr - while remain_row < tile_m: - packed_a[remain_start_row + remain_row, remain_col] = 0.0 - remain_row += 1 - - for nb in range(nbs): - j = nb * block_n - jb = min(block_n, n_size - j) - np = jb // tile_n - nr = jb % tile_n - - # packedb_ptr = cast(packed_b, ~float32) - # idx = 0 - for micropanel_idx in range(np): - panel_col_start = micropanel_idx * tile_n - for micropanel_row in range(pb): - for micropanel_col in range(tile_n): - packed_b[micropanel_row, micropanel_col + panel_col_start] = b[ - p + micropanel_row, j + micropanel_col + panel_col_start - ] - # TODO: same as above... why isn't this working? - # packedb_ptr[idx] = b[p + micropanel_row, j + micropanel_col + panel_col_start] - # idx += 1 - if nr > 0: - remain_col_start = np * tile_n - for remain_row in range(pb): - for remain_col in range(nr): - packed_b[remain_row, remain_col + remain_col_start] = b[ - p + remain_row, j + remain_col + remain_col_start - ] - remain_col = nr - while remain_col < tile_n: - packed_b[remain_row, remain_col_start + remain_col] = 0.0 - remain_col += 1 - macro_kernel(packed_a, packed_b, ~c[i, j], ib, jb, pb) - - assert isinstance(matmul_kernel_x86, hidet.ir.Function) - matmul_kernel_x86.kind = "cpu_kernel" - ir_module = module.ir_module() - return ir_module - - -class Matmulx86Op(Operator): + def gemm_5th_loop(a: float32[m_size, k_size], + b: float32[k_size, n_size], + c: float32[m_size, n_size], + work_id_5th_loop: int32, + comm_id_5th_loop: int32): + comm_id_4th_loop = comm_id_5th_loop % loop4_nthreads + + loop5_my_start = -1 + loop5_my_end = -1 + thread_range_sub(loop5_nways, work_id_5th_loop, n_size, + NR, ~loop5_my_start, ~loop5_my_end) + + loop5_iter = loop5_my_start + while loop5_iter < loop5_my_end: + b_alg_loop5 = determine_blocksize_f_sub(loop5_iter, + loop5_my_end, NC) + b_alg_loop5 = min(b_alg_loop5, loop5_my_end - loop5_iter) + + loop5_partition_b_width = b_alg_loop5, + loop5_partition_b_start_col = loop5_iter + gemm_4th_loop(a, b, c, + loop5_partition_b_width, + loop5_partition_b_start_col, + comm_id_4th_loop, + work_id_5th_loop) + loop5_iter += b_alg_loop5 + + gemm_5th_loop.kind = 'cpu_internal' + + ################### Start of the main kernel ################### + @hidet.script + def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], + c: float32[m_size, n_size]): + + init_thr(packa_thrcomm_barrier_sense, + packa_thrcomm_threads_arrived, + loop3_nways) + init_thr(packb_thrcomm_barrier_sense, + packb_thrcomm_barrier_threads_arrived, + loop5_nways) + + parallel_attr = 'p' + str(nthreads) + # The outermost loop spawning threads + for tidx in grid(nthreads, attrs=parallel_attr): + tid_5th_loop = tidx + work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways) + 
comm_id_5th_loop = tid_5th_loop + + gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop) + + assert isinstance(matmul_kernel_x86_v3, hidet.ir.Function) + matmul_kernel_x86_v3.kind = "cpu_kernel" + ir_module = module.ir_module() + return ir_module + + +class Matmulx86Op_refactored(Operator): def __init__(self, a: Tensor, b: Tensor): if not (len(a.shape) == len(b.shape) == 2 and a.shape[1] == b.shape[0]): raise ValueError('Matrix multiplication: incompatible sizes: {} and {}'.format(a.shape, b.shape)) - task = MatmulF32Taskx86(input_like(a, 'a'), input_like(b, 'b')) + task = MatmulF32Taskx86_refactored(input_like(a, 'a'), input_like(b, 'b')) super().__init__(inputs=[a, b], attributes={}, task=task) -def matmul_x86(a: Tensor, b: Tensor) -> Tensor: - return Matmulx86Op(a, b).outputs[0] +def matmul_x86_refactored(a: Tensor, b: Tensor) -> Tensor: + return Matmulx86Op_refactored(a, b).outputs[0] diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py deleted file mode 100644 index dcaee628a..000000000 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86_refactored.py +++ /dev/null @@ -1,958 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Union -from hidet.ir.dtypes import float32, int32 -from hidet.ir.expr import cast -from hidet.ir.module import IRModule -from hidet.ir.compute import TensorNode -from hidet.ir.stmt import DeclareScope -from hidet.ir.task import Task -from hidet.ir.compute import compute, reduce -from hidet.graph.ops.utils import input_like, broadcast_shape, can_mutually_broadcast -from hidet.ir.library import tune -from hidet.graph.operator import Operator, Tensor -from hidet.graph.ops.utils import broadcast_indices - - -class MatmulF32Taskx86_refactored(Task): - - def __init__(self, a: TensorNode, b: TensorNode): - a_shape = a.const_shape - b_shape = b.const_shape - - if not a.type.dtype == float32 or not b.type.dtype == float32: - raise ValueError('Both inputs must be float32 tensors') - - if len(a_shape) < 2 or len(b_shape) < 2: - raise ValueError('Matrix multiplication expect at least 2D tensor, got {} and {}'.format(a_shape, b_shape)) - - self._assert( - a_shape[-1] == b_shape[-2], - msg=( - 'Matrix multiplication expect tensor A and B with shape [..., M, K] and [..., K, N]' - ', got {} and {}'.format(a_shape, b_shape) - ), - ) - - self._assert( - can_mutually_broadcast(a_shape[:-2], b_shape[:-2]), - msg=( - 'Matrix multiplication expect tensor A and B with compatible broadcast shape, ' - 'got {} and {}'.format(a_shape, b_shape) - ), - ) - - k_size = a_shape[-1] - c_shape = broadcast_shape(a_shape[:-2], b_shape[:-2]) + [a_shape[-2], b_shape[-1]] - - c = compute( - name='c', - shape=c_shape, - fcompute=lambda *indices: reduce( - shape=[k_size], - fcompute=lambda k: a[broadcast_indices(indices[:-2], a_shape[:-2], c_shape[1:-2]) + [indices[-2], k]] - * b[broadcast_indices(indices[:-2], b_shape[:-2], c_shape[1:-2]) + [k, indices[-1]]], - reduce_type='sum', - ), - ) - - 
super().__init__( - name='matmul_f32_x86_v2', - inputs=[a, b], - outputs=[c], - attributes={'m_size': a_shape[-2], 'n_size': b_shape[-1], 'k_size': a_shape[-1]}, - ) - - def allow_epilogue(self) -> bool: - return True - - def allow_prologue(self) -> bool: - return False - - def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: - return tune.extract_ir_modules(self.schedule_matmulf32_x86) - - @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 4, 2, 1)]) - def schedule_matmulf32_x86( - self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1) - ) -> IRModule: - import hidet - from hidet.ir.type import tensor_type - from hidet.lang import tensor, grid, as_tensor_pointer - from hidet.lang.layout import row_major, column_major - from hidet.lang.cpu import avx_f32x8_store, avx_f32x8_fmadd, avx_f32x8_load, avx_f32x8_broadcast - from hidet.lang.cpu import avx_f32x8_store_aligned, avx_f32x8_load_aligned, avx_f32x8_setzero - from hidet.lang.cpu import avx_f32x8_unpacklo, avx_f32x8_unpackhi - from hidet.lang.cpu import avx_f32x8_shuffle, avx_f32x8_cast_f32x4 - from hidet.lang.cpu import avx_f32x8_insert_f32x4, avx_f32x8_permute2f32x4 - from hidet.lang.cpu import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor - - node_a, node_b = self.inputs[0], self.inputs[1] - a_shape = node_a.const_shape - b_shape = node_b.const_shape - m_size, n_size, k_size = a_shape[-2], b_shape[-1], a_shape[-1] - - MR, NR = 6, 16 - - tune.check(MC % MR == NC % NR == 0, 'Tile size must divide the corresponding block size') - - with hidet.script_module() as module: - # Get the number of threads... - loop5_nways, loop3_nways, macro_nways, loop1_nways = ways - loop4_nways = 1 - nthreads = loop5_nways * loop3_nways * macro_nways * loop1_nways - - packa_thrcomm_barrier_sense = module.define_global_var( - name="pack_a_barrier_sense", - var_type=int32[nthreads] - ) - packa_thrcomm_threads_arrived = module.define_global_var( - name="pack_a_threads_arrived", - var_type=int32[nthreads] - ) - - packb_thrcomm_barrier_sense = module.define_global_var( - name='pack_b_barrier_sense', - var_type=int32[nthreads] - ) - packb_thrcomm_barrier_threads_arrived = module.define_global_var( - name="pack_b_threads_arrived", - var_type=int32[nthreads] - ) - - @hidet.script - def init_thr(sense: ~int32, arrived: ~int32, size: int32): - for i in range(size): - sense[i] = 0 - arrived[i] = 0 - - init_thr.kind = "cpu_internal" - - # Helpers - packed_a_type = tensor_type('float32', layout=row_major(MC // MR, - 1) * column_major( - MR, KC)) - packed_b_type = tensor_type('float32', layout=row_major(1, - NC // NR) * row_major( - KC, NR)) - - # Get the number of threads remaining at each level - loop4_nthreads = nthreads // loop5_nways - loop3_nthreads = loop4_nthreads - macro_nthreads = loop3_nthreads // loop3_nways - loop1_nthreads = macro_nthreads // macro_nways - - packb_nthreads = loop3_nthreads - packa_nthreads = macro_nthreads - - packed_a_buffers_needed = loop3_nways * loop5_nways - - @hidet.script - def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~int32, end: ~int32): - if n_way == 1: - start[0] = 0 - end[0] = n - return - all_start = 0 - all_end = n - size = all_end - all_start - - n_bf_whole = size // bf - n_bf_left = size % bf - - n_bf_lo = n_bf_whole // n_way - n_bf_hi = n_bf_whole // n_way - - n_th_lo = n_bf_whole % n_way - # If some partitions must have more block_factors than others, assign the slightly larger partitions to lower index threads - if n_th_lo != 0: - 
n_bf_lo += 1 - # Compute the actual widths (in units of rows/columns) of individual threads in the low and high groups - size_lo = n_bf_lo * bf - size_hi = n_bf_hi * bf - - # Pre-compute the starting indices of the low and high groups - lo_start = all_start - hi_start = all_start + n_th_lo * size_lo - - # Compute the start and end of individual threads' ranges - if work_id < n_th_lo: - start[0] = lo_start + work_id * size_lo - end[0] = lo_start + (work_id + 1) * size_lo - else: - start[0] = hi_start + (work_id - n_th_lo) * size_hi - end[0] = hi_start + (work_id - n_th_lo + 1) * size_hi - - # Add the remainder to the last thread's end - if work_id == n_way - 1: - end[0] += n_bf_left - end[0] = min(end[0], all_end) - - thread_range_sub.kind = "cpu_internal" - - @hidet.script - def thread_range_jrir(work_id: int32, n_way: int32, n: int32, bf: int32, - start: ~int32, end: ~int32, inc: ~int32): - start[0] = work_id - end[0] = n - inc[0] = n_way - - thread_range_jrir.kind = "cpu_internal" - - @hidet.script - def determine_blocksize_f_sub(i: int32, dim: int32, b_alg: int32) -> int32: - dim_left_now = dim - i - b_now = -1 - if dim_left_now <= b_alg: - b_now = dim_left_now - else: - b_now = b_alg - assert b_now >= 0 - return b_now - - determine_blocksize_f_sub.kind = "cpu_internal" - - @hidet.script - def not_edge(i: int32, n_iter: int32, n_left: int32) -> bool: - return i != n_iter - 1 or n_left == 0 - - not_edge.kind = 'cpu_internal' - - # TODO: Is this the way to find out the "index" of the packed A buffer? - @hidet.script - def packa_index(work_id_loop5: int32, work_id_loop3: int32) -> int32: - return work_id_loop5 * loop3_nways + work_id_loop3 - - packa_index.kind = 'cpu_internal' - - # Thread barrier - @hidet.script - def thrcomm_barrier(barrier_sense: ~int32, - barrier_threads_arrived: ~int32, num_threads: int32): - if num_threads == 1: - return - orig_sense = cpu_atomic_load_n(barrier_sense, 0) # _ATOMIC_RELAXED - - # Register the current thread's arrival by incrementing the counter - my_threads_arrived = cpu_atomic_add_fetch( - barrier_threads_arrived, 1, 4) # _ATOMIC_ACQ_REL - - if my_threads_arrived == num_threads: - barrier_threads_arrived[0] = 0 - cpu_atomic_fetch_xor(barrier_sense, 1, 3) # _ATOMIC_RELEASE - else: - while cpu_atomic_load_n(barrier_sense, 2) == orig_sense: # _ATOMIC_ACQUIRE - pass - - thrcomm_barrier.kind = 'cpu_internal' - - @hidet.script - def micro_kernel( - a: packed_a_type, b: packed_b_type, c_ptr: ~float32, pb: int32, msize: int32, nsize: int32, - is_first: bool - ): - c = as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize]) - c0 = avx_f32x8_load(~c[0, 0]) - c08 = avx_f32x8_load(~c[0, 8]) - c1 = avx_f32x8_load(~c[1, 0]) - c18 = avx_f32x8_load(~c[1, 8]) - c2 = avx_f32x8_load(~c[2, 0]) - c28 = avx_f32x8_load(~c[2, 8]) - c3 = avx_f32x8_load(~c[3, 0]) - c38 = avx_f32x8_load(~c[3, 8]) - c4 = avx_f32x8_load(~c[4, 0]) - c48 = avx_f32x8_load(~c[4, 8]) - c5 = avx_f32x8_load(~c[5, 0]) - c58 = avx_f32x8_load(~c[5, 8]) - - if is_first: - c0 = avx_f32x8_setzero() - c08 = avx_f32x8_setzero() - c1 = avx_f32x8_setzero() - c18 = avx_f32x8_setzero() - c2 = avx_f32x8_setzero() - c28 = avx_f32x8_setzero() - c3 = avx_f32x8_setzero() - c38 = avx_f32x8_setzero() - c4 = avx_f32x8_setzero() - c48 = avx_f32x8_setzero() - c5 = avx_f32x8_setzero() - c58 = avx_f32x8_setzero() - a_ptr = cast(a, ~float32) - b_ptr = cast(b, ~float32) - - # TODO: For now, let's forget about unrolling for now. 
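            # (Leaving the K loop rolled is defensible here: the 6x16 tile keeps
            # twelve independent FMA accumulators live per iteration, so
            # unrolling would mostly trim loop overhead rather than expose
            # additional instruction-level parallelism.)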
- for _ in range(pb): - bb0to7 = avx_f32x8_load_aligned(b_ptr) - bb8to15 = avx_f32x8_load_aligned(b_ptr + 8) - - aa1 = avx_f32x8_broadcast(a_ptr) - c0 = avx_f32x8_fmadd(aa1, bb0to7, c0) - c08 = avx_f32x8_fmadd(aa1, bb8to15, c08) - - aa2 = avx_f32x8_broadcast(a_ptr + 1) - c1 = avx_f32x8_fmadd(aa2, bb0to7, c1) - c18 = avx_f32x8_fmadd(aa2, bb8to15, c18) - - aa3 = avx_f32x8_broadcast(a_ptr + 2) - c2 = avx_f32x8_fmadd(aa3, bb0to7, c2) - c28 = avx_f32x8_fmadd(aa3, bb8to15, c28) - - aa4 = avx_f32x8_broadcast(a_ptr + 3) - c3 = avx_f32x8_fmadd(aa4, bb0to7, c3) - c38 = avx_f32x8_fmadd(aa4, bb8to15, c38) - - aa5 = avx_f32x8_broadcast(a_ptr + 4) - c4 = avx_f32x8_fmadd(aa5, bb0to7, c4) - c48 = avx_f32x8_fmadd(aa5, bb8to15, c48) - - aa6 = avx_f32x8_broadcast(a_ptr + 5) - c5 = avx_f32x8_fmadd(aa6, bb0to7, c5) - c58 = avx_f32x8_fmadd(aa6, bb8to15, c58) - - a_ptr = a_ptr + 6 - b_ptr = b_ptr + 16 - - # Store the results - avx_f32x8_store(c_ptr, c0) - avx_f32x8_store(c_ptr + 8, c08) - - avx_f32x8_store(c_ptr + nsize, c1) - avx_f32x8_store(c_ptr + (nsize + 8), c18) - - avx_f32x8_store(c_ptr + 2 * nsize, c2) - avx_f32x8_store(c_ptr + (2 * nsize + 8), c28) - - avx_f32x8_store(c_ptr + 3 * nsize, c3) - avx_f32x8_store(c_ptr + (3 * nsize + 8), c38) - - avx_f32x8_store(c_ptr + 4 * nsize, c4) - avx_f32x8_store(c_ptr + (4 * nsize + 8), c48) - - avx_f32x8_store(c_ptr + 5 * nsize, c5) - avx_f32x8_store(c_ptr + (5 * nsize + 8), c58) - - #### Some setup code #### - packed_b_height = KC - if packed_b_height > k_size: - packed_b_height = k_size - packed_b_width = NC - if packed_b_width > n_size: - packed_b_width = (n_size + NR - 1) // NR * NR - - packed_b_total_width = packed_b_width * loop5_nways - packed_b_total_size = packed_b_total_width * packed_b_height - packed_b_individual_size = packed_b_width * packed_b_height - - packed_a_individual_height = MC - if packed_a_individual_height > m_size: - packed_a_individual_height = (m_size + MR - 1) // MR * MR - packed_a_total_height = packed_a_individual_height * packed_a_buffers_needed - - packed_a_width = KC - if packed_a_width > k_size: - packed_a_width = k_size - # pad this to be able to use the aligned version of the avx store - packed_a_width = (packed_a_width + 8 - 1) // 8 * 8 - packed_a_total_size = packed_a_total_height * packed_a_width - packed_a_individual_size = packed_a_width * packed_a_individual_height - - packb_buf_ptr = module.define_global_var( - name='packb_buf_ptr', - var_type=float32[packed_b_total_size] - ) - packa_buf_ptr = module.define_global_var( - name='packa_buf_ptr', - var_type=float32[packed_a_total_size] - ) - - packb_buf = cast(packb_buf_ptr, ~float32) - packa_buf = cast(packa_buf_ptr, ~float32) - - ##### Start of the loops around micro kernel ##### - - @hidet.script - def gemm_pack_a( - loop3_partition_a: ~float32, - loop3_partition_a_width: int32, - loop3_partition_a_height: int32, - packed_a_buf: ~float32, - work_id_packa: int32, - ): - packed_a_tensor = as_tensor_pointer( - packed_a_buf, - float32, - layout=row_major(packed_a_individual_height // MR, 1) * - column_major(MR, packed_a_width) - ) - - npanels_full_a = loop3_partition_a_height // MR - panel_a_remainder = loop3_partition_a_height % MR - - npanels_a = npanels_full_a + (1 if panel_a_remainder > 0 else 0) - - for ii_panel in range(npanels_a): - if ii_panel % packa_nthreads != work_id_packa % packa_nthreads: - continue - - a_curr_panel_row_start = ii_panel * MR - a_curr_panel_height = min(MR, - loop3_partition_a_height - a_curr_panel_row_start) - - if a_curr_panel_height == MR: # 
unroll the packing by 8 - k_iters = loop3_partition_a_width // 8 - k_remainder = loop3_partition_a_width % 8 - col = 0 - for k_iter in range(k_iters): - col = k_iter * 8 - a_curr_panel_col = loop3_partition_a + ( - a_curr_panel_row_start * k_size + col - ) - - v0 = avx_f32x8_load(a_curr_panel_col) - v1 = avx_f32x8_load(a_curr_panel_col + k_size) - v2 = avx_f32x8_load(a_curr_panel_col + (2 * k_size)) - v3 = avx_f32x8_load(a_curr_panel_col + (3 * k_size)) - v4 = avx_f32x8_load(a_curr_panel_col + (4 * k_size)) - v5 = avx_f32x8_load(a_curr_panel_col + (5 * k_size)) - - unpack0 = avx_f32x8_unpacklo(v0, v1) - unpack1 = avx_f32x8_unpackhi(v0, v1) - unpack2 = avx_f32x8_unpacklo(v2, v3) - unpack3 = avx_f32x8_unpackhi(v2, v3) - unpack4 = avx_f32x8_unpacklo(v4, v5) - unpack5 = avx_f32x8_unpackhi(v4, v5) - - shf0 = avx_f32x8_shuffle(unpack0, unpack2, 0x44) - shf1 = avx_f32x8_shuffle(unpack4, unpack0, 0xE4) - shf2 = avx_f32x8_shuffle(unpack2, unpack4, 0xEE) - shf3 = avx_f32x8_shuffle(unpack5, unpack1, 0xE4) - shf4 = avx_f32x8_shuffle(unpack3, unpack5, 0xEE) - shf5 = avx_f32x8_shuffle(unpack1, unpack3, 0x44) - - low_shf1 = avx_f32x8_cast_f32x4(shf1) - res0 = avx_f32x8_insert_f32x4(shf0, low_shf1, 0x1) - res1 = avx_f32x8_permute2f32x4(shf0, shf1, 0x31) - - low_shf5 = avx_f32x8_cast_f32x4(shf5) - res2 = avx_f32x8_insert_f32x4(shf2, low_shf5, 0x1) - res3 = avx_f32x8_permute2f32x4(shf2, shf5, 0x31) - - low_shf4 = avx_f32x8_cast_f32x4(shf4) - res4 = avx_f32x8_insert_f32x4(shf3, low_shf4, 0x1) - res5 = avx_f32x8_permute2f32x4(shf3, shf4, 0x31) - - # TODO: Now I changed to unaligned to debug... - avx_f32x8_store_aligned( - ~packed_a_tensor[a_curr_panel_row_start, col], - res0 - ) - avx_f32x8_store_aligned( - ~packed_a_tensor[a_curr_panel_row_start + 2, - col + 1], - res2 - ) - avx_f32x8_store_aligned( - ~packed_a_tensor[a_curr_panel_row_start + 4, - col + 2], - res4) - avx_f32x8_store_aligned( - ~packed_a_tensor[a_curr_panel_row_start, - col + 4], - res1 - ) - avx_f32x8_store_aligned( - ~packed_a_tensor[a_curr_panel_row_start + 2, - col + 5], - res3 - ) - avx_f32x8_store_aligned( - ~packed_a_tensor[a_curr_panel_row_start + 4, - col + 6], - res5 - ) - remaining_start_col = k_iters * 8 - for remain_off in range(k_remainder): - curr_remain_col = remaining_start_col + remain_off - for micropanel_row in range(MR): - packed_a_tensor[ - a_curr_panel_row_start + micropanel_row, - curr_remain_col] = \ - loop3_partition_a[( - micropanel_row + a_curr_panel_row_start) * k_size + curr_remain_col] - else: - remain_start_row = npanels_full_a * MR - for remain_col in range(loop3_partition_a_width): - for remain_row in range(panel_a_remainder): - packed_a_tensor[ - remain_start_row + remain_row, remain_col] = \ - loop3_partition_a[( - remain_row + remain_start_row) * k_size + remain_col] - remain_row = panel_a_remainder - while remain_row < MR: - packed_a_tensor[ - remain_start_row + remain_row, remain_col] = 0.0 - remain_row += 1 - - @hidet.script - def gemm_pack_b( - loop4_partition_b: ~float32, - loop4_partition_b_width: int32, - loop4_partition_b_height: int32, - packed_b_buf: ~float32, - work_id_packb: int32 - ): - npanels_full_b = loop4_partition_b_width // NR - npanels_b_remainder = loop4_partition_b_width % NR - - npanels_b = npanels_full_b + (1 if npanels_b_remainder != 0 else 0) - packedb_panel_stride = packed_b_height * NR - - # Loop for the packing of B - for i_panel in range(npanels_b): - if i_panel % packb_nthreads != work_id_packb % packb_nthreads: - continue - packed_b_buff_curr = packed_b_buf + ( - i_panel * 
packedb_panel_stride) - curr_panel_start = i_panel * NR - curr_panel_width = min(NR, - loop4_partition_b_width - curr_panel_start) - - if curr_panel_width == NR: - k_iters = loop4_partition_b_height // 8 - k_remainder = loop4_partition_b_height % 8 - - row = 0 - for k_iter in range(k_iters): - row = k_iter * 8 - b_panel = loop4_partition_b + ( - row * n_size + curr_panel_start) - b00 = avx_f32x8_load(b_panel) - b08 = avx_f32x8_load(b_panel + 8) - - avx_f32x8_store_aligned(packed_b_buff_curr, b00) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b08) - packed_b_buff_curr += 16 - - b10 = avx_f32x8_load(b_panel + n_size) - b18 = avx_f32x8_load(b_panel + (n_size + 8)) - - avx_f32x8_store_aligned(packed_b_buff_curr, b10) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b18) - packed_b_buff_curr += 16 - - b20 = avx_f32x8_load(b_panel + (2 * n_size)) - b28 = avx_f32x8_load(b_panel + (2 * n_size + 8)) - - avx_f32x8_store_aligned(packed_b_buff_curr, b20) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b28) - packed_b_buff_curr += 16 - - b30 = avx_f32x8_load(b_panel + (3 * n_size)) - b38 = avx_f32x8_load(b_panel + (3 * n_size + 8)) - - avx_f32x8_store_aligned(packed_b_buff_curr, b30) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b38) - packed_b_buff_curr += 16 - - b40 = avx_f32x8_load(b_panel + (4 * n_size)) - b48 = avx_f32x8_load(b_panel + (4 * n_size + 8)) - - avx_f32x8_store_aligned(packed_b_buff_curr, b40) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b48) - packed_b_buff_curr += 16 - - b50 = avx_f32x8_load(b_panel + (5 * n_size)) - b58 = avx_f32x8_load(b_panel + (5 * n_size + 8)) - - avx_f32x8_store_aligned(packed_b_buff_curr, b50) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b58) - packed_b_buff_curr += 16 - - b60 = avx_f32x8_load(b_panel + (6 * n_size)) - b68 = avx_f32x8_load(b_panel + (6 * n_size + 8)) - - avx_f32x8_store_aligned(packed_b_buff_curr, b60) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b68) - packed_b_buff_curr += 16 - - b70 = avx_f32x8_load(b_panel + (7 * n_size)) - b78 = avx_f32x8_load(b_panel + (7 * n_size + 8)) - - avx_f32x8_store_aligned(packed_b_buff_curr, b70) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b78) - - packed_b_buff_curr += 16 - - row = k_iters * 8 - for _ in range(k_remainder): - b_panel = loop4_partition_b + ( - row * n_size + curr_panel_start) - b00 = avx_f32x8_load(b_panel) - b08 = avx_f32x8_load(b_panel + 8) - avx_f32x8_store_aligned(packed_b_buff_curr, b00) - avx_f32x8_store_aligned(packed_b_buff_curr + 8, b08) - packed_b_buff_curr += 16 - row += 1 - - else: - packed_b_remaining_buf = packed_b_buf + ( - npanels_full_b * packedb_panel_stride) - if npanels_b_remainder > 0: - # TODO: I think this if should always be true if this is executed? 
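For reference: the panel loop above handles full NR-wide panels of B with 8-float vector loads and stores, while the `else` branch below copies the single leftover panel scalar-by-scalar and zero-pads it to NR columns (so the `npanels_b_remainder > 0` check it performs is always true when that branch is reached, which answers the TODO). A minimal NumPy sketch of the layout this packing produces — illustrative only, with hypothetical names and the NR = 16 used in this file:

    import numpy as np

    def pack_b_sketch(b_block, NR=16):
        # b_block: a (kc, nc) slice of B. Panel i holds columns [i*NR, (i+1)*NR),
        # stored row-major so the micro kernel can stream it contiguously.
        kc, nc = b_block.shape
        npanels = (nc + NR - 1) // NR
        packed = np.zeros((npanels, kc, NR), dtype=b_block.dtype)
        for i in range(npanels):
            panel = b_block[:, i * NR:(i + 1) * NR]
            packed[i, :, :panel.shape[1]] = panel  # the remainder panel stays zero on the right
        return packed.reshape(-1)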
- remain_col_start = npanels_full_b * NR - for remain_row in range(loop4_partition_b_height): - packed_b_remaining_buf_curr = packed_b_remaining_buf + ( - remain_row * NR) - for remain_col in range(npanels_b_remainder): - packed_b_remaining_buf_curr[0] = \ - loop4_partition_b[ - (remain_row * n_size) + ( - remain_col_start + remain_col) - ] - packed_b_remaining_buf_curr += 1 - zero_fill_col = npanels_b_remainder - while zero_fill_col < NR: - packed_b_remaining_buf_curr[0] = 0.0 - packed_b_remaining_buf_curr += 1 - zero_fill_col += 1 - - gemm_pack_b.kind = "cpu_internal" - gemm_pack_a.kind = "cpu_internal" - micro_kernel.kind = "cpu_internal" - - @hidet.script - def gemm_macro( - packed_a: ~float32, - packed_b: ~float32, - c: float32[m_size, n_size], - c_row_off: int32, - c_col_off: int32, - macro_m: int32, - macro_n: int32, - macro_k: int32, - ps_packed_a: int32, - ps_packed_b: int32, - comm_id_macro: int32, - work_id_macro: int32, - is_first: bool - ): - comm_id_1st_loop = comm_id_macro % loop1_nthreads - work_id_1st_loop = comm_id_1st_loop // (loop1_nthreads // loop1_nways) - - n_iter = macro_n // NR - n_remainder = macro_n % NR - m_iter = macro_m // MR - m_remainder = macro_m % MR - - if n_remainder > 0: - n_iter += 1 - if m_remainder > 0: - m_iter += 1 - - jr_start = -1 - jr_end = -1 - ir_start = -1 - ir_end = -1 - jr_inc = -1 - ir_inc = -1 - - thread_range_jrir( - work_id_macro, - macro_nways, - n_iter, - 1, - ~jr_start, - ~jr_end, - ~jr_inc - ) - - thread_range_jrir( - work_id_1st_loop, - loop1_nways, - m_iter, - 1, - ~ir_start, - ~ir_end, - ~ir_inc - ) - - rstep_a = ps_packed_a - cstep_b = ps_packed_b - - cstep_c = NR - rstep_c = n_size * MR - - macro_c_cast = as_tensor_pointer( - ~c[c_row_off, c_col_off], - dtype=float32, - shape=(m_size, n_size) - ) - temp_c = tensor(scope=DeclareScope.Default, - dtype=float32, - layout=row_major(MR, NR), - is_static=False) - j = jr_start - while j < jr_end: - b1 = packed_b + j * cstep_b - c1 = macro_c_cast + j * cstep_c - n_cur = NR if not_edge(j, n_iter, n_remainder) else n_remainder - - i = ir_start - while i < ir_end: - a1 = packed_a + i * rstep_a - c11 = c1 + i * rstep_c - c11 = as_tensor_pointer(c11, dtype=float32, shape=(m_size, n_size)) - m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder - - if m_cur == MR and n_cur == NR: - micro_kernel(a1, b1, c11, macro_k, m_size, n_size, is_first) - else: - for i, j in grid(MR, NR): - temp_c[i, j] = 0.0 - micro_kernel(a1, b1, cast(temp_c, ~float32), macro_k, MR, NR, is_first) - if not is_first: - for mm, nn in grid(m_cur, n_cur): - c11[mm, nn] += temp_c[mm, nn] - else: - for mm, nn in grid(m_cur, n_cur): - c11[mm, nn] = temp_c[mm, nn] - - i += ir_inc - j += jr_inc - - gemm_macro.kind = "cpu_internal" - - @hidet.script - def gemm_3rd_loop( - a: float32[m_size, k_size], - packed_b: ~float32, - c: float32[m_size, n_size], - loop3_partition_a_start_col: int32, - loop3_partition_b_start_col: int32, - loop3_partition_a_width: int32, - loop3_partition_b_width: int32, - comm_id_3rd_loop: int32, - work_id_3rd_loop: int32, - is_first: bool, work_id_5th_loop: int32): - comm_id_macro = comm_id_3rd_loop % macro_nthreads - work_id_macro = comm_id_macro // (macro_nthreads // macro_nways) - work_id_packa = comm_id_macro - - m_start_loop3 = 0 - m_end_loop3 = 0 - thread_range_sub( - loop3_nways, - work_id_3rd_loop, - m_size, - MR, - ~m_start_loop3, - ~m_end_loop3 - ) - - ii = m_start_loop3 - while ii < m_end_loop3: - b_alg_loop3 = determine_blocksize_f_sub( - ii, m_size, MC - ) - b_alg_loop3 = 
min(b_alg_loop3, m_end_loop3 - ii) - loop3_partition_a_start_row = ii - loop3_partition_a_height = b_alg_loop3 - - loop3_partition_a = cast(a, ~float32) + ( - loop3_partition_a_start_row * k_size + - loop3_partition_a_start_col - ) - - packed_a_idx = packa_index(work_id_5th_loop, work_id_3rd_loop) - packed_a_buf = packa_buf + (packed_a_idx * packed_a_individual_size) - - - thrcomm_barrier( - ~packa_thrcomm_barrier_sense[packed_a_idx], - ~packa_thrcomm_threads_arrived[packed_a_idx], - packa_nthreads - ) - - gemm_pack_a( - loop3_partition_a, - loop3_partition_a_width, - loop3_partition_a_height, - packed_a_buf, - work_id_packa, - ) - - # This marks the end of the packing of A, - # so a barrier is needed - thrcomm_barrier( - ~packa_thrcomm_barrier_sense[packed_a_idx], - ~packa_thrcomm_threads_arrived[packed_a_idx], - packa_nthreads - ) - - gemm_macro(packed_a_buf, - packed_b, - c, - loop3_partition_a_start_row, - loop3_partition_b_start_col, - loop3_partition_a_height, - loop3_partition_b_width, - loop3_partition_a_width, - MR * packed_a_width, - packed_b_height * NR, - comm_id_macro, - work_id_macro, - is_first - ) - ii += b_alg_loop3 - - gemm_3rd_loop.kind = "cpu_internal" - - @hidet.script - def gemm_4th_loop(a: float32[m_size, k_size], - b: float32[k_size, n_size], - c: float32[k_size, n_size], - loop5_partition_b_width: int32, - loop5_partition_b_start_col: int32, - comm_id_4th_loop: int32, - work_id_5th_loop: int32): - i_loop4 = 0 - - comm_id_3rd_loop = comm_id_4th_loop % loop3_nthreads - work_id_3rd_loop = comm_id_3rd_loop // (loop3_nthreads // loop3_nways) - work_id_packb = comm_id_3rd_loop - - while i_loop4 < k_size: - b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, KC) - b_alg_loop4 = min(b_alg_loop4, k_size - i_loop4) - - loop4_partition_b_height = b_alg_loop4 - loop4_partition_b_width = loop5_partition_b_width - loop4_partition_b_start_row = i_loop4 - loop4_partition_b_start_col = loop5_partition_b_start_col - - loop4_partition_a_start_col = i_loop4 - is_first = (i_loop4 == 0) - - packed_b_buf = packb_buf + ( - packed_b_individual_size * work_id_5th_loop - ) - - loop4_partition_b = cast(b, ~float32) + \ - (loop4_partition_b_start_row * n_size + - loop4_partition_b_start_col) - - thrcomm_barrier( - ~packb_thrcomm_barrier_sense[work_id_5th_loop], - ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], - packb_nthreads - ) - - gemm_pack_b(loop4_partition_b, loop4_partition_b_width, - loop4_partition_b_height, packed_b_buf, - work_id_packb) - - thrcomm_barrier( - ~packb_thrcomm_barrier_sense[work_id_5th_loop], - ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], - packb_nthreads - ) - - gemm_3rd_loop( - a, packed_b_buf, c, - loop4_partition_a_start_col, - loop4_partition_b_start_col, - loop4_partition_b_height, - loop4_partition_b_width, - comm_id_3rd_loop, - work_id_3rd_loop, - is_first, - work_id_5th_loop - ) - - thrcomm_barrier( - ~packb_thrcomm_barrier_sense[work_id_5th_loop], - ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], - packb_nthreads - ) - - i_loop4 += b_alg_loop4 - - gemm_4th_loop.kind = "cpu_internal" - - @hidet.script - def gemm_5th_loop(a: float32[m_size, k_size], - b: float32[k_size, n_size], - c: float32[m_size, n_size], - work_id_5th_loop: int32, - comm_id_5th_loop: int32): - comm_id_4th_loop = comm_id_5th_loop % loop4_nthreads - - loop5_my_start = -1 - loop5_my_end = -1 - thread_range_sub(loop5_nways, work_id_5th_loop, n_size, - NR, ~loop5_my_start, ~loop5_my_end) - - loop5_iter = loop5_my_start - while loop5_iter < loop5_my_end: - 
b_alg_loop5 = determine_blocksize_f_sub(loop5_iter,
-                                                        loop5_my_end, NC)
-                b_alg_loop5 = min(b_alg_loop5, loop5_my_end - loop5_iter)
-
-                loop5_partition_b_width = b_alg_loop5
-                loop5_partition_b_start_col = loop5_iter
-                gemm_4th_loop(a, b, c,
-                              loop5_partition_b_width,
-                              loop5_partition_b_start_col,
-                              comm_id_4th_loop,
-                              work_id_5th_loop)
-                loop5_iter += b_alg_loop5
-
-        gemm_5th_loop.kind = 'cpu_internal'
-
-        ################### Start of the main kernel ###################
-        @hidet.script
-        def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size],
-                                 c: float32[m_size, n_size]):
-
-            init_thr(packa_thrcomm_barrier_sense,
-                     packa_thrcomm_threads_arrived,
-                     loop3_nways)
-            init_thr(packb_thrcomm_barrier_sense,
-                     packb_thrcomm_barrier_threads_arrived,
-                     loop5_nways)
-
-            parallel_attr = 'p' + str(nthreads)
-            # The outermost loop spawning threads
-            for tidx in grid(nthreads, attrs=parallel_attr):
-                tid_5th_loop = tidx
-                work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways)
-                comm_id_5th_loop = tid_5th_loop
-
-                gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop)
-
-        assert isinstance(matmul_kernel_x86_v3, hidet.ir.Function)
-        matmul_kernel_x86_v3.kind = "cpu_kernel"
-        ir_module = module.ir_module()
-        return ir_module
-
-
-class Matmulx86Op_refactored(Operator):
-    def __init__(self, a: Tensor, b: Tensor):
-        if not (len(a.shape) == len(b.shape) == 2 and a.shape[1] == b.shape[0]):
-            raise ValueError('Matrix multiplication: incompatible sizes: {} and {}'.format(a.shape, b.shape))
-        task = MatmulF32Taskx86_refactored(input_like(a, 'a'), input_like(b, 'b'))
-        super().__init__(inputs=[a, b], attributes={}, task=task)
-
-
-def matmul_x86_refactored(a: Tensor, b: Tensor) -> Tensor:
-    return Matmulx86Op_refactored(a, b).outputs[0]
diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86_v3.py b/python/hidet/graph/ops/matmul/matmul_f32_x86_v3.py
deleted file mode 100644
index b65690138..000000000
--- a/python/hidet/graph/ops/matmul/matmul_f32_x86_v3.py
+++ /dev/null
@@ -1,974 +0,0 @@
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
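The call chain above — gemm_5th_loop → gemm_4th_loop → gemm_3rd_loop → gemm_macro → micro_kernel — is the classic BLIS five-loop blocking around a register-tile micro kernel. Stripped of threading, packing, and barriers, the traversal reduces to the following serial NumPy sketch (the helper name is hypothetical; the tile sizes are this schedule's defaults):

    import numpy as np

    def five_loop_gemm_sketch(a, b, MC=2016, NC=384, KC=560):
        m, k = a.shape
        n = b.shape[1]
        c = np.zeros((m, n), dtype=a.dtype)
        for jc in range(0, n, NC):          # loop 5: NC-wide column panels of B and C
            for pc in range(0, k, KC):      # loop 4: KC-deep slices of K; B is packed here
                for ic in range(0, m, MC):  # loop 3: MC-tall row panels of A; A is packed here
                    # gemm_macro + micro_kernel: MR x NR register tiles via broadcast + FMA
                    c[ic:ic + MC, jc:jc + NC] += a[ic:ic + MC, pc:pc + KC] @ b[pc:pc + KC, jc:jc + NC]
        return c

The per-level entries of `ways` decide how many threads split each of the three partitioned loops, which is why only loops 5, 3, and the macro/first loops carry way counts.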
-from typing import List, Union -from hidet.ir.dtypes import float32, int32 -from hidet.ir.expr import cast -from hidet.ir.module import IRModule -from hidet.ir.compute import TensorNode -from hidet.ir.primitives import avx_malloc -from hidet.ir.primitives.cpu import avx_f32x8_setzero, avx_f32x8_load_aligned -from hidet.ir.stmt import DeclareScope -from hidet.ir.task import Task -from hidet.ir.compute import compute, reduce -from hidet.graph.ops.utils import input_like, broadcast_shape, can_mutually_broadcast -from hidet.ir.library import tune -from hidet.graph.operator import Operator, Tensor -from hidet.graph.ops.utils import broadcast_indices - - -class MatmulF32Taskx86_v2(Task): - - def __init__(self, a: TensorNode, b: TensorNode): - a_shape = a.const_shape - b_shape = b.const_shape - - if not a.type.dtype == float32 or not b.type.dtype == float32: - raise ValueError('Both inputs must be float32 tensors') - - if len(a_shape) < 2 or len(b_shape) < 2: - raise ValueError('Matrix multiplication expect at least 2D tensor, got {} and {}'.format(a_shape, b_shape)) - - self._assert( - a_shape[-1] == b_shape[-2], - msg=( - 'Matrix multiplication expect tensor A and B with shape [..., M, K] and [..., K, N]' - ', got {} and {}'.format(a_shape, b_shape) - ), - ) - - self._assert( - can_mutually_broadcast(a_shape[:-2], b_shape[:-2]), - msg=( - 'Matrix multiplication expect tensor A and B with compatible broadcast shape, ' - 'got {} and {}'.format(a_shape, b_shape) - ), - ) - - k_size = a_shape[-1] - c_shape = broadcast_shape(a_shape[:-2], b_shape[:-2]) + [a_shape[-2], b_shape[-1]] - - c = compute( - name='c', - shape=c_shape, - fcompute=lambda *indices: reduce( - shape=[k_size], - fcompute=lambda k: a[broadcast_indices(indices[:-2], a_shape[:-2], c_shape[1:-2]) + [indices[-2], k]] - * b[broadcast_indices(indices[:-2], b_shape[:-2], c_shape[1:-2]) + [k, indices[-1]]], - reduce_type='sum', - ), - ) - - super().__init__( - name='matmul_f32_x86_v2', - inputs=[a, b], - outputs=[c], - attributes={'m_size': a_shape[-2], 'n_size': b_shape[-1], 'k_size': a_shape[-1]}, - ) - - def allow_epilogue(self) -> bool: - return True - - def allow_prologue(self) -> bool: - return False - - def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: - return tune.extract_ir_modules(self.schedule_matmulf32_x86) - - # @tune.space( - # 2, - # block_m=[2016, 3024], - # block_n=[64, 144, 192, 256, 384, 512, 592, 672, 752, 896, 1024], - # block_k=[96, 128, 256, 384, 512, 560, 688, 784], - # nthreads=[4, 8, 16, 32], - # ) - @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], nthreads=[8, 16]) - def schedule_matmulf32_x86( - self, MC=2016, NC=896, KC=512, ways=(1, 8, 4, 1) - ) -> IRModule: - import hidet - from hidet.ir.type import tensor_type - from hidet.lang import tensor, grid, as_tensor_pointer - from hidet.lang.layout import row_major, column_major - from hidet.lang.cpu import avx_f32x8_store, avx_f32x8_fmadd, avx_f32x8_load, avx_f32x8_broadcast - from hidet.lang.cpu import avx_f32x4_broadcast, avx_f32x4_fmadd, avx_f32x4_load, avx_f32x4_store - from hidet.lang.cpu import avx_f32x8_store_aligned, avx_f32x8_load_aligned - from hidet.lang.cpu import avx_f32x4_store_aligned, avx_f32x4_load_aligned - from hidet.lang.cpu import avx_f32x8_unpacklo, avx_f32x8_unpackhi - from hidet.lang.cpu import avx_f32x8_shuffle, avx_f32x8_cast_f32x4 - from hidet.lang.cpu import avx_f32x8_insert_f32x4, avx_f32x8_permute2f32x4 - from hidet.lang.cpu import cpu_atomic_load_n, cpu_atomic_add_fetch, 
cpu_atomic_fetch_xor - - node_a, node_b = self.inputs[0], self.inputs[1] - a_shape = node_a.const_shape - b_shape = node_b.const_shape - m_size, n_size, k_size = a_shape[-2], b_shape[-1], a_shape[-1] - - MR, NR = 6, 16 - - tune.check(MC % MR == NC % NR == 0, 'Tile size must divide the corresponding block size') - - packed_a_type = tensor_type('float32', layout=row_major(MC // MR, 1) * column_major(MR, KC)) - packed_b_type = tensor_type('float32', layout=row_major(1, NC // NR) * row_major(KC, NR)) - - # Get the number of threads... - loop5_nways, loop3_nways, macro_nways, loop1_nways = ways - loop4_nways = 1 - nthreads = loop5_nways * loop3_nways * macro_nways * loop1_nways - - # Get the number of threads remaining at each level - loop5_nthreads = nthreads - loop4_nthreads = loop5_nthreads // loop5_nways - loop3_nthreads = loop4_nthreads - macro_nthreads = loop3_nthreads // loop3_nways - loop1_nthreads = macro_nthreads // macro_nways - - packb_nthreads = loop3_nthreads - packa_nthreads = macro_nthreads - - # TODO: Since Hidet doesn't support the parallel region syntax as in OpenMP, - # TODO: We instead use a loop to simulate the parallel region, with the "thread id" being the loop index. - outermost_iters = nthreads - - loop5_thrcomm_barrier_sense = 0 - loop5_thrcomm_barrier_threads_arrived = 0 - - packb_thrcomm_barrier_sense = tensor('int32', shape=[loop4_nways], is_static=True) - # for idx in range(loop4_nways): - # packb_thrcomm_barrier_sense[idx] = 0 TODO: This shouldn't be necessary, as static arrays are 0-initialized - packb_thrcomm_barrier_threads_arrived = tensor('int32', shape=[loop4_nways], is_static=True) - - packa_thrcomm_barrier_sense = tensor('int32', shape=[loop3_nways], is_static=True) - packa_thrcomm_threads_arrived = tensor('int32', shape=[loop3_nways], is_static=True) - - # The buffer for storing the starting offset of the packed B buffers for thread, - # indexed by the work ID of Loop5 - packb_start_offsets = tensor('int32', shape=[loop5_nways], is_static=True) - # The buffer for storing the starting offset of the packed A buffers for thread, - # indexed by the work ID of Loop3 - packa_start_offsets = tensor('int32', shape=[loop3_nways], is_static=True) - - # The array to store the needed size for each packed B buffer, indexed by the work ID of Loop5 - packb_sizes = tensor('int32', shape=[loop5_nways], is_static=True) - # The array to store the needed size for each packed A buffer, indexed by the work ID of Loop3 - packa_sizes = tensor('int32', shape=[loop3_nways], is_static=True) - - with hidet.script_module() as module: - # Helpers - @hidet.script - def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~int32, end: ~int32): - if n_way == 1: - start[0] = 0 - end[0] = n - return - all_start = 0 - all_end = n - size = all_end - all_start - - n_bf_whole = size // bf - n_bf_left = size % bf - - n_bf_lo = n_bf_whole // n_way - n_bf_hi = n_bf_whole // n_way - - n_th_lo = n_bf_whole % n_way - # If some partitions must have more block_factors than others, assign the slightly larger partitions to lower index threads - if n_th_lo != 0: - n_bf_lo += 1 - # Compute the actual widths (in units of rows/columns) of individual threads in the low and high groups - size_lo = n_bf_lo * bf - size_hi = n_bf_hi * bf - - # Pre-compute the starting indices of the low and high groups - lo_start = all_start - hi_start = all_start + n_th_lo * size_lo - - # Compute the start and end of individual threads' ranges - if work_id < n_th_lo: - start[0] = lo_start + work_id * 
size_lo - end[0] = lo_start + (work_id + 1) * size_lo - else: - start[0] = hi_start + (work_id - n_th_lo) * size_hi - end[0] = hi_start + (work_id - n_th_lo + 1) * size_hi - - # Add the remainder to the last thread's end - if work_id == n_way - 1: - end[0] += n_bf_left - - @hidet.script - def thread_range_jrir(work_id: int32, n_way: int32, n: int32, bf: int32, - start: ~int32, end: ~int32, inc: ~int32): - start[0] = work_id - end[0] = n - inc[0] = n_way - - @hidet.script - def determine_blocksize_f_sub(i: int32, dim: int32, b_alg: int32) -> int32: - dim_left_now = dim - i - if dim_left_now <= b_alg: - b_now = dim_left_now - else: - b_now = b_alg - return b_now - - @hidet.script - def not_edge(i: int32, n_iter: int32, n_left: int32) -> bool: - return i != n_iter - 1 or n_left == 0 - - # Thread barrier - def thrcomm_barrier(tid: int32, barrier_sense: ~int32, - barrier_threads_arrived: ~int32, nthreads: int32): - if nthreads == 1: - return - orig_sense = cpu_atomic_load_n(barrier_sense, 0) # _ATOMIC_RELAXED - - # Register the current thread's arrival by incrementing the counter - my_threads_arrived = cpu_atomic_add_fetch( - barrier_threads_arrived, 1, 4) # _ATOMIC_ACQ_REL - - if my_threads_arrived == nthreads: - barrier_threads_arrived[0] = 0 - cpu_atomic_fetch_xor(barrier_sense, 1, 3) # _ATOMIC_RELEASE - else: - while cpu_atomic_load_n(barrier_sense, 2) == orig_sense: # _ATOMIC_ACQUIRE - pass - - @hidet.script - def micro_kernel( - a: packed_a_type, b: packed_b_type, c_ptr: ~float32, pb: int32, msize: int32, nsize: int32, - is_first: bool - ): - c = as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize]) - c0 = avx_f32x8_load(~c[0, 0]) - c08 = avx_f32x8_load(~c[0, 8]) - c1 = avx_f32x8_load(~c[1, 0]) - c18 = avx_f32x8_load(~c[1, 8]) - c2 = avx_f32x8_load(~c[2, 0]) - c28 = avx_f32x8_load(~c[2, 8]) - c3 = avx_f32x8_load(~c[3, 0]) - c38 = avx_f32x8_load(~c[3, 8]) - c4 = avx_f32x8_load(~c[4, 0]) - c48 = avx_f32x8_load(~c[4, 8]) - c5 = avx_f32x8_load(~c[5, 0]) - c58 = avx_f32x8_load(~c[5, 8]) - - if is_first: - c0 = avx_f32x8_setzero() - c08 = avx_f32x8_setzero() - c1 = avx_f32x8_setzero() - c18 = avx_f32x8_setzero() - c2 = avx_f32x8_setzero() - c28 = avx_f32x8_setzero() - c3 = avx_f32x8_setzero() - c38 = avx_f32x8_setzero() - c4 = avx_f32x8_setzero() - c48 = avx_f32x8_setzero() - c5 = avx_f32x8_setzero() - c58 = avx_f32x8_setzero() - a_ptr = cast(a, ~float32) - b_ptr = cast(b, ~float32) - - niters = msize // 4 - nleft = msize % 4 - # Outer iterations with step 4 - for _ in range(niters): - # First of the 4 unrolled iterations - bb0to7 = avx_f32x8_load_aligned(b_ptr) - bb8to15 = avx_f32x8_load_aligned(b_ptr + 8) - - aa = avx_f32x8_broadcast(a_ptr) - c0 = avx_f32x8_fmadd(aa, bb0to7, c0) - c08 = avx_f32x8_fmadd(aa, bb8to15, c08) - - aa = avx_f32x8_broadcast(a_ptr + 1) - c1 = avx_f32x8_fmadd(aa, bb0to7, c1) - c18 = avx_f32x8_fmadd(aa, bb8to15, c18) - - aa = avx_f32x8_broadcast(a_ptr + 2) - c2 = avx_f32x8_fmadd(aa, bb0to7, c2) - c28 = avx_f32x8_fmadd(aa, bb8to15, c28) - - aa = avx_f32x8_broadcast(a_ptr + 3) - c3 = avx_f32x8_fmadd(aa, bb0to7, c3) - c38 = avx_f32x8_fmadd(aa, bb8to15, c38) - - aa = avx_f32x8_broadcast(a_ptr + 4) - c4 = avx_f32x8_fmadd(aa, bb0to7, c4) - c48 = avx_f32x8_fmadd(aa, bb8to15, c48) - - aa = avx_f32x8_broadcast(a_ptr + 5) - c5 = avx_f32x8_fmadd(aa, bb0to7, c5) - c58 = avx_f32x8_fmadd(aa, bb8to15, c58) - - # Second of the 4 unrolled iterations - bb0to7 = avx_f32x8_load_aligned(b_ptr + 16) - bb8to15 = avx_f32x8_load_aligned(b_ptr + 24) - - aa = avx_f32x8_broadcast(a_ptr + 6) 
- c0 = avx_f32x8_fmadd(aa, bb0to7, c0) - c08 = avx_f32x8_fmadd(aa, bb8to15, c08) - - aa = avx_f32x8_broadcast(a_ptr + 7) - c1 = avx_f32x8_fmadd(aa, bb0to7, c1) - c18 = avx_f32x8_fmadd(aa, bb8to15, c18) - - aa = avx_f32x8_broadcast(a_ptr + 8) - c2 = avx_f32x8_fmadd(aa, bb0to7, c2) - c28 = avx_f32x8_fmadd(aa, bb8to15, c28) - - aa = avx_f32x8_broadcast(a_ptr + 9) - c3 = avx_f32x8_fmadd(aa, bb0to7, c3) - c38 = avx_f32x8_fmadd(aa, bb8to15, c38) - - aa = avx_f32x8_broadcast(a_ptr + 10) - c4 = avx_f32x8_fmadd(aa, bb0to7, c4) - c48 = avx_f32x8_fmadd(aa, bb8to15, c48) - - aa = avx_f32x8_broadcast(a_ptr + 11) - c5 = avx_f32x8_fmadd(aa, bb0to7, c5) - c58 = avx_f32x8_fmadd(aa, bb8to15, c58) - - # Third of the 4 unrolled iterations - bb0to7 = avx_f32x8_load_aligned(b_ptr + 32) - bb8to15 = avx_f32x8_load_aligned(b_ptr + 40) - - aa = avx_f32x8_broadcast(a_ptr + 12) - c0 = avx_f32x8_fmadd(aa, bb0to7, c0) - c08 = avx_f32x8_fmadd(aa, bb8to15, c08) - - aa = avx_f32x8_broadcast(a_ptr + 13) - c1 = avx_f32x8_fmadd(aa, bb0to7, c1) - c18 = avx_f32x8_fmadd(aa, bb8to15, c18) - - aa = avx_f32x8_broadcast(a_ptr + 14) - c2 = avx_f32x8_fmadd(aa, bb0to7, c2) - c28 = avx_f32x8_fmadd(aa, bb8to15, c28) - - aa = avx_f32x8_broadcast(a_ptr + 15) - c3 = avx_f32x8_fmadd(aa, bb0to7, c3) - c38 = avx_f32x8_fmadd(aa, bb8to15, c38) - - aa = avx_f32x8_broadcast(a_ptr + 16) - c4 = avx_f32x8_fmadd(aa, bb0to7, c4) - c48 = avx_f32x8_fmadd(aa, bb8to15, c48) - - aa = avx_f32x8_broadcast(a_ptr + 17) - c5 = avx_f32x8_fmadd(aa, bb0to7, c5) - c58 = avx_f32x8_fmadd(aa, bb8to15, c58) - - # Fourth of the 4 unrolled iterations - bb0to7 = avx_f32x8_load_aligned(b_ptr + 48) - bb8to15 = avx_f32x8_load_aligned(b_ptr + 56) - - aa = avx_f32x8_broadcast(a_ptr + 18) - c0 = avx_f32x8_fmadd(aa, bb0to7, c0) - c08 = avx_f32x8_fmadd(aa, bb8to15, c08) - - aa = avx_f32x8_broadcast(a_ptr + 19) - c1 = avx_f32x8_fmadd(aa, bb0to7, c1) - c18 = avx_f32x8_fmadd(aa, bb8to15, c18) - - aa = avx_f32x8_broadcast(a_ptr + 20) - c2 = avx_f32x8_fmadd(aa, bb0to7, c2) - c28 = avx_f32x8_fmadd(aa, bb8to15, c28) - - aa = avx_f32x8_broadcast(a_ptr + 21) - c3 = avx_f32x8_fmadd(aa, bb0to7, c3) - c38 = avx_f32x8_fmadd(aa, bb8to15, c38) - - aa = avx_f32x8_broadcast(a_ptr + 22) - c4 = avx_f32x8_fmadd(aa, bb0to7, c4) - c48 = avx_f32x8_fmadd(aa, bb8to15, c48) - - aa = avx_f32x8_broadcast(a_ptr + 23) - c5 = avx_f32x8_fmadd(aa, bb0to7, c5) - c58 = avx_f32x8_fmadd(aa, bb8to15, c58) - - # Increment the a_ptr and b_ptr for the next iteration of the outermost loop - a_ptr += 24 - b_ptr += 64 - - # process the edge - for _ in range(nleft): - aa = avx_f32x8_broadcast(a_ptr) - bb0to7 = avx_f32x8_load_aligned(b_ptr) - bb8to15 = avx_f32x8_load_aligned(b_ptr + 8) - - c0 = avx_f32x8_fmadd(aa, bb0to7, c0) - c08 = avx_f32x8_fmadd(aa, bb8to15, c08) - - aa = avx_f32x8_broadcast(a_ptr + 1) - c1 = avx_f32x8_fmadd(aa, bb0to7, c1) - c18 = avx_f32x8_fmadd(aa, bb8to15, c18) - - aa = avx_f32x8_broadcast(a_ptr + 2) - c2 = avx_f32x8_fmadd(aa, bb0to7, c2) - c28 = avx_f32x8_fmadd(aa, bb8to15, c28) - - aa = avx_f32x8_broadcast(a_ptr + 3) - c3 = avx_f32x8_fmadd(aa, bb0to7, c3) - c38 = avx_f32x8_fmadd(aa, bb8to15, c38) - - aa = avx_f32x8_broadcast(a_ptr + 4) - c4 = avx_f32x8_fmadd(aa, bb0to7, c4) - c48 = avx_f32x8_fmadd(aa, bb8to15, c48) - - aa = avx_f32x8_broadcast(a_ptr + 5) - c5 = avx_f32x8_fmadd(aa, bb0to7, c5) - c58 = avx_f32x8_fmadd(aa, bb8to15, c58) - - a_ptr += 6 - b_ptr += 16 - - # Store the results - avx_f32x8_store(c_ptr, c0) - avx_f32x8_store(c_ptr + 8, c08) - - avx_f32x8_store(c_ptr + nsize, c1) - 
avx_f32x8_store(c_ptr + (nsize + 8), c18) - - avx_f32x8_store(c_ptr + 2 * nsize, c2) - avx_f32x8_store(c_ptr + (2 * nsize + 8), c28) - - avx_f32x8_store(c_ptr + 3 * nsize, c3) - avx_f32x8_store(c_ptr + (3 * nsize + 8), c38) - - avx_f32x8_store(c_ptr + 4 * nsize, c4) - avx_f32x8_store(c_ptr + (4 * nsize + 8), c48) - - avx_f32x8_store(c_ptr + 5 * nsize, c5) - avx_f32x8_store(c_ptr + (5 * nsize + 8), c58) - - @hidet.script - def macro_kernel( - a: packed_a_type, b: packed_b_type, c_in_macro: float32[m_size, n_size], ib: int32, jb: int32, - pb: int32 - ): - return - - #### Some setup code #### - packed_b_total_width = 0 - for workid_loop5 in range(loop5_nways): - loop5_start = 0 - loop5_end = 0 - thread_range_sub(loop5_nways, workid_loop5, n_size, NR, ~loop5_start, ~loop5_end) - curr_width = loop5_end - loop5_start - # packed_b_total_width += curr_width - # packb_start_offsets[workid_loop5] = temp_prev - # temp_prev += curr_width - packb_start_offsets[workid_loop5] = packed_b_total_width - packed_b_total_width += curr_width - - packed_b_height = KC - if packed_b_height > k_size: - packed_b_height = (k_size + NR - 1) // NR * NR - packed_b_total_size = packed_b_total_width * packed_b_height - - a_height_mr_partitions = (m_size + MR - 1) // MR - a_height_mr_remainder = m_size % MR - packed_a_individual_height = MC - packed_a_total_height = packed_a_individual_height * loop3_nways - # if packed_a_total_height > m_size: - # packed_a_total_height = a_height_mr_partitions * MR - packed_a_width = KC - if packed_a_width > k_size: - packed_a_width = (k_size + MR - 1) // MR * MR - packed_a_total_size = packed_a_total_height * packed_a_width - packed_a_individual_size = packed_a_width * packed_a_individual_height - - packb_buf_ptr = avx_malloc(packed_b_total_size * 4, 4096) - packa_buf_ptr = avx_malloc(packed_a_total_size * 4, 4096) - - packb_buf = as_tensor_pointer(packb_buf_ptr, dtype=float32, shape=[packed_b_total_size]) - packa_buf = as_tensor_pointer(packa_buf_ptr, dtype=float32, shape=[packed_a_total_size]) - - packed_a_type = tensor_type( - dtype='float32', - layout=row_major(packed_a_individual_height // MR, 1) * column_major(MR, packed_a_width) - ) - - ################### Start of the main kernel ################### - @hidet.script - def matmul_kernel_x86_v2(a: float32[m_size, k_size], b: float32[k_size, n_size], - c: float32[m_size, n_size]): - b_width_nr_partitions = (n_size + NR - 1) // NR - b_width_nr_remainder = n_size % NR - # TODO: Since we(they, BLIS) use a memory broker... Allocate a little more memory is OK I think??? - # packed_b_individual_width = NC - - parallel_attr = 'p' + str(nthreads) - # The outermost loop spawning threads - for tidx in grid(nthreads, attrs=parallel_attr): - tid_5th_loop = tidx - work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways) - comm_id_5th_loop = tid_5th_loop - - # Before each loop, we compute the work id and comm id for the loop after it. 
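Besides the work- and comm-id bookkeeping that follows, note that each packing phase below is bracketed by a pair of thrcomm_barrier calls: the threads of one communicator cooperatively fill a shared packed buffer, and none may read it before all have written their panels. thrcomm_barrier is a sense-reversal barrier; a plain-Python model of the same protocol (a threading.Lock stands in for the GCC atomic builtins, and the class name is illustrative, not part of the source):

    import threading

    class SenseReversalBarrier:
        def __init__(self, nthreads):
            self.nthreads = nthreads
            self.sense = 0                   # flipped once per barrier episode
            self.arrived = 0                 # an atomic counter in the real kernel
            self._lock = threading.Lock()    # models cpu_atomic_add_fetch

        def wait(self):
            with self._lock:
                my_sense = self.sense        # like cpu_atomic_load_n(barrier_sense, ...)
                self.arrived += 1
                is_last = self.arrived == self.nthreads
                if is_last:
                    self.arrived = 0         # reset the counter for the next episode
                    self.sense ^= 1          # flipping the sense releases the spinners
            if not is_last:
                while self.sense == my_sense:  # spin until the last thread flips the sense
                    pass

Because the sense flips rather than resets, the same two integers can be reused for every episode, which is why one sense/arrived pair per communicator suffices in static storage.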
- comm_id_4th_loop = comm_id_5th_loop % loop4_nways - work_id_4th_loop = comm_id_4th_loop // (loop4_nthreads // loop4_nways) - - my_start = -1 - my_end = -1 - b_alg_loop5 = NC - thread_range_sub(loop5_nways, work_id_5th_loop, n_size, NR, ~my_start, ~my_end) - loop5_iter = my_start - while loop5_iter < my_end: - b_alg_loop5 = determine_blocksize_f_sub(loop5_iter, my_end, NC) - - loop5_partition_c_width = b_alg_loop5 - loop5_partition_c_start_col = loop5_iter - - loop5_partition_b_width = b_alg_loop5 - loop5_partition_b_start_col = loop5_iter - - comm_id_3rd_loop = comm_id_4th_loop % loop3_nthreads - work_id_3rd_loop = comm_id_3rd_loop // (loop3_nthreads // loop3_nways) - - # After getting the communicator and work id for the 3rd loop, - # we can now get the ids for the packing of B. - comm_id_packb = comm_id_3rd_loop - work_id_packb = comm_id_3rd_loop - packb_nways = loop3_nthreads - - # Below: The start of loop4 - b_alg_loop4 = KC - i_loop4 = 0 - while i_loop4 < k_size: - b_alg_loop4 = determine_blocksize_f_sub(i_loop4, k_size, KC) - loop4_partition_b_height = b_alg_loop4 - loop4_partition_b_width = loop5_partition_b_width - loop4_partition_b_start_row = i_loop4 - loop4_partition_b_start_col = loop5_partition_b_start_col - - loop4_partition_a_start_col = i_loop4 - - is_first = (i_loop4 == 0) - - # Get the thread's partition of buffer and matrix - packed_b_buf = packb_buf + ( - packb_start_offsets[work_id_5th_loop] * packed_b_height) # TODO: Check this - loop4_partition_b = b + (loop4_partition_b_start_row * n_size + loop4_partition_b_start_col) - npanels_full_b = loop4_partition_b_width // NR - npanels_b_remainder = loop4_partition_b_width % NR - - npanels_b = npanels_full_b + (npanels_b_remainder != 0) - packedb_panel_stride = packed_b_height * NR - - # TODO: If passed, see if this barrier is really needed - thrcomm_barrier( - comm_id_packb, - ~packb_thrcomm_barrier_sense[work_id_4th_loop], - ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], - packb_nthreads # TODO: Check this last parameter - ) - - # Start of the packing of B - for i_panel in range(npanels_b): - if i_panel % packb_nways != work_id_packb % packb_nways: - continue - packed_b_buf_curr = packed_b_buf + (i_panel * packedb_panel_stride) - - curr_panel_start = i_panel * NR - curr_panel_width = min(NR, loop4_partition_b_width - curr_panel_start) - if curr_panel_width == NR: - k_iters = loop4_partition_b_height // 8 - k_remainder = loop4_partition_b_height % 8 - row = 0 - for k_iter in range(k_iters): - row = k_iter * 8 - b_panel = loop4_partition_b + (row * n_size + curr_panel_start) - b00 = avx_f32x8_load(b_panel) - b08 = avx_f32x8_load(b_panel + 8) - - avx_f32x8_store_aligned(packed_b_buf_curr, b00) - avx_f32x8_store_aligned(packed_b_buf_curr + 8, b08) - packed_b_buf_curr += 16 - - b10 = avx_f32x8_load(b_panel + n_size) - b18 = avx_f32x8_load(b_panel + (n_size + 8)) - - avx_f32x8_store_aligned(packed_b_buf_curr, b10) - avx_f32x8_store_aligned(packed_b_buf_curr + 8, b18) - packed_b_buf_curr += 16 - - b20 = avx_f32x8_load(b_panel + (2 * n_size)) - b28 = avx_f32x8_load(b_panel + (2 * n_size + 8)) - - avx_f32x8_store_aligned(packed_b_buf_curr, b20) - avx_f32x8_store_aligned(packed_b_buf_curr + 8, b28) - - packed_b_buf_curr += 16 - - b30 = avx_f32x8_load(b_panel + (3 * n_size)) - b38 = avx_f32x8_load(b_panel + (3 * n_size + 8)) - - avx_f32x8_store_aligned(packed_b_buf_curr, b30) - avx_f32x8_store_aligned(packed_b_buf_curr + 8, b38) - - packed_b_buf_curr += 16 - - b40 = avx_f32x8_load(b_panel + (4 * n_size)) - b48 = 
avx_f32x8_load(b_panel + (4 * n_size + 8)) - - avx_f32x8_store_aligned(packed_b_buf_curr, b40) - avx_f32x8_store_aligned(packed_b_buf_curr + 8, b48) - - packed_b_buf_curr += 16 - - b50 = avx_f32x8_load(b_panel + (5 * n_size)) - b58 = avx_f32x8_load(b_panel + (5 * n_size + 8)) - - avx_f32x8_store_aligned(packed_b_buf_curr, b50) - avx_f32x8_store_aligned(packed_b_buf_curr + 8, b58) - - packed_b_buf_curr += 16 - - b60 = avx_f32x8_load(b_panel + (6 * n_size)) - b68 = avx_f32x8_load(b_panel + (6 * n_size + 8)) - - avx_f32x8_store_aligned(packed_b_buf_curr, b60) - avx_f32x8_store_aligned(packed_b_buf_curr + 8, b68) - - packed_b_buf_curr += 16 - - b70 = avx_f32x8_load(b_panel + (7 * n_size)) - b78 = avx_f32x8_load(b_panel + (7 * n_size + 8)) - - avx_f32x8_store_aligned(packed_b_buf_curr, b70) - avx_f32x8_store_aligned(packed_b_buf_curr + 8, b78) - - packed_b_buf_curr += 16 - - row += 8 - for remaining_row in range(k_remainder): - b_panel = loop4_partition_b + (row * n_size + curr_panel_start) - b00 = avx_f32x8_load(b_panel) - b08 = avx_f32x8_load(b_panel + 8) - - avx_f32x8_store_aligned(packed_b_buf_curr, b00) - avx_f32x8_store_aligned(packed_b_buf_curr + 8, b08) - packed_b_buf_curr += 16 - row += 1 - else: - packed_b_remaining_buf = packed_b_buf + (npanels_full_b * packedb_panel_stride) - if npanels_b_remainder > 0: - remain_col_start = npanels_full_b * NR - for remain_row in range(loop4_partition_b_height): - packed_b_remaining_buf_curr = packed_b_remaining_buf + (remain_row * NR) - for remain_col in range(npanels_b_remainder): - packed_b_remaining_buf_curr[0] = loop4_partition_b[ - (remain_row * n_size) + (remain_col_start + remain_col)] - packed_b_remaining_buf_curr += 1 - zero_fill_col = npanels_b_remainder - while zero_fill_col < NR: - packed_b_remaining_buf_curr[0] = 0.0 - packed_b_remaining_buf_curr += 1 - zero_fill_col += 1 - - # The barrier at the end of packing - thrcomm_barrier(comm_id_packb, - ~packb_thrcomm_barrier_sense[work_id_4th_loop], - ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], - packb_nthreads - ) - - # TODO: Loop 3 should start here! - # Loop 3 - comm_id_macro = work_id_3rd_loop % macro_nthreads - work_id_macro = comm_id_macro // (macro_nthreads // macro_nways) - - comm_id_packa = comm_id_macro - work_id_packa = comm_id_macro - packa_nways = macro_nthreads - - m_start_loop3 = 0 - m_end_loop3 = 0 - # Partition the dimension M for loop 3 - thread_range_sub( - loop3_nways, - work_id_3rd_loop, - m_size, - MR, - ~m_start_loop3, - ~m_end_loop3 - ) - - b_alg_loop3 = -1 - ii = m_start_loop3 - while ii < m_end_loop3: - b_alg_loop3 = determine_blocksize_f_sub(ii, m_size, MC) - - # Acquire the partition at Loop 3 - loop3_partition_c_start_row = ii - loop3_partition_a_start_row = ii - - loop3_partition_a_start_col = loop4_partition_a_start_col - loop3_partition_b_start_col = loop4_partition_b_start_col - loop3_partition_c_start_col = loop4_partition_b_start_col - - loop3_partition_height = b_alg_loop3 - # TODO: Is this right? 
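(The TODO just above can be answered affirmatively: the A partition's width and the B partition's height assigned below are both the current KC-sized slice of the K dimension, so they must match.) The packing of A that follows lays an MC x KC block out as MR-tall panels that are column-major within each panel, so one column of A becomes MR contiguous floats; the unpack/shuffle/permute sequence further down is an in-register transpose of one MR x 8 tile that produces exactly this order. A NumPy sketch of the target layout (hypothetical names, MR = 6 as in this file):

    import numpy as np

    def pack_a_sketch(a_block, MR=6):
        # a_block: an (mc, kc) slice of A. Panel i holds rows [i*MR, (i+1)*MR),
        # transposed so each column of A becomes MR contiguous floats.
        mc, kc = a_block.shape
        npanels = (mc + MR - 1) // MR
        packed = np.zeros((npanels, kc, MR), dtype=a_block.dtype)
        for i in range(npanels):
            panel = a_block[i * MR:(i + 1) * MR, :]
            packed[i, :, :panel.shape[0]] = panel.T  # a short final panel is zero-padded
        return packed.reshape(-1)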
- loop3_partition_a_width = loop4_partition_b_height - loop3_partition_b_width = loop4_partition_b_width - loop3_partition_a_height = b_alg_loop3 - loop3_partition_c_height = b_alg_loop3 - - loop3_partition_a = a + ( - loop3_partition_a_start_row * k_size + loop3_partition_a_start_col) - npanels_full_a = loop3_partition_a_height // MR - panel_a_remainder = loop3_partition_a_height % MR - - npanels_a = npanels_full_a + (1 if panel_a_remainder > 0 else 0) - packeda_panel_stride = MR * loop3_partition_a_width - - # Get our position within the packed A global buffer - packed_a_buf = packa_buf + (work_id_packa * packed_a_individual_size) - packed_a_tensor = as_tensor_pointer( - packed_a_buf, - float32, - layout=row_major(packed_a_individual_height // MR, 1) * column_major(MR, packed_a_width) - ) - - thrcomm_barrier( - comm_id_packa, - ~packa_thrcomm_barrier_sense[work_id_3rd_loop], - ~packa_thrcomm_threads_arrived[work_id_3rd_loop], - packa_nthreads - ) - - # Pack A - for ii_panel in range(npanels_a): - if ii_panel % packa_nways != work_id_packa % packa_nways: - continue - a_curr_panel_row_start = ii_panel * MR - a_curr_panel_height = min(MR, loop3_partition_a_height - a_curr_panel_row_start) - - if a_curr_panel_height == MR: # we unroll the packing by 8 - k_iters = loop3_partition_a_width // 8 - k_remainder = loop3_partition_a_width % 8 - col = 0 - for k_iter in range(k_iters): - col = k_iter * 8 - a_curr_panel_col = loop3_partition_a + ( - a_curr_panel_row_start * k_size + col) - v0 = avx_f32x8_load(a_curr_panel_col) - v1 = avx_f32x8_load(a_curr_panel_col + k_size) - v2 = avx_f32x8_load(a_curr_panel_col + (2 * k_size)) - v3 = avx_f32x8_load(a_curr_panel_col + (3 * k_size)) - v4 = avx_f32x8_load(a_curr_panel_col + (4 * k_size)) - v5 = avx_f32x8_load(a_curr_panel_col + (5 * k_size)) - - unpack0 = avx_f32x8_unpacklo(v0, v1) - unpack1 = avx_f32x8_unpackhi(v0, v1) - unpack2 = avx_f32x8_unpacklo(v2, v3) - unpack3 = avx_f32x8_unpackhi(v2, v3) - unpack4 = avx_f32x8_unpacklo(v4, v5) - unpack5 = avx_f32x8_unpackhi(v4, v5) - - shf0 = avx_f32x8_shuffle(unpack0, unpack2, 0x44) - shf1 = avx_f32x8_shuffle(unpack4, unpack0, 0xE4) - shf2 = avx_f32x8_shuffle(unpack2, unpack4, 0xEE) - shf3 = avx_f32x8_shuffle(unpack5, unpack1, 0xE4) - shf4 = avx_f32x8_shuffle(unpack3, unpack5, 0xEE) - shf5 = avx_f32x8_shuffle(unpack1, unpack3, 0x44) - - low_shf1 = avx_f32x8_cast_f32x4(shf1) - res0 = avx_f32x8_insert_f32x4(shf0, low_shf1, 0x1) - res1 = avx_f32x8_permute2f32x4(shf0, shf1, 0x31) - - low_shf5 = avx_f32x8_cast_f32x4(shf5) - res2 = avx_f32x8_insert_f32x4(shf2, low_shf5, 0x1) - res3 = avx_f32x8_permute2f32x4(shf2, shf5, 0x31) - - low_shf4 = avx_f32x8_cast_f32x4(shf4) - res4 = avx_f32x8_insert_f32x4(shf3, low_shf4, 0x1) - res5 = avx_f32x8_permute2f32x4(shf3, shf4, 0x31) - - avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start, col], res0) - avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start + 2, col + 1], res2) - avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start + 4, col + 2], res4) - avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start, col + 4], res1) - avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start + 2, col + 5], res3) - avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start + 4, col + 6], res5) - - remaining_start_col = k_iters * 8 - for remain_off in range(k_remainder): - curr_remain_col = remaining_start_col + remain_off - for micropanel_row in range(MR): - packed_a_tensor[a_curr_panel_row_start + micropanel_row, curr_remain_col] = \ - 
loop3_partition_a[(micropanel_row + a_curr_panel_row_start) * k_size + curr_remain_col] - else: - remain_start_row = npanels_a * MR - for remain_col in range(loop3_partition_a_width): - for remain_row in range(panel_a_remainder): - packed_a_tensor[remain_start_row + remain_row, remain_col] = \ - loop3_partition_a[(remain_row + remain_start_row) * k_size + remain_col] - remain_row = panel_a_remainder - while remain_row < MR: - packed_a_tensor[remain_start_row + remain_row, remain_col] = 0 - remain_row += 1 - - # This marks the end of the packing of A, or so I wish - # Now let's go to the macrokernel - # But first, barrier... - thrcomm_barrier( - comm_id_packa, - ~packa_thrcomm_barrier_sense[work_id_3rd_loop], - ~packa_thrcomm_threads_arrived[work_id_3rd_loop], - packa_nthreads - ) - - comm_id_1st_loop = comm_id_macro % loop1_nthreads - work_id_1st_loop = comm_id_macro // (loop1_nthreads // loop1_nways) - - jr_nt = macro_nways - jr_tid = work_id_macro - ir_nt = loop1_nways - ir_tid = work_id_1st_loop - - jr_start = -1 - jr_end = -1 - ir_start = -1 - ir_end = -1 - jr_inc = -1 - ir_inc = -1 - - macro_m = loop3_partition_a_height - macro_n = loop3_partition_b_width - macro_k = loop3_partition_a_width - - n_iter = macro_n // NR - n_remainder = macro_n % NR - m_iter = macro_m // MR - m_remainder = macro_m % MR - - if n_remainder > 0: - n_iter += 1 - if m_remainder > 0: - m_iter += 1 - - thread_range_jrir( - work_id_macro, - macro_nways, - n_iter, - 1, - ~jr_start, - ~jr_end, - ~jr_inc - ) - - thread_range_jrir( - work_id_1st_loop, - loop1_nways, - m_iter, - 1, - ~ir_start, - ~ir_end, - ~ir_inc - ) - - # Some variables as in the original code... - # TODO: There must be some useless ones, delete after passing tests - rs_packeda = 1 - cs_packeda = MR - panel_dim_packeda = MR - ps_packed_a = packeda_panel_stride - rs_packedb = MR - cs_packedb = 1 - ps_packed_b = packedb_panel_stride - - rstep_a = ps_packed_a - cstep_b = ps_packed_b - - rstep_c = rs_packeda * MR - - cstep_c = NR - rstep_c = n_size * MR - - macro_c_cast = as_tensor_pointer( - ~c[loop3_partition_a_start_row, loop3_partition_b_start_col], - dtype=float32, - shape=(m_size, n_size) - ) - - temp_c = tensor( - scope=DeclareScope.Default, - dtype='float32', - layout=row_major(MR, NR), - is_static=True - ) - - - j = jr_start - while j < jr_end: - b1 = packed_b_buf + j * cstep_b - c1 = macro_c_cast + j * cstep_c - - n_cur = NR if not_edge(j, n_iter, n_remainder) else n_remainder - b2 = b1 - # Loop over the m dimension, MR rows at a time - i = ir_start - while i < ir_end: - a1 = packed_a_buf + i * rstep_a - c11 = c1 + i * rstep_c - - m_cur = MR if not_edge(i, m_iter, m_remainder) else m_remainder - - if m_cur == MR and n_cur == NR: - micro_kernel(a1, b1, c11, macro_k, macro_m, macro_n, is_first) - else: - for i in range(MR): - for j in range(NR): - temp_c[i, j] = 0 - micro_kernel(a1, b1, temp_c, macro_k, macro_m, macro_n, is_first) - if not is_first: - for mm in range(m_cur): - for nn in range(n_cur): - c11[mm, nn] += temp_c[mm, nn] - else: - for mm in range(m_cur): - for nn in range(n_cur): - c11[mm, nn] = temp_c[mm, nn] - i += ir_inc - j += jr_inc - ii += b_alg_loop3 - # End of loop4 - # According to the original code, there seems to be a barrier here - # TODO: Looks weird, check later, especially about whether it's really the ids of packb - # arrays that are used here - thrcomm_barrier( - comm_id_packb, - ~packb_thrcomm_barrier_sense[work_id_4th_loop], - ~packb_thrcomm_barrier_threads_arrived[work_id_4th_loop], - packb_nthreads - ) - i_loop4 
+= b_alg_loop4 - # End of loop5 - loop5_iter += b_alg_loop5 - - return - - assert isinstance(matmul_kernel_x86_v2, hidet.ir.Function) - matmul_kernel_x86_v2.kind = "cpu_kernel" - ir_module = module.ir_module() - return ir_module - - -class Matmulx86Op_v2(Operator): - def __init__(self, a: Tensor, b: Tensor): - if not (len(a.shape) == len(b.shape) == 2 and a.shape[1] == b.shape[0]): - raise ValueError('Matrix multiplication: incompatible sizes: {} and {}'.format(a.shape, b.shape)) - task = MatmulF32Taskx86_v2(input_like(a, 'a'), input_like(b, 'b')) - super().__init__(inputs=[a, b], attributes={}, task=task) - - -def matmul_x86_v2(a: Tensor, b: Tensor) -> Tensor: - return Matmulx86Op_v2(a, b).outputs[0] From 656bbd02d163ec0d3b9e981c89f3728378ef00cc Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 16 Nov 2023 22:05:58 -0500 Subject: [PATCH 132/148] ...... --- python/hidet/backend/codegen.py | 4 +- python/hidet/graph/ops/__init__.py | 1 - python/hidet/graph/ops/matmul/__init__.py | 8 +- .../hidet/graph/ops/matmul/matmul_f32_x86.py | 429 +++++++----------- python/hidet/ir/primitives/cpu/__init__.py | 1 - python/hidet/ir/primitives/cpu/atomic.py | 19 +- python/hidet/ir/primitives/cpu/avx.py | 21 +- python/hidet/lang/cpu.py | 2 +- python/mat_new.py | 4 +- tests/operators/test_matmul.py | 2 +- 10 files changed, 195 insertions(+), 296 deletions(-) diff --git a/python/hidet/backend/codegen.py b/python/hidet/backend/codegen.py index 4ee704db9..4e43e4260 100644 --- a/python/hidet/backend/codegen.py +++ b/python/hidet/backend/codegen.py @@ -496,9 +496,7 @@ def visit_ForStmt(self, stmt: ForStmt): doc += NewLine() + '#pragma unroll' elif stmt.attr.parallel: if stmt.attr.parallel_threads: - doc += NewLine() + '#pragma omp parallel for num_threads({})'.format( - stmt.attr.parallel_threads - ) + doc += NewLine() + '#pragma omp parallel for num_threads({})'.format(stmt.attr.parallel_threads) else: doc += NewLine() + '#pragma omp parallel for' doc += NewLine() + Text('for (') + init_doc + '; ' + cond_doc + '; ' + update_doc + ') ' diff --git a/python/hidet/graph/ops/__init__.py b/python/hidet/graph/ops/__init__.py index 54ce4c657..8a2fcd5f3 100644 --- a/python/hidet/graph/ops/__init__.py +++ b/python/hidet/graph/ops/__init__.py @@ -11,7 +11,6 @@ # limitations under the License. # pylint: disable=redefined-builtin from .matmul import batch_matmul, matmul, matmul_x86 -from .matmul import matmul_x86_refactored from .conv1d import conv1d, conv1d_gemm from .conv1d_transpose import conv1d_transpose from .conv2d import conv2d, conv2d_channel_last, conv2d_winograd, conv2d_gemm, conv2d_gemm_fp16 diff --git a/python/hidet/graph/ops/matmul/__init__.py b/python/hidet/graph/ops/matmul/__init__.py index c33a08573..a1e8c0be5 100644 --- a/python/hidet/graph/ops/matmul/__init__.py +++ b/python/hidet/graph/ops/matmul/__init__.py @@ -13,10 +13,6 @@ from .batch_matmul import batch_matmul, BatchMatmulOp, BatchMatmulTask from . 
import resolve -from .matmul_f32_x86 import matmul_x86 - -from .matmul_f32_x86 import MatmulF32Taskx86, Matmulx86Op - -from .matmul_f32_x86_refactored import Matmulx86Op_refactored, MatmulF32Taskx86_refactored -from .matmul_f32_x86_refactored import matmul_x86_refactored +from .matmul_f32_x86 import Matmulx86Op, MatmulF32Taskx86 +from .matmul_f32_x86 import matmul_x86 diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86.py b/python/hidet/graph/ops/matmul/matmul_f32_x86.py index 28c355fb2..ba08c63d9 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86.py @@ -23,8 +23,7 @@ from hidet.graph.ops.utils import broadcast_indices -class MatmulF32Taskx86_refactored(Task): - +class MatmulF32Taskx86(Task): def __init__(self, a: TensorNode, b: TensorNode): a_shape = a.const_shape b_shape = b.const_shape @@ -60,7 +59,7 @@ def __init__(self, a: TensorNode, b: TensorNode): fcompute=lambda *indices: reduce( shape=[k_size], fcompute=lambda k: a[broadcast_indices(indices[:-2], a_shape[:-2], c_shape[1:-2]) + [indices[-2], k]] - * b[broadcast_indices(indices[:-2], b_shape[:-2], c_shape[1:-2]) + [k, indices[-1]]], + * b[broadcast_indices(indices[:-2], b_shape[:-2], c_shape[1:-2]) + [k, indices[-1]]], reduce_type='sum', ), ) @@ -82,9 +81,7 @@ def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: return tune.extract_ir_modules(self.schedule_matmulf32_x86) @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 4, 2, 1)]) - def schedule_matmulf32_x86( - self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1) - ) -> IRModule: + def schedule_matmulf32_x86(self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1)) -> IRModule: import hidet from hidet.ir.type import tensor_type from hidet.lang import tensor, grid, as_tensor_pointer @@ -108,25 +105,20 @@ def schedule_matmulf32_x86( with hidet.script_module() as module: # Get the number of threads... 
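The `ways` tuple unpacked just below factors the thread pool across the parallelized loops: nthreads is the product of the four way counts, and each thread recovers its work id at a level by integer division of its communicator id, exactly as the kernel does. A self-contained sketch of the decomposition (the helper name is illustrative; ways=(1, 4, 2, 1) is this schedule's default, and loop 4 is not parallelized):

    def thread_ids_sketch(tid, ways=(1, 4, 2, 1)):
        loop5_nways, loop3_nways, macro_nways, loop1_nways = ways
        nthreads = loop5_nways * loop3_nways * macro_nways * loop1_nways
        work_id_loop5 = tid // (nthreads // loop5_nways)
        loop4_nthreads = nthreads // loop5_nways      # threads sharing one loop-5 partition
        comm_id_loop3 = tid % loop4_nthreads
        work_id_loop3 = comm_id_loop3 // (loop4_nthreads // loop3_nways)
        macro_nthreads = loop4_nthreads // loop3_nways
        comm_id_macro = comm_id_loop3 % macro_nthreads
        work_id_macro = comm_id_macro // (macro_nthreads // macro_nways)
        return work_id_loop5, work_id_loop3, work_id_macro

For example, [thread_ids_sketch(t) for t in range(8)] shows which of the eight threads share one loop-5 partition (and hence one packed-B buffer) and which additionally share a loop-3 partition (one packed-A buffer).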
loop5_nways, loop3_nways, macro_nways, loop1_nways = ways - loop4_nways = 1 nthreads = loop5_nways * loop3_nways * macro_nways * loop1_nways packa_thrcomm_barrier_sense = module.define_global_var( - name="pack_a_barrier_sense", - var_type=int32[nthreads] + name="pack_a_barrier_sense", var_type=int32[nthreads] ) packa_thrcomm_threads_arrived = module.define_global_var( - name="pack_a_threads_arrived", - var_type=int32[nthreads] + name="pack_a_threads_arrived", var_type=int32[nthreads] ) packb_thrcomm_barrier_sense = module.define_global_var( - name='pack_b_barrier_sense', - var_type=int32[nthreads] + name='pack_b_barrier_sense', var_type=int32[nthreads] ) packb_thrcomm_barrier_threads_arrived = module.define_global_var( - name="pack_b_threads_arrived", - var_type=int32[nthreads] + name="pack_b_threads_arrived", var_type=int32[nthreads] ) @hidet.script @@ -138,12 +130,8 @@ def init_thr(sense: ~int32, arrived: ~int32, size: int32): init_thr.kind = "cpu_internal" # Helpers - packed_a_type = tensor_type('float32', layout=row_major(MC // MR, - 1) * column_major( - MR, KC)) - packed_b_type = tensor_type('float32', layout=row_major(1, - NC // NR) * row_major( - KC, NR)) + packed_a_type = tensor_type('float32', layout=row_major(MC // MR, 1) * column_major(MR, KC)) + packed_b_type = tensor_type('float32', layout=row_major(1, NC // NR) * row_major(KC, NR)) # Get the number of threads remaining at each level loop4_nthreads = nthreads // loop5_nways @@ -173,7 +161,8 @@ def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~ n_bf_hi = n_bf_whole // n_way n_th_lo = n_bf_whole % n_way - # If some partitions must have more block_factors than others, assign the slightly larger partitions to lower index threads + # If some partitions must have more block_factors than others, + # assign the slightly larger partitions to lower index threads if n_th_lo != 0: n_bf_lo += 1 # Compute the actual widths (in units of rows/columns) of individual threads in the low and high groups @@ -200,8 +189,9 @@ def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~ thread_range_sub.kind = "cpu_internal" @hidet.script - def thread_range_jrir(work_id: int32, n_way: int32, n: int32, bf: int32, - start: ~int32, end: ~int32, inc: ~int32): + def thread_range_jrir( + work_id: int32, n_way: int32, n: int32, bf: int32, start: ~int32, end: ~int32, inc: ~int32 + ): start[0] = work_id end[0] = n inc[0] = n_way @@ -235,15 +225,13 @@ def packa_index(work_id_loop5: int32, work_id_loop3: int32) -> int32: # Thread barrier @hidet.script - def thrcomm_barrier(barrier_sense: ~int32, - barrier_threads_arrived: ~int32, num_threads: int32): + def thrcomm_barrier(barrier_sense: ~int32, barrier_threads_arrived: ~int32, num_threads: int32): if num_threads == 1: return orig_sense = cpu_atomic_load_n(barrier_sense, 0) # _ATOMIC_RELAXED # Register the current thread's arrival by incrementing the counter - my_threads_arrived = cpu_atomic_add_fetch( - barrier_threads_arrived, 1, 4) # _ATOMIC_ACQ_REL + my_threads_arrived = cpu_atomic_add_fetch(barrier_threads_arrived, 1, 4) # _ATOMIC_ACQ_REL if my_threads_arrived == num_threads: barrier_threads_arrived[0] = 0 @@ -256,8 +244,13 @@ def thrcomm_barrier(barrier_sense: ~int32, @hidet.script def micro_kernel( - a: packed_a_type, b: packed_b_type, c_ptr: ~float32, pb: int32, msize: int32, nsize: int32, - is_first: bool + a: packed_a_type, + b: packed_b_type, + c_ptr: ~float32, + pb: int32, + msize: int32, + nsize: int32, + is_first: bool, ): c = 
as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize]) c0 = avx_f32x8_load(~c[0, 0]) @@ -340,38 +333,28 @@ def micro_kernel( avx_f32x8_store(c_ptr + (5 * nsize + 8), c58) #### Some setup code #### - packed_b_height = KC - if packed_b_height > k_size: - packed_b_height = k_size - packed_b_width = NC - if packed_b_width > n_size: - packed_b_width = (n_size + NR - 1) // NR * NR + packed_b_height = min(KC, k_size) + packed_b_width = min(NC, (n_size + NR - 1) // NR * NR) packed_b_total_width = packed_b_width * loop5_nways packed_b_total_size = packed_b_total_width * packed_b_height packed_b_individual_size = packed_b_width * packed_b_height - packed_a_individual_height = MC - if packed_a_individual_height > m_size: - packed_a_individual_height = (m_size + MR - 1) // MR * MR + packed_a_individual_height = min(MC, (m_size + MR - 1) // MR * MR) packed_a_total_height = packed_a_individual_height * packed_a_buffers_needed - packed_a_width = KC - if packed_a_width > k_size: - packed_a_width = k_size - # pad this to be able to use the aligned version of the avx store - packed_a_width = (packed_a_width + 8 - 1) // 8 * 8 + # packed_a_width = KC + # if packed_a_width > k_size: + # packed_a_width = k_size + # # pad this to be able to use the aligned version of the avx store + # packed_a_width = (packed_a_width + 8 - 1) // 8 * 8 + packed_a_width = min(KC, (k_size + 8 - 1) // 8 * 8) + packed_a_total_size = packed_a_total_height * packed_a_width packed_a_individual_size = packed_a_width * packed_a_individual_height - packb_buf_ptr = module.define_global_var( - name='packb_buf_ptr', - var_type=float32[packed_b_total_size] - ) - packa_buf_ptr = module.define_global_var( - name='packa_buf_ptr', - var_type=float32[packed_a_total_size] - ) + packb_buf_ptr = module.define_global_var(name='packb_buf_ptr', var_type=float32[packed_b_total_size]) + packa_buf_ptr = module.define_global_var(name='packa_buf_ptr', var_type=float32[packed_a_total_size]) packb_buf = cast(packb_buf_ptr, ~float32) packa_buf = cast(packa_buf_ptr, ~float32) @@ -380,17 +363,16 @@ def micro_kernel( @hidet.script def gemm_pack_a( - loop3_partition_a: ~float32, - loop3_partition_a_width: int32, - loop3_partition_a_height: int32, - packed_a_buf: ~float32, - work_id_packa: int32, + loop3_partition_a: ~float32, + loop3_partition_a_width: int32, + loop3_partition_a_height: int32, + packed_a_buf: ~float32, + work_id_packa: int32, ): packed_a_tensor = as_tensor_pointer( packed_a_buf, float32, - layout=row_major(packed_a_individual_height // MR, 1) * - column_major(MR, packed_a_width) + layout=row_major(packed_a_individual_height // MR, 1) * column_major(MR, packed_a_width), ) npanels_full_a = loop3_partition_a_height // MR @@ -403,8 +385,7 @@ def gemm_pack_a( continue a_curr_panel_row_start = ii_panel * MR - a_curr_panel_height = min(MR, - loop3_partition_a_height - a_curr_panel_row_start) + a_curr_panel_height = min(MR, loop3_partition_a_height - a_curr_panel_row_start) if a_curr_panel_height == MR: # unroll the packing by 8 k_iters = loop3_partition_a_width // 8 @@ -412,9 +393,7 @@ def gemm_pack_a( col = 0 for k_iter in range(k_iters): col = k_iter * 8 - a_curr_panel_col = loop3_partition_a + ( - a_curr_panel_row_start * k_size + col - ) + a_curr_panel_col = loop3_partition_a + (a_curr_panel_row_start * k_size + col) v0 = avx_f32x8_load(a_curr_panel_col) v1 = avx_f32x8_load(a_curr_panel_col + k_size) @@ -449,64 +428,40 @@ def gemm_pack_a( res4 = avx_f32x8_insert_f32x4(shf3, low_shf4, 0x1) res5 = avx_f32x8_permute2f32x4(shf3, shf4, 0x31) - 
avx_f32x8_store_aligned( - ~packed_a_tensor[a_curr_panel_row_start, col], - res0 - ) - avx_f32x8_store_aligned( - ~packed_a_tensor[a_curr_panel_row_start + 2, - col + 1], - res2 - ) - avx_f32x8_store_aligned( - ~packed_a_tensor[a_curr_panel_row_start + 4, - col + 2], - res4) - avx_f32x8_store_aligned( - ~packed_a_tensor[a_curr_panel_row_start, - col + 4], - res1 - ) - avx_f32x8_store_aligned( - ~packed_a_tensor[a_curr_panel_row_start + 2, - col + 5], - res3 - ) - avx_f32x8_store_aligned( - ~packed_a_tensor[a_curr_panel_row_start + 4, - col + 6], - res5 - ) + avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start, col], res0) + avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start + 2, col + 1], res2) + avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start + 4, col + 2], res4) + avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start, col + 4], res1) + avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start + 2, col + 5], res3) + avx_f32x8_store_aligned(~packed_a_tensor[a_curr_panel_row_start + 4, col + 6], res5) remaining_start_col = k_iters * 8 for remain_off in range(k_remainder): curr_remain_col = remaining_start_col + remain_off for micropanel_row in range(MR): packed_a_tensor[ - a_curr_panel_row_start + micropanel_row, - curr_remain_col] = \ - loop3_partition_a[( - micropanel_row + a_curr_panel_row_start) * k_size + curr_remain_col] + a_curr_panel_row_start + micropanel_row, curr_remain_col + ] = loop3_partition_a[ + (micropanel_row + a_curr_panel_row_start) * k_size + curr_remain_col + ] else: remain_start_row = npanels_full_a * MR for remain_col in range(loop3_partition_a_width): for remain_row in range(panel_a_remainder): - packed_a_tensor[ - remain_start_row + remain_row, remain_col] = \ - loop3_partition_a[( - remain_row + remain_start_row) * k_size + remain_col] + packed_a_tensor[remain_start_row + remain_row, remain_col] = loop3_partition_a[ + (remain_row + remain_start_row) * k_size + remain_col + ] remain_row = panel_a_remainder while remain_row < MR: - packed_a_tensor[ - remain_start_row + remain_row, remain_col] = 0.0 + packed_a_tensor[remain_start_row + remain_row, remain_col] = 0.0 remain_row += 1 @hidet.script def gemm_pack_b( - loop4_partition_b: ~float32, - loop4_partition_b_width: int32, - loop4_partition_b_height: int32, - packed_b_buf: ~float32, - work_id_packb: int32 + loop4_partition_b: ~float32, + loop4_partition_b_width: int32, + loop4_partition_b_height: int32, + packed_b_buf: ~float32, + work_id_packb: int32, ): npanels_full_b = loop4_partition_b_width // NR npanels_b_remainder = loop4_partition_b_width % NR @@ -518,11 +473,9 @@ def gemm_pack_b( for i_panel in range(npanels_b): if i_panel % packb_nthreads != work_id_packb % packb_nthreads: continue - packed_b_buff_curr = packed_b_buf + ( - i_panel * packedb_panel_stride) + packed_b_buff_curr = packed_b_buf + (i_panel * packedb_panel_stride) curr_panel_start = i_panel * NR - curr_panel_width = min(NR, - loop4_partition_b_width - curr_panel_start) + curr_panel_width = min(NR, loop4_partition_b_width - curr_panel_start) if curr_panel_width == NR: k_iters = loop4_partition_b_height // 8 @@ -531,8 +484,7 @@ def gemm_pack_b( row = 0 for k_iter in range(k_iters): row = k_iter * 8 - b_panel = loop4_partition_b + ( - row * n_size + curr_panel_start) + b_panel = loop4_partition_b + (row * n_size + curr_panel_start) b00 = avx_f32x8_load(b_panel) b08 = avx_f32x8_load(b_panel + 8) @@ -592,8 +544,7 @@ def gemm_pack_b( row = k_iters * 8 for _ in range(k_remainder): - b_panel = 
loop4_partition_b + ( - row * n_size + curr_panel_start) + b_panel = loop4_partition_b + (row * n_size + curr_panel_start) b00 = avx_f32x8_load(b_panel) b08 = avx_f32x8_load(b_panel + 8) avx_f32x8_store_aligned(packed_b_buff_curr, b00) @@ -602,19 +553,15 @@ def gemm_pack_b( row += 1 else: - packed_b_remaining_buf = packed_b_buf + ( - npanels_full_b * packedb_panel_stride) + packed_b_remaining_buf = packed_b_buf + (npanels_full_b * packedb_panel_stride) if npanels_b_remainder > 0: remain_col_start = npanels_full_b * NR for remain_row in range(loop4_partition_b_height): - packed_b_remaining_buf_curr = packed_b_remaining_buf + ( - remain_row * NR) + packed_b_remaining_buf_curr = packed_b_remaining_buf + (remain_row * NR) for remain_col in range(npanels_b_remainder): - packed_b_remaining_buf_curr[0] = \ - loop4_partition_b[ - (remain_row * n_size) + ( - remain_col_start + remain_col) - ] + packed_b_remaining_buf_curr[0] = loop4_partition_b[ + (remain_row * n_size) + (remain_col_start + remain_col) + ] packed_b_remaining_buf_curr += 1 zero_fill_col = npanels_b_remainder while zero_fill_col < NR: @@ -628,19 +575,19 @@ def gemm_pack_b( @hidet.script def gemm_macro( - packed_a: ~float32, - packed_b: ~float32, - c: float32[m_size, n_size], - c_row_off: int32, - c_col_off: int32, - macro_m: int32, - macro_n: int32, - macro_k: int32, - ps_packed_a: int32, - ps_packed_b: int32, - comm_id_macro: int32, - work_id_macro: int32, - is_first: bool + packed_a: ~float32, + packed_b: ~float32, + c: float32[m_size, n_size], + c_row_off: int32, + c_col_off: int32, + macro_m: int32, + macro_n: int32, + macro_k: int32, + ps_packed_a: int32, + ps_packed_b: int32, + comm_id_macro: int32, + work_id_macro: int32, + is_first: bool, ): comm_id_1st_loop = comm_id_macro % loop1_nthreads work_id_1st_loop = comm_id_1st_loop // (loop1_nthreads // loop1_nways) @@ -662,25 +609,9 @@ def gemm_macro( jr_inc = -1 ir_inc = -1 - thread_range_jrir( - work_id_macro, - macro_nways, - n_iter, - 1, - ~jr_start, - ~jr_end, - ~jr_inc - ) + thread_range_jrir(work_id_macro, macro_nways, n_iter, 1, ~jr_start, ~jr_end, ~jr_inc) - thread_range_jrir( - work_id_1st_loop, - loop1_nways, - m_iter, - 1, - ~ir_start, - ~ir_end, - ~ir_inc - ) + thread_range_jrir(work_id_1st_loop, loop1_nways, m_iter, 1, ~ir_start, ~ir_end, ~ir_inc) rstep_a = ps_packed_a cstep_b = ps_packed_b @@ -688,15 +619,8 @@ def gemm_macro( cstep_c = NR rstep_c = n_size * MR - macro_c_cast = as_tensor_pointer( - ~c[c_row_off, c_col_off], - dtype=float32, - shape=(m_size, n_size) - ) - temp_c = tensor(scope=DeclareScope.Default, - dtype=float32, - layout=row_major(MR, NR), - is_static=False) + macro_c_cast = as_tensor_pointer(~c[c_row_off, c_col_off], dtype=float32, shape=(m_size, n_size)) + temp_c = tensor(scope=DeclareScope.Default, dtype=float32, layout=row_major(MR, NR), is_static=False) j = jr_start while j < jr_end: b1 = packed_b + j * cstep_b @@ -730,53 +654,44 @@ def gemm_macro( @hidet.script def gemm_3rd_loop( - a: float32[m_size, k_size], - packed_b: ~float32, - c: float32[m_size, n_size], - loop3_partition_a_start_col: int32, - loop3_partition_b_start_col: int32, - loop3_partition_a_width: int32, - loop3_partition_b_width: int32, - comm_id_3rd_loop: int32, - work_id_3rd_loop: int32, - is_first: bool, work_id_5th_loop: int32): + a: float32[m_size, k_size], + packed_b: ~float32, + c: float32[m_size, n_size], + loop3_partition_a_start_col: int32, + loop3_partition_b_start_col: int32, + loop3_partition_a_width: int32, + loop3_partition_b_width: int32, + comm_id_3rd_loop: 
int32, + work_id_3rd_loop: int32, + is_first: bool, + work_id_5th_loop: int32, + ): comm_id_macro = comm_id_3rd_loop % macro_nthreads work_id_macro = comm_id_macro // (macro_nthreads // macro_nways) work_id_packa = comm_id_macro m_start_loop3 = 0 m_end_loop3 = 0 - thread_range_sub( - loop3_nways, - work_id_3rd_loop, - m_size, - MR, - ~m_start_loop3, - ~m_end_loop3 - ) + thread_range_sub(loop3_nways, work_id_3rd_loop, m_size, MR, ~m_start_loop3, ~m_end_loop3) ii = m_start_loop3 while ii < m_end_loop3: - b_alg_loop3 = determine_blocksize_f_sub( - ii, m_size, MC - ) + b_alg_loop3 = determine_blocksize_f_sub(ii, m_size, MC) b_alg_loop3 = min(b_alg_loop3, m_end_loop3 - ii) loop3_partition_a_start_row = ii loop3_partition_a_height = b_alg_loop3 loop3_partition_a = cast(a, ~float32) + ( - loop3_partition_a_start_row * k_size + - loop3_partition_a_start_col + loop3_partition_a_start_row * k_size + loop3_partition_a_start_col ) packed_a_idx = packa_index(work_id_5th_loop, work_id_3rd_loop) packed_a_buf = packa_buf + (packed_a_idx * packed_a_individual_size) - thrcomm_barrier( ~packa_thrcomm_barrier_sense[packed_a_idx], ~packa_thrcomm_threads_arrived[packed_a_idx], - packa_nthreads + packa_nthreads, ) gemm_pack_a( @@ -792,35 +707,38 @@ def gemm_3rd_loop( thrcomm_barrier( ~packa_thrcomm_barrier_sense[packed_a_idx], ~packa_thrcomm_threads_arrived[packed_a_idx], - packa_nthreads + packa_nthreads, ) - gemm_macro(packed_a_buf, - packed_b, - c, - loop3_partition_a_start_row, - loop3_partition_b_start_col, - loop3_partition_a_height, - loop3_partition_b_width, - loop3_partition_a_width, - MR * packed_a_width, - packed_b_height * NR, - comm_id_macro, - work_id_macro, - is_first - ) + gemm_macro( + packed_a_buf, + packed_b, + c, + loop3_partition_a_start_row, + loop3_partition_b_start_col, + loop3_partition_a_height, + loop3_partition_b_width, + loop3_partition_a_width, + MR * packed_a_width, + packed_b_height * NR, + comm_id_macro, + work_id_macro, + is_first, + ) ii += b_alg_loop3 gemm_3rd_loop.kind = "cpu_internal" @hidet.script - def gemm_4th_loop(a: float32[m_size, k_size], - b: float32[k_size, n_size], - c: float32[k_size, n_size], - loop5_partition_b_width: int32, - loop5_partition_b_start_col: int32, - comm_id_4th_loop: int32, - work_id_5th_loop: int32): + def gemm_4th_loop( + a: float32[m_size, k_size], + b: float32[k_size, n_size], + c: float32[k_size, n_size], + loop5_partition_b_width: int32, + loop5_partition_b_start_col: int32, + comm_id_4th_loop: int32, + work_id_5th_loop: int32, + ): i_loop4 = 0 comm_id_3rd_loop = comm_id_4th_loop % loop3_nthreads @@ -837,34 +755,38 @@ def gemm_4th_loop(a: float32[m_size, k_size], loop4_partition_b_start_col = loop5_partition_b_start_col loop4_partition_a_start_col = i_loop4 - is_first = (i_loop4 == 0) + is_first = i_loop4 == 0 - packed_b_buf = packb_buf + ( - packed_b_individual_size * work_id_5th_loop - ) + packed_b_buf = packb_buf + (packed_b_individual_size * work_id_5th_loop) - loop4_partition_b = cast(b, ~float32) + \ - (loop4_partition_b_start_row * n_size + - loop4_partition_b_start_col) + loop4_partition_b = cast(b, ~float32) + ( + loop4_partition_b_start_row * n_size + loop4_partition_b_start_col + ) thrcomm_barrier( ~packb_thrcomm_barrier_sense[work_id_5th_loop], ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], - packb_nthreads + packb_nthreads, ) - gemm_pack_b(loop4_partition_b, loop4_partition_b_width, - loop4_partition_b_height, packed_b_buf, - work_id_packb) + gemm_pack_b( + loop4_partition_b, + loop4_partition_b_width, + 
loop4_partition_b_height, + packed_b_buf, + work_id_packb, + ) thrcomm_barrier( ~packb_thrcomm_barrier_sense[work_id_5th_loop], ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], - packb_nthreads + packb_nthreads, ) gemm_3rd_loop( - a, packed_b_buf, c, + a, + packed_b_buf, + c, loop4_partition_a_start_col, loop4_partition_b_start_col, loop4_partition_b_height, @@ -872,13 +794,13 @@ def gemm_4th_loop(a: float32[m_size, k_size], comm_id_3rd_loop, work_id_3rd_loop, is_first, - work_id_5th_loop + work_id_5th_loop, ) thrcomm_barrier( ~packb_thrcomm_barrier_sense[work_id_5th_loop], ~packb_thrcomm_barrier_threads_arrived[work_id_5th_loop], - packb_nthreads + packb_nthreads, ) i_loop4 += b_alg_loop4 @@ -886,46 +808,47 @@ def gemm_4th_loop(a: float32[m_size, k_size], gemm_4th_loop.kind = "cpu_internal" @hidet.script - def gemm_5th_loop(a: float32[m_size, k_size], - b: float32[k_size, n_size], - c: float32[m_size, n_size], - work_id_5th_loop: int32, - comm_id_5th_loop: int32): + def gemm_5th_loop( + a: float32[m_size, k_size], + b: float32[k_size, n_size], + c: float32[m_size, n_size], + work_id_5th_loop: int32, + comm_id_5th_loop: int32, + ): comm_id_4th_loop = comm_id_5th_loop % loop4_nthreads loop5_my_start = -1 loop5_my_end = -1 - thread_range_sub(loop5_nways, work_id_5th_loop, n_size, - NR, ~loop5_my_start, ~loop5_my_end) + thread_range_sub(loop5_nways, work_id_5th_loop, n_size, NR, ~loop5_my_start, ~loop5_my_end) loop5_iter = loop5_my_start while loop5_iter < loop5_my_end: - b_alg_loop5 = determine_blocksize_f_sub(loop5_iter, - loop5_my_end, NC) + b_alg_loop5 = determine_blocksize_f_sub(loop5_iter, loop5_my_end, NC) b_alg_loop5 = min(b_alg_loop5, loop5_my_end - loop5_iter) - loop5_partition_b_width = b_alg_loop5, + loop5_partition_b_width = (b_alg_loop5,) loop5_partition_b_start_col = loop5_iter - gemm_4th_loop(a, b, c, - loop5_partition_b_width, - loop5_partition_b_start_col, - comm_id_4th_loop, - work_id_5th_loop) + gemm_4th_loop( + a, + b, + c, + loop5_partition_b_width, + loop5_partition_b_start_col, + comm_id_4th_loop, + work_id_5th_loop, + ) loop5_iter += b_alg_loop5 gemm_5th_loop.kind = 'cpu_internal' ################### Start of the main kernel ################### @hidet.script - def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], - c: float32[m_size, n_size]): + def matmul_kernel_x86_v3( + a: float32[m_size, k_size], b: float32[k_size, n_size], c: float32[m_size, n_size] + ): - init_thr(packa_thrcomm_barrier_sense, - packa_thrcomm_threads_arrived, - loop3_nways) - init_thr(packb_thrcomm_barrier_sense, - packb_thrcomm_barrier_threads_arrived, - loop5_nways) + init_thr(packa_thrcomm_barrier_sense, packa_thrcomm_threads_arrived, loop3_nways) + init_thr(packb_thrcomm_barrier_sense, packb_thrcomm_barrier_threads_arrived, loop5_nways) parallel_attr = 'p' + str(nthreads) # The outermost loop spawning threads @@ -942,13 +865,13 @@ def matmul_kernel_x86_v3(a: float32[m_size, k_size], b: float32[k_size, n_size], return ir_module -class Matmulx86Op_refactored(Operator): +class Matmulx86Op(Operator): def __init__(self, a: Tensor, b: Tensor): if not (len(a.shape) == len(b.shape) == 2 and a.shape[1] == b.shape[0]): raise ValueError('Matrix multiplication: incompatible sizes: {} and {}'.format(a.shape, b.shape)) - task = MatmulF32Taskx86_refactored(input_like(a, 'a'), input_like(b, 'b')) + task = MatmulF32Taskx86(input_like(a, 'a'), input_like(b, 'b')) super().__init__(inputs=[a, b], attributes={}, task=task) -def matmul_x86_refactored(a: Tensor, b: Tensor) -> 
Tensor: - return Matmulx86Op_refactored(a, b).outputs[0] +def matmul_x86(a: Tensor, b: Tensor) -> Tensor: + return Matmulx86Op(a, b).outputs[0] diff --git a/python/hidet/ir/primitives/cpu/__init__.py b/python/hidet/ir/primitives/cpu/__init__.py index f62df917e..ddb89978e 100644 --- a/python/hidet/ir/primitives/cpu/__init__.py +++ b/python/hidet/ir/primitives/cpu/__init__.py @@ -26,4 +26,3 @@ ) from .atomic import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor - diff --git a/python/hidet/ir/primitives/cpu/atomic.py b/python/hidet/ir/primitives/cpu/atomic.py index b38a52247..dbf94e4c9 100644 --- a/python/hidet/ir/primitives/cpu/atomic.py +++ b/python/hidet/ir/primitives/cpu/atomic.py @@ -11,7 +11,7 @@ # limitations under the License. from typing import Union -from hidet.ir.expr import Expr, Call +from hidet.ir.expr import Expr from hidet.ir.type import FuncType, VoidType, PointerType from hidet.ir.primitives.func import register_primitive_function from hidet.utils import initialize @@ -40,20 +40,3 @@ def cpu_atomic_add_fetch(ptr: Expr, val: Union[Expr, int], order: Union[Expr, in def cpu_atomic_fetch_xor(ptr: Expr, val: Union[Expr, int], order: Union[Expr, int]) -> Expr: return call_primitive_func('cpu_atomic_fetch_xor', [ptr, val, order]) - - - - - - - - - - - - - - - - - diff --git a/python/hidet/ir/primitives/cpu/avx.py b/python/hidet/ir/primitives/cpu/avx.py index bc0392df1..fc6646f83 100644 --- a/python/hidet/ir/primitives/cpu/avx.py +++ b/python/hidet/ir/primitives/cpu/avx.py @@ -27,15 +27,21 @@ def register_primitive_functions(): ('avx_x86_float32x4_load_aligned', '_mm_load_ps', FuncType([PointerType('float32')], 'float32x4')), ('avx_x86_float32x4_store', '_mm_storeu_ps', FuncType([PointerType('float32'), 'float32x4'], VoidType())), ( - 'avx_x86_float32x4_store_aligned', '_mm_store_ps', FuncType([PointerType('float32'), 'float32x4'], VoidType())), + 'avx_x86_float32x4_store_aligned', + '_mm_store_ps', + FuncType([PointerType('float32'), 'float32x4'], VoidType()), + ), ('avx_x86_float32x4_setzero', '_mm_setzero_ps', FuncType([], 'float32x4')), ('avx_x86_float32x8_broadcast', '_mm256_broadcast_ss', FuncType([PointerType('float32')], 'float32x8')), ('avx_x86_float32x8_fmadd', '_mm256_fmadd_ps', FuncType(['float32x8', 'float32x8', 'float32x8'], 'float32x8')), ('avx_x86_float32x8_load', '_mm256_loadu_ps', FuncType([PointerType('float32')], 'float32x8')), ('avx_x86_float32x8_load_aligned', '_mm256_load_ps', FuncType([PointerType('float32')], 'float32x8')), ('avx_x86_float32x8_store', '_mm256_storeu_ps', FuncType([PointerType('float32'), 'float32x8'], VoidType())), - ('avx_x86_float32x8_store_aligned', '_mm256_store_ps', - FuncType([PointerType('float32'), 'float32x8'], VoidType())), + ( + 'avx_x86_float32x8_store_aligned', + '_mm256_store_ps', + FuncType([PointerType('float32'), 'float32x8'], VoidType()), + ), ('avx_x86_float32x8_setzero', '_mm256_setzero_ps', FuncType([], 'float32x8')), ('avx_x86_malloc', '_mm_malloc', FuncType(['uint64', 'uint64'], PointerType(VoidType()))), ('avx_x86_free', '_mm_free', FuncType([PointerType(VoidType())], VoidType())), @@ -52,12 +58,12 @@ def register_primitive_functions(): ( 'avx_x86_float32x8_insert_float32x4', '_mm256_insertf128_ps', - FuncType(['float32x8', 'float32x4', 'int32'], 'float32x8') + FuncType(['float32x8', 'float32x4', 'int32'], 'float32x8'), ), ( 'avx_x86_float32x8_permute2float32x4', '_mm256_permute2f128_ps', - FuncType(['float32x8', 'float32x8', 'int32'], 'float32x8') + FuncType(['float32x8', 'float32x8', 'int32'], 
'float32x8'), ), ] for name, codegen_name, func_type in functions: @@ -162,8 +168,3 @@ def avx_f32x8_insert_f32x4(a: Expr, b: Expr, imm: Union[int, Expr]) -> Call: def avx_f32x8_permute2f32x4(a: Expr, b: Expr, imm: Union[int, Expr]) -> Call: return call_primitive_func('avx_x86_float32x8_permute2float32x4', [a, b, imm]) - - - - - diff --git a/python/hidet/lang/cpu.py b/python/hidet/lang/cpu.py index 0a2da1da8..7706e8fd2 100644 --- a/python/hidet/lang/cpu.py +++ b/python/hidet/lang/cpu.py @@ -45,4 +45,4 @@ avx_f32x8_permute2f32x4, ) -from hidet.ir.primitives.cpu import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor \ No newline at end of file +from hidet.ir.primitives.cpu import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor diff --git a/python/mat_new.py b/python/mat_new.py index 265fe205f..8b4fcbe2a 100644 --- a/python/mat_new.py +++ b/python/mat_new.py @@ -2,7 +2,7 @@ import pytest import hidet -from hidet.graph.ops import matmul_x86_refactored +from hidet.graph.ops import matmul_x86 from hidet.testing import check_binary from hidet.option import debug_cache_tuning @@ -63,7 +63,7 @@ def matmul_ansor(M, K, N, dtype): x1 = hidet.symbol_like(a) x2 = hidet.symbol_like(b) - y = matmul_x86_refactored(x1, x2) + y = matmul_x86(x1, x2) graph = hidet.trace_from( y, inputs=[x1, x2] ) diff --git a/tests/operators/test_matmul.py b/tests/operators/test_matmul.py index 1e308ed47..b7acdffcf 100644 --- a/tests/operators/test_matmul.py +++ b/tests/operators/test_matmul.py @@ -17,7 +17,7 @@ from hidet.testing import check_binary, check_binary_dynamic -@pytest.mark.skip(reason="when running matmul_x86 multiple times, it will produce wrong result. need fix.") +# @pytest.mark.skip(reason="when running matmul_x86 multiple times, it will produce wrong result. need fix.") @pytest.mark.parametrize("a_shape, b_shape", [[[333, 444], [444, 555]], [[133, 1], [1, 177]]]) def test_matmul_x86(a_shape, b_shape): # TODO: Doesn't support broadcasting yet; need to add it later? 
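A note on the synchronization used throughout these kernels: thrcomm_barrier (visible as context in the following patch) implements a sense-reversing barrier. Each thread snapshots the shared sense word, atomically increments an arrival counter, and the last arrival resets the counter and XOR-flips the sense, releasing every spinner at once. Below is a minimal sketch of the same protocol, written with Python's threading primitives purely for illustration; the kernel spins on cpu_atomic_load_n / cpu_atomic_add_fetch / cpu_atomic_fetch_xor instead of blocking, and the class name here is illustrative, not part of the patches.

import threading

class SenseReversingBarrier:
    def __init__(self, num_threads: int):
        self.num_threads = num_threads
        self.sense = 0    # flipped once per episode, like cpu_atomic_fetch_xor on barrier_sense
        self.arrived = 0  # arrival counter, like cpu_atomic_add_fetch on barrier_threads_arrived
        self.cond = threading.Condition()

    def wait(self):
        with self.cond:
            my_sense = self.sense  # snapshot, like orig_sense in thrcomm_barrier
            self.arrived += 1
            if self.arrived == self.num_threads:
                self.arrived = 0         # last thread resets the counter
                self.sense ^= 1          # and flips the sense to release everyone
                self.cond.notify_all()
            else:
                while self.sense == my_sense:  # wait until the sense flips
                    self.cond.wait()

Reuse across episodes is safe because each thread compares against its own snapshot rather than a fixed value; this is why gemm_3rd_loop and gemm_4th_loop can bracket gemm_pack_a and gemm_pack_b with two waits each: one so that no thread still reads the previous panel while it is overwritten, and one so that the panel is fully packed before anyone consumes it.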
From ebcc78f42fc084c74794ba595e672a9de6ee185a Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Fri, 17 Nov 2023 17:25:41 -0500 Subject: [PATCH 133/148] avoid changing function attributes from outside --- .../hidet/graph/ops/matmul/matmul_f32_x86.py | 49 +++++++------------ 1 file changed, 17 insertions(+), 32 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86.py b/python/hidet/graph/ops/matmul/matmul_f32_x86.py index ba08c63d9..25dbed008 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86.py @@ -21,6 +21,7 @@ from hidet.ir.library import tune from hidet.graph.operator import Operator, Tensor from hidet.graph.ops.utils import broadcast_indices +from hidet.lang import attrs class MatmulF32Taskx86(Task): @@ -123,12 +124,11 @@ def schedule_matmulf32_x86(self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1)) -> @hidet.script def init_thr(sense: ~int32, arrived: ~int32, size: int32): + attrs.func_kind = 'cpu_internal' for i in range(size): sense[i] = 0 arrived[i] = 0 - init_thr.kind = "cpu_internal" - # Helpers packed_a_type = tensor_type('float32', layout=row_major(MC // MR, 1) * column_major(MR, KC)) packed_b_type = tensor_type('float32', layout=row_major(1, NC // NR) * row_major(KC, NR)) @@ -146,6 +146,7 @@ def init_thr(sense: ~int32, arrived: ~int32, size: int32): @hidet.script def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~int32, end: ~int32): + attrs.func_kind = "cpu_internal" if n_way == 1: start[0] = 0 end[0] = n @@ -186,20 +187,18 @@ def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~ end[0] += n_bf_left end[0] = min(end[0], all_end) - thread_range_sub.kind = "cpu_internal" - @hidet.script def thread_range_jrir( work_id: int32, n_way: int32, n: int32, bf: int32, start: ~int32, end: ~int32, inc: ~int32 ): + attrs.func_kind = "cpu_internal" start[0] = work_id end[0] = n inc[0] = n_way - thread_range_jrir.kind = "cpu_internal" - @hidet.script def determine_blocksize_f_sub(i: int32, dim: int32, b_alg: int32) -> int32: + attrs.func_kind = 'cpu_internal' dim_left_now = dim - i b_now = -1 if dim_left_now <= b_alg: @@ -209,23 +208,20 @@ def determine_blocksize_f_sub(i: int32, dim: int32, b_alg: int32) -> int32: assert b_now >= 0 return b_now - determine_blocksize_f_sub.kind = "cpu_internal" - @hidet.script def not_edge(i: int32, n_iter: int32, n_left: int32) -> bool: + attrs.func_kind = 'cpu_internal' return i != n_iter - 1 or n_left == 0 - not_edge.kind = 'cpu_internal' - @hidet.script def packa_index(work_id_loop5: int32, work_id_loop3: int32) -> int32: + attrs.func_kind = 'cpu_internal' return work_id_loop5 * loop3_nways + work_id_loop3 - packa_index.kind = 'cpu_internal' - # Thread barrier @hidet.script def thrcomm_barrier(barrier_sense: ~int32, barrier_threads_arrived: ~int32, num_threads: int32): + attrs.func_kind = 'cpu_internal' if num_threads == 1: return orig_sense = cpu_atomic_load_n(barrier_sense, 0) # _ATOMIC_RELAXED @@ -240,8 +236,6 @@ def thrcomm_barrier(barrier_sense: ~int32, barrier_threads_arrived: ~int32, num_ while cpu_atomic_load_n(barrier_sense, 2) == orig_sense: # _ATOMIC_ACQUIRE pass - thrcomm_barrier.kind = 'cpu_internal' - @hidet.script def micro_kernel( a: packed_a_type, @@ -252,6 +246,7 @@ def micro_kernel( nsize: int32, is_first: bool, ): + attrs.func_kind = 'cpu_internal' c = as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize]) c0 = avx_f32x8_load(~c[0, 0]) c08 = avx_f32x8_load(~c[0, 8]) @@ -343,11 +338,6 @@ def 
micro_kernel( packed_a_individual_height = min(MC, (m_size + MR - 1) // MR * MR) packed_a_total_height = packed_a_individual_height * packed_a_buffers_needed - # packed_a_width = KC - # if packed_a_width > k_size: - # packed_a_width = k_size - # # pad this to be able to use the aligned version of the avx store - # packed_a_width = (packed_a_width + 8 - 1) // 8 * 8 packed_a_width = min(KC, (k_size + 8 - 1) // 8 * 8) packed_a_total_size = packed_a_total_height * packed_a_width @@ -369,6 +359,7 @@ def gemm_pack_a( packed_a_buf: ~float32, work_id_packa: int32, ): + attrs.func_kind = 'cpu_internal' packed_a_tensor = as_tensor_pointer( packed_a_buf, float32, @@ -463,6 +454,7 @@ def gemm_pack_b( packed_b_buf: ~float32, work_id_packb: int32, ): + attrs.func_kind = 'cpu_internal' npanels_full_b = loop4_partition_b_width // NR npanels_b_remainder = loop4_partition_b_width % NR @@ -569,10 +561,6 @@ def gemm_pack_b( packed_b_remaining_buf_curr += 1 zero_fill_col += 1 - gemm_pack_b.kind = "cpu_internal" - gemm_pack_a.kind = "cpu_internal" - micro_kernel.kind = "cpu_internal" - @hidet.script def gemm_macro( packed_a: ~float32, @@ -589,6 +577,7 @@ def gemm_macro( work_id_macro: int32, is_first: bool, ): + attrs.func_kind = 'cpu_internal' comm_id_1st_loop = comm_id_macro % loop1_nthreads work_id_1st_loop = comm_id_1st_loop // (loop1_nthreads // loop1_nways) @@ -650,8 +639,6 @@ def gemm_macro( i += ir_inc j += jr_inc - gemm_macro.kind = "cpu_internal" - @hidet.script def gemm_3rd_loop( a: float32[m_size, k_size], @@ -666,6 +653,7 @@ def gemm_3rd_loop( is_first: bool, work_id_5th_loop: int32, ): + attrs.func_kind = 'cpu_internal' comm_id_macro = comm_id_3rd_loop % macro_nthreads work_id_macro = comm_id_macro // (macro_nthreads // macro_nways) work_id_packa = comm_id_macro @@ -727,8 +715,6 @@ def gemm_3rd_loop( ) ii += b_alg_loop3 - gemm_3rd_loop.kind = "cpu_internal" - @hidet.script def gemm_4th_loop( a: float32[m_size, k_size], @@ -739,6 +725,7 @@ def gemm_4th_loop( comm_id_4th_loop: int32, work_id_5th_loop: int32, ): + attrs.func_kind = 'cpu_internal' i_loop4 = 0 comm_id_3rd_loop = comm_id_4th_loop % loop3_nthreads @@ -805,8 +792,6 @@ def gemm_4th_loop( i_loop4 += b_alg_loop4 - gemm_4th_loop.kind = "cpu_internal" - @hidet.script def gemm_5th_loop( a: float32[m_size, k_size], @@ -815,6 +800,7 @@ def gemm_5th_loop( work_id_5th_loop: int32, comm_id_5th_loop: int32, ): + attrs.func_kind = 'cpu_internal' comm_id_4th_loop = comm_id_5th_loop % loop4_nthreads loop5_my_start = -1 @@ -839,13 +825,12 @@ def gemm_5th_loop( ) loop5_iter += b_alg_loop5 - gemm_5th_loop.kind = 'cpu_internal' - ################### Start of the main kernel ################### @hidet.script def matmul_kernel_x86_v3( a: float32[m_size, k_size], b: float32[k_size, n_size], c: float32[m_size, n_size] ): + attrs.func_kind = 'cpu_kernel' init_thr(packa_thrcomm_barrier_sense, packa_thrcomm_threads_arrived, loop3_nways) init_thr(packb_thrcomm_barrier_sense, packb_thrcomm_barrier_threads_arrived, loop5_nways) @@ -860,7 +845,7 @@ def matmul_kernel_x86_v3( gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop) assert isinstance(matmul_kernel_x86_v3, hidet.ir.Function) - matmul_kernel_x86_v3.kind = "cpu_kernel" + # matmul_kernel_x86_v3.kind = "cpu_kernel" ir_module = module.ir_module() return ir_module From fa3945644c05d2362377b031e8cdab3df414c398 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Mon, 11 Dec 2023 20:42:49 -0500 Subject: [PATCH 134/148] Delete python/mat_new.py deleted redundant file --- python/mat_new.py | 148 
---------------------------------------------- 1 file changed, 148 deletions(-) delete mode 100644 python/mat_new.py diff --git a/python/mat_new.py b/python/mat_new.py deleted file mode 100644 index 8b4fcbe2a..000000000 --- a/python/mat_new.py +++ /dev/null @@ -1,148 +0,0 @@ -import numpy as np -import pytest - -import hidet -from hidet.graph.ops import matmul_x86 -from hidet.testing import check_binary -from hidet.option import debug_cache_tuning - -import torch - -import tvm -from tvm import te, auto_scheduler - -@auto_scheduler.register_workload -def matmul_ansor(M, K, N, dtype): - A = te.placeholder((M, K), name='A', dtype=dtype) - B = te.placeholder((K, N), name='B', dtype=dtype) - - k = te.reduce_axis((0, K), name='k') - rst = te.compute( - (M, N), - lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), - name='matmul_ansor', - attrs={"layout_free_placeholders": [B], - # Enable automatic layout transform for B} - } - ) - - return [A, B, rst] -hidet.option.cache_dir("./wtf") - -target = tvm.target.Target("llvm -mcpu=core-avx2") -debug_cache_tuning(True) -hidet.option.search_space(0) - -np.random.seed(42) -# for m, n, k in [(33, 65, 60), (32, 92, 128)]: -# for m, n, k in [(7, 1, 17), (256, 256, 256), (512, 512, 512), (768, 768, 768)]: -# for m, n, k in [(7, 1, 17), (32, 32, 32), (36, 36, 36), (37, 37, 37)]: -# for m, n, k in [(7, 17, 1), (16, 16, 16), (333, 444, 555), (768, 768, 768)]: -# for m, n, k in [(7, 17, 1), (16, 16, 16), (17, 17, 17), (36, 36, 36), (37, 37, 37), (128, 128, 128), (256, 256, 256), (333, 444, 555), (768, 768, 768)]: -# for m, n, k in [(20, 20, 20), (333, 444, 555), (768, 768, 768)]: -for m, n, k in [(20, 20, 20), (333, 444, 555), (768, 768, 768), (555, 256, 3072), (2048, 2048, 2048)]: - # a = hidet.randn([m, k], device='cpu') - # b = hidet.randn([k, n], device='cpu') - - # a_torch = torch.arange(0, m*k).reshape(m, k).float().to('cpu') - # b_torch = torch.arange(0, k*n).reshape(k, n).float().to('cpu') - # # - # # print(f"a_torch: {a_torch}") - # # print(f"b_torch: {b_torch}") - # - # a = hidet.from_torch(a_torch).to(dtype='float32', device='cpu') - # b = hidet.from_torch(b_torch).to(dtype='float32', device='cpu') - # print(f"a: {a}") - # print(f"b: {b}") - - a = hidet.randn([m, k], device='cpu') - b = hidet.randn([k, n], device='cpu') - # a = hidet.ones([m, k], device='cpu') - # b = hidet.ones([k, n], device='cpu') - # - - x1 = hidet.symbol_like(a) - x2 = hidet.symbol_like(b) - y = matmul_x86(x1, x2) - graph = hidet.trace_from( - y, inputs=[x1, x2] - ) - opt_graph = hidet.graph.optimize(graph) - compiled_func = opt_graph.nodes[0].compiled_task - c = compiled_func(a, b) - - actual = c.numpy() - desired = a.numpy() @ b.numpy() - - fails = 0 - - for i in range(m): - for j in range(n): - if abs(actual[i, j] - desired[i, j]) < 1e-3: - # print(f"Actually passed for i={i}, j={j}") - continue - else: - print(f"Failed for i={i}, j={j}, and we have [i, j] = {actual[i, j]} and desired [i, j] = {desired[i, j]}") - fails += 1 - - print(f"Total fails: {fails}") - - # for i in range(m): - # for j in range(n): - # if actual[i, j] == 0.0: - # print(f"element is 0 for i={i}, j={j}") - - - np.testing.assert_allclose( - actual=actual, - desired=desired, - rtol=1e-3, - atol=1e-3 - ) - - print("passed for m={}, n={}, k={}".format(m, n, k)) - - # hidet_latency = hidet.utils.benchmark_func( - # lambda: compiled_func(a, b), repeat=50 - # ) - # np_latency = hidet.utils.benchmark_func( - # lambda: a.numpy() @ b.numpy(), repeat=50 - # ) - # - # ansor_task = tvm.auto_scheduler.SearchTask( - # 
func=matmul_ansor, args=(m, k, n, "float32"), target=target - # ) - # log_file = f"matmul_{m}x{k}x{n}.json" - # tune_option = auto_scheduler.TuningOptions( - # num_measure_trials=1000, - # measure_callbacks=[auto_scheduler.RecordToFile(log_file)], - # verbose=2, - # ) - # - # ansor_task.tune(tune_option) - # sch, args = ansor_task.apply_best(log_file) - # with open(f"./matmul_TIR_{m}x{k}x{n}", 'w') as f: - # f.write(str(tvm.lower(sch, args, simple_mode=True))) - # ansor_func = tvm.build(sch, args, target) - # dev = tvm.cpu() - # a_tvm = tvm.nd.array(a.numpy(), device=dev) - # b_tvm = tvm.nd.array(b.numpy(), device=dev) - # c_tvm = tvm.nd.empty((m, n), device=dev) - # - # ansor_func(a_tvm, b_tvm, c_tvm) - # - # np.testing.assert_allclose( - # actual=c_tvm.numpy(), - # desired=a_tvm.numpy() @ b_tvm.numpy(), - # rtol=1e-3, - # atol=1e-3 - # ) - # - # ansor_latency = hidet.utils.benchmark_func( - # lambda: ansor_func(a_tvm, b_tvm, c_tvm), repeat=30 - # ) - # - # with open(f"./perf_{m}x{k}x{n}.txt", 'w') as f: - # f.write(f"m={m}, k={k}, n={n}: hidet takes {hidet_latency:.2f} ms\n") - # f.write(f"m={m}, k={k}, n={n}: numpy takes {np_latency: .2f} ms\n") - # f.write(f"m={m}, k={k}, n={n}: ansor takes {ansor_latency: .2f} ms\n") From b61722d0b6869006eb9cdea70e63caaef4478e82 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Mon, 11 Dec 2023 21:10:03 -0500 Subject: [PATCH 135/148] Update matmul_f32_x86.py use the original name --- python/hidet/graph/ops/matmul/matmul_f32_x86.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86.py b/python/hidet/graph/ops/matmul/matmul_f32_x86.py index 25dbed008..eeb467b30 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86.py @@ -66,7 +66,7 @@ def __init__(self, a: TensorNode, b: TensorNode): ) super().__init__( - name='matmul_f32_x86_v2', + name='matmul_f32_x86', inputs=[a, b], outputs=[c], attributes={'m_size': a_shape[-2], 'n_size': b_shape[-1], 'k_size': a_shape[-1]}, From 0e7eb63bd8cdfdc3283b615b756fb3c7b06492a2 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 21 Dec 2023 00:24:19 -0500 Subject: [PATCH 136/148] adding batch support --- .../hidet/graph/ops/matmul/matmul_f32_x86.py | 100 +++++++++--------- 1 file changed, 52 insertions(+), 48 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86.py b/python/hidet/graph/ops/matmul/matmul_f32_x86.py index eeb467b30..04700af17 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86.py @@ -11,7 +11,7 @@ # limitations under the License. 
from typing import List, Union from hidet.ir.dtypes import float32, int32 -from hidet.ir.expr import cast +from hidet.ir.expr import cast, is_constant from hidet.ir.module import IRModule from hidet.ir.compute import TensorNode from hidet.ir.stmt import DeclareScope @@ -26,42 +26,19 @@ class MatmulF32Taskx86(Task): def __init__(self, a: TensorNode, b: TensorNode): - a_shape = a.const_shape - b_shape = b.const_shape - if not a.type.dtype == float32 or not b.type.dtype == float32: - raise ValueError('Both inputs must be float32 tensors') - - if len(a_shape) < 2 or len(b_shape) < 2: - raise ValueError('Matrix multiplication expect at least 2D tensor, got {} and {}'.format(a_shape, b_shape)) - - self._assert( - a_shape[-1] == b_shape[-2], - msg=( - 'Matrix multiplication expect tensor A and B with shape [..., M, K] and [..., K, N]' - ', got {} and {}'.format(a_shape, b_shape) - ), - ) - - self._assert( - can_mutually_broadcast(a_shape[:-2], b_shape[:-2]), - msg=( - 'Matrix multiplication expect tensor A and B with compatible broadcast shape, ' - 'got {} and {}'.format(a_shape, b_shape) - ), - ) - - k_size = a_shape[-1] - c_shape = broadcast_shape(a_shape[:-2], b_shape[:-2]) + [a_shape[-2], b_shape[-1]] + batch_size, m_size, k_size = a.shape + batch_size, k_size, n_size = b.shape + self.batch_size = batch_size + self.m_size = m_size + self.n_size = n_size + self.k_size = k_size c = compute( name='c', - shape=c_shape, - fcompute=lambda *indices: reduce( - shape=[k_size], - fcompute=lambda k: a[broadcast_indices(indices[:-2], a_shape[:-2], c_shape[1:-2]) + [indices[-2], k]] - * b[broadcast_indices(indices[:-2], b_shape[:-2], c_shape[1:-2]) + [k, indices[-1]]], - reduce_type='sum', + shape=[batch_size, m_size, n_size], + fcompute = lambda r, i, j: reduce( + shape=[k_size], fcompute=lambda k: a[r, i, k] * b[r, k, j], reduce_type='sum' ), ) @@ -69,9 +46,10 @@ def __init__(self, a: TensorNode, b: TensorNode): name='matmul_f32_x86', inputs=[a, b], outputs=[c], - attributes={'m_size': a_shape[-2], 'n_size': b_shape[-1], 'k_size': a_shape[-1]}, + attributes={'batch_size': batch_size, 'm_size': m_size, 'n_size': n_size, 'k_size': k_size}, ) + def allow_epilogue(self) -> bool: return True @@ -81,7 +59,9 @@ def allow_prologue(self) -> bool: def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: return tune.extract_ir_modules(self.schedule_matmulf32_x86) - @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 4, 2, 1)]) + # @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 4, 2, 1)]) + @tune.space(2, MC=[144, 288, 432, 576, 720], NC=[800], KC=[256, 560, 768, 384], + ways=[(1, 4, 2, 1), (2, 4, 4, 1), (1, 4, 4, 1), (1, 2, 4, 2), (1, 4, 4, 2), (2, 4, 2, 2)]) def schedule_matmulf32_x86(self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1)) -> IRModule: import hidet from hidet.ir.type import tensor_type @@ -94,10 +74,15 @@ def schedule_matmulf32_x86(self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1)) -> from hidet.lang.cpu import avx_f32x8_insert_f32x4, avx_f32x8_permute2f32x4 from hidet.lang.cpu import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor + task = self node_a, node_b = self.inputs[0], self.inputs[1] a_shape = node_a.const_shape b_shape = node_b.const_shape - m_size, n_size, k_size = a_shape[-2], b_shape[-1], a_shape[-1] + + batch_size = task.batch_size + m_size = task.m_size + n_size = task.n_size + k_size = task.k_size MR, NR = 6, 16 @@ -828,21 +813,32 @@ def gemm_5th_loop( ################### Start of the main kernel 
################### @hidet.script def matmul_kernel_x86_v3( - a: float32[m_size, k_size], b: float32[k_size, n_size], c: float32[m_size, n_size] + a: float32[batch_size, m_size, k_size], b: float32[batch_size, k_size, n_size], c: float32[batch_size, m_size, n_size] ): attrs.func_kind = 'cpu_kernel' - - init_thr(packa_thrcomm_barrier_sense, packa_thrcomm_threads_arrived, loop3_nways) - init_thr(packb_thrcomm_barrier_sense, packb_thrcomm_barrier_threads_arrived, loop5_nways) + a_ptr = cast(a, ~float32) + b_ptr = cast(b, ~float32) + c_ptr = cast(c, ~float32) parallel_attr = 'p' + str(nthreads) # The outermost loop spawning threads - for tidx in grid(nthreads, attrs=parallel_attr): - tid_5th_loop = tidx - work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways) - comm_id_5th_loop = tid_5th_loop - - gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop) + for batch in range(batch_size): + init_thr(packa_thrcomm_barrier_sense, packa_thrcomm_threads_arrived, loop3_nways) + init_thr(packb_thrcomm_barrier_sense, packb_thrcomm_barrier_threads_arrived, loop5_nways) + # Iterate through the batch dimension, and for each batch, + # locate the corresponding a, b, and c matrices, and then call the single matmul kernel + a_matrix_size = m_size * k_size + b_matrix_size = k_size * n_size + c_matrix_size = m_size * n_size + a_matrix = as_tensor_pointer(a_ptr + (batch * a_matrix_size), dtype=float32, shape=[m_size, k_size]) + b_matrix = as_tensor_pointer(b_ptr + (batch * b_matrix_size), dtype=float32, shape=[k_size, n_size]) + c_matrix = as_tensor_pointer(c_ptr + (batch * c_matrix_size), dtype=float32, shape=[m_size, n_size]) + for tidx in grid(nthreads, attrs=parallel_attr): + tid_5th_loop = tidx + work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways) + comm_id_5th_loop = tid_5th_loop + + gemm_5th_loop(a_matrix, b_matrix, c_matrix, work_id_5th_loop, comm_id_5th_loop) assert isinstance(matmul_kernel_x86_v3, hidet.ir.Function) # matmul_kernel_x86_v3.kind = "cpu_kernel" @@ -852,8 +848,16 @@ def matmul_kernel_x86_v3( class Matmulx86Op(Operator): def __init__(self, a: Tensor, b: Tensor): - if not (len(a.shape) == len(b.shape) == 2 and a.shape[1] == b.shape[0]): - raise ValueError('Matrix multiplication: incompatible sizes: {} and {}'.format(a.shape, b.shape)) + # if not (len(a.shape) == len(b.shape) == 2 and a.shape[1] == b.shape[0]): + # raise ValueError('Matrix multiplication: incompatible sizes: {} and {}'.format(a.shape, b.shape)) + if not ( + len(a.shape) == len(b.shape) == 3 + and (not is_constant(a.shape[0], b.shape[0]) or a.shape[0] == b.shape[0]) + and (not is_constant(a.shape[2], b.shape[1]) or a.shape[2] == b.shape[1]) + ): + raise ValueError("Matrix multiplication expects tensor A and B with shape [B, M, K] and [B, K, N]" + + ", got {} and {}".format(a.shape, b.shape) + ) task = MatmulF32Taskx86(input_like(a, 'a'), input_like(b, 'b')) super().__init__(inputs=[a, b], attributes={}, task=task) From 4d9505d5eb91edc448ca86851f57044fd9095aed Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Wed, 10 Jan 2024 18:46:16 -0500 Subject: [PATCH 137/148] . 
--- python/hidet/graph/ops/__init__.py | 2 +- python/hidet/graph/ops/matmul/__init__.py | 2 +- .../hidet/graph/ops/matmul/matmul_f32_x86.py | 8 +++--- python/hidet/graph/ops/matmul/resolve.py | 26 ++++++++++++++----- tests/operators/test_matmul.py | 2 +- 5 files changed, 27 insertions(+), 13 deletions(-) diff --git a/python/hidet/graph/ops/__init__.py b/python/hidet/graph/ops/__init__.py index 8a2fcd5f3..c7deb291a 100644 --- a/python/hidet/graph/ops/__init__.py +++ b/python/hidet/graph/ops/__init__.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # pylint: disable=redefined-builtin -from .matmul import batch_matmul, matmul, matmul_x86 +from .matmul import batch_matmul, matmul, batch_matmul_x86 from .conv1d import conv1d, conv1d_gemm from .conv1d_transpose import conv1d_transpose from .conv2d import conv2d, conv2d_channel_last, conv2d_winograd, conv2d_gemm, conv2d_gemm_fp16 diff --git a/python/hidet/graph/ops/matmul/__init__.py b/python/hidet/graph/ops/matmul/__init__.py index a1e8c0be5..4eed4fcf5 100644 --- a/python/hidet/graph/ops/matmul/__init__.py +++ b/python/hidet/graph/ops/matmul/__init__.py @@ -15,4 +15,4 @@ from .matmul_f32_x86 import Matmulx86Op, MatmulF32Taskx86 -from .matmul_f32_x86 import matmul_x86 +from .matmul_f32_x86 import batch_matmul_x86 diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86.py b/python/hidet/graph/ops/matmul/matmul_f32_x86.py index 04700af17..8c09f2fd4 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86.py @@ -812,7 +812,7 @@ def gemm_5th_loop( ################### Start of the main kernel ################### @hidet.script - def matmul_kernel_x86_v3( + def matmul_kernel_x86( a: float32[batch_size, m_size, k_size], b: float32[batch_size, k_size, n_size], c: float32[batch_size, m_size, n_size] ): attrs.func_kind = 'cpu_kernel' @@ -840,8 +840,8 @@ def matmul_kernel_x86_v3( gemm_5th_loop(a_matrix, b_matrix, c_matrix, work_id_5th_loop, comm_id_5th_loop) - assert isinstance(matmul_kernel_x86_v3, hidet.ir.Function) - # matmul_kernel_x86_v3.kind = "cpu_kernel" + assert isinstance(matmul_kernel_x86, hidet.ir.Function) + # matmul_kernel_x86.kind = "cpu_kernel" ir_module = module.ir_module() return ir_module @@ -862,5 +862,5 @@ def __init__(self, a: Tensor, b: Tensor): super().__init__(inputs=[a, b], attributes={}, task=task) -def matmul_x86(a: Tensor, b: Tensor) -> Tensor: +def batch_matmul_x86(a: Tensor, b: Tensor) -> Tensor: return Matmulx86Op(a, b).outputs[0] diff --git a/python/hidet/graph/ops/matmul/resolve.py b/python/hidet/graph/ops/matmul/resolve.py index 8d6adbdbf..46371416a 100644 --- a/python/hidet/graph/ops/matmul/resolve.py +++ b/python/hidet/graph/ops/matmul/resolve.py @@ -22,6 +22,7 @@ from .matmul import MatmulOp from .batch_matmul import batch_matmul +from .matmul_f32_x86 import batch_matmul_x86 from .matmul_f16 import matmul_f16 from ..transform import broadcast, flatten from ..utils import broadcast_shapes @@ -90,7 +91,7 @@ class MatmulResolveRule(ResolveRule): The generic matrix multiplication operator has the same semantics as numpy.matmul that accepts variable dimensions of inputs. 
- On ther other hand, the batched matrix multiplication operator accepts inputs with shape: + On the other hand, the batched matrix multiplication operator accepts inputs with shape: [batch_size, m_size, k_size] x [batch_size, k_size, n_size] This resolve rule also parallelize k dimension when possible, and determine the mma instruction. @@ -125,6 +126,9 @@ def run_batch_matmul(self, a: Tensor, b: Tensor) -> Tensor: c = batch_matmul(aa, bb, mma=mma).reshape([batch_size, nparts, m_size, n_size]).sum(1) return c + def run_batch_matmul_cpu(self, a: Tensor, b: Tensor) -> Tensor: + return batch_matmul_x86(a, b) + def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]: assert isinstance(op, MatmulOp) a: Tensor = op.inputs[0] @@ -133,30 +137,38 @@ def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]: if a.dtype.nbytes > 4 or b.dtype.nbytes > 4: return None + run_func = self.run_batch_matmul + if op.device.is_cpu(): + run_func = self.run_batch_matmul_cpu + if len(a.shape) == 1: # shape: [a] a = a.unsqueeze([0, 1]) # [1, 1, a] if len(b.shape) == 2: # shape: [a, b] # [a] x [a, b] -> [b] b = b.unsqueeze([0]) # [1, a, b] - c = self.run_batch_matmul(a, b) # [1, 1, b] + # c = self.run_batch_matmul(a, b) # [1, 1, b] FIXME: Delete later + c = run_func(a, b) # [1, 1, b] c = c.squeeze([0, 1]) # [b] else: assert len(b.shape) >= 3 # shape example: [b, c, a, d] # [a] x [b, c, a, d] -> [b, c, d] b = flatten(b, start_dim=0, end_dim=-3) # [b * c, a, d] - c = self.run_batch_matmul(a, b) # [b * c, 1, d] + # c = self.run_batch_matmul(a, b) # [b * c, 1, d] FIXME: Delete later + c = run_func(a, b) # [b * c, 1, d] c = c.reshape(c_shape) # [b, c, d] elif len(b.shape) == 1: # shape: [b] b = b.unsqueeze([0, 2]) # [1, b, 1] if len(a.shape) == 2: # shape: [a, b] a = a.unsqueeze([0]) # [1, a, b] - c = self.run_batch_matmul(a, b) # [1, a, 1] + # c = self.run_batch_matmul(a, b) # [1, a, 1] FIXME: Delete later + c = run_func(a, b) # [1, a, 1] c = c.squeeze([0, 2]) # [a] else: assert len(a.shape) >= 3 # shape example: [a, c, d, b] # [a, c, d, b] x [b] -> [a, c, d] a = flatten(a, start_dim=0, end_dim=-3) # [a * c, d, b] - c = self.run_batch_matmul(a, b) # [a * c, d, 1] + # c = self.run_batch_matmul(a, b) # [a * c, d, 1] FIXME: Delete later + c = run_func(a, b) # [a * c, d, 1] c = c.reshape(c_shape) # [a, c, d] else: # example: [a, b, c] x [c, d] -> [a, b, d] @@ -168,7 +180,8 @@ def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]: b_broadcast_shape = c_head + list(b.shape[-2:]) a = flatten(broadcast(a, a_broadcast_shape), start_dim=0, end_dim=-3) b = flatten(broadcast(b, b_broadcast_shape), start_dim=0, end_dim=-3) - c = self.run_batch_matmul(a, b) + # c = self.run_batch_matmul(a, b) FIXME: Delete later + c = run_func(a, b) c = c.reshape(c_shape) return [c] @@ -239,6 +252,7 @@ def resolve_f16(self, op: Operator) -> Optional[List[Tensor]]: c = matmul_f16(a, b, parallel_k_parts=k_parts).sum(0) return [c] + def resolve(self, op: Operator) -> Optional[List[Tensor]]: if op.device.is_cpu(): return None diff --git a/tests/operators/test_matmul.py b/tests/operators/test_matmul.py index b7acdffcf..ccdd5c0e0 100644 --- a/tests/operators/test_matmul.py +++ b/tests/operators/test_matmul.py @@ -25,7 +25,7 @@ def test_matmul_x86(a_shape, b_shape): a_shape, b_shape, lambda x, y: np.matmul(x, y), - lambda x, y: ops.matmul_x86(x, y) - ops.matmul_x86(x, y) + ops.matmul_x86(x, y), + lambda x, y: ops.batch_matmul_x86(x, y) - ops.batch_matmul_x86(x, y) + ops.batch_matmul_x86(x, y), dtype="float32", atol=1e-4, 
rtol=1e-4, From 170896e2205b944991cdb5f861cfe632d753e027 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 11 Jan 2024 14:08:57 -0500 Subject: [PATCH 138/148] resolve rule + batch support --- .../hidet/graph/ops/matmul/matmul_f32_x86.py | 31 ++++++++++--------- python/hidet/graph/ops/matmul/resolve.py | 1 - 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86.py b/python/hidet/graph/ops/matmul/matmul_f32_x86.py index 8c09f2fd4..b39e7f6b4 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86.py @@ -17,10 +17,9 @@ from hidet.ir.stmt import DeclareScope from hidet.ir.task import Task from hidet.ir.compute import compute, reduce -from hidet.graph.ops.utils import input_like, broadcast_shape, can_mutually_broadcast +from hidet.graph.ops.utils import input_like from hidet.ir.library import tune from hidet.graph.operator import Operator, Tensor -from hidet.graph.ops.utils import broadcast_indices from hidet.lang import attrs @@ -37,7 +36,7 @@ def __init__(self, a: TensorNode, b: TensorNode): c = compute( name='c', shape=[batch_size, m_size, n_size], - fcompute = lambda r, i, j: reduce( + fcompute=lambda r, i, j: reduce( shape=[k_size], fcompute=lambda k: a[r, i, k] * b[r, k, j], reduce_type='sum' ), ) @@ -49,7 +48,6 @@ def __init__(self, a: TensorNode, b: TensorNode): attributes={'batch_size': batch_size, 'm_size': m_size, 'n_size': n_size, 'k_size': k_size}, ) - def allow_epilogue(self) -> bool: return True @@ -59,9 +57,14 @@ def allow_prologue(self) -> bool: def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: return tune.extract_ir_modules(self.schedule_matmulf32_x86) - # @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 4, 2, 1)]) - @tune.space(2, MC=[144, 288, 432, 576, 720], NC=[800], KC=[256, 560, 768, 384], - ways=[(1, 4, 2, 1), (2, 4, 4, 1), (1, 4, 4, 1), (1, 2, 4, 2), (1, 4, 4, 2), (2, 4, 2, 2)]) + @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(2, 4, 2, 2)]) + @tune.space( + 2, + MC=[144, 288, 432, 576, 720], + NC=[800], + KC=[256, 560, 768, 384], + ways=[(1, 4, 2, 1), (2, 4, 4, 1), (1, 4, 4, 1), (1, 2, 4, 2), (1, 4, 4, 2), (2, 4, 2, 2)], + ) def schedule_matmulf32_x86(self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1)) -> IRModule: import hidet from hidet.ir.type import tensor_type @@ -75,9 +78,6 @@ def schedule_matmulf32_x86(self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1)) -> from hidet.lang.cpu import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor task = self - node_a, node_b = self.inputs[0], self.inputs[1] - a_shape = node_a.const_shape - b_shape = node_b.const_shape batch_size = task.batch_size m_size = task.m_size @@ -813,7 +813,9 @@ def gemm_5th_loop( ################### Start of the main kernel ################### @hidet.script def matmul_kernel_x86( - a: float32[batch_size, m_size, k_size], b: float32[batch_size, k_size, n_size], c: float32[batch_size, m_size, n_size] + a: float32[batch_size, m_size, k_size], + b: float32[batch_size, k_size, n_size], + c: float32[batch_size, m_size, n_size], ): attrs.func_kind = 'cpu_kernel' a_ptr = cast(a, ~float32) @@ -855,9 +857,10 @@ def __init__(self, a: Tensor, b: Tensor): and (not is_constant(a.shape[0], b.shape[0]) or a.shape[0] == b.shape[0]) and (not is_constant(a.shape[2], b.shape[1]) or a.shape[2] == b.shape[1]) ): - raise ValueError("Matrix multiplication expects tensor A and B with shape [B, M, K] and [B, K, N]" - + ", got {} 
and {}".format(a.shape, b.shape) - ) + raise ValueError( + "Matrix multiplication expects tensor A and B with shape [B, M, K] and [B, K, N]" + + ", got {} and {}".format(a.shape, b.shape) + ) task = MatmulF32Taskx86(input_like(a, 'a'), input_like(b, 'b')) super().__init__(inputs=[a, b], attributes={}, task=task) diff --git a/python/hidet/graph/ops/matmul/resolve.py b/python/hidet/graph/ops/matmul/resolve.py index 46371416a..231ec3c37 100644 --- a/python/hidet/graph/ops/matmul/resolve.py +++ b/python/hidet/graph/ops/matmul/resolve.py @@ -252,7 +252,6 @@ def resolve_f16(self, op: Operator) -> Optional[List[Tensor]]: c = matmul_f16(a, b, parallel_k_parts=k_parts).sum(0) return [c] - def resolve(self, op: Operator) -> Optional[List[Tensor]]: if op.device.is_cpu(): return None From 71fcd6a16afd7ae698f62b6a41b6cbac2b8dc3c8 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 11 Jan 2024 14:18:46 -0500 Subject: [PATCH 139/148] modify test --- python/hidet/graph/ops/__init__.py | 2 +- tests/operators/test_matmul.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/python/hidet/graph/ops/__init__.py b/python/hidet/graph/ops/__init__.py index c7deb291a..0472aefbe 100644 --- a/python/hidet/graph/ops/__init__.py +++ b/python/hidet/graph/ops/__init__.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # pylint: disable=redefined-builtin -from .matmul import batch_matmul, matmul, batch_matmul_x86 +from .matmul import batch_matmul, matmul, matmul_cublas, batch_matmul_x86 from .conv1d import conv1d, conv1d_gemm from .conv1d_transpose import conv1d_transpose from .conv2d import conv2d, conv2d_channel_last, conv2d_winograd, conv2d_gemm, conv2d_gemm_fp16 diff --git a/tests/operators/test_matmul.py b/tests/operators/test_matmul.py index 00b09de72..6d0c858cd 100644 --- a/tests/operators/test_matmul.py +++ b/tests/operators/test_matmul.py @@ -18,10 +18,8 @@ from hidet.testing import check_binary, check_binary_dynamic, check_torch_binary -# @pytest.mark.skip(reason="when running matmul_x86 multiple times, it will produce wrong result. need fix.") -@pytest.mark.parametrize("a_shape, b_shape", [[[333, 444], [444, 555]], [[133, 1], [1, 177]]]) +@pytest.mark.parametrize("a_shape, b_shape", [[[1, 333, 444], [1, 444, 555]], [[1, 133, 1], [1, 1, 177]]]) def test_matmul_x86(a_shape, b_shape): - # TODO: Doesn't support broadcasting yet; need to add it later? 
check_binary( a_shape, b_shape, From 60319ca2f690798de908fc3fcbf909a0934266fa Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 11 Jan 2024 19:58:11 -0500 Subject: [PATCH 140/148] Update python/hidet/graph/ops/matmul/matmul_f32_x86.py Co-authored-by: Yaoyao Ding --- python/hidet/graph/ops/matmul/matmul_f32_x86.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86.py b/python/hidet/graph/ops/matmul/matmul_f32_x86.py index b39e7f6b4..1883d6480 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86.py @@ -850,8 +850,6 @@ def matmul_kernel_x86( class Matmulx86Op(Operator): def __init__(self, a: Tensor, b: Tensor): - # if not (len(a.shape) == len(b.shape) == 2 and a.shape[1] == b.shape[0]): - # raise ValueError('Matrix multiplication: incompatible sizes: {} and {}'.format(a.shape, b.shape)) if not ( len(a.shape) == len(b.shape) == 3 and (not is_constant(a.shape[0], b.shape[0]) or a.shape[0] == b.shape[0]) From 4a6f641d390f99ada8749bd546215a19100dde18 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 11 Jan 2024 19:58:26 -0500 Subject: [PATCH 141/148] Update python/hidet/graph/ops/matmul/resolve.py Co-authored-by: Yaoyao Ding --- python/hidet/graph/ops/matmul/resolve.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/resolve.py b/python/hidet/graph/ops/matmul/resolve.py index 231ec3c37..b3671a8b3 100644 --- a/python/hidet/graph/ops/matmul/resolve.py +++ b/python/hidet/graph/ops/matmul/resolve.py @@ -146,7 +146,6 @@ def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]: if len(b.shape) == 2: # shape: [a, b] # [a] x [a, b] -> [b] b = b.unsqueeze([0]) # [1, a, b] - # c = self.run_batch_matmul(a, b) # [1, 1, b] FIXME: Delete later c = run_func(a, b) # [1, 1, b] c = c.squeeze([0, 1]) # [b] else: From c4152e2d3727c08e49a983a6a25c394e08835ce4 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 11 Jan 2024 19:58:36 -0500 Subject: [PATCH 142/148] Update python/hidet/graph/ops/matmul/resolve.py Co-authored-by: Yaoyao Ding --- python/hidet/graph/ops/matmul/resolve.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/resolve.py b/python/hidet/graph/ops/matmul/resolve.py index b3671a8b3..29a54914d 100644 --- a/python/hidet/graph/ops/matmul/resolve.py +++ b/python/hidet/graph/ops/matmul/resolve.py @@ -152,7 +152,6 @@ def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]: assert len(b.shape) >= 3 # shape example: [b, c, a, d] # [a] x [b, c, a, d] -> [b, c, d] b = flatten(b, start_dim=0, end_dim=-3) # [b * c, a, d] - # c = self.run_batch_matmul(a, b) # [b * c, 1, d] FIXME: Delete later c = run_func(a, b) # [b * c, 1, d] c = c.reshape(c_shape) # [b, c, d] elif len(b.shape) == 1: # shape: [b] From f1bddb5d7cf7b493ae4e9de7372d7eaf6f87a5fc Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Thu, 11 Jan 2024 19:58:44 -0500 Subject: [PATCH 143/148] Update python/hidet/graph/ops/matmul/resolve.py Co-authored-by: Yaoyao Ding --- python/hidet/graph/ops/matmul/resolve.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/hidet/graph/ops/matmul/resolve.py b/python/hidet/graph/ops/matmul/resolve.py index 29a54914d..0f0ed2d1c 100644 --- a/python/hidet/graph/ops/matmul/resolve.py +++ b/python/hidet/graph/ops/matmul/resolve.py @@ -178,7 +178,6 @@ def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]: b_broadcast_shape = c_head + list(b.shape[-2:]) a = flatten(broadcast(a, a_broadcast_shape), start_dim=0, end_dim=-3) b = 
flatten(broadcast(b, b_broadcast_shape), start_dim=0, end_dim=-3)
-            # c = self.run_batch_matmul(a, b) FIXME: Delete later
             c = run_func(a, b)
             c = c.reshape(c_shape)
         return [c]

From c2ad5de6d561993450f57ce26ef99c98e525f17e Mon Sep 17 00:00:00 2001
From: Bolin Sun
Date: Thu, 11 Jan 2024 19:59:03 -0500
Subject: [PATCH 144/148] Update tests/operators/test_matmul.py

Co-authored-by: Yaoyao Ding
---
 tests/operators/test_matmul.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/operators/test_matmul.py b/tests/operators/test_matmul.py
index 6d0c858cd..de38cde78 100644
--- a/tests/operators/test_matmul.py
+++ b/tests/operators/test_matmul.py
@@ -24,7 +24,7 @@ def test_matmul_x86(a_shape, b_shape):
         a_shape,
         b_shape,
         lambda x, y: np.matmul(x, y),
-        lambda x, y: ops.batch_matmul_x86(x, y) - ops.batch_matmul_x86(x, y) + ops.batch_matmul_x86(x, y),
+        lambda x, y: ops.batch_matmul_x86(x, y),
         dtype="float32",
         atol=1e-4,
         rtol=1e-4,

From 95dd0fb1f1c37c1626f922562e741c5ce59253c6 Mon Sep 17 00:00:00 2001
From: Bolin Sun
Date: Thu, 11 Jan 2024 19:59:11 -0500
Subject: [PATCH 145/148] Update python/hidet/graph/ops/matmul/resolve.py

Co-authored-by: Yaoyao Ding
---
 python/hidet/graph/ops/matmul/resolve.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/hidet/graph/ops/matmul/resolve.py b/python/hidet/graph/ops/matmul/resolve.py
index 0f0ed2d1c..9720d6ca8 100644
--- a/python/hidet/graph/ops/matmul/resolve.py
+++ b/python/hidet/graph/ops/matmul/resolve.py
@@ -158,7 +158,6 @@ def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]:
             b = b.unsqueeze([0, 2])  # [1, b, 1]
             if len(a.shape) == 2:  # shape: [a, b]
                 a = a.unsqueeze([0])  # [1, a, b]
-                # c = self.run_batch_matmul(a, b)  # [1, a, 1] FIXME: Delete later
                 c = run_func(a, b)  # [1, a, 1]
                 c = c.squeeze([0, 2])  # [a]
             else:

From 8d09697d8709bbfc175c05cdc6cb9a49b1cbc687 Mon Sep 17 00:00:00 2001
From: Bolin Sun
Date: Thu, 11 Jan 2024 19:59:19 -0500
Subject: [PATCH 146/148] Update python/hidet/graph/ops/matmul/resolve.py

Co-authored-by: Yaoyao Ding
---
 python/hidet/graph/ops/matmul/resolve.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/hidet/graph/ops/matmul/resolve.py b/python/hidet/graph/ops/matmul/resolve.py
index 9720d6ca8..1e20a752e 100644
--- a/python/hidet/graph/ops/matmul/resolve.py
+++ b/python/hidet/graph/ops/matmul/resolve.py
@@ -164,7 +164,6 @@ def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]:
             assert len(a.shape) >= 3  # shape example: [a, c, d, b]
             # [a, c, d, b] x [b] -> [a, c, d]
             a = flatten(a, start_dim=0, end_dim=-3)  # [a * c, d, b]
-            # c = self.run_batch_matmul(a, b)  # [a * c, d, 1] FIXME: Delete later
             c = run_func(a, b)  # [a * c, d, 1]
             c = c.reshape(c_shape)  # [a, c, d]
         else:

From f9caaf672718d4ec45270e43ec95f284cc0f0e3e Mon Sep 17 00:00:00 2001
From: Bolin Sun
Date: Fri, 23 Feb 2024 14:44:40 -0500
Subject: [PATCH 147/148] fix CPU matmul dispatch and scheduler kernel naming

---
 .../hidet/graph/ops/matmul/resolve.py       |  3 ++-
 python/hidet/ir/schedulers/cpu/scheduler.py | 10 +++++++++-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/python/hidet/graph/ops/matmul/resolve.py b/python/hidet/graph/ops/matmul/resolve.py
index 1e20a752e..df86fe04b 100644
--- a/python/hidet/graph/ops/matmul/resolve.py
+++ b/python/hidet/graph/ops/matmul/resolve.py
@@ -137,9 +137,10 @@ def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]:
         if a.dtype.nbytes > 4 or b.dtype.nbytes > 4:
             return None

-        run_func = self.run_batch_matmul
         if op.device.is_cpu():
             run_func = self.run_batch_matmul_cpu
+        else:
+            run_func =
self.run_batch_matmul if len(a.shape) == 1: # shape: [a] a = a.unsqueeze([0, 1]) # [1, 1, a] diff --git a/python/hidet/ir/schedulers/cpu/scheduler.py b/python/hidet/ir/schedulers/cpu/scheduler.py index 9089c288c..9e8a5301a 100644 --- a/python/hidet/ir/schedulers/cpu/scheduler.py +++ b/python/hidet/ir/schedulers/cpu/scheduler.py @@ -26,7 +26,15 @@ def schedule_grid_compute(self, node: GridCompute, tensor_map: Dict[TensorNode, params, param_map, call_args = self.grid_compute_params_and_args(node, tensor_map) if self.task is not None: - name = f'{self.task.name}_compute_{node.name}' + # name = f'{self.task.name}_compute_{node.name}' + from hidet.graph.ops.fusion.fused_operator import FusedTask + + if isinstance(self.task, FusedTask): + fused_name = self.task.attrs['fused_ops'].replace(' ', '_') + name = f'fused_{fused_name}_{node.name}' + else: + name = f'{self.task.name}_{node.name}' + else: name = f'compute_{node.name}' From 2220e8f7aa3bd7003bb3ece838753fed87d030cb Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Wed, 13 Mar 2024 16:57:22 -0400 Subject: [PATCH 148/148] commit before fixing matmul for global var --- .../hidet/graph/ops/fusion/apply_prologue_epilogue.py | 2 +- python/hidet/graph/ops/matmul/matmul_f32_x86.py | 6 +++++- python/hidet/graph/ops/matmul/resolve.py | 3 ++- python/hidet/graph/ops/softmax.py | 10 +++++----- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/python/hidet/graph/ops/fusion/apply_prologue_epilogue.py b/python/hidet/graph/ops/fusion/apply_prologue_epilogue.py index 2b2a7f74e..ea744164d 100644 --- a/python/hidet/graph/ops/fusion/apply_prologue_epilogue.py +++ b/python/hidet/graph/ops/fusion/apply_prologue_epilogue.py @@ -464,7 +464,7 @@ def visit_Call(self, e: Call): func_name = e.func_var.name if func_name in self.func_records: args = self.process_call(func_name, list(e.args)) - return Call(e.func_var, args) + return Call(e.func_var, tuple(args)) return super().visit_Call(e) def visit_LaunchKernelStmt(self, stmt: LaunchKernelStmt): diff --git a/python/hidet/graph/ops/matmul/matmul_f32_x86.py b/python/hidet/graph/ops/matmul/matmul_f32_x86.py index 1883d6480..81e352d24 100644 --- a/python/hidet/graph/ops/matmul/matmul_f32_x86.py +++ b/python/hidet/graph/ops/matmul/matmul_f32_x86.py @@ -21,6 +21,7 @@ from hidet.ir.library import tune from hidet.graph.operator import Operator, Tensor from hidet.lang import attrs +from hidet.ir.expr import if_then_else class MatmulF32Taskx86(Task): @@ -320,7 +321,10 @@ def micro_kernel( packed_b_total_size = packed_b_total_width * packed_b_height packed_b_individual_size = packed_b_width * packed_b_height - packed_a_individual_height = min(MC, (m_size + MR - 1) // MR * MR) + # packed_a_individual_height = min(MC, (m_size + MR - 1) // MR * MR) # FIXME: what? Error on this line? 
+ temp_packed_a_ind = (m_size + MR - 1) // MR * MR + packed_a_individual_height = if_then_else(temp_packed_a_ind > MR, MR, temp_packed_a_ind) + packed_a_total_height = packed_a_individual_height * packed_a_buffers_needed packed_a_width = min(KC, (k_size + 8 - 1) // 8 * 8) diff --git a/python/hidet/graph/ops/matmul/resolve.py b/python/hidet/graph/ops/matmul/resolve.py index df86fe04b..4424b1362 100644 --- a/python/hidet/graph/ops/matmul/resolve.py +++ b/python/hidet/graph/ops/matmul/resolve.py @@ -249,8 +249,9 @@ def resolve_f16(self, op: Operator) -> Optional[List[Tensor]]: return [c] def resolve(self, op: Operator) -> Optional[List[Tensor]]: + print("Here resolve is called.......") if op.device.is_cpu(): - return None + return self.resolve_generic(op) resolve_funcs: List[Callable[[Operator], Any]] = [self.resolve_f16, self.resolve_generic] for resolve_func in resolve_funcs: outs = resolve_func(op) diff --git a/python/hidet/graph/ops/softmax.py b/python/hidet/graph/ops/softmax.py index 4421eae61..d17544adb 100644 --- a/python/hidet/graph/ops/softmax.py +++ b/python/hidet/graph/ops/softmax.py @@ -157,11 +157,6 @@ def softmax_kernel(xs: xdtype[shape], ys: xdtype[shape]): return ir_module - def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: - if self.inputs[0].type.dtype != float32: - return NotImplemented # use auto-scheduler - return tune.extract_ir_modules(self.schedule_softmax_cpu) - class CPUSoftmaxTask(SoftmaxTask): def allow_epilogue(self) -> bool: @@ -170,6 +165,11 @@ def allow_epilogue(self) -> bool: def allow_prologue(self) -> bool: return False + def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: + if self.inputs[-1].type.dtype != float32: + return NotImplemented # use auto-scheduler + return tune.extract_ir_modules(self.schedule_softmax_cpu) + @tune.space(2, nthreads=['', 4, 8, 16, 32, 64, 96]) @tune.space(1, nthreads=['', 8, 16]) def schedule_softmax_cpu(self, nthreads='') -> IRModule:
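A detail worth keeping in mind when reading gemm_pack_a across this series: the packed buffer is declared with layout row_major(packed_a_individual_height // MR, 1) * column_major(MR, packed_a_width), i.e. A is cut into MR-row micro-panels stored back to back, and each panel is transposed to column-major so that the MR values the micro-kernel needs at a given k step sit contiguously. A reference sketch of that layout in NumPy, assuming the tile height is already padded to a multiple of MR (the zero-fill branch of gemm_pack_a guarantees this); the function name is illustrative, not part of the patches:

import numpy as np

def pack_a_reference(a_tile: np.ndarray, MR: int = 6) -> np.ndarray:
    # a_tile: an (mc, kc) tile of A with mc a multiple of MR; in the patches,
    # kc (packed_a_width) is further padded to a multiple of 8 so the 8-float
    # aligned AVX stores used by the real packing routine stay in bounds.
    mc, kc = a_tile.shape
    assert mc % MR == 0, 'zero-pad the tile to a multiple of MR rows first'
    panels = a_tile.reshape(mc // MR, MR, kc)  # split into MR-row micro-panels
    # Transpose each panel to (kc, MR): column-major within the panel, so the
    # MR values consumed at each k step of the micro-kernel are adjacent.
    return np.ascontiguousarray(panels.transpose(0, 2, 1)).ravel()

Element [i, k] of packed_a_tensor then lives in panel i // MR at offset k * MR + i % MR, which is exactly the order this sketch produces and the order the 6x8 AVX transpose (the unpacklo/unpackhi/shuffle/insert/permute2f128 sequence in gemm_pack_a) writes.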