[CUDNN] Add CuDNN performance benchmarks
Yudi Sun committed Mar 14, 2024
1 parent 4b37c61 commit f7ce7ef
Showing 3 changed files with 109 additions and 1 deletion.
107 changes: 107 additions & 0 deletions python/hidet/cuda/cudnn/benchmark.py
@@ -0,0 +1,107 @@
import hidet
from hidet.cuda.cudnn import cudnnDataType
import numpy as np
import torch
from hidet.utils.benchmark import do_bench


def benchmark_cudnn_conv2d(dtype_str, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations):
    # n, c, h, w: input shape; k: output channels; r, s: kernel size; p, q: output spatial size
    if dtype_str == "float32":
        dtype = hidet.float32
    elif dtype_str == "float64":
        dtype = hidet.float64
    else:
        raise Exception("Unsupported DataType")

    tx = tw = ty = dtype
    pad_dim1, pad_dim2 = padding
    str_dim1, str_dim2 = stride
    dil_dim1, dil_dim2 = dilations

    tensor_x = hidet.randn((n, c, h, w), device='cuda', dtype=tx)
    tensor_w = hidet.randn((k, c, r, s), device='cuda', dtype=tw)
    tensor_y = hidet.empty((n, k, p, q), device='cuda', dtype=ty)

    latencies = do_bench(
        lambda: hidet.cuda.cudnn.conv2d(
            n, c, h, w,                    # input tensor dimensions
            k, r, s,                       # output channels and kernel size
            p, q,                          # output spatial dimensions
            tensor_x, tensor_w, tensor_y,  # input, weight, and output tensors
            tx, tw, ty,                    # data types of x, w, y
            compute_type,                  # cuDNN compute type
            pad_dim1, pad_dim2,
            str_dim1, str_dim2,
            dil_dim1, dil_dim2,
        ),
        warmup=10,
        rep=100,
    )

    print(
        f"CuDNN Results for Configuration: dtype = {dtype_str}, input shape = {[n, c, h, w]}, "
        f"weight shape = {[k, c, r, s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:"
    )
    print("20th Percentile Latency Is: " + str(latencies[0]) + " milliseconds")
    print("50th Percentile Latency Is: " + str(latencies[1]) + " milliseconds")
    print("80th Percentile Latency Is: " + str(latencies[2]) + " milliseconds")
    print("-------------------------------------------------")


def benchmark_torch_conv2d(dtype_str, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations):
    # compute_type is unused here; it is accepted only so both benchmarks share the same call signature.
    if dtype_str == "float32":
        dtype = np.float32
    elif dtype_str == "float64":
        dtype = np.float64
    else:
        raise Exception("Unsupported DataType")

    data = np.random.randn(n, c, h, w).astype(dtype)
    weight = np.random.randn(k, c, r, s).astype(dtype)

    data_torch, weight_torch = torch.from_numpy(data), torch.from_numpy(weight)
    data_torch = data_torch.cuda()
    weight_torch = weight_torch.cuda()

    latencies = do_bench(
        lambda: torch.nn.functional.conv2d(
            data_torch, weight_torch, bias=None, stride=stride, padding=padding, dilation=dilations, groups=1
        ),
        warmup=10,
        rep=100,
    )

    print(
        f"PyTorch Results for Configuration: dtype = {dtype_str}, input shape = {[n, c, h, w]}, "
        f"weight shape = {[k, c, r, s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:"
    )
    print("20th Percentile Latency Is: " + str(latencies[0]) + " milliseconds")
    print("50th Percentile Latency Is: " + str(latencies[1]) + " milliseconds")
    print("80th Percentile Latency Is: " + str(latencies[2]) + " milliseconds")
    print("-------------------------------------------------")


if __name__ == '__main__':
    # Each entry: n, c, h, w, k, p, q, r, s, padding, stride, dilations
    sizes = [
        [1, 3, 32, 32, 12, 30, 30, 3, 3, [0, 0], [1, 1], [1, 1]],
        [2, 3, 224, 224, 16, 109, 109, 7, 7, [0, 0], [2, 2], [1, 1]],
    ]
    dtypes = [['float32', cudnnDataType.CUDNN_DATA_FLOAT], ['float64', cudnnDataType.CUDNN_DATA_DOUBLE]]

    for dtype in dtypes:
        for size in sizes:
            benchmark_cudnn_conv2d(*(dtype + size))
            benchmark_torch_conv2d(*(dtype + size))
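
The percentile reporting above comes from hidet.utils.benchmark.do_bench, which here yields the 20th/50th/80th percentile latencies in milliseconds. For reference, a minimal sketch of equivalent quantile-based GPU timing using only PyTorch (illustrative only: the name bench_quantiles and the exact warmup/rep handling are assumptions, not hidet's implementation) could look like this:

import torch

def bench_quantiles(fn, warmup=10, rep=100, quantiles=(0.2, 0.5, 0.8)):
    # Hypothetical stand-in for do_bench: time `fn` on the GPU and report latency quantiles.
    for _ in range(warmup):  # warm up to exclude one-time compilation/caching costs
        fn()
    torch.cuda.synchronize()

    times_ms = []
    for _ in range(rep):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times_ms.append(start.elapsed_time(end))  # elapsed_time reports milliseconds

    t = torch.tensor(times_ms)
    # Three values matching the 20th/50th/80th percentile prints above.
    return [torch.quantile(t, q).item() for q in quantiles]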
2 changes: 1 addition & 1 deletion src/hidet/runtime/cuda/cudnn.cpp
@@ -585,7 +585,7 @@ DLL void hidet_cudnn_conv2d(

     void *dev_ptrs[3] = {ptr_x, ptr_w, ptr_y}; // device pointers
     int64_t uids[3] = {'x', 'w', 'y'};
-    void *workspace = hidet_cuda_malloc_async(workspaceSize, cur_stream);
+    void *workspace = request_cuda_workspace(workspaceSize, false);
 
     cudnnBackendDescriptor_t varpack;
     CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR, &varpack));
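
This one-line change takes cuDNN's scratch space from the runtime's reusable CUDA workspace (request_cuda_workspace) instead of allocating a fresh buffer with hidet_cuda_malloc_async on every hidet_cudnn_conv2d call. A rough Python sketch of that grow-only caching pattern (illustrative only: the helper below is hypothetical and uses a PyTorch tensor as the backing buffer, not hidet's runtime allocator):

import torch

_workspace = torch.empty(0, dtype=torch.uint8, device='cuda')  # cached scratch buffer

def request_workspace(nbytes: int) -> torch.Tensor:
    # Grow-only cache: reuse one buffer across calls, reallocating only when a larger size is requested.
    global _workspace
    if nbytes > _workspace.numel():
        _workspace = torch.empty(nbytes, dtype=torch.uint8, device='cuda')
    return _workspace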
1 change: 1 addition & 0 deletions tests/cuda/test_cudnn.py
@@ -22,6 +22,7 @@
         [1, 3, 32, 32, 12, 30, 30, 3, 3, [0, 0], [1, 1], [1, 1]],  # kernel 3,
         [2, 3, 32, 32, 12, 11, 6, 7, 7, [1, 2], [2, 3], [2, 3]],  # kernel 7, batch size 2
         [1, 3, 32, 32, 12, 16, 11, 1, 1, [0, 0], [2, 3], [1, 1]],  # kernel 1,
+        [2, 3, 224, 224, 16, 109, 109, 7, 7, [0, 0], [2, 2], [1, 1]],
     ],
 )
@pytest.mark.parametrize(
