Skip to content

Commit

Permalink
[Operators] Conv2d fp16 implicit gemm kernel (#283)
Browse files Browse the repository at this point in the history
Co-authored-by: Allan Lin <[email protected]>
  • Loading branch information
Aalanli and Allan Lin committed Jun 20, 2023
1 parent bb1612e commit 289377a
Show file tree
Hide file tree
Showing 8 changed files with 747 additions and 17 deletions.
10 changes: 9 additions & 1 deletion python/hidet/graph/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@
# pylint: disable=redefined-builtin
from .conv1d import conv1d
from .conv1d_transpose import conv1d_transpose
from .conv2d import conv2d, conv2d_winograd, conv2d_gemm, conv2d_gemm_image_transform
from .conv2d import (
conv2d,
conv2d_channel_last,
conv2d_winograd,
conv2d_gemm,
conv2d_gemm_fp16,
conv2d_gemm_fp16_channel_last,
conv2d_gemm_image_transform,
)
from .conv2d_transpose import conv2d_transpose, conv2d_transpose_gemm
from .conv3d import conv3d, conv3d_gemm
from .conv3d_transpose import conv3d_transpose
Expand Down
12 changes: 9 additions & 3 deletions python/hidet/graph/ops/conv2d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .conv2d import conv2d
from .conv2d import Conv2dOp
from .conv2d import conv2d, conv2d_channel_last
from .conv2d import Conv2dOp, Conv2dChannelLastOp
from .conv2d_winograd import conv2d_winograd, conv2d_winograd_image_transform, conv2d_winograd_filter_transform
from .conv2d_winograd import conv2d_winograd_inverse_transform
from .conv2d_winograd import Conv2dWinogradInverseTransformOp, Conv2dWinogradFilterTransformOp
from .conv2d_winograd import Conv2dWinogradImageTransformOp
from .conv2d_gemm import conv2d_gemm, conv2d_gemm_image_transform, conv2d_gemm_filter_transform
from .conv2d_gemm import (
conv2d_gemm,
conv2d_gemm_fp16,
conv2d_gemm_fp16_channel_last,
conv2d_gemm_image_transform,
conv2d_gemm_filter_transform,
)
from .conv2d_gemm import conv2d_gemm_inverse_transform
from .conv2d_gemm import Conv2dGemmImageTransformOp

Expand Down
63 changes: 63 additions & 0 deletions python/hidet/graph/ops/conv2d/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,48 @@ def __init__(self, data: TensorNode, weight: TensorNode, stride: List[int], dila
super().__init__(name='conv2d', inputs=[data, weight], outputs=[output])


class Conv2dChannelLastTask(Task):
    """Task computing a grouped 2D convolution on NHWC (channel-last) data.

    Input ``data`` has shape [n, h, w, c]; ``weight`` has shape
    [out_channels, in_channels // groups, kx, ky]. No padding is applied here
    (output spatial size follows the VALID-convolution formula); callers are
    expected to pad beforehand if needed.
    """

    def __init__(self, data: TensorNode, weight: TensorNode, stride: List[int], dilations: List[int], groups: int):
        # pylint: disable=too-many-locals
        # we assume that only data needs to have dynamic shape
        n, h, w, c = data.shape
        oc, wc, kx, ky = weight.shape
        sx, sy = stride
        dilx, dily = dilations
        # Output spatial dims for a dilated, strided, unpadded convolution.
        p, q = (h - dilx * (kx - 1) - 1) // sx + 1, (w - dily * (ky - 1) - 1) // sy + 1
        # Fix: both divisibility conditions must hold (the error message below
        # says "and"); the previous logical_or accepted shapes where only one
        # of the two conditions was satisfied.
        self._assert(
            ir.logical_and(c % groups == 0, oc % groups == 0),
            msg=(
                'Conv2d expect the in_channels % groups == 0 and out_channels % groups == 0, \n'
                'but got in_channels, out_channels, groups: {}, {}, {}'.format(c, oc, groups)
            ),
        )
        self._assert(
            wc * groups == c,
            msg=(
                'Conv2d expect the weight has shape [out_channels, in_channels / groups, kx, ky], \n'
                'got weight shape {}, in_channels {} and groups {}'.format([oc, wc, kx, ky], c, groups)
            ),
        )
        out_group_size = oc // groups
        output = compute(
            name='out',
            shape=[n, p, q, oc],
            fcompute=lambda ni, pi, qi, oci: reduce(
                shape=[wc, kx, ky],
                # Channel-last indexing: the input channel for output channel
                # `oci` starts at its group's offset (oci // out_group_size) * wc.
                fcompute=lambda wci, kxi, kyi: (
                    data[ni, pi * sx + kxi * dilx, qi * sy + kyi * dily, (oci // out_group_size) * wc + wci]
                    * weight[oci, wci, kxi, kyi]
                ),
                reduce_type='sum',
            ),
        )
        self.channels = c
        self.stride = stride
        self.groups = groups
        super().__init__(name='conv2d_channel_last', inputs=[data, weight], outputs=[output])


class Conv2dOp(Operator):
def __init__(self, x: Tensor, w: Tensor, stride: Sequence[int], dilations: Union[int, Sequence[int]], groups: int):
stride = normalize_stride(stride)
Expand All @@ -68,6 +110,17 @@ def __init__(self, x: Tensor, w: Tensor, stride: Sequence[int], dilations: Union
)


class Conv2dChannelLastOp(Operator):
    """Operator wrapping Conv2dChannelLastTask: conv2d over NHWC-layout input."""

    def __init__(self, x: Tensor, w: Tensor, stride: Sequence[int], dilations: Union[int, Sequence[int]], groups: int):
        # Canonicalize the stride/dilation specs before building the task.
        norm_stride = normalize_stride(stride)
        norm_dilations = normalize_dilations(dilations)
        task = Conv2dChannelLastTask(input_like(x, 'x'), input_like(w, 'w'), norm_stride, norm_dilations, groups)
        super().__init__(
            inputs=[x, w],
            attributes={'stride': norm_stride, 'groups': groups, 'dilations': norm_dilations},
            task=task,
        )


def conv2d(
data: Tensor,
weight: Tensor,
Expand All @@ -76,3 +129,13 @@ def conv2d(
groups: int = 1,
) -> Tensor:
return Conv2dOp(data, weight, stride, dilations, groups).get_output(0)


def conv2d_channel_last(
    data: Tensor,
    weight: Tensor,
    stride: Union[int, Sequence[int]] = (1, 1),
    dilations: Union[int, Sequence[int]] = (1, 1),
    groups: int = 1,
) -> Tensor:
    """2D convolution on channel-last (NHWC) data.

    Thin functional wrapper around Conv2dChannelLastOp; returns the op's
    single output tensor.
    """
    op = Conv2dChannelLastOp(data, weight, stride, dilations, groups)
    return op.get_output(0)
Loading

0 comments on commit 289377a

Please sign in to comment.