[cuDNN] Add cudnn conv2d (#435)
Add cuDNN conv2d to the runtime.
yaoyaoding committed Jul 3, 2024
2 parents 531b8d3 + a48012c commit a2a60b1
Showing 10 changed files with 1,599 additions and 0 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -19,6 +19,7 @@ message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
add_library(hidet_runtime SHARED
src/hidet/runtime/cuda/context.cpp
src/hidet/runtime/cuda/cublas.cpp
src/hidet/runtime/cuda/cudnn.cpp
src/hidet/runtime/cuda/cuda.cpp
src/hidet/runtime/cpu/context.cpp
src/hidet/runtime/callbacks.cpp
27 changes: 27 additions & 0 deletions include/hidet/runtime/cuda/cudnn.h
@@ -9,3 +9,30 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#define HIDET_CUDNN_MAX_GPUS 32

#include <hidet/runtime/common.h>

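// Forward declarations of cuDNN's opaque types, so this header compiles
// without cudnn.h; the actual library is loaded at runtime.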
struct cudnnContext;
typedef struct cudnnContext *cudnnHandle_t;

typedef void *cudnnBackendDescriptor_t;

/* Legacy API */
struct cudnnTensorStruct;
struct cudnnFilterStruct;
struct cudnnConvolutionStruct;

typedef struct cudnnTensorStruct *cudnnTensorDescriptor_t;
typedef struct cudnnFilterStruct *cudnnFilterDescriptor_t;
typedef struct cudnnConvolutionStruct *cudnnConvolutionDescriptor_t;

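// Process-wide registry of cuDNN handles, one slot per GPU device.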
struct CudnnContext {
    cudnnHandle_t handles[HIDET_CUDNN_MAX_GPUS];
    static CudnnContext* global();
    static cudnnHandle_t current_handle();
};

DLL void hidet_cudnn_set_library_path(const char* path);

1 change: 1 addition & 0 deletions python/hidet/cuda/__init__.py
@@ -18,3 +18,4 @@
from .event import Event

from . import cublas
from . import cudnn
13 changes: 13 additions & 0 deletions python/hidet/cuda/cudnn/__init__.py
@@ -0,0 +1,13 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .ffi import cudnnDataType
from .kernels import conv2d, conv2d_gemm
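
For orientation, a minimal sketch of how the exported conv2d entry point is called, mirroring the benchmark file below. The caller precomputes the output shape (p, q); the dtype strings for tx/tw/ty and the enum compute type follow the benchmark's usage:

import hidet
from hidet.cuda.cudnn import cudnnDataType

n, c, h, w = 1, 3, 224, 224  # input, NCHW layout
k, r, s = 64, 7, 7           # filters, KCRS layout
p = q = 112                  # output size for 7x7 / stride 2 / padding 3 on 224x224

x = hidet.randn((n, c, h, w), device='cuda', dtype='float32')
wt = hidet.randn((k, c, r, s), device='cuda', dtype='float32')
y = hidet.empty((n, k, p, q), device='cuda', dtype='float32')

hidet.cuda.cudnn.conv2d(
    n, c, h, w, k, r, s, p, q,
    x, wt, y,
    'float32', 'float32', 'float32',  # tx, tw, ty
    cudnnDataType.CUDNN_DATA_FLOAT,   # compute_type
    3, 3,                             # padding
    2, 2,                             # stride
    1, 1,                             # dilation
)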
175 changes: 175 additions & 0 deletions python/hidet/cuda/cudnn/benchmark.py
@@ -0,0 +1,175 @@
import numpy as np
import torch

import hidet
from hidet.cuda.cudnn import cudnnDataType
from hidet.utils.benchmark import do_bench
from hidet import ops


def benchmark_cudnn_conv2d(dtype, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations):
    # Uses the ordinary cudnn.conv2d, implemented with the cuDNN graph API
    tx = tw = ty = dtype
    pad_dim1, pad_dim2 = padding
    str_dim1, str_dim2 = stride
    dil_dim1, dil_dim2 = dilations

    tensor_x = hidet.randn((n, c, h, w), device='cuda', dtype=tx)
    tensor_w = hidet.randn((k, c, r, s), device='cuda', dtype=tw)
    tensor_y = hidet.empty((n, k, p, q), device='cuda', dtype=ty)

    latencies = do_bench(
        lambda: hidet.cuda.cudnn.conv2d(
            n,
            c,
            h,
            w,
            k,
            r,
            s,
            p,
            q,
            tensor_x,
            tensor_w,
            tensor_y,
            tx,
            tw,
            ty,
            compute_type,
            pad_dim1,
            pad_dim2,
            str_dim1,
            str_dim2,
            dil_dim1,
            dil_dim2,
        ),
        warmup=10,
        rep=1,
    )

    print(
        f"cuDNN Results for Configuration: dtype = {dtype}, input shape = {[n,c,h,w]}, "
        f"weight shape = {[k,c,r,s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:"
    )
    print("Median latency: " + str(latencies[1]) + " milliseconds")
    print("-------------------------------------------------")


def benchmark_cudnn_conv2d_gemm(dtype, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations):
    # Uses cudnn.conv2d_gemm, implemented with the cuDNN legacy API
    tx = tw = ty = dtype
    pad_dim1, pad_dim2 = padding
    str_dim1, str_dim2 = stride
    dil_dim1, dil_dim2 = dilations

    tensor_x = hidet.randn((n, c, h, w), device='cuda', dtype=tx)
    tensor_w = hidet.randn((k, c, r, s), device='cuda', dtype=tw)
    tensor_y = hidet.empty((n, k, p, q), device='cuda', dtype=ty)

    latencies = do_bench(
        lambda: hidet.cuda.cudnn.conv2d_gemm(
            n,
            c,
            h,
            w,
            k,
            r,
            s,
            tensor_x,
            tensor_w,
            tensor_y,
            tx,
            tw,
            ty,
            compute_type,
            pad_dim1,
            pad_dim2,
            str_dim1,
            str_dim2,
            dil_dim1,
            dil_dim2,
        ),
        warmup=10,
        rep=100,
    )

    print(
        f"cudnn_gemm Results for Configuration: dtype = {dtype}, input shape = {[n,c,h,w]}, "
        f"weight shape = {[k,c,r,s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:"
    )
    print("Median latency: " + str(latencies[1]) + " milliseconds")
    print("-------------------------------------------------")


def benchmark_torch_conv2d(dtype, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations):
    # Native PyTorch eager-mode execution
    data = np.array(np.random.randn(n, c, h, w)).astype(dtype)
    weight = np.array(np.random.randn(k, c, r, s)).astype(dtype)

    data_torch, weight_torch = torch.from_numpy(data), torch.from_numpy(weight)
    data_torch = data_torch.cuda()
    weight_torch = weight_torch.cuda()

    latencies = do_bench(
        lambda: torch.nn.functional.conv2d(
            data_torch, weight_torch, bias=None, stride=stride, padding=padding, dilation=dilations, groups=1
        ),
        warmup=10,
        rep=100,
    )

    print(
        f"PyTorch Results for Configuration: dtype = {dtype}, input shape = {[n,c,h,w]}, "
        f"weight shape = {[k,c,r,s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:"
    )
    print("Median latency: " + str(latencies[1]) + " milliseconds")
    print("-------------------------------------------------")


def benchmark_hidet_conv2d(dtype, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations):
    # Uses the optimized Hidet graph implementation
    tx = tw = dtype
    pad_dim1, pad_dim2 = padding
    str_dim1, str_dim2 = stride
    dil_dim1, dil_dim2 = dilations

    hidet.option.search_space(2)
    tensor_x = hidet.symbol((n, c, h, w), device='cuda', dtype=tx)
    tensor_w = hidet.randn((k, c, r, s), device='cuda', dtype=tw)
    output = ops.conv2d(
        tensor_x, tensor_w, stride=(str_dim1, str_dim2), dilations=(dil_dim1, dil_dim2), padding=(pad_dim1, pad_dim2)
    )
    graph = hidet.trace_from(output, inputs=[tensor_x, tensor_w])
    graph = hidet.graph.optimize(graph)
    graph = graph.cuda_graph()

    latencies = do_bench(lambda: graph.run_async(), warmup=10, rep=100)

    print(
        f"Optimized Hidet Results for Configuration: dtype = {dtype}, input shape = {[n,c,h,w]}, "
        f"weight shape = {[k,c,r,s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:"
    )
    print("Median latency: " + str(latencies[1]) + " milliseconds")
    print("-------------------------------------------------")


if __name__ == '__main__':
    sizes = [
        # Group 1
        [1, 3, 224, 224, 64, 112, 112, 7, 7, [3, 3], [2, 2], [1, 1]],
        [2, 3, 224, 224, 64, 112, 112, 7, 7, [3, 3], [2, 2], [1, 1]],
        [4, 3, 224, 224, 64, 112, 112, 7, 7, [3, 3], [2, 2], [1, 1]],
        [8, 3, 224, 224, 64, 112, 112, 7, 7, [3, 3], [2, 2], [1, 1]],
        # Group 2
        [1, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]],
        [2, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]],
        [4, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]],
        [8, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]],
    ]
    dtypes = [['float32', cudnnDataType.CUDNN_DATA_FLOAT], ['float16', cudnnDataType.CUDNN_DATA_HALF]]

    for data_type in dtypes:
        for size in sizes:
            benchmark_cudnn_conv2d_gemm(*(data_type + size))
            benchmark_torch_conv2d(*(data_type + size))
            benchmark_hidet_conv2d(*(data_type + size))
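
The hard-coded p and q in each sizes entry are the convolution output sizes implied by the rest of the configuration. A quick sanity check under the standard output-shape formula (a sketch, not part of the commit):

def conv_out(dim, pad, stride, dil, ksize):
    # floor((dim + 2*pad - dil*(ksize - 1) - 1) / stride) + 1
    return (dim + 2 * pad - dil * (ksize - 1) - 1) // stride + 1

for n, c, h, w, k, p, q, r, s, padding, stride, dilations in sizes:
    assert p == conv_out(h, padding[0], stride[0], dilations[0], r)
    assert q == conv_out(w, padding[1], stride[1], dilations[1], s)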
109 changes: 109 additions & 0 deletions python/hidet/cuda/cudnn/ffi.py
@@ -0,0 +1,109 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import glob
from enum import IntEnum
from ctypes import c_int32, c_void_p, c_char_p
from hidet.ffi.ffi import get_func

from hidet.utils.py import initialize


class cudnnDataType(IntEnum):
    """
    defined in cudnn_ops_infer_v8.h
    """

    CUDNN_DATA_FLOAT = 0
    CUDNN_DATA_DOUBLE = 1
    CUDNN_DATA_HALF = 2
    CUDNN_DATA_INT8 = 3
    CUDNN_DATA_INT32 = 4
    CUDNN_DATA_INT8x4 = 5
    CUDNN_DATA_UINT8 = 6
    CUDNN_DATA_UINT8x4 = 7
    CUDNN_DATA_INT8x32 = 8
    CUDNN_DATA_BFLOAT16 = 9
    CUDNN_DATA_INT64 = 10
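
The C entry points receive tx, tw, ty, and compute_type as plain int32 values (see the arg_types below), while the benchmark passes hidet dtype strings for tx/tw/ty; the kernels module presumably converts between the two. A hypothetical helper illustrating that conversion (the names are illustrative, not part of this diff):

_DTYPE_TO_CUDNN = {
    'float32': cudnnDataType.CUDNN_DATA_FLOAT,
    'float64': cudnnDataType.CUDNN_DATA_DOUBLE,
    'float16': cudnnDataType.CUDNN_DATA_HALF,
    'bfloat16': cudnnDataType.CUDNN_DATA_BFLOAT16,
    'int8': cudnnDataType.CUDNN_DATA_INT8,
    'int32': cudnnDataType.CUDNN_DATA_INT32,
}

def as_cudnn_type(dtype: str) -> int:
    # hypothetical: map a hidet dtype name to its cudnnDataType value
    return int(_DTYPE_TO_CUDNN[dtype])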


set_library_path = get_func(func_name='hidet_cudnn_set_library_path', arg_types=[c_char_p], restype=None)

conv2d = get_func(
    func_name='hidet_cudnn_conv2d',
    arg_types=[
        c_int32,  # n
        c_int32,  # c
        c_int32,  # h
        c_int32,  # w
        c_int32,  # k
        c_int32,  # r
        c_int32,  # s
        c_int32,  # p
        c_int32,  # q
        c_void_p,  # ptr_x
        c_void_p,  # ptr_w
        c_void_p,  # ptr_y
        c_int32,  # tx
        c_int32,  # tw
        c_int32,  # ty
        c_int32,  # compute_type
        c_int32,  # pad_dim1
        c_int32,  # pad_dim2
        c_int32,  # str_dim1
        c_int32,  # str_dim2
        c_int32,  # dil_dim1
        c_int32,  # dil_dim2
    ],
    restype=None,
)

conv2d_gemm = get_func(
    func_name='hidet_cudnn_conv2d_gemm',
    arg_types=[
        c_int32,  # n
        c_int32,  # c
        c_int32,  # h
        c_int32,  # w
        c_int32,  # k
        c_int32,  # r
        c_int32,  # s
        c_void_p,  # ptr_x
        c_void_p,  # ptr_w
        c_void_p,  # ptr_y
        c_int32,  # tx
        c_int32,  # tw
        c_int32,  # ty
        c_int32,  # compute_type
        c_int32,  # pad_dim1
        c_int32,  # pad_dim2
        c_int32,  # str_dim1
        c_int32,  # str_dim2
        c_int32,  # dil_dim1
        c_int32,  # dil_dim2
    ],
    restype=None,
)


@initialize()
def set_cudnn_library_path():
    # prefer the cuDNN shipped with the nvidia pip packages (found under 'nvidia/cudnn/lib' on sys.path)
    for path in sys.path:
        nvidia_path = os.path.join(path, 'nvidia')
        if not os.path.exists(nvidia_path):
            continue
        cudnn_path = glob.glob(os.path.join(nvidia_path, 'cudnn', 'lib', 'libcudnn.so.[0-9]*'))
        if cudnn_path:
            set_library_path(cudnn_path[0].encode('utf-8'))
            return