cuDNN cleanup
Yudi Sun committed Jun 25, 2024
1 parent 730d30f commit 40a6149
Showing 6 changed files with 24 additions and 297 deletions.
python/hidet/cuda/cudnn/__init__.py: 1 addition, 1 deletion
@@ -10,4 +10,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from .ffi import cudnnDataType
-from .kernels import conv2d, conv2d_gemm, conv2d_autoselect_algo
+from .kernels import conv2d, conv2d_gemm
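The module's public surface after this commit is conv2d and conv2d_gemm. A minimal usage sketch, assuming conv2d_gemm takes the same argument list as the removed conv2d_autoselect_algo wrapper (the two C++ entry points in this commit share identical parameters):

import hidet
from hidet.cuda.cudnn import cudnnDataType

# Shapes follow the NCHW / KCRS layout used by the benchmarks below.
n, c, h, w = 1, 3, 32, 32   # input
k, r, s = 12, 3, 3          # filter
p, q = 30, 30               # output H/W for pad=0, stride=1, dilation=1

x = hidet.randn((n, c, h, w), device='cuda', dtype='float32')
wt = hidet.randn((k, c, r, s), device='cuda', dtype='float32')
y = hidet.empty((n, k, p, q), device='cuda', dtype='float32')

hidet.cuda.cudnn.conv2d_gemm(
    n, c, h, w, k, r, s, x, wt, y,
    'float32', 'float32', 'float32',  # tx, tw, ty
    cudnnDataType.CUDNN_DATA_FLOAT,   # compute_type
    0, 0,                             # padding (height, width)
    1, 1,                             # stride (height, width)
    1, 1,                             # dilation (height, width)
)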
python/hidet/cuda/cudnn/benchmark.py: 2 additions, 52 deletions
@@ -101,51 +101,6 @@ def benchmark_cudnn_conv2d_gemm(dtype, compute_type, n, c, h, w, k, p, q, r, s,
print("-------------------------------------------------")


def benchmark_cudnn_conv2d_autoselect_algo(dtype, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations):
# Uses cudnn Legacy-API to autoselect the fastest algorithm
tx = tw = ty = dtype
pad_dim1, pad_dim2 = padding
str_dim1, str_dim2 = stride
dil_dim1, dil_dim2 = dilations

tensor_x = hidet.randn((n, c, h, w), device='cuda', dtype=tx)
tensor_w = hidet.randn((k, c, r, s), device='cuda', dtype=tw)
tensor_y = hidet.empty((n, k, p, q), device='cuda', dtype=ty)

latencies = do_bench(
lambda: hidet.cuda.cudnn.conv2d_autoselect_algo(
n,
c,
h,
w,
k,
r,
s,
tensor_x,
tensor_w,
tensor_y,
tx,
tw,
ty,
compute_type,
pad_dim1,
pad_dim2,
str_dim1,
str_dim2,
dil_dim1,
dil_dim2,
),
warmup=10,
rep=100,
)

print(
f"cudnn_autoselect_algo Results for Configuration: dtype = {dtype}, input shape = {[n,c,h,w]}, "
f"weight shape = {[k,c,r,s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:"
)
print("Median Latency Is: " + str(latencies[1]) + " milliseconds")
print("-------------------------------------------------")


def benchmark_torch_conv2d(dtype, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations):
# Native PyTorch Eager-mode Execution
@@ -189,11 +144,7 @@ def benchmark_hidet_conv2d(dtype, compute_type, n, c, h, w, k, p, q, r, s, paddi
     graph = hidet.graph.optimize(graph)
     graph = graph.cuda_graph()
 
-    latencies = do_bench(
-        lambda: graph.run_async(),
-        warmup=10,
-        rep=100,
-    )
+    latencies = do_bench(lambda: graph.run_async(), warmup=10, rep=100)
 
     print(
         f"Optimized Hidet Results for Configuration: dtype = {dtype}, input shape = {[n,c,h,w]}, "
@@ -216,11 +167,10 @@ def benchmark_hidet_conv2d(dtype, compute_type, n, c, h, w, k, p, q, r, s, paddi
     [4, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]],
     [8, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]],
 ]
-dtypes = [['float32', cudnnDataType.CUDNN_DATA_FLOAT], ['float16', cudnnDataType.CUDNN_DATA_HALF]]#, ['float64', cudnnDataType.CUDNN_DATA_DOUBLE]]
+dtypes = [['float32', cudnnDataType.CUDNN_DATA_FLOAT], ['float16', cudnnDataType.CUDNN_DATA_HALF]]
 
 for data_type in dtypes:
     for size in sizes:
         benchmark_cudnn_conv2d_gemm(*(data_type + size))
         benchmark_torch_conv2d(*(data_type + size))
-        benchmark_cudnn_conv2d_autoselect_algo(*(data_type + size))
         benchmark_hidet_conv2d(*(data_type + size))
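Note that p and q in each sizes row above are determined by the other parameters rather than chosen freely. A small helper (hypothetical, not in the repo) that reproduces the hardcoded values via the standard convolution output formula:

def conv2d_out_dim(in_dim: int, filt_dim: int, pad: int, stride: int, dil: int) -> int:
    # floor((in + 2*pad - dil*(filt - 1) - 1) / stride) + 1
    return (in_dim + 2 * pad - dil * (filt_dim - 1) - 1) // stride + 1

# First row above: h = w = 56, r = s = 1, padding [0, 0], stride [1, 1], dilations [1, 1]
assert conv2d_out_dim(56, 1, 0, 1, 1) == 56  # matches p = q = 56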
python/hidet/cuda/cudnn/ffi.py: 0 additions, 27 deletions
@@ -95,33 +95,6 @@ class cudnnDataType(IntEnum):
     restype=None,
 )
 
-conv2d_autoselect_algo = get_func(
-    func_name='hidet_cudnn_conv2d_autoselect_algo',
-    arg_types=[
-        c_int32, # n
-        c_int32, # c
-        c_int32, # h
-        c_int32, # w
-        c_int32, # k
-        c_int32, # r
-        c_int32, # s
-        c_void_p, # ptr_x
-        c_void_p, # ptr_w
-        c_void_p, # ptr_y
-        c_int32, # tx
-        c_int32, # tw
-        c_int32, # ty
-        c_int32, # compute_type
-        c_int32, # pad_dim1
-        c_int32, # pad_dim2
-        c_int32, # str_dim1
-        c_int32, # str_dim2
-        c_int32, # dil_dim1
-        c_int32, # dil_dim2
-    ],
-    restype=None,
-)
-
 
 @initialize()
 def set_cudnn_library_path():
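The deleted binding used the same get_func pattern as the surviving conv2d and conv2d_gemm bindings. In plain ctypes the declaration would look roughly like the sketch below; the library name is hypothetical, and hidet's get_func additionally handles lazy loading and symbol lookup:

import ctypes

lib = ctypes.CDLL('libhidet_cudnn.so')  # hypothetical name; hidet resolves the real path itself
fn = lib.hidet_cudnn_conv2d_gemm
fn.argtypes = (
    [ctypes.c_int32] * 7      # n, c, h, w, k, r, s
    + [ctypes.c_void_p] * 3   # ptr_x, ptr_w, ptr_y
    + [ctypes.c_int32] * 10   # tx, tw, ty, compute_type, pads, strides, dilations
)
fn.restype = None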
python/hidet/cuda/cudnn/kernels.py: 0 additions, 93 deletions
@@ -210,96 +210,3 @@ def conv2d_gemm(
         dil_dim2,
     )
 
-
-def conv2d_autoselect_algo(
-    n: int,
-    c: int,
-    h: int,
-    w: int,
-    k: int,
-    r: int,
-    s: int,
-    ptr_x,
-    ptr_w,
-    ptr_y,
-    tx: Union[int, DataType],
-    tw: Union[int, DataType],
-    ty: Union[int, DataType],
-    compute_type: Union[int, cudnnDataType],
-    pad_dim1: int,
-    pad_dim2: int,
-    str_dim1: int,
-    str_dim2: int,
-    dil_dim1: int,
-    dil_dim2: int,
-):
-    """
-    Calculates the 2D convolution of tensor x with filter w, stores the result in tensor y.
-    Parameters
-    ----------
-    n: int
-        Batch number.
-    c: int
-        Number of channels in the input tensor x.
-    h: int
-        Height of the input tensor x.
-    w: int
-        Width of the input tensor x.
-    k: int
-        Number of channels in the output tensor y.
-    r: int
-        Height of the filter w.
-    s: int
-        Width of the filter w.
-    ptr_x: hidet.Tensor or int
-        Input tensor x, can be either a Tensor or an integer (the address of the tensor).
-    ptr_w: hidet.Tensor or int
-        Weight tensor w, can be either a Tensor or an integer (the address of the tensor).
-    ptr_y: hidet.Tensor or int
-        Output tensor y, can be either a Tensor or an integer (the address of the tensor).
-    tx: Union[int, DataType]
-        Type of elements in tensor x.
-    tw: Union[int, DataType]
-        Type of elements in tensor w.
-    ty: Union[int, DataType]
-        Type of elements in tensor y.
-    compute_type: Union[int, cudnnDataType]
-        The compute type of the operation.
-        For cuDNN, there's no such thing as a cudnnComputeType_t type.
-        As per the official example, the computeType is defined in terms of cudnnDataType_t
-    pad_dim1: int
-        The value to use for padding along the height dimension
-    pad_dim2: int
-        The value to use for padding along the width dimension
-    str_dim1: int
-        The stride to use for the height dimension
-    str_dim2: int
-        The stride to use for the width dimension
-    dil_dim1: int
-        The dilation to use for the height dimension
-    dil_dim2: int
-        The dilation to use for the width dimension
-    """
-    ffi.conv2d_autoselect_algo(
-        n,
-        c,
-        h,
-        w,
-        k,
-        r,
-        s,
-        as_pointer(ptr_x),
-        as_pointer(ptr_w),
-        as_pointer(ptr_y),
-        as_cudnn_type(tx),
-        as_cudnn_type(tw),
-        as_cudnn_type(ty),
-        compute_type,
-        pad_dim1,
-        pad_dim2,
-        str_dim1,
-        str_dim2,
-        dil_dim1,
-        dil_dim2,
-    )
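On the compute_type note in the deleted docstring: since cuDNN expresses computeType in cudnnDataType_t terms, mixed-precision setups are just a choice of enum value. A hedged sketch extending the dtypes table from benchmark.py, pairing half-precision tensors with float accumulation (a configuration cuDNN permits for NCHW convolutions; an assumption worth verifying against a given cuDNN version):

from hidet.cuda.cudnn import cudnnDataType

dtypes = [
    ['float32', cudnnDataType.CUDNN_DATA_FLOAT],  # fp32 data, fp32 accumulate
    ['float16', cudnnDataType.CUDNN_DATA_HALF],   # fp16 data, fp16 accumulate
    ['float16', cudnnDataType.CUDNN_DATA_FLOAT],  # fp16 data, fp32 accumulate
]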
src/hidet/runtime/cuda/cudnn.cpp: 21 additions, 69 deletions
@@ -10,6 +10,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <chrono>
 #include <hidet/runtime/cuda/cuda.h>
 #include <hidet/runtime/cuda/context.h>
 #include <hidet/runtime/cuda/cudnn.h>
@@ -632,6 +633,7 @@ DLL void hidet_cudnn_conv2d_gemm(
     int tx, int tw, int ty, int compute_type,
     int pad_dim1, int pad_dim2, int str_dim1, int str_dim2, int dil_dim1, int dil_dim2)
 {
+    auto begin1 = std::chrono::steady_clock::now();
     lazy_load_cudnn();
 
     cudnnHandle_t cur_handle = CudnnContext::current_handle();
@@ -659,90 +661,40 @@
     CHECK_CUDNN(cudnnCreateTensorDescriptor(&output_descriptor));
     CHECK_CUDNN(cudnnSetTensor4dDescriptor(output_descriptor, CUDNN_TENSOR_NCHW, cudnnDataType_t(ty),
                                            out_n, out_c, out_h, out_w));
-    size_t workspaceSize{0};
-    CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cur_handle, input_descriptor, kernel_descriptor,
-        convolution_descriptor, output_descriptor, CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
-        &workspaceSize));
-    void *workspace = request_cuda_workspace(workspaceSize, false);
+
+    // size_t workspaceSize{0};
+    // CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cur_handle, input_descriptor, kernel_descriptor,
+    //     convolution_descriptor, output_descriptor, CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
+    //     &workspaceSize));
+
+    size_t workspaceSize{2000000};
+    // std::cout << workspaceSize << std::endl;
+    // void *workspace = request_cuda_workspace(workspaceSize, false);
+    void *workspace = hidet_cuda_malloc_async(workspaceSize, cur_stream);
 
     void *p_alpha = nullptr;
     void *p_beta = nullptr;
     cudnnDataType_t compType = cudnnDataType_t(compute_type);
     set_alpha_beta(&p_alpha, &p_beta, compType);
+    auto end1 = std::chrono::steady_clock::now();
+    std::cout << "Time difference 1 = " << std::chrono::duration_cast<std::chrono::microseconds>(end1 - begin1).count() << "[µs]" << std::endl;
 
+    auto begin2 = std::chrono::steady_clock::now();
     CHECK_CUDNN(cudnnConvolutionForward(cur_handle, p_alpha, input_descriptor, ptr_x, kernel_descriptor, ptr_w,
         convolution_descriptor, CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
         workspace, workspaceSize,
         p_beta, output_descriptor, ptr_y));
+    auto end2 = std::chrono::steady_clock::now();
+    std::cout << "Time difference 2 = " << std::chrono::duration_cast<std::chrono::microseconds>(end2 - begin2).count() << "[µs]" << std::endl;
 
+    auto begin3 = std::chrono::steady_clock::now();
     CHECK_CUDNN(cudnnDestroyTensorDescriptor(input_descriptor));
     CHECK_CUDNN(cudnnDestroyTensorDescriptor(output_descriptor));
     CHECK_CUDNN(cudnnDestroyFilterDescriptor(kernel_descriptor));
     CHECK_CUDNN(cudnnDestroyConvolutionDescriptor(convolution_descriptor));
-}
-
-DLL void hidet_cudnn_conv2d_autoselect_algo(
-    int n, int c, int h, int w, int k, int r, int s,
-    void *ptr_x, void *ptr_w, void *ptr_y,
-    int tx, int tw, int ty, int compute_type,
-    int pad_dim1, int pad_dim2, int str_dim1, int str_dim2, int dil_dim1, int dil_dim2)
-{
-    lazy_load_cudnn();
-
-    cudnnHandle_t cur_handle = CudnnContext::current_handle();
-
-    // Set the stream to the current stream
-    cudaStream_t cur_stream = get_cuda_stream();
-    CHECK_CUDNN(cudnnSetStream(cur_handle, cur_stream));
-
-    // Build descriptors and launch the kernel
-    cudnnTensorDescriptor_t input_descriptor;
-    CHECK_CUDNN(cudnnCreateTensorDescriptor(&input_descriptor));
-    CHECK_CUDNN(cudnnSetTensor4dDescriptor(input_descriptor, CUDNN_TENSOR_NCHW, cudnnDataType_t(tx), n, c, h, w));
-    cudnnFilterDescriptor_t kernel_descriptor;
-    CHECK_CUDNN(cudnnCreateFilterDescriptor(&kernel_descriptor));
-    CHECK_CUDNN(cudnnSetFilter4dDescriptor(kernel_descriptor, cudnnDataType_t(tw), CUDNN_TENSOR_NCHW, k, c, r, s));
-    cudnnConvolutionDescriptor_t convolution_descriptor;
-    CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&convolution_descriptor));
-    CHECK_CUDNN(cudnnSetConvolution2dDescriptor(convolution_descriptor, pad_dim1, pad_dim2, str_dim1, str_dim2, dil_dim1, dil_dim2,
-        CUDNN_CROSS_CORRELATION, cudnnDataType_t(compute_type)));
-
-    int out_n{0}, out_c{0}, out_h{0}, out_w{0};
-    CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(convolution_descriptor, input_descriptor, kernel_descriptor,
-        &out_n, &out_c, &out_h, &out_w));
-    cudnnTensorDescriptor_t output_descriptor;
-    CHECK_CUDNN(cudnnCreateTensorDescriptor(&output_descriptor));
-    CHECK_CUDNN(cudnnSetTensor4dDescriptor(output_descriptor, CUDNN_TENSOR_NCHW, cudnnDataType_t(ty),
-                                           out_n, out_c, out_h, out_w));
-
-    int returnedAlgoCount;
-    cudnnConvolutionFwdAlgoPerf_t perfResults;
-
-    CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm_v7(cur_handle, input_descriptor, kernel_descriptor,
-        convolution_descriptor, output_descriptor,
-        1, &returnedAlgoCount, &perfResults));
-    cudnnConvolutionFwdAlgo_t convolution_algorithm = perfResults.algo;
-
-    size_t workspaceSize{0};
-    CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cur_handle, input_descriptor, kernel_descriptor,
-        convolution_descriptor, output_descriptor, convolution_algorithm,
-        &workspaceSize));
-    void *workspace = request_cuda_workspace(workspaceSize, false);
-
-    void *p_alpha = nullptr;
-    void *p_beta = nullptr;
-    cudnnDataType_t compType = cudnnDataType_t(compute_type);
-    set_alpha_beta(&p_alpha, &p_beta, compType);
-
-    CHECK_CUDNN(cudnnConvolutionForward(cur_handle, p_alpha, input_descriptor, ptr_x, kernel_descriptor, ptr_w,
-        convolution_descriptor, convolution_algorithm,
-        workspace, workspaceSize,
-        p_beta, output_descriptor, ptr_y));
-
-    CHECK_CUDNN(cudnnDestroyTensorDescriptor(input_descriptor));
-    CHECK_CUDNN(cudnnDestroyTensorDescriptor(output_descriptor));
-    CHECK_CUDNN(cudnnDestroyFilterDescriptor(kernel_descriptor));
-    CHECK_CUDNN(cudnnDestroyConvolutionDescriptor(convolution_descriptor));
+    hidet_cuda_free_async(workspace, cur_stream);
+    auto end3 = std::chrono::steady_clock::now();
+    std::cout << "Time difference 3 = " << std::chrono::duration_cast<std::chrono::microseconds>(end3 - begin3).count() << "[µs]" << std::endl;
 }
 
 DLL void hidet_cudnn_conv2d(
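Two caveats on the instrumentation added to hidet_cudnn_conv2d_gemm above. First, cudnnConvolutionForward is launched asynchronously on cur_stream, so the host-side std::chrono pair around it ("Time difference 2") mostly measures launch overhead rather than kernel execution. Second, the hardcoded workspaceSize{2000000} replaces the commented-out cudnnGetConvolutionForwardWorkspaceSize query, and risks an error from cudnnConvolutionForward for configurations that need more than 2 MB of scratch space. A sketch of a synchronization-aware timing helper on the Python side, assuming hidet.cuda.synchronize() as exposed by hidet's CUDA runtime (do_bench in benchmark.py serves the same purpose):

import time
import hidet

def time_kernel_ms(run, iters=100):
    run()                         # warm-up; also triggers lazy cuDNN loading
    hidet.cuda.synchronize()      # drain pending asynchronous work
    t0 = time.perf_counter()
    for _ in range(iters):
        run()
    hidet.cuda.synchronize()      # wait for the kernels to actually finish
    return (time.perf_counter() - t0) * 1e3 / iters  # milliseconds per call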
[diff for the sixth changed file did not load]
