diff --git a/python/hidet/graph/ops/__init__.py b/python/hidet/graph/ops/__init__.py index 901ede4d2..599b403c9 100644 --- a/python/hidet/graph/ops/__init__.py +++ b/python/hidet/graph/ops/__init__.py @@ -12,7 +12,15 @@ # pylint: disable=redefined-builtin from .conv1d import conv1d from .conv1d_transpose import conv1d_transpose -from .conv2d import conv2d, conv2d_winograd, conv2d_gemm, conv2d_gemm_image_transform +from .conv2d import ( + conv2d, + conv2d_channel_last, + conv2d_winograd, + conv2d_gemm, + conv2d_gemm_fp16, + conv2d_gemm_fp16_channel_last, + conv2d_gemm_image_transform, +) from .conv2d_transpose import conv2d_transpose, conv2d_transpose_gemm from .conv3d import conv3d, conv3d_gemm from .conv3d_transpose import conv3d_transpose diff --git a/python/hidet/graph/ops/conv2d/__init__.py b/python/hidet/graph/ops/conv2d/__init__.py index 3f4aca142..2719f0fe0 100644 --- a/python/hidet/graph/ops/conv2d/__init__.py +++ b/python/hidet/graph/ops/conv2d/__init__.py @@ -9,13 +9,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .conv2d import conv2d -from .conv2d import Conv2dOp +from .conv2d import conv2d, conv2d_channel_last +from .conv2d import Conv2dOp, Conv2dChannelLastOp from .conv2d_winograd import conv2d_winograd, conv2d_winograd_image_transform, conv2d_winograd_filter_transform from .conv2d_winograd import conv2d_winograd_inverse_transform from .conv2d_winograd import Conv2dWinogradInverseTransformOp, Conv2dWinogradFilterTransformOp from .conv2d_winograd import Conv2dWinogradImageTransformOp -from .conv2d_gemm import conv2d_gemm, conv2d_gemm_image_transform, conv2d_gemm_filter_transform +from .conv2d_gemm import ( + conv2d_gemm, + conv2d_gemm_fp16, + conv2d_gemm_fp16_channel_last, + conv2d_gemm_image_transform, + conv2d_gemm_filter_transform, +) from .conv2d_gemm import conv2d_gemm_inverse_transform from .conv2d_gemm import Conv2dGemmImageTransformOp diff --git a/python/hidet/graph/ops/conv2d/conv2d.py b/python/hidet/graph/ops/conv2d/conv2d.py index 90e44cb39..beedb08e9 100644 --- a/python/hidet/graph/ops/conv2d/conv2d.py +++ b/python/hidet/graph/ops/conv2d/conv2d.py @@ -57,6 +57,48 @@ def __init__(self, data: TensorNode, weight: TensorNode, stride: List[int], dila super().__init__(name='conv2d', inputs=[data, weight], outputs=[output]) +class Conv2dChannelLastTask(Task): + def __init__(self, data: TensorNode, weight: TensorNode, stride: List[int], dilations: List[int], groups: int): + # pylint: disable=too-many-locals + # we assume that only data needs to have dynamic shape + n, h, w, c = data.shape + oc, wc, kx, ky = weight.shape + sx, sy = stride + dilx, dily = dilations + p, q = (h - dilx * (kx - 1) - 1) // sx + 1, (w - dily * (ky - 1) - 1) // sy + 1 + self._assert( + ir.logical_or(c % groups == 0, oc % groups == 0), + msg=( + 'Conv2d expect the in_channels % groups == 0 and out_channels % groups == 0, \n' + 'but got in_channels, out_channels, groups: {}, {}, {}'.format(c, oc, groups) + ), + ) + self._assert( + wc * groups == c, + msg=( + 'Conv2d expect the weight has shape [out_channels, in_channels / groups, kx, ky], \n' + 'got weight shape {}, in_channels {} and groups {}'.format([oc, wc, kx, ky], c, groups) + ), + ) + out_group_size = oc // groups + output = compute( + name='out', + shape=[n, p, q, oc], + fcompute=lambda ni, pi, qi, oci: reduce( + shape=[wc, kx, ky], + fcompute=lambda wci, kxi, kyi: ( + data[ni, pi * sx + kxi * dilx, qi 
* sy + kyi * dily, (oci // out_group_size) * wc + wci] + * weight[oci, wci, kxi, kyi] + ), + reduce_type='sum', + ), + ) + self.channels = c + self.stride = stride + self.groups = groups + super().__init__(name='conv2d_channel_last', inputs=[data, weight], outputs=[output]) + + class Conv2dOp(Operator): def __init__(self, x: Tensor, w: Tensor, stride: Sequence[int], dilations: Union[int, Sequence[int]], groups: int): stride = normalize_stride(stride) @@ -68,6 +110,17 @@ def __init__(self, x: Tensor, w: Tensor, stride: Sequence[int], dilations: Union ) +class Conv2dChannelLastOp(Operator): + def __init__(self, x: Tensor, w: Tensor, stride: Sequence[int], dilations: Union[int, Sequence[int]], groups: int): + stride = normalize_stride(stride) + dilations = normalize_dilations(dilations) + super().__init__( + inputs=[x, w], + attributes={'stride': stride, 'groups': groups, 'dilations': dilations}, + task=Conv2dChannelLastTask(input_like(x, 'x'), input_like(w, 'w'), stride, dilations, groups), + ) + + def conv2d( data: Tensor, weight: Tensor, @@ -76,3 +129,13 @@ def conv2d( groups: int = 1, ) -> Tensor: return Conv2dOp(data, weight, stride, dilations, groups).get_output(0) + + +def conv2d_channel_last( + data: Tensor, + weight: Tensor, + stride: Union[int, Sequence[int]] = (1, 1), + dilations: Union[int, Sequence[int]] = (1, 1), + groups: int = 1, +) -> Tensor: + return Conv2dChannelLastOp(data, weight, stride, dilations, groups).get_output(0) diff --git a/python/hidet/graph/ops/conv2d/conv2d_gemm.py b/python/hidet/graph/ops/conv2d/conv2d_gemm.py index 972b6c9b6..1fd17c806 100644 --- a/python/hidet/graph/ops/conv2d/conv2d_gemm.py +++ b/python/hidet/graph/ops/conv2d/conv2d_gemm.py @@ -9,12 +9,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Sequence +from typing import List, Tuple, Sequence -from hidet.ir.expr import is_constant +from hidet.graph.tensor import Tensor +from hidet.ir import dtypes +from hidet.ir.dtypes import float16 +from hidet.ir.expr import if_then_else, is_constant, Int +from hidet.ir.func import Function +from hidet.ir.module import IRModule +from hidet.ir.compute import TensorNode +from hidet.ir.task import Task +from hidet.ir.compute import compute, reduce from hidet.graph.ops.matmul import matmul -from hidet.graph.ops.utils import Task, Operator, Tensor, compute, input_like, TensorNode -from hidet.graph.ops.utils import normalize_kernel, normalize_stride +from hidet.graph.ops.utils import Operator, input_like +from hidet.graph.ops.utils import normalize_kernel, normalize_stride, tune +from hidet.utils.py import is_power_of_two, cdiv from .utils import infer_conv2d_shape @@ -94,3 +103,531 @@ def conv2d_gemm(data: Tensor, weight: Tensor, stride, dilations: List[int], grou y_shape = infer_conv2d_shape(data.shape, weight.shape, stride, groups, dilations) y = conv2d_gemm_inverse_transform(gemm_y, out_height=y_shape[2], out_width=y_shape[3]) return y + + +class Conv2dGemmFp16Task(Task): + def __init__( + self, + img: TensorNode, + weight: TensorNode, + orig_weight_shape: List[int], + stride: List[int], + dilations: List[int], + groups: int = 1, + parallel_k_parts: int = 1, + ): + # Channel last + # This kernel expects the weight to be transformed in the following way: + # weight.shape [OC, WC, KY, KX] -> [KY * KX * WC, OC] + self._assert(len(img.shape) == 4, f"expect img shape to be in NHWC format, got {img.shape}") + self._assert( + len(weight.shape) == 2, + f"expected weight to be transformed from [OC, WC, KY, kX] to [KY * KX * WC, OC], got {weight.shape}", + ) + self._assert(img.type.dtype == float16 and weight.type.dtype == float16, 'Both inputs must be float16 tensors') + + self.groups = groups + self.dilations = dilations + self.stride = stride + self.img_shape = img.shape + self.orig_weight_shape = orig_weight_shape + + DILY, DILX = dilations + STRY, STRX = stride + # orig_weight_shape == [OC, WC, KY, KX] + N, H, W, C = img.shape + OC, WC, KY, KX = orig_weight_shape + + self._assert(C % groups == 0, f"expected input channels to be divisible by groups, got {C}") + self._assert(OC % groups == 0, f"expected output channels to be divisible by groups, got {OC}") + self._assert( + groups * WC == C, + f"expected groups * WC == C, got groups: {groups}, WC: {WC}, C: {C}; make sure the image is channels last!", + ) + self._assert( + DILX > 0 and DILY > 0 and STRX > 0 and STRY > 0, + f"dilations and strides must be larger than 0, got strides={(STRY, STRX)}, dilations={(DILY, DILX)}", + ) + self._assert(parallel_k_parts > 0, "expected parallel_k_parts to be greater than 0") + self._assert(H >= KY and W >= KX, "expected image dimensions to be greater than filter dimensions") + + OUT_H = (H - DILY * (KY - 1) - 1) // STRY + 1 + OUT_W = (W - DILX * (KX - 1) - 1) // STRX + 1 + + self.out_shape = [parallel_k_parts, N, OUT_H, OUT_W, OC] + + k_size = WC * KY * KX + k_part_extent = cdiv(k_size, parallel_k_parts) + + # k is tiled from [ky, kx, wc] + # this compute definition is not ever going to be used, since we always + # implement cuda on fp16 + def f_compute(k, ni, hi, wi, oci): + wci = k % WC + ky = (k // (WC * KX)) % KY + kx = (k // WC) % KX + out_group_size = OC // groups + return ( + img[ni, hi * STRY + ky * DILY, wi * STRX + kx * DILX, (oci // out_group_size) * WC + wci] + * weight[k, oci] 
+ ) + + c = compute( + name='c', + shape=self.out_shape, + fcompute=lambda kpi, ni, hi, wi, oci: reduce( + shape=[k_part_extent], + fcompute=lambda k: if_then_else( + kpi * k_part_extent + k < k_size, f_compute(kpi * k_part_extent + k, ni, hi, wi, oci), float16(0.0) + ), + reduce_type='sum', + ), + ) + + super().__init__( + name='conv_gemm_fp16_pk', + inputs=[img, weight], + outputs=[c], + attributes={ + 'stride': stride, + 'dilations': dilations, + 'orig_weight_shape': orig_weight_shape, + 'groups': groups, + 'parallel_k_parts': parallel_k_parts, + }, + ) + + def allow_prologue(self) -> bool: + return False + + def allow_epilogue(self) -> bool: + return True + + def implement_cuda(self, working_dir: str) -> List[IRModule]: + return tune.extract_ir_modules(self.schedule) + + @tune.space( + 2, + block_m=[32, 64, 128, 256], + block_n=[32, 64, 128, 256], + block_k=[8, 16, 32, 64, 128], + warp_m=[16, 32, 48, 64], + warp_n=[16, 32, 48, 64], + warp_k=[8, 16, 32, 64], + mma=['m16n8k16'], + ) + @tune.space(1, block_m=[128], block_n=[128], block_k=[16], warp_m=[64], warp_n=[64], warp_k=[16], mma=['m16n8k16']) + def schedule( + self, block_m=64, block_n=128, block_k=16, warp_m=32, warp_n=64, warp_k=16, mma: str = 'm16n8k16' + ) -> IRModule: + # pylint: disable=unused-variable + import hidet + from hidet.ir.type import tensor_type + from hidet.lang import attrs, view, u32, tensor_pointer, grid + from hidet.lang.layout import row_layout + from hidet.lang.mapping import spatial, auto_map + from hidet.lang.cuda import blockIdx, threadIdx, syncthreads, dynamic_shared_memory + from hidet.lang.cuda import MmaConfig, mma_sync, cp_async, cp_async_wait_all, ldmatrix + from hidet.lang.cuda import register_tensor + + DILY, DILX = self.dilations + STRY, STRX = self.stride + N, H, W, C = self.img_shape + OC, WC, KY, KX = self.orig_weight_shape + GROUPS = self.groups + + GROUP_C = C // GROUPS + GROUP_OC = OC // GROUPS + # actual shape = [KY * KX * WC, OC] + + K_PARTS, _, OUT_H, OUT_W, _ = self.out_shape + + # the problem is that the block_k is not contiguous across the channel dimension, depending on certain + # configuration of parameters + TILES_K = cdiv(GROUP_C, block_k) * KX * KY + K_TILES_PER_BLOCK = cdiv(TILES_K, K_PARTS) # number of tiles assigned to each block + + # schedule parameters + mma_configs = {'m16n8k8': MmaConfig.m16n8k8_f16_f16(), 'm16n8k16': MmaConfig.m16n8k16_f16_f16()} + tune.check(mma in mma_configs) + mma_config = mma_configs[mma] + + # number of elements each warp handles at once + mma_m, mma_n, mma_k = mma_config.m, mma_config.n, mma_config.k # 16, 8, 16 + # number of warps in each dimension + warp_count_m, warp_count_n, warp_count_k = block_m // warp_m, block_n // warp_n, block_k // warp_k + # number of repeats that each warp has to do + mma_count_m, mma_count_n, mma_count_k = warp_m // mma_m, warp_n // mma_n, warp_k // mma_k + threads = warp_count_m * warp_count_n * warp_count_k * 32 + + grid_dim: Tuple[Int, Int, Int] = cdiv(OUT_H * OUT_W, block_m), cdiv(GROUP_OC, block_n), N * K_PARTS * GROUPS + dynamic_smem_bytes = max(2 * (block_m + block_n) * block_k * 2, block_m * block_n * 2) + + ### checks + tune.check(block_m % warp_m == block_n % warp_n == block_k % warp_k == 0, 'warp dims divide block dims') + tune.check(warp_m % mma_m == warp_n % mma_n == warp_k % mma_k == 0, 'mma dims divide warp dims') + tune.check(threads <= 1024, 'threads in a block <= 1024') + maximum_smem_bytes = 49152 + tune.check(dynamic_smem_bytes <= maximum_smem_bytes, 'dynamic shared memory <= 49152') + + 
tune.check(block_n % 64 == 0, 'block_n must be multiple of 64, required by async gmem -> smem loading') + tune.check(block_k % 8 == 0) + tune.check(is_power_of_two(block_k // 8)) + + smem_img_type = tensor_type( + 'float16', + shape=[block_m, block_k], + layout=row_layout(block_m, block_k // 8).swizzle(1) * row_layout(1, 8) + # layout=row_layout(block_m, block_k) + ) + smem_weight_type = tensor_type( + 'float16', + shape=[block_k, block_n], + layout=row_layout(block_k // 8, block_n // 64) * row_layout(8, 8).swizzle(1) * row_layout(1, 8), + # layout=row_layout(block_k, block_n) + ) + load_smem_a_map = auto_map(block_m, block_k // 8, workers=threads, on_fail=lambda msg: tune.check(False, msg)) + load_smem_b_map = auto_map(block_k, block_n // 8, workers=threads, on_fail=lambda msg: tune.check(False, msg)) + store_smem_c_map = auto_map(block_m, block_n, workers=threads, on_fail=lambda msg: tune.check(False, msg)) + + with hidet.script_module() as module: + + @hidet.script + def load_regs_a(mi: int, k1: int, smem_a: smem_img_type, regs_a: float16[mma_config.a_elements]): + # mi - mma_count_m + # k1 - mma_count_k + # block - [warp_count_m, warp_count_n, warp_count_k] + # each warp handles: [warp_m, warp_k] == [mma_count_m * mma_m, mma_count_k * mma_k] + # smem_a - [block_m, block_k] + warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32 + wk = warp_id % warp_count_k + wi = warp_id // (warp_count_k * warp_count_n) + p = lane_id % 16 + q = lane_id // 16 + row_addr = ~smem_a[wi * warp_m + mi * mma_m + p, wk * warp_k + k1 * mma_k + q * 8] + b32_regs = view(regs_a, u32[4]) + ldmatrix( + regs=[b32_regs[0], b32_regs[1], b32_regs[2], b32_regs[3]], + smem_addr=row_addr, + shared_space_addr=False, + trans=False, + ) + + @hidet.script + def load_regs_b(mj: int, k1: int, smem_b: smem_weight_type, regs_b: float16[mma_config.b_elements]): + # mj - mma_count_n + # k1 - mma_count_k + # each warp handles: [warp_k, warp_n] == [mma_count_k * mma_k, mma_count_n * mma_n] + # smem_b - [block_k, block_n] + warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32 + wj = (warp_id // warp_count_k) % warp_count_n + wk = warp_id % warp_count_k + + p = lane_id % 16 + # have not used q as we only use the address of the first 16 threads to load 2 of 8x8 f16 matrix. 
+ row_addr = ~smem_b[wk * warp_k + k1 * mma_k + p, wj * warp_n + mj * mma_n] + regs = view(regs_b, u32[2]) + ldmatrix(regs=[regs[0], regs[1]], smem_addr=row_addr, trans=True) + + @hidet.script + def warp_mma( + regs_a: float16[mma_config.a_elements], + regs_b: float16[mma_config.b_elements], + regs_c: float16[mma_config.c_elements], + ): + mma_sync(mma_config, regs_a, regs_b, regs_c) + + @hidet.script + def load_smem_img(k0: int, img: float16[N, H, W, C], smem_img: smem_img_type): + offset_m = blockIdx.x * block_m # this is the output pixel index + + # the current global tile index, where each tile is of size [block_k] + k_tile_idx = (blockIdx.z // (N * GROUPS)) * K_TILES_PER_BLOCK + k0 + + batch_idx = (blockIdx.z // GROUPS) % N + + group_idx = blockIdx.z % GROUPS + num_tiles_per_channel = cdiv(GROUP_C, block_k) + channel_idx = k_tile_idx // num_tiles_per_channel + channel_group_offset = (k_tile_idx % num_tiles_per_channel) * block_k + filter_y = channel_idx // KX + filter_x = channel_idx % KX + + for i, k_seg in load_smem_a_map.on(threadIdx.x): + k = k_seg * 8 + + # tiling the output image spatial dimension [OUT_H, OUT_W] + img_spatial = i + offset_m + oh_idx = img_spatial // OUT_W + ow_idx = img_spatial % OUT_W + + # these are the input pixel coordinates + ih_idx = oh_idx * STRY + filter_y * DILY + iw_idx = ow_idx * STRX + filter_x * DILX + + channel_offset = channel_group_offset + k + group_idx * GROUP_C + + src_size = 0 + if iw_idx < W and ih_idx < H and channel_group_offset + k < GROUP_C: + src_size = min(8, GROUP_C - (channel_group_offset + k)) + + # a bit strange, the two branches should be the same, but gives different results + # but only when GROUP_C % 8 != 0 + if GROUP_C % 8 == 0: + cp_async( + ~smem_img[i, k], + ~img[batch_idx, ih_idx, iw_idx, channel_offset], + cp_size=16, + src_size=src_size * 2, + cache_level='global', + ) + else: + for ki in range(src_size): + smem_img[i, k + ki] = img[batch_idx, ih_idx, iw_idx, channel_offset + ki] + for ki in range(8 - src_size): + smem_img[i, k + ki + src_size] = 0 + + @hidet.script + def load_smem_weight(k0: int, weight: float16[KX * KY * WC, OC], smem_weight: smem_weight_type): + group_idx = blockIdx.z % GROUPS + offset_n_group = blockIdx.y * block_n + + k_tile_idx = (blockIdx.z // (N * GROUPS)) * K_TILES_PER_BLOCK + k0 + offset_k = 0 + + num_tiles_per_channel = cdiv(GROUP_C, block_k) + channel_idx = k_tile_idx // num_tiles_per_channel + channel_offset = k_tile_idx % num_tiles_per_channel + filter_y = channel_idx // KX + filter_x = channel_idx % KX + offset_k = filter_y * KX * WC + filter_x * WC + channel_offset * block_k + + for k, j_seg in load_smem_b_map.on(threadIdx.x): + j = j_seg * 8 + # we don't need to mask channel wise, since we have already done so for the img + # so the extra bits are not relevant when multipled by zeros + offset_n = offset_n_group + group_idx * GROUP_OC + src_size = ( + 0 + if (offset_n_group + j >= GROUP_OC or offset_k + k >= KY * KX * WC) + else min(8, GROUP_OC - (offset_n_group + j)) + ) + + # also quite strange; the two branches should be the same, but gives different + # results when GROUP_OC % 8 != 0 + if GROUP_OC % 8 == 0: + cp_async( + ~smem_weight[k, j], + ~weight[offset_k + k, offset_n + j], + cp_size=16, + src_size=src_size * 2, + cache_level='global', + ) + else: + for ji in range(src_size): + smem_weight[k, j + ji] = weight[offset_k + k, offset_n + j + ji] + for ji in range(8 - src_size): + smem_weight[k, j + ji + src_size] = 0 + + @hidet.script + def matmul_f16_kernel( + img: float16[N, H, W, 
C], weight: float16[KX * KY * WC, OC], res: float16[K_PARTS, N, OUT_H, OUT_W, OC] + ): + # matrix multiplication, using mma instruction + attrs.cuda.grid_dim = grid_dim + attrs.cuda.block_dim = threads + # the second 2 means '2 bytes per float16' + attrs.cuda.dynamic_smem_bytes = dynamic_smem_bytes + # smem_storage = dyn_smem_storage + smem_img = tensor_pointer( + 'float16', shape=[2, block_m, block_k], layout=row_layout(2) + smem_img_type.layout + ) + smem_weight = tensor_pointer( + 'float16', shape=[2, block_k, block_n], layout=row_layout(2) + smem_weight_type.layout + ) + smem_img = dynamic_shared_memory(byte_offset=0, dtype=float16) + smem_weight = dynamic_shared_memory(byte_offset=2 * block_m * block_k * 2, dtype=float16) + regs_a = register_tensor(float16, [2, mma_count_m, mma_config.a_elements]) + regs_b = register_tensor(float16, [2, mma_count_n, mma_config.b_elements]) + regs_c = register_tensor(float16, [mma_count_m, mma_count_n, mma_config.c_elements]) + + for i, j, p in grid(mma_count_m, mma_count_n, mma_config.c_elements): + regs_c[i, j, p] = 0.0 + + load_smem_img(0, img, ~smem_img[0, 0, 0]) + load_smem_weight(0, weight, ~smem_weight[0, 0, 0]) + cp_async_wait_all() + + syncthreads() + for k0 in range(K_TILES_PER_BLOCK): + load_smem_img(k0 + 1, img, ~smem_img[(k0 + 1) % 2, 0, 0]) + load_smem_weight(k0 + 1, weight, ~smem_weight[(k0 + 1) % 2, 0, 0]) + + for mi in range(mma_count_m): + load_regs_a(mi, 0, ~smem_img[k0 % 2, 0, 0], ~regs_a[0, mi, 0]) + for mj in range(mma_count_n): + load_regs_b(mj, 0, ~smem_weight[k0 % 2, 0, 0], ~regs_b[0, mj, 0]) + for mk in range(mma_count_k): + if mk + 1 < mma_count_k: + for mi in range(mma_count_m): + load_regs_a(mi, mk + 1, ~smem_img[k0 % 2, 0, 0], ~regs_a[(mk + 1) % 2, mi, 0]) + for mj in range(mma_count_n): + load_regs_b(mj, mk + 1, ~smem_weight[k0 % 2, 0, 0], ~regs_b[(mk + 1) % 2, mj, 0]) + for mi, mj in grid(mma_count_m, mma_count_n): + warp_mma(~regs_a[mk % 2, mi, 0], ~regs_b[mk % 2, mj, 0], ~regs_c[mi, mj, 0]) + cp_async_wait_all() + syncthreads() + + # store back + warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32 + offset_m, offset_n = blockIdx.x * block_m, blockIdx.y * block_n + k_part_idx = blockIdx.z // (N * GROUPS) + batch_idx = (blockIdx.z // GROUPS) % N + group_idx = blockIdx.z % GROUPS + group_offset = group_idx * GROUP_OC + + if warp_count_k == 1: + wi = warp_id // (warp_count_n * warp_count_k) + wj = (warp_id // warp_count_k) % warp_count_n + wk = warp_id % warp_count_k + + for mi in range(mma_count_m): + for mj in range(mma_count_n): + p = 0 + for i, j in mma_config.c_store_map.on(lane_id): + res_spatial = wi * warp_m + mi * mma_m + i + offset_m + channel_group_idx = wj * warp_n + mj * mma_n + j + offset_n + + channel_idx = channel_group_idx + group_offset + res_x = res_spatial % OUT_W + res_y = res_spatial // OUT_W + in_bound = (res_spatial < OUT_H * OUT_W) and (channel_group_idx < GROUP_OC) + if in_bound: + res[k_part_idx, batch_idx, res_y, res_x, channel_idx] = regs_c[mi, mj, p] + p += 1 + else: + smem_c = tensor_pointer('float16', shape=[block_m, block_n]) + smem_c = dynamic_shared_memory(byte_offset=0, dtype=float16) + + for k_round in range(warp_count_k): + for wi, wj, wk in spatial(warp_count_m, warp_count_n, warp_count_k).on(warp_id): + if wk == k_round: + for mi, mj in grid(mma_count_m, mma_count_n): + p = 0 + for i, j in mma_config.c_store_map.on(lane_id): + delta_m = wi * warp_m + mi * mma_m + i + delta_n = wj * warp_n + mj * mma_n + j + in_bound = (offset_m + delta_m < OUT_H * OUT_W) and (offset_n + delta_n < 
OC) + if in_bound: + if k_round == 0: + smem_c[delta_m, delta_n] = regs_c[mi, mj, p] + else: + smem_c[delta_m, delta_n] += regs_c[mi, mj, p] + p += 1 + if warp_count_k > 1: + syncthreads() + for i, j in store_smem_c_map.on(threadIdx.x): + res_spatial = i + offset_m + channel_group_idx = j + offset_n + channel_idx = channel_group_idx + group_offset + + res_x = res_spatial % OUT_W + res_y = res_spatial // OUT_W + if res_spatial < OUT_H * OUT_W and channel_group_idx < GROUP_OC: + res[k_part_idx, batch_idx, res_y, res_x, channel_idx] = smem_c[i, j] + + ir_module = module.ir_module() + assert isinstance(matmul_f16_kernel, Function) + + return ir_module + + +class Conv2dGemmFp16Op(Operator): + def __init__( + self, + img: Tensor, + weight: Tensor, + orig_weight_shape: List[int], + stride: List[int], + dilations: List[int], + groups: int, + parallel_k_parts=1, + ): + if not (isinstance(parallel_k_parts, int) and not isinstance(parallel_k_parts, bool)): + raise ValueError('parallel_k_parts must be an integer, got {}'.format(parallel_k_parts)) + + super().__init__( + inputs=[img, weight], + attributes={ + 'stride': stride, + 'dilations': dilations, + 'orig_weight_shape': orig_weight_shape, + 'groups': groups, + 'parallel_k_parts': parallel_k_parts, + }, + task=Conv2dGemmFp16Task( + input_like(img, 'img'), + input_like(weight, 'weight'), + orig_weight_shape, + stride, + dilations, + groups=groups, + parallel_k_parts=parallel_k_parts, + ), + ) + + +# pylint: disable=dangerous-default-value +def parallel_part_heuristic( + input_shape, weight_shape, stride: List[int] = [1, 1], dilation: List[int] = [1, 1], groups: int = 1 +): + N, H, W, _ = input_shape + OC, WC, KY, KX = weight_shape + DILY, DILX = dilation + STRY, STRX = stride + OUT_H = (H - DILY * (KY - 1) - 1) // STRY + 1 + OUT_W = (W - DILX * (KX - 1) - 1) // STRX + 1 + m_size = OUT_H * OUT_W + n_size = OC // groups + k_size = WC * KX * KY + estimate_blocks = N * cdiv(m_size, 64) * cdiv(n_size, 64) * groups + estimate_concurrent_blocks = 80 * 5 + max_k_parts = cdiv(k_size, 64) + k_parts = min(cdiv(estimate_concurrent_blocks, estimate_blocks), max_k_parts) + return k_parts + + +def conv2d_gemm_fp16_channel_last( + img: Tensor, weight: Tensor, stride: List[int], dilations: List[int], groups: int, parallel_k_parts=1 +) -> Tensor: + import hidet + + if len(img.shape) != 4 or len(weight.shape) != 4: + raise ValueError('a and b must have 4 dimensions, got shape {} and {}'.format(img.shape, weight.shape)) + if img.dtype != dtypes.float16 or weight.dtype != dtypes.float16: + raise ValueError('ConvGemmF16Op only support float16, got {} and {}'.format(img.dtype, weight.dtype)) + oc, wc, ky, kx = weight.shape + weight = hidet.ops.transpose(weight, [2, 3, 1, 0]).reshape([ky * kx * wc, oc]) + return ( + Conv2dGemmFp16Op( + img, + weight, + orig_weight_shape=[oc, wc, ky, kx], + stride=stride, + dilations=dilations, + groups=groups, + parallel_k_parts=parallel_k_parts, + ) + .get_output(0) + .sum(0) + ) + + +def conv2d_gemm_fp16( + img: Tensor, weight: Tensor, stride: List[int], dilations: List[int], groups: int, parallel_k_parts=1 +) -> Tensor: + import hidet + + img = hidet.ops.transpose(img, [0, 2, 3, 1]) + res = conv2d_gemm_fp16_channel_last(img, weight, stride, dilations, groups, parallel_k_parts) + return hidet.ops.transpose(res, [0, 3, 1, 2]) diff --git a/python/hidet/graph/ops/conv2d/conv2d_winograd.py b/python/hidet/graph/ops/conv2d/conv2d_winograd.py index ab1909c07..2e7f219ef 100644 --- a/python/hidet/graph/ops/conv2d/conv2d_winograd.py +++ 
b/python/hidet/graph/ops/conv2d/conv2d_winograd.py @@ -17,7 +17,7 @@ import numpy as np -from hidet.ir.expr import const_tensor, Constant, cast +from hidet.ir.expr import const_tensor, Constant, cast, is_constant from hidet.graph.ops.matmul import matmul from ..utils import Tensor, Operator, Task, TensorNode, input_like, compute, reduce, normalize_kernel @@ -69,9 +69,9 @@ def __init__(self, w: TensorNode, ms: List[int]): assert len(w.shape) == 4 oc, c, rx, ry = w.shape mx, my = ms - if not rx.is_const(): + if not is_constant(rx): raise ValueError('winograd filter transform: rx must be const') - if not ry.is_const(): + if not is_constant(ry): raise ValueError('winograd filter transform: ry must be const') alpha_x, alpha_y = mx + rx - 1, my + ry - 1 GH = winograd_transform_matrices(mx, int(rx))[0] diff --git a/python/hidet/graph/ops/conv2d/resolve.py b/python/hidet/graph/ops/conv2d/resolve.py index 11a3b3cf4..ae8e40fb6 100644 --- a/python/hidet/graph/ops/conv2d/resolve.py +++ b/python/hidet/graph/ops/conv2d/resolve.py @@ -14,8 +14,10 @@ from hidet.graph.transforms import ResolveRule, register_resolve_rule from hidet.graph import ops from hidet.ir.expr import is_constant +from hidet.ir.dtypes import float16 -from .conv2d import Conv2dOp +from .conv2d import Conv2dOp, Conv2dChannelLastOp +from .conv2d_gemm import parallel_part_heuristic @register_resolve_rule(Conv2dOp) @@ -34,10 +36,38 @@ def resolve(self, op: Operator) -> Optional[List[Tensor]]: return None # use depthwise schedule in the default Task data, weight = op.inputs kernel_size = weight.shape[2:] - if self.enable_winograd and tuple(stride) == (1, 1) and tuple(kernel_size) == (3, 3) and groups == 1: + if channels >= 16 and data.dtype == float16 and weight.dtype == float16: + # we set parallel_k to 1 for channel first, because we need to transpose back; + # setting parallel_k > 1 pervents epilogue fusion, leading to bad performance. 
+            k_parts = 1
+            out = ops.conv2d_gemm_fp16(data, weight, stride, dilations, groups, k_parts)
+        elif self.enable_winograd and tuple(stride) == (1, 1) and tuple(kernel_size) == (3, 3) and groups == 1:
             # winograd algorithm
             out = ops.conv2d_winograd(data, weight)
         else:
             # implicit gemm algorithm
             out = ops.conv2d_gemm(data, weight, stride, dilations, groups)
         return [out]
+
+
+@register_resolve_rule(Conv2dChannelLastOp)
+class Conv2dChannelLastResolveRule(ResolveRule):
+    def resolve(self, op: Operator) -> Optional[List[Tensor]]:
+        assert isinstance(op, Conv2dChannelLastOp)
+        stride = ops.utils.normalize_stride(op.attrs['stride'])
+        groups = op.attrs['groups']
+        dilations = op.attrs['dilations']
+        channels = op.inputs[0].shape[-1]
+        # TODO: current assert mechanism does not cover this use case
+        if is_constant(channels) and groups == channels:
+            return None  # use depthwise schedule in the default Task
+        data, weight = op.inputs
+        if channels >= 16 and data.dtype == float16 and weight.dtype == float16:
+            # after some benchmarking, k_parts = 1 is sufficient for most cases
+            if all(is_constant(s) for s in data.shape):
+                k_parts = parallel_part_heuristic(data.shape, weight.shape, stride, dilations, groups)
+            else:
+                k_parts = 1
+            out = ops.conv2d_gemm_fp16_channel_last(data, weight, stride, dilations, groups, k_parts)
+            return [out]
+        return None
diff --git a/python/hidet/ir/layout.py b/python/hidet/ir/layout.py
index f4e940217..1fdf6dbd7 100644
--- a/python/hidet/ir/layout.py
+++ b/python/hidet/ir/layout.py
@@ -95,7 +95,7 @@ def __str__(self):
             return '{}(shape={}, size={})'.format(self.__class__.__name__, self.shape, self.size)
         else:
             shape = [int(v) for v in self.shape]
-            table = np.zeros(shape=shape, dtype=np.int)
+            table = np.zeros(shape=shape, dtype=int)
             ranges = [range(v) for v in shape]
             for indices in itertools.product(*ranges):
                 local_index = self.global2local(*indices)
diff --git a/tests/operators/test_conv2d.py b/tests/operators/test_conv2d.py
index 0a2fd53ad..2e63868ff 100644
--- a/tests/operators/test_conv2d.py
+++ b/tests/operators/test_conv2d.py
@@ -16,17 +16,103 @@
 import pytest
 
 from hidet import ops
-from hidet.testing import check_binary, check_binary_dynamic
+from hidet.testing import check_binary, check_binary_dynamic, check_torch_binary
 
 
-def torch_conv2d(data: np.ndarray, weight: np.ndarray, padding: List[int], stride: List[int], dilations: List[int]):
+def torch_conv2d(
+    data: np.ndarray, weight: np.ndarray, padding: List[int], stride: List[int], dilations: List[int], groups: int = 1
+):
     data_torch, weight_torch = torch.from_numpy(data), torch.from_numpy(weight)
+    needs_convert = False
+    if data_torch.dtype == torch.float16 and not data_torch.is_cuda:
+        data_torch = data_torch.cuda()
+        weight_torch = weight_torch.cuda()
+        needs_convert = True
     torch_out = torch.nn.functional.conv2d(
-        data_torch, weight_torch, bias=None, stride=stride, padding=[padding[0], padding[1]], dilation=dilations
+        data_torch,
+        weight_torch,
+        bias=None,
+        stride=stride,
+        padding=[padding[0], padding[1]],
+        dilation=dilations,
+        groups=groups,
     )
+    if needs_convert:
+        torch_out = torch_out.cpu()
     return torch_out.numpy()
 
 
+@pytest.mark.parametrize(
+    "n, c, h, w, oc, kx, ky",
+    [
+        [1, 64, 32, 32, 12, 3, 3],  # kernel 3x3
+        [2, 128, 32, 32, 32, 5, 5],  # kernel 5x5, batch size 2
+        [1, 32, 32, 32, 64, 1, 1],  # kernel 1x1
+    ],
+)
+@pytest.mark.parametrize("groups", [1, 2, 4])
+@pytest.mark.parametrize("stride", [[1, 1], [2, 3]])
+@pytest.mark.parametrize("dilations", [[1, 1], [2, 3]])
+@pytest.mark.parametrize("parallel_k", [1, 2, 3])
+@pytest.mark.parametrize(
+    "device", ["cuda"]
+)  # we don't test for cpu because it's quite imprecise in fp16 for larger kernel sizes
+def test_conv2d_gemm_fp16(n, c, h, w, oc, kx, ky, groups, stride, dilations, parallel_k, device):
+    if device == 'cpu':
+        tol = 0.8
+    else:
+        tol = 0.5
+    check_binary(
+        a_shape=[n, c, h, w],
+        b_shape=[oc, c // groups, kx, ky],
+        numpy_op=lambda data, weight: torch_conv2d(data, weight, [0, 0], stride, dilations, groups),
+        hidet_op=lambda data, weight: ops.transpose(
+            ops.conv2d_gemm_fp16_channel_last(
+                ops.transpose(data, [0, 2, 3, 1]),
+                weight,
+                stride=stride,
+                dilations=dilations,
+                groups=groups,
+                parallel_k_parts=parallel_k,
+            ),
+            [0, 3, 1, 2],
+        ),
+        dtype='float16',
+        device=device,
+        atol=tol,
+        rtol=tol,
+    )
+
+
+@pytest.mark.parametrize(
+    "n, c, h, w, oc, kx, ky",
+    [
+        [1, 64, 32, 32, 12, 3, 3],  # kernel 3x3
+        [2, 128, 32, 32, 32, 5, 5],  # kernel 5x5, batch size 2
+        [1, 32, 32, 32, 64, 1, 1],  # kernel 1x1
+    ],
+)
+@pytest.mark.parametrize("groups", [1, 2, 4])
+@pytest.mark.parametrize("stride", [[1, 1], [2, 3]])
+@pytest.mark.parametrize("dilations", [[1, 1], [2, 3]])
+def test_conv2d_channel_last(n, c, h, w, oc, kx, ky, groups, stride, dilations):
+    check_torch_binary(
+        a_shape=[n, c, h, w],
+        b_shape=[oc, c // groups, kx, ky],
+        torch_func=lambda data, weight: torch.nn.functional.conv2d(
+            data, weight, bias=None, stride=stride, padding=[0, 0], dilation=dilations, groups=groups
+        ),
+        hidet_func=lambda data, weight: ops.transpose(
+            ops.conv2d_channel_last(
+                ops.transpose(data, [0, 2, 3, 1]), weight, stride=stride, dilations=dilations, groups=groups
+            ),
+            [0, 3, 1, 2],
+        ),
+        atol=0.5,
+        rtol=0.5,
+    )
+
+
 @pytest.mark.parametrize("hidet_op", [ops.conv2d, ops.conv2d_gemm])
 @pytest.mark.parametrize(
     "n, c, h, w, oc, kx, ky",
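
# --- Reviewer note (not part of the patch): rough usage sketch of the new public ops.
# This is a hedged sketch assuming a CUDA build of hidet; the tensor-creation helper
# (hidet.randn) and its dtype/device arguments are assumptions, not taken from this diff.
import hidet
from hidet import ops

data = hidet.randn([1, 32, 32, 64], dtype='float16', device='cuda')   # NHWC input
weight = hidet.randn([12, 64, 3, 3], dtype='float16', device='cuda')  # [OC, WC, KY, KX]

# channel-last conv2d op added by this patch; Conv2dChannelLastResolveRule can rewrite it
# to the fp16 implicit-gemm kernel when the flow graph is optimized
y1 = ops.conv2d_channel_last(data, weight, stride=(1, 1), dilations=(1, 1), groups=1)

# direct fp16 implicit-gemm path; parallel_k_parts > 1 splits the reduction over thread blocks,
# and the partial results are summed (the trailing .sum(0) in conv2d_gemm_fp16_channel_last)
y2 = ops.conv2d_gemm_fp16_channel_last(data, weight, stride=[1, 1], dilations=[1, 1], groups=1, parallel_k_parts=2)

print(y1.shape, y2.shape)  # both [1, 30, 30, 12] for a 3x3 kernel with stride 1 and no dilation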
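
# --- Reviewer note (not part of the patch): a minimal NumPy reference of the compute
# definition encoded by Conv2dChannelLastTask, useful for checking the index arithmetic.
# The helper name is hypothetical; variable names follow the task code (note that the
# third weight axis pairs with H and the fourth with W, exactly as in fcompute).
import numpy as np


def conv2d_channel_last_reference(data, weight, stride=(1, 1), dilations=(1, 1), groups=1):
    n, h, w, c = data.shape         # NHWC
    oc, wc, kx, ky = weight.shape   # [OC, C // groups, KX, KY]
    sx, sy = stride
    dilx, dily = dilations
    p = (h - dilx * (kx - 1) - 1) // sx + 1
    q = (w - dily * (ky - 1) - 1) // sy + 1
    out_group_size = oc // groups
    out = np.zeros((n, p, q, oc), dtype=data.dtype)
    for ni, pi, qi, oci in np.ndindex(n, p, q, oc):
        group_offset = (oci // out_group_size) * wc
        acc = 0.0
        for wci, kxi, kyi in np.ndindex(wc, kx, ky):
            acc += (
                data[ni, pi * sx + kxi * dilx, qi * sy + kyi * dily, group_offset + wci]
                * weight[oci, wci, kxi, kyi]
            )
        out[ni, pi, qi, oci] = acc
    return out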
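
# --- Reviewer note (not part of the patch): NumPy sketch of the implicit-gemm layout used by
# Conv2dGemmFp16Task (groups == 1 for brevity; all shapes below are made up). The weight is
# flattened the same way conv2d_gemm_fp16_channel_last does it, and the reduction index k is
# laid out as [KY, KX, WC], matching the decomposition in f_compute
# (wci = k % WC, kx = (k // WC) % KX, ky = k // (WC * KX)).
import numpy as np

N, H, W, C = 1, 8, 8, 4
OC, KY, KX = 6, 3, 3
STRY = STRX = DILY = DILX = 1
OUT_H = (H - DILY * (KY - 1) - 1) // STRY + 1
OUT_W = (W - DILX * (KX - 1) - 1) // STRX + 1

img = np.random.rand(N, H, W, C).astype(np.float32)         # NHWC
weight = np.random.rand(OC, C, KY, KX).astype(np.float32)   # [OC, WC, KY, KX]

# [OC, WC, KY, KX] -> [KY, KX, WC, OC] -> [KY * KX * WC, OC], as in conv2d_gemm_fp16_channel_last
weight_2d = weight.transpose(2, 3, 1, 0).reshape(KY * KX * C, OC)

# im2col with the K axis ordered as [KY, KX, WC]
patches = np.zeros((N * OUT_H * OUT_W, KY * KX * C), dtype=np.float32)
for ni, hi, wi in np.ndindex(N, OUT_H, OUT_W):
    row = (ni * OUT_H + hi) * OUT_W + wi
    for ky, kx, wc in np.ndindex(KY, KX, C):
        k = ky * KX * C + kx * C + wc
        patches[row, k] = img[ni, hi * STRY + ky * DILY, wi * STRX + kx * DILX, wc]

gemm_out = (patches @ weight_2d).reshape(N, OUT_H, OUT_W, OC)

# parallel_k_parts splits the K axis into k_part_extent-sized chunks; the kernel writes a
# [K_PARTS, N, OUT_H, OUT_W, OC] tensor and conv2d_gemm_fp16_channel_last reduces it with .sum(0)
K_PARTS = 2
k_size = KY * KX * C
k_part_extent = -(-k_size // K_PARTS)  # cdiv
parts = np.stack(
    [
        patches[:, i * k_part_extent : (i + 1) * k_part_extent]
        @ weight_2d[i * k_part_extent : (i + 1) * k_part_extent, :]
        for i in range(K_PARTS)
    ]
)
assert np.allclose(parts.sum(axis=0).reshape(N, OUT_H, OUT_W, OC), gemm_out, atol=1e-3)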