-
Notifications
You must be signed in to change notification settings - Fork 52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Operators] Add batch support for x86 CPU matrix multiplication + resolve rule #415
Changes from all commits
efe3e14
b19a212
d7e4043
e13af0a
a7bce75
bad483c
d7f6469
f211a48
bbb5afc
569fb49
b32ea73
dbbb2b6
014f5c1
2d82325
11c9e70
4586e89
65c3b9d
286c107
bfacaf8
8246466
7518042
f8a97b2
1a87c27
3104473
68bc03d
9059ca3
df5a177
27da1ba
fca3694
14973b4
45ad16a
36b3c52
c79fcca
87cdd76
8648ced
075cc64
7814d6d
ff058bf
f9f3b81
0a7b2fe
99954e1
b884a95
12a139a
8cf009d
717069f
7b53554
f933711
42054a4
60599c2
2d65005
747508b
4e5c7da
23f2768
0ab4888
1631d77
62c075c
5d4a314
e30ab31
134a1d5
e1e2d29
7a7ff5e
29de46f
ca9e67d
43d4a60
3d67673
3c9d792
6782047
9401c1e
e655035
4c7ed70
21978bb
c90991f
7c3ef0a
4acf6c0
c740a3a
8f0ee0e
01e84ec
90505e7
805959e
8bb52d3
94abfa7
a3f35dc
230e6d0
e3bf60a
e5e4466
601e6b2
2df7355
ee30078
cb54a7e
8e07dad
8723df6
0919d12
b2a6c15
43922bb
553dfc4
ae29fb3
0572ace
ce1f5fd
aaa500c
6445811
d3e1a1d
6589848
4bc93c8
563b121
17011a1
18f8b53
12e44c2
ceb22dd
68fbba8
0c3639f
76d55a1
9e289e4
165c3d5
e898772
4cb35cb
6ba8075
073266a
d736d96
83118f3
df1cc83
a85e56f
728ec9a
dfdf084
d2e1ab4
0c0efe0
1bd2cfe
6721ed2
442fbd2
b4e00e9
ad9c453
d34f031
954da89
78d09c4
838a61e
6f572a4
3fbb635
656bbd0
ebcc78f
fa39456
b61722d
575acaf
ef57171
0e7eb63
d33093e
4d9505d
091492c
170896e
71fcd6a
60319ca
4a6f641
c4152e2
f1bddb5
c2ad5de
95dd0fb
8d09697
40b7f51
0cc633a
94c97b8
b843c57
04cc68a
f9caaf6
2220e8f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,65 +11,42 @@ | |
# limitations under the License. | ||
from typing import List, Union | ||
from hidet.ir.dtypes import float32, int32 | ||
from hidet.ir.expr import cast | ||
from hidet.ir.expr import cast, is_constant | ||
from hidet.ir.module import IRModule | ||
from hidet.ir.compute import TensorNode | ||
from hidet.ir.stmt import DeclareScope | ||
from hidet.ir.task import Task | ||
from hidet.ir.compute import compute, reduce | ||
from hidet.graph.ops.utils import input_like, broadcast_shape, can_mutually_broadcast | ||
from hidet.graph.ops.utils import input_like | ||
from hidet.ir.library import tune | ||
from hidet.graph.operator import Operator, Tensor | ||
from hidet.graph.ops.utils import broadcast_indices | ||
from hidet.lang import attrs | ||
from hidet.ir.expr import if_then_else | ||
|
||
|
||
class MatmulF32Taskx86(Task): | ||
def __init__(self, a: TensorNode, b: TensorNode): | ||
a_shape = a.const_shape | ||
b_shape = b.const_shape | ||
|
||
if not a.type.dtype == float32 or not b.type.dtype == float32: | ||
raise ValueError('Both inputs must be float32 tensors') | ||
|
||
if len(a_shape) < 2 or len(b_shape) < 2: | ||
raise ValueError('Matrix multiplication expect at least 2D tensor, got {} and {}'.format(a_shape, b_shape)) | ||
|
||
self._assert( | ||
a_shape[-1] == b_shape[-2], | ||
msg=( | ||
'Matrix multiplication expect tensor A and B with shape [..., M, K] and [..., K, N]' | ||
', got {} and {}'.format(a_shape, b_shape) | ||
), | ||
) | ||
|
||
self._assert( | ||
can_mutually_broadcast(a_shape[:-2], b_shape[:-2]), | ||
msg=( | ||
'Matrix multiplication expect tensor A and B with compatible broadcast shape, ' | ||
'got {} and {}'.format(a_shape, b_shape) | ||
), | ||
) | ||
|
||
k_size = a_shape[-1] | ||
c_shape = broadcast_shape(a_shape[:-2], b_shape[:-2]) + [a_shape[-2], b_shape[-1]] | ||
batch_size, m_size, k_size = a.shape | ||
batch_size, k_size, n_size = b.shape | ||
self.batch_size = batch_size | ||
self.m_size = m_size | ||
self.n_size = n_size | ||
self.k_size = k_size | ||
|
||
c = compute( | ||
name='c', | ||
shape=c_shape, | ||
fcompute=lambda *indices: reduce( | ||
shape=[k_size], | ||
fcompute=lambda k: a[broadcast_indices(indices[:-2], a_shape[:-2], c_shape[1:-2]) + [indices[-2], k]] | ||
* b[broadcast_indices(indices[:-2], b_shape[:-2], c_shape[1:-2]) + [k, indices[-1]]], | ||
reduce_type='sum', | ||
shape=[batch_size, m_size, n_size], | ||
fcompute=lambda r, i, j: reduce( | ||
shape=[k_size], fcompute=lambda k: a[r, i, k] * b[r, k, j], reduce_type='sum' | ||
), | ||
) | ||
|
||
super().__init__( | ||
name='matmul_f32_x86', | ||
inputs=[a, b], | ||
outputs=[c], | ||
attributes={'m_size': a_shape[-2], 'n_size': b_shape[-1], 'k_size': a_shape[-1]}, | ||
attributes={'batch_size': batch_size, 'm_size': m_size, 'n_size': n_size, 'k_size': k_size}, | ||
) | ||
|
||
def allow_epilogue(self) -> bool: | ||
|
@@ -81,7 +58,14 @@ def allow_prologue(self) -> bool: | |
def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]: | ||
return tune.extract_ir_modules(self.schedule_matmulf32_x86) | ||
|
||
@tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 4, 2, 1)]) | ||
@tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(2, 4, 2, 2)]) | ||
@tune.space( | ||
2, | ||
MC=[144, 288, 432, 576, 720], | ||
NC=[800], | ||
KC=[256, 560, 768, 384], | ||
ways=[(1, 4, 2, 1), (2, 4, 4, 1), (1, 4, 4, 1), (1, 2, 4, 2), (1, 4, 4, 2), (2, 4, 2, 2)], | ||
) | ||
def schedule_matmulf32_x86(self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1)) -> IRModule: | ||
import hidet | ||
from hidet.ir.type import tensor_type | ||
|
@@ -94,10 +78,12 @@ def schedule_matmulf32_x86(self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1)) -> | |
from hidet.lang.cpu import avx_f32x8_insert_f32x4, avx_f32x8_permute2f32x4 | ||
from hidet.lang.cpu import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor | ||
|
||
node_a, node_b = self.inputs[0], self.inputs[1] | ||
a_shape = node_a.const_shape | ||
b_shape = node_b.const_shape | ||
m_size, n_size, k_size = a_shape[-2], b_shape[-1], a_shape[-1] | ||
task = self | ||
|
||
batch_size = task.batch_size | ||
m_size = task.m_size | ||
n_size = task.n_size | ||
k_size = task.k_size | ||
|
||
MR, NR = 6, 16 | ||
|
||
|
@@ -335,7 +321,10 @@ def micro_kernel( | |
packed_b_total_size = packed_b_total_width * packed_b_height | ||
packed_b_individual_size = packed_b_width * packed_b_height | ||
|
||
packed_a_individual_height = min(MC, (m_size + MR - 1) // MR * MR) | ||
# packed_a_individual_height = min(MC, (m_size + MR - 1) // MR * MR) # FIXME: what? Error on this line? | ||
temp_packed_a_ind = (m_size + MR - 1) // MR * MR | ||
packed_a_individual_height = if_then_else(temp_packed_a_ind > MR, MR, temp_packed_a_ind) | ||
|
||
packed_a_total_height = packed_a_individual_height * packed_a_buffers_needed | ||
|
||
packed_a_width = min(KC, (k_size + 8 - 1) // 8 * 8) | ||
|
@@ -827,36 +816,56 @@ def gemm_5th_loop( | |
|
||
################### Start of the main kernel ################### | ||
@hidet.script | ||
def matmul_kernel_x86_v3( | ||
a: float32[m_size, k_size], b: float32[k_size, n_size], c: float32[m_size, n_size] | ||
def matmul_kernel_x86( | ||
a: float32[batch_size, m_size, k_size], | ||
b: float32[batch_size, k_size, n_size], | ||
c: float32[batch_size, m_size, n_size], | ||
): | ||
attrs.func_kind = 'cpu_kernel' | ||
|
||
init_thr(packa_thrcomm_barrier_sense, packa_thrcomm_threads_arrived, loop3_nways) | ||
init_thr(packb_thrcomm_barrier_sense, packb_thrcomm_barrier_threads_arrived, loop5_nways) | ||
a_ptr = cast(a, ~float32) | ||
b_ptr = cast(b, ~float32) | ||
c_ptr = cast(c, ~float32) | ||
|
||
parallel_attr = 'p' + str(nthreads) | ||
# The outermost loop spawning threads | ||
for tidx in grid(nthreads, attrs=parallel_attr): | ||
tid_5th_loop = tidx | ||
work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways) | ||
comm_id_5th_loop = tid_5th_loop | ||
|
||
gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop) | ||
|
||
assert isinstance(matmul_kernel_x86_v3, hidet.ir.Function) | ||
# matmul_kernel_x86_v3.kind = "cpu_kernel" | ||
for batch in range(batch_size): | ||
init_thr(packa_thrcomm_barrier_sense, packa_thrcomm_threads_arrived, loop3_nways) | ||
init_thr(packb_thrcomm_barrier_sense, packb_thrcomm_barrier_threads_arrived, loop5_nways) | ||
# Iterate through the batch dimension, and for each batch, | ||
# locate the corresponding a, b, and c matrices, and then call the single matmul kernel | ||
a_matrix_size = m_size * k_size | ||
b_matrix_size = k_size * n_size | ||
c_matrix_size = m_size * n_size | ||
a_matrix = as_tensor_pointer(a_ptr + (batch * a_matrix_size), dtype=float32, shape=[m_size, k_size]) | ||
b_matrix = as_tensor_pointer(b_ptr + (batch * b_matrix_size), dtype=float32, shape=[k_size, n_size]) | ||
c_matrix = as_tensor_pointer(c_ptr + (batch * c_matrix_size), dtype=float32, shape=[m_size, n_size]) | ||
for tidx in grid(nthreads, attrs=parallel_attr): | ||
tid_5th_loop = tidx | ||
work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways) | ||
comm_id_5th_loop = tid_5th_loop | ||
|
||
gemm_5th_loop(a_matrix, b_matrix, c_matrix, work_id_5th_loop, comm_id_5th_loop) | ||
|
||
assert isinstance(matmul_kernel_x86, hidet.ir.Function) | ||
# matmul_kernel_x86.kind = "cpu_kernel" | ||
ir_module = module.ir_module() | ||
return ir_module | ||
|
||
|
||
class Matmulx86Op(Operator): | ||
def __init__(self, a: Tensor, b: Tensor): | ||
if not (len(a.shape) == len(b.shape) == 2 and a.shape[1] == b.shape[0]): | ||
raise ValueError('Matrix multiplication: incompatible sizes: {} and {}'.format(a.shape, b.shape)) | ||
if not ( | ||
len(a.shape) == len(b.shape) == 3 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we use the same template to support matmul like:
You can have a look at https://github.com/hidet-org/hidet/blob/main/python/hidet/graph/ops/matmul/batch_matmul.py as a reference. |
||
and (not is_constant(a.shape[0], b.shape[0]) or a.shape[0] == b.shape[0]) | ||
and (not is_constant(a.shape[2], b.shape[1]) or a.shape[2] == b.shape[1]) | ||
Comment on lines
+859
to
+860
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do you require the shapes of a and b to be constant? Is it possible to support dynamic shapes like https://github.com/hidet-org/hidet/blob/main/python/hidet/graph/ops/matmul/batch_matmul.py |
||
): | ||
raise ValueError( | ||
"Matrix multiplication expects tensor A and B with shape [B, M, K] and [B, K, N]" | ||
+ ", got {} and {}".format(a.shape, b.shape) | ||
) | ||
task = MatmulF32Taskx86(input_like(a, 'a'), input_like(b, 'b')) | ||
super().__init__(inputs=[a, b], attributes={}, task=task) | ||
|
||
|
||
def matmul_x86(a: Tensor, b: Tensor) -> Tensor: | ||
def batch_matmul_x86(a: Tensor, b: Tensor) -> Tensor: | ||
return Matmulx86Op(a, b).outputs[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can use this function https://github.com/hidet-org/hidet/blob/main/python/hidet/ir/compute/cops/matmul.py#L30 to add a computation definition for matmul.
Like
hidet/python/hidet/graph/ops/matmul/matmul.py
Line 20 in 359c5f3