
Commit

Merge branch 'hidet-org:main' into main
BolinSNLHM authored Jan 26, 2024
2 parents 40b7f51 + 072a606 commit 0cc633a
Showing 10 changed files with 267 additions and 104 deletions.
21 changes: 14 additions & 7 deletions .github/scripts/bench/bench_op.py
@@ -14,7 +14,8 @@ def bench_matmul_f16(params: str, *args, **kwargs) -> float:
c = hidet.ops.matmul(a, b)
g = hidet.trace_from(c, inputs=[a, b])
g = hidet.graph.optimize(g)
return g.latency()
g = g.cuda_graph()
return bench_torch_model(lambda: g.run_async(), [])

def bench_batch_matmul(params: str, *args, **kwargs) -> float:
# Default to benchmarking f32 for now, though this op can run other dtypes
@@ -26,7 +27,8 @@ def bench_batch_matmul(params: str, *args, **kwargs) -> float:
c = hidet.ops.matmul(a, b)
g = hidet.trace_from(c, inputs=[a, b])
g = hidet.graph.optimize(g)
return g.latency()
g = g.cuda_graph()
return bench_torch_model(lambda: g.run_async(), [])

def bench_conv2d(params: str, *args, **kwargs) -> float:
x_shape, w_shape = params.split(',')
@@ -37,7 +39,8 @@ def bench_conv2d(params: str, *args, **kwargs) -> float:
o = hidet.ops.conv2d(x, w)
g = hidet.trace_from(o, inputs=[x, w])
g = hidet.graph.optimize(g)
return g.latency()
g = g.cuda_graph()
return bench_torch_model(lambda: g.run_async(), [])

def bench_conv2d_gemm_f16(params: str, *args, **kwargs) -> float:
x_shape, w_shape = params.split(',')
@@ -48,7 +51,8 @@ def bench_conv2d_gemm_f16(params: str, *args, **kwargs) -> float:
o = hidet.ops.conv2d(x, w)
g = hidet.trace_from(o, inputs=[x, w])
g = hidet.graph.optimize(g)
return g.latency()
g = g.cuda_graph()
return bench_torch_model(lambda: g.run_async(), [])

def bench_attn(params: str, *args, **kwargs) -> float:
bs, seqlen, nhead, hdim = [int(s) for s in params.split('x')]
@@ -61,7 +65,8 @@ def bench_attn(params: str, *args, **kwargs) -> float:
o = hidet.ops.attention(q, k, v)
g = hidet.trace_from(o, inputs=[q, k, v])
g = hidet.graph.optimize(g)
return g.latency()
g = g.cuda_graph()
return bench_torch_model(lambda: g.run_async(), [])

def bench_attn_mask_add(params: str, *args, **kwargs) -> float:
bs, seqlen, nhead, hdim = [int(s) for s in params.split('x')]
@@ -76,7 +81,8 @@ def bench_attn_mask_add(params: str, *args, **kwargs) -> float:
o = hidet.ops.attention(q, k, v, mask=mask)
g = hidet.trace_from(o, inputs=[q, k, v, mask])
g = hidet.graph.optimize(g)
return g.latency()
g = g.cuda_graph()
return bench_torch_model(lambda: g.run_async(), [])

def bench_reduce(params: str, *args, **kwargs) -> float:
x_shape, axis = params.split(',', maxsplit=1)
@@ -88,7 +94,8 @@ def bench_reduce(params: str, *args, **kwargs) -> float:
o = hidet.ops.sum(x, dims=axis)
g = hidet.trace_from(o, inputs=[x])
g = hidet.graph.optimize(g)
return g.latency()
g = g.cuda_graph()
return bench_torch_model(lambda: g.run_async(), [])

bench_func_map = {
'matmul_f16': bench_matmul_f16,
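Every benchmark in bench_op.py now follows the same measurement path: trace the op, optimize the flow graph, capture it into a CUDA graph, and time the replay with bench_torch_model instead of calling g.latency(). A minimal end-to-end sketch of that pattern follows; the hidet.symbol inputs, the 1024x1024 float16 shapes, and the import path of bench_torch_model are illustrative assumptions, not part of this diff.

```python
import hidet
# helper defined in .github/scripts/bench/bench_utils.py (import path assumed)
from bench_utils import bench_torch_model

# Build a symbolic matmul and trace it into a flow graph (shapes/dtype assumed).
a = hidet.symbol([1024, 1024], dtype='float16', device='cuda')
b = hidet.symbol([1024, 1024], dtype='float16', device='cuda')
c = hidet.ops.matmul(a, b)

g = hidet.trace_from(c, inputs=[a, b])
g = hidet.graph.optimize(g)
g = g.cuda_graph()  # capture the optimized graph into a CUDA graph
latency = bench_torch_model(lambda: g.run_async(), [])  # time CUDA-graph replays; no extra inputs
print(latency)
```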
13 changes: 7 additions & 6 deletions .github/scripts/bench/bench_utils.py
@@ -35,9 +35,10 @@ def bench_torch_model(model, torch_inputs, bench_iters=100, warmup_iters=10):
return latency

def enable_compile_server(enable=True):
hidet.option.compile_server.addr(os.environ.get('CI_CS_HOSTNAME'))
hidet.option.compile_server.port(int(os.environ.get('CI_CS_PORT')))
hidet.option.compile_server.username(os.environ.get('CI_CS_USERNAME'))
hidet.option.compile_server.password(os.environ.get('CI_CS_PASSWORD'))
hidet.option.compile_server.repo(os.environ.get('REPO_NAME').strip(), os.environ.get('REPO_BRANCH').strip())
hidet.option.compile_server.enable(flag=enable)
if os.environ.get('CI_CS_HOSTNAME'):
hidet.option.compile_server.addr(os.environ.get('CI_CS_HOSTNAME'))
hidet.option.compile_server.port(int(os.environ.get('CI_CS_PORT')))
hidet.option.compile_server.username(os.environ.get('CI_CS_USERNAME'))
hidet.option.compile_server.password(os.environ.get('CI_CS_PASSWORD'))
hidet.option.compile_server.repo(os.environ.get('REPO_NAME').strip(), os.environ.get('REPO_BRANCH').strip())
hidet.option.compile_server.enable(flag=enable)
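Only the tail of bench_torch_model is visible in the hunk above (its signature and the final return latency). For context, a hedged sketch of such a helper is given below; the warmup loop, CUDA-event timing, and millisecond units are assumptions, not the file's actual implementation.

```python
import torch

def bench_torch_model(model, torch_inputs, bench_iters=100, warmup_iters=10):
    # Warm up to keep compilation and cache effects out of the measurement.
    for _ in range(warmup_iters):
        model(*torch_inputs)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(bench_iters):
        model(*torch_inputs)
    end.record()
    torch.cuda.synchronize()

    latency = start.elapsed_time(end) / bench_iters  # average milliseconds per iteration
    return latency
```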
2 changes: 1 addition & 1 deletion gallery/how-to-guides/visualize-flow-graph.py
@@ -84,7 +84,7 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor):
#
# You can download the generated json file
# :download:`attention-graph.json <../../../../gallery/how-to-guides/attention-graph.json>`
# and open it with the `customized Netron viewer </netron>`_.
# and open it with the `customized Netron viewer </docs/netron>`_.
#

# %%
1 change: 0 additions & 1 deletion python/hidet/graph/ops/fusion/apply_prologue_epilogue.py
@@ -771,7 +771,6 @@ def process_module(self, ir_module: IRModule) -> IRModule:
try:
rewriter = PrologueEpilogueFuseRewriter(self.fused_task, prologues, epilogues, tensor_map, marks)
ir_module = rewriter.rewrite(ir_module)
print('success')
return ir_module
except CanNotFuseError:
pass
80 changes: 73 additions & 7 deletions python/hidet/graph/ops/matmul/batch_matmul.py
@@ -76,6 +76,7 @@ def implement_cuda(self, working_dir: str) -> List[IRModule]:
spatial(2, 1) * spatial(1, 8) * spatial(2, 1),
],
warp_inner=[(4, 4)],
use_cublas=[True, False],
)
@tune.space(
1,
@@ -84,12 +85,41 @@ def implement_cuda(self, working_dir: str) -> List[IRModule]:
warp_outer=[(1, 1), (1, 2), (2, 1), (2, 2)],
warp_mid=[spatial(4, 8)],
warp_inner=[(4, 4), (4, 8), (8, 4)],
use_cublas=[True, False],
)
def schedule_simt(
self, block_warps_k=8, block_warps=(4, 2), warp_outer=(2, 2), warp_mid=spatial(4, 8), warp_inner=(4, 4)
self,
block_warps_k=8,
block_warps=(4, 2),
warp_outer=(2, 2),
warp_mid=spatial(4, 8),
warp_inner=(4, 4),
use_cublas=False,
) -> IRModule:
task = self
dtype = task.inputs[0].type.dtype

if use_cublas:
from hidet.graph.ops.utils.schedule_utils import get_cublas_matmul_schedule

a_shape = task.inputs[0].type.shape
b_shape = task.inputs[1].type.shape
c_shape = task.outputs[0].type.shape
# Hack to reduce redundant schedules. When use_cublas == False, other tuning params are irrelevant
# and we only need one copy of the schedule.
from hidet.ir.mapping import SpatialTaskMapping

schedule_filter = (
block_warps_k == 8
and block_warps == (1, 1)
and warp_outer == (1, 1)
and isinstance(warp_mid, SpatialTaskMapping)
and warp_mid.task_shape == (4, 8)
and warp_inner == (4, 4)
)
tune.check(schedule_filter)
return get_cublas_matmul_schedule(a_shape, b_shape, c_shape, dtype, dtype, dtype)

warp_k = 1

# Task Layouts
@@ -369,6 +399,7 @@ def batch_matmul_kernel(
warp_n=[16, 32, 64],
warp_k=[8, 16, 32],
mma_config=MmaConfig.all(),
use_cublas=[True, False],
)
@tune.space(
1,
@@ -379,9 +410,18 @@ def batch_matmul_kernel(
warp_n=[32, 64],
warp_k=[8, 16, 32],
mma_config=MmaConfig.all(),
use_cublas=[True, False],
)
def schedule_mma(
self, block_m=64, block_n=64, block_k=16, warp_m=32, warp_n=32, warp_k=16, mma_config: MmaConfig = None
self,
block_m=64,
block_n=64,
block_k=16,
warp_m=32,
warp_n=32,
warp_k=16,
mma_config: MmaConfig = None,
use_cublas=False,
) -> IRModule:
def resolve_mma_type(a_dtype: DataType, b_dtype: DataType, c_dtype: DataType):
dtype_rank = {'float16': 0, 'bfloat16': 1, 'tfloat32': 2, 'float32': 4}
@@ -398,9 +438,35 @@ def resolve_mma_type(a_dtype: DataType, b_dtype: DataType, c_dtype: DataType):

task = self

input_a, input_b, input_c = task.inputs[0], task.inputs[1], task.outputs[0]
input_a_dtype, input_b_dtype, input_c_dtype = [t.type.dtype for t in [input_a, input_b, input_c]]
mma_type = resolve_mma_type(input_a_dtype, input_b_dtype, input_c_dtype)
input_a, input_b, output_c = task.inputs[0], task.inputs[1], task.outputs[0]
input_a_dtype, input_b_dtype, output_c_dtype = [t.type.dtype for t in [input_a, input_b, output_c]]
input_a_shape, input_b_shape, output_c_shape = [t.type.shape for t in [input_a, input_b, output_c]]

if use_cublas:
from hidet.graph.ops.utils.schedule_utils import get_cublas_matmul_schedule

# Hack to reduce redundant schedules. When use_cublas == False, other tuning params are irrelevant
# and we only need one copy of the schedule.
schedule_filter = (
block_m == 64
and block_n == 64
and block_k == 8
and warp_m == 32
and warp_n == 32
and warp_k == 8
and mma_config
and mma_config.m == 16
and mma_config.n == 8
and mma_config.k == 8
and mma_config.input_dtype == 'f16'
and mma_config.output_dtype == 'f16'
)
tune.check(schedule_filter)
return get_cublas_matmul_schedule(
input_a_shape, input_b_shape, output_c_shape, input_a_dtype, input_b_dtype, output_c_dtype
)

mma_type = resolve_mma_type(input_a_dtype, input_b_dtype, output_c_dtype)

# Resolve parameters when space level is 0
if mma_config is None:
@@ -549,7 +615,7 @@ def copy_b_s2r(
@hidet.script
def copy_c_r2g(
regs_c: TensorType(dtype=c_dtype, layout=regs_c_layout),
c: input_c_dtype[bs, m_size, n_size],
c: output_c_dtype[bs, m_size, n_size],
offset_m: i32,
offset_n: i32,
smem: void_p,
@@ -610,7 +676,7 @@ def mma(
def batch_matmul_kernel(
a: input_a_dtype[bs, m_size, k_size],
b: input_b_dtype[bs, k_size, n_size],
c: input_c_dtype[bs, m_size, n_size],
c: output_c_dtype[bs, m_size, n_size],
):
attrs.cuda.grid_dim = (m_tiles * n_tiles, bs)
attrs.cuda.block_dim = block_size
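Both schedule_simt and schedule_mma add use_cublas to their tuning spaces. Once the cuBLAS path is taken, the hand-written tiling parameters no longer affect the generated code, so schedule_filter pins them to one canonical combination and tune.check discards every other candidate, leaving a single cuBLAS schedule per space. Below is a standalone sketch of that pruning idea (simplified names; this is not hidet's tuning API):

```python
from itertools import product

def make_schedule(block, warp, use_cublas=False):
    if use_cublas:
        # Pin the otherwise-irrelevant parameters to one canonical combination,
        # mirroring tune.check(schedule_filter) in schedule_simt / schedule_mma.
        if not (block == (64, 64) and warp == (32, 32)):
            return None  # pruned candidate, analogous to a failed tune.check
        return 'cublas schedule'
    return f'simt schedule block={block} warp={warp}'

space = product([(64, 64), (128, 64)], [(32, 32), (64, 32)], [False, True])
schedules = [s for s in (make_schedule(b, w, u) for b, w, u in space) if s is not None]
print(schedules)  # four hand-written variants plus exactly one 'cublas schedule'
```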
82 changes: 5 additions & 77 deletions python/hidet/graph/ops/matmul/matmul_cublas.py
@@ -9,17 +9,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, List, Optional
from typing import Union, Optional

import hidet
from hidet.ir.module import IRModule
from hidet.ir.type import DataType
from hidet.ir.expr import Expr, is_true
from hidet.ir.dtypes import f16, f32
from hidet.utils import prod
from hidet.ir.expr import Expr
from hidet.cuda.cublas import cublasComputeType
from ..utils import Task, Operator, Tensor, input_like
from ..utils import TensorInput
from ..utils.schedule_utils import convert_to_cublas_strided_gemm, resolve_cublas_compute_type


class CublasMatmulTask(Task):
@@ -30,83 +28,13 @@ def __init__(self, a: TensorInput, b: TensorInput, compute_type: Optional[Union[
if a.type.dtype != b.type.dtype:
raise ValueError('dtype of a and b must be the same, got {} and {}'.format(a.type.dtype, b.type.dtype))

self.compute_type: cublasComputeType = self.resolve_compute_type(a.type.dtype, a.type.dtype, compute_type)
self.compute_type: cublasComputeType = resolve_cublas_compute_type(a.type.dtype, a.type.dtype, compute_type)

c = cops.matmul(a, b, allow_1d=True)
super().__init__(
name='cublas_matmul', inputs=[a, b], outputs=[c], attributes={'compute_type': self.compute_type}
)

def resolve_compute_type(
self, in_dtype: DataType, out_dtype: DataType, compute_type: Optional[Union[int, cublasComputeType]]
) -> cublasComputeType:
if compute_type is not None:
return cublasComputeType(compute_type)
if in_dtype == out_dtype == f16:
# use tensor core whenever possible
return cublasComputeType.CUBLAS_COMPUTE_16F
elif in_dtype == out_dtype == f32:
# use tensor core whenever possible
return cublasComputeType.CUBLAS_COMPUTE_32F
else:
raise NotImplementedError(
'not implemented resolve rules for compute_type with in_dtype={}, out_dtype={}'.format(
in_dtype, out_dtype
)
)

def convert_to_strided_gemm(self, a_shape: List[Expr], b_shape: List[Expr], c_shape: List[Expr]):
a_rank: int = len(a_shape)
b_rank: int = len(b_shape)

assert a_rank >= 1 and b_rank >= 1 and (a_rank >= 2 or b_rank >= 2)
if a_rank == 1:
bs = prod(b_shape[:-2])
m = 1
n = b_shape[-1]
k = a_shape[0]
stride_a = 0
stride_b = b_shape[-2] * b_shape[-1]
stride_c = c_shape[-2] * c_shape[-1]
elif b_rank == 1:
bs = prod(a_shape[:-2])
m = a_shape[-2]
n = 1
k = b_shape[0]
stride_a = a_shape[-2] * a_shape[-1]
stride_b = 0
stride_c = c_shape[-1]
else:
if is_true(prod(a_shape[:-2]) == 1):
bs = prod(b_shape[:-2])
m = a_shape[-2]
n = b_shape[-1]
k = a_shape[-1]
stride_a = 0
stride_b = b_shape[-2] * b_shape[-1]
stride_c = c_shape[-2] * c_shape[-1]
elif is_true(prod(b_shape[:-2]) == 1):
bs = prod(a_shape[:-2])
m = a_shape[-2]
n = b_shape[-1]
k = a_shape[-1]
stride_a = a_shape[-2] * a_shape[-1]
stride_b = 0
stride_c = c_shape[-2] * c_shape[-1]
elif all(is_true(a == b) for a, b in zip(a_shape[:-2], b_shape[:-2])):
bs = prod(a_shape[:-2])
m = a_shape[-2]
n = b_shape[-1]
k = a_shape[-1]
stride_a = a_shape[-2] * a_shape[-1]
stride_b = b_shape[-2] * b_shape[-1]
stride_c = c_shape[-2] * c_shape[-1]
else:
# todo: add cublasGemmBatchedEx to support this case
# https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmbatchedex
raise NotImplementedError('Can not convert matmul {} @ {} to strided_gemm'.format(a_shape, b_shape))
return bs, m, n, k, stride_a, stride_b, stride_c

def implement_cuda(self, working_dir: str) -> IRModule:
from hidet.lang import attrs
from hidet.lang.cuda import cublas
@@ -120,7 +48,7 @@ def implement_cuda(self, working_dir: str) -> IRModule:
with hidet.script_module() as script_module:

def generate(a: Expr, b: Expr, c: Expr) -> Expr:
bs, m, n, k, stride_a, stride_b, stride_c = self.convert_to_strided_gemm(a_shape, b_shape, c_shape)
bs, m, n, k, stride_a, stride_b, stride_c = convert_to_cublas_strided_gemm(a_shape, b_shape, c_shape)
return cublas.strided_gemm(
bs,
m,
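The convert_to_strided_gemm helper deleted above now lives in hidet.graph.ops.utils.schedule_utils as convert_to_cublas_strided_gemm, next to resolve_cublas_compute_type, so both matmul_cublas.py and batch_matmul.py can share it. As a reminder of what the mapping computes in the common case of equal batch dimensions, here is a small standalone sketch using plain integer shapes (the real helper works on hidet Expr shapes and also handles the rank-1 and broadcast cases):

```python
from math import prod

def strided_gemm_params(a_shape, b_shape, c_shape):
    # Common branch of the mapping: a[..., m, k] @ b[..., k, n] with identical
    # leading (batch) dimensions; strides are element counts per batch matrix.
    assert a_shape[:-2] == b_shape[:-2]
    bs = prod(a_shape[:-2])
    m, k, n = a_shape[-2], a_shape[-1], b_shape[-1]
    stride_a = a_shape[-2] * a_shape[-1]  # m * k
    stride_b = b_shape[-2] * b_shape[-1]  # k * n
    stride_c = c_shape[-2] * c_shape[-1]  # m * n
    return bs, m, n, k, stride_a, stride_b, stride_c

print(strided_gemm_params([2, 128, 64], [2, 64, 256], [2, 128, 256]))
# -> (2, 128, 256, 64, 8192, 16384, 32768)
```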