Skip to content

Commit

Permalink
remove shady batch mat mul
Browse files Browse the repository at this point in the history
  • Loading branch information
fishingguy456 committed Jan 4, 2024
1 parent 587ba64 commit 89d5646
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 35 deletions.
60 changes: 26 additions & 34 deletions python/hidet/graph/ops/matmul/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from .matmul import MatmulOp
from .batch_matmul import batch_matmul
from .matmul_f32_x86 import matmul_x86
from .matmul_f16 import matmul_f16
from ..transform import broadcast, flatten
from ..utils import broadcast_shapes
Expand Down Expand Up @@ -98,40 +97,33 @@ class MatmulResolveRule(ResolveRule):
"""

def run_batch_matmul(self, a: Tensor, b: Tensor) -> Tensor:
if a.device.is_cpu():
cc = [matmul_x86(a[i], b[i]).unsqueeze(0) for i in range(a.shape[0])]
c = cc[0]
for i in range(1, a.shape[0]):
c = hidet.ops.concat([cc[i], c], axis=0)
return c
else:
parallel_k = self.get_config('parallel_k', default='default') # 'default', 'search', 2, 4, ...
mma = self.get_config('mma', default='simt') # 'simt', 'mma'
parallel_k = self.get_config('parallel_k', default='default') # 'default', 'search', 2, 4, ...
mma = self.get_config('mma', default='simt') # 'simt', 'mma'

if any(not isinstance(v, int) for v in a.shape + b.shape):
if any(not isinstance(v, int) for v in a.shape + b.shape):
nparts = 1
else:
batch_size, m_size, n_size, k_size = a.shape[0], a.shape[1], b.shape[2], a.shape[2]
if parallel_k == 'default':
nparts = parallel_k_heuristic_nparts(batch_size, m_size, n_size, k_size)
elif parallel_k == 'search':
nparts = parallel_k_search_nparts(a.dtype.name, mma, batch_size, m_size, n_size, k_size)
elif parallel_k == 'disabled':
nparts = 1
elif isinstance(parallel_k, int):
nparts = gcd(parallel_k, k_size)
else:
batch_size, m_size, n_size, k_size = a.shape[0], a.shape[1], b.shape[2], a.shape[2]
if parallel_k == 'default':
nparts = parallel_k_heuristic_nparts(batch_size, m_size, n_size, k_size)
elif parallel_k == 'search':
nparts = parallel_k_search_nparts(a.dtype.name, mma, batch_size, m_size, n_size, k_size)
elif parallel_k == 'disabled':
nparts = 1
elif isinstance(parallel_k, int):
nparts = gcd(parallel_k, k_size)
else:
raise ValueError(f'invalid parallel_k: {parallel_k}')
raise ValueError(f'invalid parallel_k: {parallel_k}')

if nparts == 1:
c = batch_matmul(a, b, mma=mma)
else:
# [batch_size * nparts, m_size, k_size // nparts]
aa = a.reshape([batch_size, m_size, nparts, k_size // nparts]).rearrange([[0, 2], [1], [3]])
# [batch_size * nparts, k_size // nparts, n_size]
bb = b.reshape([batch_size, nparts, k_size // nparts, n_size]).rearrange([[0, 1], [2], [3]])
c = batch_matmul(aa, bb, mma=mma).reshape([batch_size, nparts, m_size, n_size]).sum(1)
return c
if nparts == 1:
c = batch_matmul(a, b, mma=mma)
else:
# [batch_size * nparts, m_size, k_size // nparts]
aa = a.reshape([batch_size, m_size, nparts, k_size // nparts]).rearrange([[0, 2], [1], [3]])
# [batch_size * nparts, k_size // nparts, n_size]
bb = b.reshape([batch_size, nparts, k_size // nparts, n_size]).rearrange([[0, 1], [2], [3]])
c = batch_matmul(aa, bb, mma=mma).reshape([batch_size, nparts, m_size, n_size]).sum(1)
return c

def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]:
assert isinstance(op, MatmulOp)
Expand Down Expand Up @@ -186,9 +178,6 @@ def resolve_f16(self, op: Operator) -> Optional[List[Tensor]]:
# if op.task.has_symbolic_shape():
# return None

if op.device.is_cpu():
return None

a: Tensor = op.inputs[0]
b: Tensor = op.inputs[1]
c: Tensor = op.outputs[0]
Expand Down Expand Up @@ -251,9 +240,12 @@ def resolve_f16(self, op: Operator) -> Optional[List[Tensor]]:
return [c]

def resolve(self, op: Operator) -> Optional[List[Tensor]]:
if op.device.is_cpu():
return None
resolve_funcs: List[Callable[[Operator], Any]] = [self.resolve_f16, self.resolve_generic]
for resolve_func in resolve_funcs:
outs = resolve_func(op)
if outs is not None:
return outs
return None

2 changes: 1 addition & 1 deletion python/hidet/graph/ops/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,4 +311,4 @@ def softmax_cpu_kernel(x: float32[shape], out: float32[shape]):

assert isinstance(softmax_cpu_kernel, hidet.ir.Function)
ir_module = module.ir_module()
return ir_module
return ir_module

0 comments on commit 89d5646

Please sign in to comment.