
[Operators] Add batch support for x86 CPU matrix multiplication + resolve rule #415

Status: Closed. This pull request wanted to merge 162 commits; the changes shown below are from 148 of them.

Commits (162):
efe3e14
.
BolinSNLHM May 28, 2023
b19a212
Merge branch 'hidet-org:main' into main
BolinSNLHM May 29, 2023
d7e4043
.
BolinSNLHM Jun 21, 2023
e13af0a
Merge branch 'main' of github.com:BolinSNLHM/hidet into main
BolinSNLHM Jul 17, 2023
a7bce75
added basic openMP primitives
BolinSNLHM Jul 17, 2023
bad483c
Merge branch 'main' into omp
BolinSNLHM Aug 8, 2023
d7f6469
added those primitives back
BolinSNLHM Aug 8, 2023
f211a48
let me pretend like it's all good for tonight
BolinSNLHM Aug 13, 2023
bbb5afc
...
BolinSNLHM Aug 13, 2023
569fb49
working on refactoring
BolinSNLHM Aug 17, 2023
b32ea73
ready to be tested on the eco server
BolinSNLHM Aug 20, 2023
dbbb2b6
fix stupid error
BolinSNLHM Aug 20, 2023
014f5c1
..
BolinSNLHM Aug 20, 2023
2d82325
fix more error
BolinSNLHM Aug 21, 2023
11c9e70
..
BolinSNLHM Aug 21, 2023
4586e89
fixing hidet script error
BolinSNLHM Aug 23, 2023
65c3b9d
...:
BolinSNLHM Aug 23, 2023
286c107
....
BolinSNLHM Aug 23, 2023
bfacaf8
...
BolinSNLHM Aug 23, 2023
8246466
..
BolinSNLHM Aug 23, 2023
7518042
..
BolinSNLHM Aug 23, 2023
f8a97b2
fixing strange error
BolinSNLHM Aug 23, 2023
1a87c27
more errors
BolinSNLHM Aug 23, 2023
3104473
more err
BolinSNLHM Aug 23, 2023
68bc03d
...
BolinSNLHM Aug 23, 2023
9059ca3
...
BolinSNLHM Aug 23, 2023
df5a177
global
BolinSNLHM Aug 23, 2023
27da1ba
global var
BolinSNLHM Aug 25, 2023
fca3694
.
BolinSNLHM Aug 25, 2023
14973b4
.
BolinSNLHM Aug 25, 2023
45ad16a
...
BolinSNLHM Aug 25, 2023
36b3c52
..:
BolinSNLHM Aug 25, 2023
c79fcca
cast
BolinSNLHM Aug 25, 2023
87cdd76
cast
BolinSNLHM Aug 25, 2023
8648ced
...
BolinSNLHM Aug 25, 2023
075cc64
.
BolinSNLHM Aug 25, 2023
7814d6d
now segfault not internal errors
BolinSNLHM Aug 25, 2023
ff058bf
stupid error
BolinSNLHM Aug 26, 2023
f9f3b81
err
BolinSNLHM Aug 26, 2023
0a7b2fe
...
BolinSNLHM Aug 26, 2023
99954e1
..
BolinSNLHM Aug 26, 2023
b884a95
..
BolinSNLHM Aug 27, 2023
12a139a
.
BolinSNLHM Aug 27, 2023
8cf009d
.
BolinSNLHM Aug 27, 2023
717069f
...
BolinSNLHM Aug 27, 2023
7b53554
.
BolinSNLHM Aug 27, 2023
f933711
small fix
BolinSNLHM Aug 27, 2023
42054a4
..
BolinSNLHM Aug 27, 2023
60599c2
..
BolinSNLHM Aug 27, 2023
2d65005
.
BolinSNLHM Aug 27, 2023
747508b
.
BolinSNLHM Aug 27, 2023
4e5c7da
.
BolinSNLHM Aug 27, 2023
23f2768
try single thread first
BolinSNLHM Aug 27, 2023
0ab4888
..
BolinSNLHM Aug 27, 2023
1631d77
dumb mistake again
BolinSNLHM Aug 27, 2023
62c075c
..
BolinSNLHM Aug 27, 2023
5d4a314
..
BolinSNLHM Aug 27, 2023
e30ab31
keep debugging
BolinSNLHM Aug 28, 2023
134a1d5
..
BolinSNLHM Aug 28, 2023
e1e2d29
..
BolinSNLHM Aug 28, 2023
7a7ff5e
.
BolinSNLHM Aug 28, 2023
29de46f
..
BolinSNLHM Aug 28, 2023
ca9e67d
...
BolinSNLHM Aug 28, 2023
43d4a60
..:
BolinSNLHM Aug 28, 2023
3d67673
.
BolinSNLHM Aug 28, 2023
3c9d792
..
BolinSNLHM Aug 29, 2023
6782047
.
BolinSNLHM Aug 29, 2023
9401c1e
..
BolinSNLHM Aug 29, 2023
e655035
..
BolinSNLHM Aug 29, 2023
4c7ed70
..
BolinSNLHM Aug 29, 2023
21978bb
..
BolinSNLHM Aug 29, 2023
c90991f
..
BolinSNLHM Aug 29, 2023
7c3ef0a
continue fixing
BolinSNLHM Aug 29, 2023
4acf6c0
..
BolinSNLHM Aug 29, 2023
c740a3a
.
BolinSNLHM Aug 29, 2023
8f0ee0e
...
BolinSNLHM Aug 29, 2023
01e84ec
...
BolinSNLHM Aug 29, 2023
90505e7
..
BolinSNLHM Aug 29, 2023
805959e
...
BolinSNLHM Aug 29, 2023
8bb52d3
..
BolinSNLHM Aug 29, 2023
94abfa7
..
BolinSNLHM Aug 29, 2023
a3f35dc
.
BolinSNLHM Aug 29, 2023
230e6d0
..
BolinSNLHM Aug 29, 2023
e3bf60a
..
BolinSNLHM Aug 29, 2023
e5e4466
.
BolinSNLHM Aug 29, 2023
601e6b2
.
BolinSNLHM Aug 29, 2023
2df7355
..
BolinSNLHM Aug 29, 2023
ee30078
bruh
BolinSNLHM Aug 29, 2023
cb54a7e
..
BolinSNLHM Aug 29, 2023
8e07dad
.
BolinSNLHM Aug 29, 2023
8723df6
.
BolinSNLHM Aug 29, 2023
0919d12
..
BolinSNLHM Aug 29, 2023
b2a6c15
..
BolinSNLHM Aug 29, 2023
43922bb
..
BolinSNLHM Aug 29, 2023
553dfc4
..
BolinSNLHM Aug 29, 2023
ae29fb3
...
BolinSNLHM Aug 29, 2023
0572ace
.
BolinSNLHM Aug 29, 2023
ce1f5fd
.
BolinSNLHM Aug 29, 2023
aaa500c
..
BolinSNLHM Aug 29, 2023
6445811
.
BolinSNLHM Aug 29, 2023
d3e1a1d
.
BolinSNLHM Aug 29, 2023
6589848
..
BolinSNLHM Aug 29, 2023
4bc93c8
.
BolinSNLHM Aug 29, 2023
563b121
.
BolinSNLHM Aug 29, 2023
17011a1
.
BolinSNLHM Aug 29, 2023
18f8b53
..
BolinSNLHM Aug 29, 2023
12e44c2
..
BolinSNLHM Aug 29, 2023
ceb22dd
..
BolinSNLHM Aug 29, 2023
68fbba8
..
BolinSNLHM Aug 29, 2023
0c3639f
.
BolinSNLHM Aug 30, 2023
76d55a1
..
BolinSNLHM Aug 30, 2023
9e289e4
..
BolinSNLHM Aug 30, 2023
165c3d5
..
BolinSNLHM Aug 30, 2023
e898772
..
BolinSNLHM Aug 30, 2023
4cb35cb
.
BolinSNLHM Aug 30, 2023
6ba8075
..
BolinSNLHM Aug 30, 2023
073266a
.
BolinSNLHM Aug 30, 2023
d736d96
..
BolinSNLHM Aug 30, 2023
83118f3
.
BolinSNLHM Aug 30, 2023
df1cc83
....
BolinSNLHM Aug 30, 2023
a85e56f
..
BolinSNLHM Aug 30, 2023
728ec9a
kept debugging the matrix mul kernel
BolinSNLHM Oct 6, 2023
dfdf084
bruh
BolinSNLHM Oct 27, 2023
d2e1ab4
fixed a dumb bug that got me stuck for way too much longer than neces…
BolinSNLHM Nov 9, 2023
0c0efe0
.
BolinSNLHM Nov 9, 2023
1bd2cfe
remove prints
BolinSNLHM Nov 9, 2023
6721ed2
.
BolinSNLHM Nov 9, 2023
442fbd2
..
BolinSNLHM Nov 9, 2023
b4e00e9
logic error fix in packing of A
BolinSNLHM Nov 9, 2023
ad9c453
seems like still bugs, but they disappear with print...
BolinSNLHM Nov 10, 2023
d34f031
fix bug caused by static local variable
BolinSNLHM Nov 11, 2023
954da89
...
BolinSNLHM Nov 15, 2023
78d09c4
fix alignment
BolinSNLHM Nov 15, 2023
838a61e
cleanup
BolinSNLHM Nov 17, 2023
6f572a4
Merge branch 'fix-zero-init' into main
BolinSNLHM Nov 17, 2023
3fbb635
ready for PR
BolinSNLHM Nov 17, 2023
656bbd0
......
BolinSNLHM Nov 17, 2023
ebcc78f
avoid changing function attributes from outside
BolinSNLHM Nov 17, 2023
fa39456
Delete python/mat_new.py
BolinSNLHM Dec 12, 2023
b61722d
Update matmul_f32_x86.py
BolinSNLHM Dec 12, 2023
575acaf
Merge branch 'hidet-org:main' into main
BolinSNLHM Dec 13, 2023
ef57171
Merge branch 'hidet-org:main' into main
BolinSNLHM Dec 20, 2023
0e7eb63
adding batch support
BolinSNLHM Dec 21, 2023
d33093e
Merge branch 'hidet-org:main' into main
BolinSNLHM Jan 10, 2024
4d9505d
.
BolinSNLHM Jan 10, 2024
091492c
fix conflict
BolinSNLHM Jan 10, 2024
170896e
resolve rule + batch support
BolinSNLHM Jan 11, 2024
71fcd6a
modify test
BolinSNLHM Jan 11, 2024
60319ca
Update python/hidet/graph/ops/matmul/matmul_f32_x86.py
BolinSNLHM Jan 12, 2024
4a6f641
Update python/hidet/graph/ops/matmul/resolve.py
BolinSNLHM Jan 12, 2024
c4152e2
Update python/hidet/graph/ops/matmul/resolve.py
BolinSNLHM Jan 12, 2024
f1bddb5
Update python/hidet/graph/ops/matmul/resolve.py
BolinSNLHM Jan 12, 2024
c2ad5de
Update tests/operators/test_matmul.py
BolinSNLHM Jan 12, 2024
95dd0fb
Update python/hidet/graph/ops/matmul/resolve.py
BolinSNLHM Jan 12, 2024
8d09697
Update python/hidet/graph/ops/matmul/resolve.py
BolinSNLHM Jan 12, 2024
40b7f51
Merge branch 'hidet-org:main' into main
BolinSNLHM Jan 18, 2024
0cc633a
Merge branch 'hidet-org:main' into main
BolinSNLHM Jan 26, 2024
94c97b8
Merge branch 'main' into gpt2-benchmark
BolinSNLHM Jan 26, 2024
b843c57
Merge branch 'main' of github.com:BolinSNLHM/hidet into main
BolinSNLHM Jan 26, 2024
04cc68a
Merge branch 'main' into gpt2-benchmark
BolinSNLHM Jan 26, 2024
f9caaf6
resolve asdfasdf
BolinSNLHM Feb 23, 2024
2220e8f
commit before fixing matmul for global var
BolinSNLHM Mar 13, 2024
Files changed:
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/__init__.py
@@ -10,7 +10,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # pylint: disable=redefined-builtin
-from .matmul import batch_matmul, matmul, matmul_x86, matmul_cublas
+from .matmul import batch_matmul, matmul, matmul_cublas, batch_matmul_x86
 from .conv1d import conv1d, conv1d_gemm
 from .conv1d_transpose import conv1d_transpose
 from .conv2d import conv2d, conv2d_channel_last, conv2d_winograd, conv2d_gemm, conv2d_gemm_fp16
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/matmul/__init__.py
@@ -16,4 +16,4 @@


 from .matmul_f32_x86 import Matmulx86Op, MatmulF32Taskx86
-from .matmul_f32_x86 import matmul_x86
+from .matmul_f32_x86 import batch_matmul_x86
123 changes: 65 additions & 58 deletions python/hidet/graph/ops/matmul/matmul_f32_x86.py
@@ -11,65 +11,41 @@
 # limitations under the License.
 from typing import List, Union
 from hidet.ir.dtypes import float32, int32
-from hidet.ir.expr import cast
+from hidet.ir.expr import cast, is_constant
 from hidet.ir.module import IRModule
 from hidet.ir.compute import TensorNode
 from hidet.ir.stmt import DeclareScope
 from hidet.ir.task import Task
 from hidet.ir.compute import compute, reduce
-from hidet.graph.ops.utils import input_like, broadcast_shape, can_mutually_broadcast
+from hidet.graph.ops.utils import input_like
 from hidet.ir.library import tune
 from hidet.graph.operator import Operator, Tensor
-from hidet.graph.ops.utils import broadcast_indices
 from hidet.lang import attrs


 class MatmulF32Taskx86(Task):
     def __init__(self, a: TensorNode, b: TensorNode):
-        a_shape = a.const_shape
-        b_shape = b.const_shape
-
-        if not a.type.dtype == float32 or not b.type.dtype == float32:
-            raise ValueError('Both inputs must be float32 tensors')
-
-        if len(a_shape) < 2 or len(b_shape) < 2:
-            raise ValueError('Matrix multiplication expect at least 2D tensor, got {} and {}'.format(a_shape, b_shape))
-
-        self._assert(
-            a_shape[-1] == b_shape[-2],
-            msg=(
-                'Matrix multiplication expect tensor A and B with shape [..., M, K] and [..., K, N]'
-                ', got {} and {}'.format(a_shape, b_shape)
-            ),
-        )
-
-        self._assert(
-            can_mutually_broadcast(a_shape[:-2], b_shape[:-2]),
-            msg=(
-                'Matrix multiplication expect tensor A and B with compatible broadcast shape, '
-                'got {} and {}'.format(a_shape, b_shape)
-            ),
-        )
-
-        k_size = a_shape[-1]
-        c_shape = broadcast_shape(a_shape[:-2], b_shape[:-2]) + [a_shape[-2], b_shape[-1]]
+        batch_size, m_size, k_size = a.shape
+        batch_size, k_size, n_size = b.shape
+        self.batch_size = batch_size
+        self.m_size = m_size
+        self.n_size = n_size
+        self.k_size = k_size

         c = compute(
             name='c',
-            shape=c_shape,
-            fcompute=lambda *indices: reduce(
-                shape=[k_size],
-                fcompute=lambda k: a[broadcast_indices(indices[:-2], a_shape[:-2], c_shape[1:-2]) + [indices[-2], k]]
-                * b[broadcast_indices(indices[:-2], b_shape[:-2], c_shape[1:-2]) + [k, indices[-1]]],
-                reduce_type='sum',
+            shape=[batch_size, m_size, n_size],
+            fcompute=lambda r, i, j: reduce(
+                shape=[k_size], fcompute=lambda k: a[r, i, k] * b[r, k, j], reduce_type='sum'
             ),
         )

         super().__init__(
             name='matmul_f32_x86',
             inputs=[a, b],
             outputs=[c],
-            attributes={'m_size': a_shape[-2], 'n_size': b_shape[-1], 'k_size': a_shape[-1]},
+            attributes={'batch_size': batch_size, 'm_size': m_size, 'n_size': n_size, 'k_size': k_size},

Reviewer comment (Member): We can use this function
https://github.com/hidet-org/hidet/blob/main/python/hidet/ir/compute/cops/matmul.py#L30
to add the computation definition for matmul, like:

    c = cops.matmul(a, b, allow_1d=True)

         )
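For readers, a plain NumPy model of what the new compute/reduce definition expresses: a fixed-batch GEMM over [B, M, K] x [B, K, N] with no broadcasting. This is an illustrative sketch, not code from the PR:

    import numpy as np

    def reference_bmm(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        # Mirrors fcompute=lambda r, i, j: reduce over k of a[r, i, k] * b[r, k, j]
        batch_size, m_size, k_size = a.shape
        _, _, n_size = b.shape
        c = np.zeros((batch_size, m_size, n_size), dtype=np.float32)
        for r in range(batch_size):
            for i in range(m_size):
                for j in range(n_size):
                    c[r, i, j] = float(np.sum(a[r, i, :] * b[r, :, j]))  # reduce over k
        return c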

     def allow_epilogue(self) -> bool:
@@ -81,7 +57,14 @@ def allow_prologue(self) -> bool:
     def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
         return tune.extract_ir_modules(self.schedule_matmulf32_x86)

-    @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(1, 4, 2, 1)])
+    @tune.space(1, MC=[2016], NC=[256, 384, 512], KC=[384, 512, 560], ways=[(2, 4, 2, 2)])
+    @tune.space(
+        2,
+        MC=[144, 288, 432, 576, 720],
+        NC=[800],
+        KC=[256, 560, 768, 384],
+        ways=[(1, 4, 2, 1), (2, 4, 4, 1), (1, 4, 4, 1), (1, 2, 4, 2), (1, 4, 4, 2), (2, 4, 2, 2)],
+    )
     def schedule_matmulf32_x86(self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1)) -> IRModule:
         import hidet
         from hidet.ir.type import tensor_type
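A note on the tuning space above: MC, NC, and KC appear to be BLIS-style cache-blocking sizes for the macro-kernel loops, and ways appears to split the nthreads threads across the nested loops (compare the loop3_nways and loop5_nways variables in the kernel below). A minimal NumPy model of the blocking itself, illustrative only; the real kernel packs A/B panels and runs a 6x16 AVX micro-kernel (MR, NR = 6, 16) inside these loops:

    import numpy as np

    def blocked_matmul(a: np.ndarray, b: np.ndarray, MC: int = 144, NC: int = 800, KC: int = 256) -> np.ndarray:
        m, k = a.shape
        _, n = b.shape
        c = np.zeros((m, n), dtype=a.dtype)
        for jc in range(0, n, NC):            # 5th loop: NC-wide column panels of B/C
            for pc in range(0, k, KC):        # 4th loop: KC-deep K slices (B panel would be packed here)
                for ic in range(0, m, MC):    # 3rd loop: MC-tall row blocks (A block would be packed here)
                    c[ic:ic + MC, jc:jc + NC] += a[ic:ic + MC, pc:pc + KC] @ b[pc:pc + KC, jc:jc + NC]
        return c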
@@ -94,10 +77,12 @@ def schedule_matmulf32_x86(self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1)) -> IRModule:
         from hidet.lang.cpu import avx_f32x8_insert_f32x4, avx_f32x8_permute2f32x4
         from hidet.lang.cpu import cpu_atomic_load_n, cpu_atomic_add_fetch, cpu_atomic_fetch_xor

-        node_a, node_b = self.inputs[0], self.inputs[1]
-        a_shape = node_a.const_shape
-        b_shape = node_b.const_shape
-        m_size, n_size, k_size = a_shape[-2], b_shape[-1], a_shape[-1]
+        task = self
+
+        batch_size = task.batch_size
+        m_size = task.m_size
+        n_size = task.n_size
+        k_size = task.k_size

         MR, NR = 6, 16

@@ -827,36 +812,58 @@ def gemm_5th_loop(

         ################### Start of the main kernel ###################
         @hidet.script
-        def matmul_kernel_x86_v3(
-            a: float32[m_size, k_size], b: float32[k_size, n_size], c: float32[m_size, n_size]
+        def matmul_kernel_x86(
+            a: float32[batch_size, m_size, k_size],
+            b: float32[batch_size, k_size, n_size],
+            c: float32[batch_size, m_size, n_size],
         ):
             attrs.func_kind = 'cpu_kernel'

-            init_thr(packa_thrcomm_barrier_sense, packa_thrcomm_threads_arrived, loop3_nways)
-            init_thr(packb_thrcomm_barrier_sense, packb_thrcomm_barrier_threads_arrived, loop5_nways)
+            a_ptr = cast(a, ~float32)
+            b_ptr = cast(b, ~float32)
+            c_ptr = cast(c, ~float32)

             parallel_attr = 'p' + str(nthreads)
             # The outermost loop spawning threads
-            for tidx in grid(nthreads, attrs=parallel_attr):
-                tid_5th_loop = tidx
-                work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways)
-                comm_id_5th_loop = tid_5th_loop
-
-                gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop)
-
-        assert isinstance(matmul_kernel_x86_v3, hidet.ir.Function)
-        # matmul_kernel_x86_v3.kind = "cpu_kernel"
+            for batch in range(batch_size):
+                init_thr(packa_thrcomm_barrier_sense, packa_thrcomm_threads_arrived, loop3_nways)
+                init_thr(packb_thrcomm_barrier_sense, packb_thrcomm_barrier_threads_arrived, loop5_nways)
+                # Iterate through the batch dimension, and for each batch,
+                # locate the corresponding a, b, and c matrices, and then call the single matmul kernel
+                a_matrix_size = m_size * k_size
+                b_matrix_size = k_size * n_size
+                c_matrix_size = m_size * n_size
+                a_matrix = as_tensor_pointer(a_ptr + (batch * a_matrix_size), dtype=float32, shape=[m_size, k_size])
+                b_matrix = as_tensor_pointer(b_ptr + (batch * b_matrix_size), dtype=float32, shape=[k_size, n_size])
+                c_matrix = as_tensor_pointer(c_ptr + (batch * c_matrix_size), dtype=float32, shape=[m_size, n_size])
+                for tidx in grid(nthreads, attrs=parallel_attr):
+                    tid_5th_loop = tidx
+                    work_id_5th_loop = tid_5th_loop // (nthreads // loop5_nways)
+                    comm_id_5th_loop = tid_5th_loop

+                    gemm_5th_loop(a_matrix, b_matrix, c_matrix, work_id_5th_loop, comm_id_5th_loop)

+        assert isinstance(matmul_kernel_x86, hidet.ir.Function)
+        # matmul_kernel_x86.kind = "cpu_kernel"
         ir_module = module.ir_module()
         return ir_module
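The batching above is plain pointer arithmetic: in a C-contiguous [B, M, K] tensor, batch r starts at element offset r * m_size * k_size, which is what the a_ptr + batch * a_matrix_size and as_tensor_pointer lines compute. A small NumPy check of that layout assumption, illustrative only and not from the PR:

    import numpy as np

    a = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)  # [batch=2, m=3, k=4]
    flat = a.reshape(-1)                                         # models the kernel's a_ptr view
    for batch in range(2):
        offset = batch * 3 * 4                                   # batch * m_size * k_size
        a_matrix = flat[offset:offset + 3 * 4].reshape(3, 4)     # models as_tensor_pointer(...)
        assert np.array_equal(a_matrix, a[batch])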


 class Matmulx86Op(Operator):
     def __init__(self, a: Tensor, b: Tensor):
-        if not (len(a.shape) == len(b.shape) == 2 and a.shape[1] == b.shape[0]):
-            raise ValueError('Matrix multiplication: incompatible sizes: {} and {}'.format(a.shape, b.shape))
+        # if not (len(a.shape) == len(b.shape) == 2 and a.shape[1] == b.shape[0]):
+        #     raise ValueError('Matrix multiplication: incompatible sizes: {} and {}'.format(a.shape, b.shape))
+        if not (
+            len(a.shape) == len(b.shape) == 3

Reviewer comment (Member): Can we use the same template to support matmul like:
  • [12, 1024, 1024] @ [1024, 1024]
  • [12, 1024, 1024] @ [1, 1024, 1024]
  • [1024, 1024] @ [4, 5, 1024, 1024]
You can have a look at https://github.com/hidet-org/hidet/blob/main/python/hidet/graph/ops/matmul/batch_matmul.py as a reference.

+            and (not is_constant(a.shape[0], b.shape[0]) or a.shape[0] == b.shape[0])
+            and (not is_constant(a.shape[2], b.shape[1]) or a.shape[2] == b.shape[1])

Reviewer comment (Member), on lines +859 to +860: Do you require the shapes of a and b to be constant? Is it possible to support dynamic shapes, as in https://github.com/hidet-org/hidet/blob/main/python/hidet/graph/ops/matmul/batch_matmul.py?

+        ):
+            raise ValueError(
+                "Matrix multiplication expects tensor A and B with shape [B, M, K] and [B, K, N]"
+                + ", got {} and {}".format(a.shape, b.shape)
+            )
         task = MatmulF32Taskx86(input_like(a, 'a'), input_like(b, 'b'))
         super().__init__(inputs=[a, b], attributes={}, task=task)


-def matmul_x86(a: Tensor, b: Tensor) -> Tensor:
+def batch_matmul_x86(a: Tensor, b: Tensor) -> Tensor:
     return Matmulx86Op(a, b).outputs[0]
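After the rename, the operator is exposed as ops.batch_matmul_x86 and requires rank-3 float32 inputs with matching batch and inner dimensions. A hypothetical usage sketch, assuming a CPU build of hidet with this branch installed:

    import hidet
    from hidet.graph import ops

    a = hidet.randn([4, 128, 256], dtype='float32', device='cpu')
    b = hidet.randn([4, 256, 64], dtype='float32', device='cpu')
    c = ops.batch_matmul_x86(a, b)  # [B, M, K] x [B, K, N] -> [B, M, N], here [4, 128, 64]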
25 changes: 19 additions & 6 deletions python/hidet/graph/ops/matmul/resolve.py
@@ -22,6 +22,7 @@

 from .matmul import MatmulOp
 from .batch_matmul import batch_matmul
+from .matmul_f32_x86 import batch_matmul_x86
 from .matmul_f16 import matmul_f16
 from ..transform import broadcast, flatten
 from ..utils import broadcast_shapes
@@ -90,7 +91,7 @@ class MatmulResolveRule(ResolveRule):
     The generic matrix multiplication operator has the same semantics as numpy.matmul that accepts
     variable dimensions of inputs.

-    On ther other hand, the batched matrix multiplication operator accepts inputs with shape:
+    On the other hand, the batched matrix multiplication operator accepts inputs with shape:
     [batch_size, m_size, k_size] x [batch_size, k_size, n_size]

     This resolve rule also parallelize k dimension when possible, and determine the mma instruction.
@@ -125,6 +126,9 @@ def run_batch_matmul(self, a: Tensor, b: Tensor) -> Tensor:
         c = batch_matmul(aa, bb, mma=mma).reshape([batch_size, nparts, m_size, n_size]).sum(1)
         return c

+    def run_batch_matmul_cpu(self, a: Tensor, b: Tensor) -> Tensor:
+        return batch_matmul_x86(a, b)
+
     def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]:
         assert isinstance(op, MatmulOp)
         a: Tensor = op.inputs[0]
@@ -133,30 +137,38 @@ def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]:
         if a.dtype.nbytes > 4 or b.dtype.nbytes > 4:
             return None

+        run_func = self.run_batch_matmul
+        if op.device.is_cpu():
+            run_func = self.run_batch_matmul_cpu
+
         if len(a.shape) == 1:  # shape: [a]
             a = a.unsqueeze([0, 1])  # [1, 1, a]
             if len(b.shape) == 2:  # shape: [a, b]
                 # [a] x [a, b] -> [b]
                 b = b.unsqueeze([0])  # [1, a, b]
-                c = self.run_batch_matmul(a, b)  # [1, 1, b]
+                # c = self.run_batch_matmul(a, b)  # [1, 1, b] FIXME: Delete later
+                c = run_func(a, b)  # [1, 1, b]
                 c = c.squeeze([0, 1])  # [b]
             else:
                 assert len(b.shape) >= 3  # shape example: [b, c, a, d]
                 # [a] x [b, c, a, d] -> [b, c, d]
                 b = flatten(b, start_dim=0, end_dim=-3)  # [b * c, a, d]
-                c = self.run_batch_matmul(a, b)  # [b * c, 1, d]
+                # c = self.run_batch_matmul(a, b)  # [b * c, 1, d] FIXME: Delete later
+                c = run_func(a, b)  # [b * c, 1, d]
                 c = c.reshape(c_shape)  # [b, c, d]
         elif len(b.shape) == 1:  # shape: [b]
             b = b.unsqueeze([0, 2])  # [1, b, 1]
             if len(a.shape) == 2:  # shape: [a, b]
                 a = a.unsqueeze([0])  # [1, a, b]
-                c = self.run_batch_matmul(a, b)  # [1, a, 1]
+                # c = self.run_batch_matmul(a, b)  # [1, a, 1] FIXME: Delete later
+                c = run_func(a, b)  # [1, a, 1]
                 c = c.squeeze([0, 2])  # [a]
             else:
                 assert len(a.shape) >= 3  # shape example: [a, c, d, b]
                 # [a, c, d, b] x [b] -> [a, c, d]
                 a = flatten(a, start_dim=0, end_dim=-3)  # [a * c, d, b]
-                c = self.run_batch_matmul(a, b)  # [a * c, d, 1]
+                # c = self.run_batch_matmul(a, b)  # [a * c, d, 1] FIXME: Delete later
+                c = run_func(a, b)  # [a * c, d, 1]
                 c = c.reshape(c_shape)  # [a, c, d]
         else:
             # example: [a, b, c] x [c, d] -> [a, b, d]
@@ -168,7 +180,8 @@ def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]:
             a_broadcast_shape = c_head + list(a.shape[-2:])
             b_broadcast_shape = c_head + list(b.shape[-2:])
             a = flatten(broadcast(a, a_broadcast_shape), start_dim=0, end_dim=-3)
             b = flatten(broadcast(b, b_broadcast_shape), start_dim=0, end_dim=-3)
-            c = self.run_batch_matmul(a, b)
+            # c = self.run_batch_matmul(a, b)  FIXME: Delete later
+            c = run_func(a, b)
             c = c.reshape(c_shape)
         return [c]
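The run_func dispatch above means a generic matmul on CPU tensors is lowered to batch_matmul_x86 when the resolve rule runs during graph optimization. A sketch of exercising that path; the tracing and optimize calls follow hidet's public API as I understand it, so treat the exact invocation as an assumption:

    import hidet

    a = hidet.symbol([12, 1024, 1024], dtype='float32', device='cpu')
    b = hidet.symbol([1024, 1024], dtype='float32', device='cpu')
    c = hidet.ops.matmul(a, b)  # generic matmul; broadcast/flatten normalizes to [batch, m, k] x [batch, k, n]
    graph = hidet.trace_from(c, inputs=[a, b])
    graph_opt = hidet.graph.optimize(graph)  # MatmulResolveRule routes CPU matmuls to run_batch_matmul_cpu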
6 changes: 2 additions & 4 deletions tests/operators/test_matmul.py
@@ -18,15 +18,13 @@
 from hidet.testing import check_binary, check_binary_dynamic, check_torch_binary


-# @pytest.mark.skip(reason="when running matmul_x86 multiple times, it will produce wrong result. need fix.")
-@pytest.mark.parametrize("a_shape, b_shape", [[[333, 444], [444, 555]], [[133, 1], [1, 177]]])
+@pytest.mark.parametrize("a_shape, b_shape", [[[1, 333, 444], [1, 444, 555]], [[1, 133, 1], [1, 1, 177]]])
 def test_matmul_x86(a_shape, b_shape):
-    # TODO: Doesn't support broadcasting yet; need to add it later?
     check_binary(
         a_shape,
         b_shape,
         lambda x, y: np.matmul(x, y),
-        lambda x, y: ops.matmul_x86(x, y) - ops.matmul_x86(x, y) + ops.matmul_x86(x, y),
+        lambda x, y: ops.batch_matmul_x86(x, y) - ops.batch_matmul_x86(x, y) + ops.batch_matmul_x86(x, y),
         dtype="float32",
         atol=1e-4,
         rtol=1e-4,
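A note on the updated test: the lambda calls batch_matmul_x86 three times on purpose. For a correct, stateless kernel, f(x, y) - f(x, y) + f(x, y) equals f(x, y), so the pattern regression-tests the "wrong result when run multiple times" bug mentioned in the deleted skip marker. An equivalent standalone check, sketched under the same API assumptions as above:

    import numpy as np
    import hidet
    from hidet.graph import ops

    a = hidet.randn([1, 133, 1], dtype='float32', device='cpu')
    b = hidet.randn([1, 1, 177], dtype='float32', device='cpu')
    c = ops.batch_matmul_x86(a, b) - ops.batch_matmul_x86(a, b) + ops.batch_matmul_x86(a, b)
    np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4)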