From 4b37c6192fc33798212c7cf7a52c028510bced29 Mon Sep 17 00:00:00 2001 From: Yudi Sun Date: Wed, 6 Mar 2024 21:23:33 -0500 Subject: [PATCH 01/12] Add cudnn conv2d --- CMakeLists.txt | 1 + include/hidet/runtime/cuda/cudnn.h | 18 + python/hidet/cuda/__init__.py | 1 + python/hidet/cuda/cudnn/__init__.py | 13 + python/hidet/cuda/cudnn/ffi.py | 82 ++++ python/hidet/cuda/cudnn/kernels.py | 117 ++++++ python/hidet/cuda/cudnn/utils.py | 41 ++ src/hidet/runtime/cuda/cudnn.cpp | 616 ++++++++++++++++++++++++++++ tests/cuda/test_cudnn.py | 73 ++++ 9 files changed, 962 insertions(+) create mode 100644 python/hidet/cuda/cudnn/__init__.py create mode 100644 python/hidet/cuda/cudnn/ffi.py create mode 100644 python/hidet/cuda/cudnn/kernels.py create mode 100644 python/hidet/cuda/cudnn/utils.py create mode 100644 src/hidet/runtime/cuda/cudnn.cpp create mode 100644 tests/cuda/test_cudnn.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 6dab979b7..3089d698d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,6 +19,7 @@ message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") add_library(hidet_runtime SHARED src/hidet/runtime/cuda/context.cpp src/hidet/runtime/cuda/cublas.cpp + src/hidet/runtime/cuda/cudnn.cpp src/hidet/runtime/cuda/cuda.cpp src/hidet/runtime/cpu/context.cpp src/hidet/runtime/callbacks.cpp diff --git a/include/hidet/runtime/cuda/cudnn.h b/include/hidet/runtime/cuda/cudnn.h index 5653e0cb5..fcbc43697 100644 --- a/include/hidet/runtime/cuda/cudnn.h +++ b/include/hidet/runtime/cuda/cudnn.h @@ -9,3 +9,21 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
+#pragma once +#define HIDET_CUDNN_MAX_GPUS 32 + +#include <hidet/runtime/common.h> + +struct cudnnContext; +typedef struct cudnnContext *cudnnHandle_t; + +typedef void *cudnnBackendDescriptor_t; + +struct CudnnContext { + cudnnHandle_t handles[HIDET_CUDNN_MAX_GPUS]; + static CudnnContext* global(); + static cudnnHandle_t current_handle(); +}; + +DLL void hidet_cudnn_set_library_path(const char* path); + diff --git a/python/hidet/cuda/__init__.py b/python/hidet/cuda/__init__.py index 7e6efbfa5..6f8c77b12 100644 --- a/python/hidet/cuda/__init__.py +++ b/python/hidet/cuda/__init__.py @@ -18,3 +18,4 @@ from .event import Event from . import cublas +from . import cudnn diff --git a/python/hidet/cuda/cudnn/__init__.py b/python/hidet/cuda/cudnn/__init__.py new file mode 100644 index 000000000..de471d1a6 --- /dev/null +++ b/python/hidet/cuda/cudnn/__init__.py @@ -0,0 +1,13 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .ffi import cudnnDataType +from .kernels import conv2d diff --git a/python/hidet/cuda/cudnn/ffi.py b/python/hidet/cuda/cudnn/ffi.py new file mode 100644 index 000000000..bb559fcda --- /dev/null +++ b/python/hidet/cuda/cudnn/ffi.py @@ -0,0 +1,82 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +import glob +from enum import IntEnum +from ctypes import c_int32, c_void_p, c_char_p +from hidet.ffi.ffi import get_func + +from hidet.utils.py import initialize + + +class cudnnDataType(IntEnum): + """ + defined in cudnn_ops_infer_v8.h + """ + + CUDNN_DATA_FLOAT = 0 + CUDNN_DATA_DOUBLE = 1 + CUDNN_DATA_HALF = 2 + CUDNN_DATA_INT8 = 3 + CUDNN_DATA_INT32 = 4 + CUDNN_DATA_INT8x4 = 5 + CUDNN_DATA_UINT8 = 6 + CUDNN_DATA_UINT8x4 = 7 + CUDNN_DATA_INT8x32 = 8 + CUDNN_DATA_BFLOAT16 = 9 + CUDNN_DATA_INT64 = 10 + + +set_library_path = get_func(func_name='hidet_cudnn_set_library_path', arg_types=[c_char_p], restype=None) + +conv2d = get_func( + func_name='hidet_cudnn_conv2d', + arg_types=[ + c_int32, # n + c_int32, # c + c_int32, # h + c_int32, # w + c_int32, # k + c_int32, # r + c_int32, # s + c_int32, # p + c_int32, # q + c_void_p, # ptr_x + c_void_p, # ptr_w + c_void_p, # ptr_y + c_int32, # tx + c_int32, # tw + c_int32, # ty + c_int32, # compute_type + c_int32, # pad_dim1 + c_int32, # pad_dim2 + c_int32, # str_dim1 + c_int32, # str_dim2 + c_int32, # dil_dim1 + c_int32, # dil_dim2 + ], + restype=None, +) + + +@initialize() +def set_cudnn_library_path(): + # use nvidia-cuda-cudnn + for path in sys.path: + nvidia_path = os.path.join(path, 'nvidia') + if not os.path.exists(nvidia_path): + continue + cudnn_path = glob.glob(os.path.join(nvidia_path, 'cudnn', 'lib', 'libcudnn.so.[0-9]*')) + if cudnn_path: + set_library_path(cudnn_path[0].encode('utf-8')) + return diff --git a/python/hidet/cuda/cudnn/kernels.py b/python/hidet/cuda/cudnn/kernels.py 
new file mode 100644 index 000000000..2b1f3b2e9 --- /dev/null +++ b/python/hidet/cuda/cudnn/kernels.py @@ -0,0 +1,117 @@ +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union +from hidet.ir.dtypes import DataType +from .ffi import cudnnDataType +from . import ffi +from .utils import as_pointer, as_cudnn_type + + +def conv2d( + n: int, + c: int, + h: int, + w: int, + k: int, + r: int, + s: int, + p: int, + q: int, + ptr_x, + ptr_w, + ptr_y, + tx: Union[int, DataType], + tw: Union[int, DataType], + ty: Union[int, DataType], + compute_type: Union[int, cudnnDataType], + pad_dim1: int, + pad_dim2: int, + str_dim1: int, + str_dim2: int, + dil_dim1: int, + dil_dim2: int, +): + """ + Calculates the 2D convolution of tensor x with filter w, stores the result in tensor y. + + Parameters + ---------- + n: int + Batch number. + c: int + Number of channels in the input tensor x. + h: int + Height of the input tensor x. + w: int + Width of the input tensor x. + k: int + Number of channels in the output tensor y. + r: int + Height of the filter w. + s: int + Width of the filter w. + p: int + Height of the output tensor y. + q: int + Width of the output tensor y. + ptr_x: hidet.Tensor or int + Input tensor x, can be either a Tensor or an integer (the address of the tensor). + ptr_w: hidet.Tensor or int + Weight tensor w, can be either a Tensor or an integer (the address of the tensor). + ptr_y: hidet.Tensor or int + Output tensor y, can be either a Tensor or an integer (the address of the tensor). 
+ tx: Union[int, DataType] + Type of elements in tensor x. + tw: Union[int, DataType] + Type of elements in tensor w. + ty: Union[int, DataType] + Type of elements in tensor y. + compute_type: Union[int, cudnnDataType] + The compute type of the operation. + For cuDNN, there's no such thing as a cudnnComputeType_t type. + As per the official example, the computeType is defined in terms of cudnnDataType_t + pad_dim1: int + The value to use for padding along the height dimension + pad_dim2: int + The value to use for padding along the width dimension + str_dim1: int + The stride to use for the height dimension + str_dim2: int + The stride to use for the width dimension + dil_dim1: int + The dilation to use for the height dimension + dil_dim2: int + The dilation to use for the width dimension + """ + ffi.conv2d( + n, + c, + h, + w, + k, + r, + s, + p, + q, + as_pointer(ptr_x), + as_pointer(ptr_w), + as_pointer(ptr_y), + as_cudnn_type(tx), + as_cudnn_type(tw), + as_cudnn_type(ty), + compute_type, + pad_dim1, + pad_dim2, + str_dim1, + str_dim2, + dil_dim1, + dil_dim2, + ) diff --git a/python/hidet/cuda/cudnn/utils.py b/python/hidet/cuda/cudnn/utils.py new file mode 100644 index 000000000..a599a880b --- /dev/null +++ b/python/hidet/cuda/cudnn/utils.py @@ -0,0 +1,41 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from hidet.ir import dtypes +from hidet.ir.dtypes import DataType +from .ffi import cudnnDataType + +_cudnn_type_dict = { + dtypes.float32: cudnnDataType.CUDNN_DATA_FLOAT, + dtypes.float64: cudnnDataType.CUDNN_DATA_DOUBLE, + dtypes.int32: cudnnDataType.CUDNN_DATA_INT32, + dtypes.int64: cudnnDataType.CUDNN_DATA_INT64, +} + + +def as_pointer(obj) -> int: + from hidet.graph.tensor import Tensor + + if isinstance(obj, Tensor): + return obj.storage.addr + elif isinstance(obj, int): + return obj + else: + raise TypeError(f'Expected Tensor or int, but got {type(obj)}') + + +def as_cudnn_type(obj) -> int: + if isinstance(obj, DataType): + return _cudnn_type_dict[obj] + elif isinstance(obj, int): + return obj + else: + raise TypeError(f'Expected DataType or int, but got {type(obj)}') diff --git a/src/hidet/runtime/cuda/cudnn.cpp b/src/hidet/runtime/cuda/cudnn.cpp new file mode 100644 index 000000000..468e5c2dc --- /dev/null +++ b/src/hidet/runtime/cuda/cudnn.cpp @@ -0,0 +1,616 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include <dlfcn.h> +#include <hidet/runtime/cuda/context.h> +#include <hidet/runtime/cuda/cuda.h> +#include <hidet/runtime/cuda/cudnn.h> +#include "./utils.h" + +/* + * CUDNN return codes - defined in cudnn_ops_infer_v8.h + */ +typedef enum { + CUDNN_STATUS_SUCCESS = 0, + CUDNN_STATUS_NOT_INITIALIZED = 1, + CUDNN_STATUS_ALLOC_FAILED = 2, + CUDNN_STATUS_BAD_PARAM = 3, + CUDNN_STATUS_INTERNAL_ERROR = 4, + CUDNN_STATUS_INVALID_VALUE = 5, + CUDNN_STATUS_ARCH_MISMATCH = 6, + CUDNN_STATUS_MAPPING_ERROR = 7, + CUDNN_STATUS_EXECUTION_FAILED = 8, + CUDNN_STATUS_NOT_SUPPORTED = 9, + CUDNN_STATUS_LICENSE_ERROR = 10, + CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING = 11, + CUDNN_STATUS_RUNTIME_IN_PROGRESS = 12, + CUDNN_STATUS_RUNTIME_FP_OVERFLOW = 13, + CUDNN_STATUS_VERSION_MISMATCH = 14, +} cudnnStatus_t; + +/* +* CUDNN Descriptor Types - defined in cudnn_backend_v8.h +*/ +typedef enum { + CUDNN_BACKEND_POINTWISE_DESCRIPTOR = 0, + CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR, + CUDNN_BACKEND_ENGINE_DESCRIPTOR, + CUDNN_BACKEND_ENGINECFG_DESCRIPTOR, + CUDNN_BACKEND_ENGINEHEUR_DESCRIPTOR, + CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR, + CUDNN_BACKEND_INTERMEDIATE_INFO_DESCRIPTOR, + CUDNN_BACKEND_KNOB_CHOICE_DESCRIPTOR, + CUDNN_BACKEND_KNOB_INFO_DESCRIPTOR, + CUDNN_BACKEND_LAYOUT_INFO_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_GEN_STATS_DESCRIPTOR, + CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR, + CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR, + CUDNN_BACKEND_TENSOR_DESCRIPTOR, + CUDNN_BACKEND_MATMUL_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_BN_FINALIZE_STATISTICS_DESCRIPTOR, + CUDNN_BACKEND_REDUCTION_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_BN_BWD_WEIGHTS_DESCRIPTOR, +} cudnnBackendDescriptorType_t; + +/* + * CUDNN data type - defined in cudnn_ops_infer_v8.h + */ +typedef enum { + 
CUDNN_DATA_FLOAT = 0, + CUDNN_DATA_DOUBLE = 1, + CUDNN_DATA_HALF = 2, + CUDNN_DATA_INT8 = 3, + CUDNN_DATA_INT32 = 4, + CUDNN_DATA_INT8x4 = 5, + CUDNN_DATA_UINT8 = 6, + CUDNN_DATA_UINT8x4 = 7, + CUDNN_DATA_INT8x32 = 8, + CUDNN_DATA_BFLOAT16 = 9, + CUDNN_DATA_INT64 = 10, +} cudnnDataType_t; + +/* +* CUDNN Backend Attribute Names - defined in cudnn_backend_v8.h +*/ +typedef enum { + CUDNN_ATTR_POINTWISE_MODE = 0, + CUDNN_ATTR_POINTWISE_MATH_PREC = 1, + CUDNN_ATTR_POINTWISE_NAN_PROPAGATION = 2, + CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP = 3, + CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP = 4, + CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP_SLOPE = 5, + CUDNN_ATTR_POINTWISE_ELU_ALPHA = 6, + CUDNN_ATTR_POINTWISE_SOFTPLUS_BETA = 7, + CUDNN_ATTR_POINTWISE_SWISH_BETA = 8, + + CUDNN_ATTR_CONVOLUTION_COMP_TYPE = 100, + CUDNN_ATTR_CONVOLUTION_CONV_MODE = 101, + CUDNN_ATTR_CONVOLUTION_DILATIONS = 102, + CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES = 103, + CUDNN_ATTR_CONVOLUTION_POST_PADDINGS = 104, + CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS = 105, + CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS = 106, + + CUDNN_ATTR_ENGINEHEUR_MODE = 200, + CUDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH = 201, + CUDNN_ATTR_ENGINEHEUR_RESULTS = 202, + + CUDNN_ATTR_ENGINECFG_ENGINE = 300, + CUDNN_ATTR_ENGINECFG_INTERMEDIATE_INFO = 301, + CUDNN_ATTR_ENGINECFG_KNOB_CHOICES = 302, + + CUDNN_ATTR_EXECUTION_PLAN_HANDLE = 400, + CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG = 401, + CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE = 402, + CUDNN_ATTR_EXECUTION_PLAN_COMPUTED_INTERMEDIATE_UIDS = 403, + CUDNN_ATTR_EXECUTION_PLAN_RUN_ONLY_INTERMEDIATE_UIDS = 404, + + CUDNN_ATTR_INTERMEDIATE_INFO_UNIQUE_ID = 500, + CUDNN_ATTR_INTERMEDIATE_INFO_SIZE = 501, + CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_DATA_UIDS = 502, + CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_ATTRIBUTES = 503, + + CUDNN_ATTR_KNOB_CHOICE_KNOB_TYPE = 600, + CUDNN_ATTR_KNOB_CHOICE_KNOB_VALUE = 601, + + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA = 700, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA = 701, + 
CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC = 702, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W = 703, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X = 704, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y = 705, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA = 706, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA = 707, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC = 708, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W = 709, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX = 710, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY = 711, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA = 712, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA = 713, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC = 714, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW = 715, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X = 716, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY = 717, + + CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR = 750, + CUDNN_ATTR_OPERATION_POINTWISE_XDESC = 751, + CUDNN_ATTR_OPERATION_POINTWISE_BDESC = 752, + CUDNN_ATTR_OPERATION_POINTWISE_YDESC = 753, + CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1 = 754, + CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2 = 755, + CUDNN_ATTR_OPERATION_POINTWISE_DXDESC = 756, + CUDNN_ATTR_OPERATION_POINTWISE_DYDESC = 757, + + CUDNN_ATTR_OPERATION_GENSTATS_MODE = 770, + CUDNN_ATTR_OPERATION_GENSTATS_MATH_PREC = 771, + CUDNN_ATTR_OPERATION_GENSTATS_XDESC = 772, + CUDNN_ATTR_OPERATION_GENSTATS_SUMDESC = 773, + CUDNN_ATTR_OPERATION_GENSTATS_SQSUMDESC = 774, + + CUDNN_ATTR_OPERATION_BN_FINALIZE_STATS_MODE = 780, + CUDNN_ATTR_OPERATION_BN_FINALIZE_MATH_PREC = 781, + CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SUM_DESC = 782, + CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SQ_SUM_DESC = 783, + CUDNN_ATTR_OPERATION_BN_FINALIZE_SCALE_DESC = 784, + CUDNN_ATTR_OPERATION_BN_FINALIZE_BIAS_DESC = 785, + CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_MEAN_DESC = 786, + CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_VAR_DESC = 787, + 
CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_MEAN_DESC = 788, + CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_VAR_DESC = 789, + CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_MEAN_DESC = 790, + CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_INV_STD_DESC = 791, + CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_SCALE_DESC = 792, + CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_BIAS_DESC = 793, + CUDNN_ATTR_OPERATION_BN_FINALIZE_ACCUM_COUNT_DESC = 794, + CUDNN_ATTR_OPERATION_BN_FINALIZE_EPSILON_DESC = 795, + CUDNN_ATTR_OPERATION_BN_FINALIZE_EXP_AVERATE_FACTOR_DESC = 796, + + CUDNN_ATTR_OPERATIONGRAPH_HANDLE = 800, + CUDNN_ATTR_OPERATIONGRAPH_OPS = 801, + CUDNN_ATTR_OPERATIONGRAPH_ENGINE_GLOBAL_COUNT = 802, + + CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT = 900, + CUDNN_ATTR_TENSOR_DATA_TYPE = 901, + CUDNN_ATTR_TENSOR_DIMENSIONS = 902, + CUDNN_ATTR_TENSOR_STRIDES = 903, + CUDNN_ATTR_TENSOR_VECTOR_COUNT = 904, + CUDNN_ATTR_TENSOR_VECTORIZED_DIMENSION = 905, + CUDNN_ATTR_TENSOR_UNIQUE_ID = 906, + CUDNN_ATTR_TENSOR_IS_VIRTUAL = 907, + CUDNN_ATTR_TENSOR_IS_BY_VALUE = 908, + + CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS = 1000, + CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS = 1001, + CUDNN_ATTR_VARIANT_PACK_INTERMEDIATES = 1002, + CUDNN_ATTR_VARIANT_PACK_WORKSPACE = 1003, + + CUDNN_ATTR_LAYOUT_INFO_TENSOR_UID = 1100, + CUDNN_ATTR_LAYOUT_INFO_TYPES = 1101, + + CUDNN_ATTR_KNOB_INFO_TYPE = 1200, + CUDNN_ATTR_KNOB_INFO_MAXIMUM_VALUE = 1201, + CUDNN_ATTR_KNOB_INFO_MINIMUM_VALUE = 1202, + CUDNN_ATTR_KNOB_INFO_STRIDE = 1203, + + CUDNN_ATTR_ENGINE_OPERATION_GRAPH = 1300, + CUDNN_ATTR_ENGINE_GLOBAL_INDEX = 1301, + CUDNN_ATTR_ENGINE_KNOB_INFO = 1302, + CUDNN_ATTR_ENGINE_NUMERICAL_NOTE = 1303, + CUDNN_ATTR_ENGINE_LAYOUT_INFO = 1304, + CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE = 1305, + + CUDNN_ATTR_MATMUL_COMP_TYPE = 1500, + + CUDNN_ATTR_OPERATION_MATMUL_ADESC = 1520, + CUDNN_ATTR_OPERATION_MATMUL_BDESC = 1521, + CUDNN_ATTR_OPERATION_MATMUL_CDESC = 1522, + CUDNN_ATTR_OPERATION_MATMUL_DESC = 1523, + 
CUDNN_ATTR_OPERATION_MATMUL_IRREGULARLY_STRIDED_BATCH_COUNT = 1524, + + CUDNN_ATTR_REDUCTION_OPERATOR = 1600, + CUDNN_ATTR_REDUCTION_COMP_TYPE = 1601, + + CUDNN_ATTR_OPERATION_REDUCTION_XDESC = 1610, + CUDNN_ATTR_OPERATION_REDUCTION_YDESC = 1611, + CUDNN_ATTR_OPERATION_REDUCTION_DESC = 1612, + + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MATH_PREC = 1620, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MEAN_DESC = 1621, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_INVSTD_DESC = 1622, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_BN_SCALE_DESC = 1623, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_X_DESC = 1624, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DY_DESC = 1625, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_SCALE_DESC = 1626, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_BIAS_DESC = 1627, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_DY_SCALE_DESC = 1628, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_X_SCALE_DESC = 1629, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_BIAS = 1630, +} cudnnBackendAttributeName_t; + +/* +* CUDNN Backend Attribute Type - defined in cudnn_backend_v8.h +*/ +typedef enum { + CUDNN_TYPE_HANDLE = 0, + CUDNN_TYPE_DATA_TYPE, + CUDNN_TYPE_BOOLEAN, + CUDNN_TYPE_INT64, + CUDNN_TYPE_FLOAT, + CUDNN_TYPE_DOUBLE, + CUDNN_TYPE_VOID_PTR, + CUDNN_TYPE_CONVOLUTION_MODE, + CUDNN_TYPE_HEUR_MODE, + CUDNN_TYPE_KNOB_TYPE, + CUDNN_TYPE_NAN_PROPOGATION, + CUDNN_TYPE_NUMERICAL_NOTE, + CUDNN_TYPE_LAYOUT_TYPE, + CUDNN_TYPE_ATTRIB_NAME, + CUDNN_TYPE_POINTWISE_MODE, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + CUDNN_TYPE_GENSTATS_MODE, + CUDNN_TYPE_BN_FINALIZE_STATS_MODE, + CUDNN_TYPE_REDUCTION_OPERATOR_TYPE, + CUDNN_TYPE_BEHAVIOR_NOTE, +} cudnnBackendAttributeType_t; + +/* + * convolution mode - defined in cudnn_cnn_infer_v8.h + */ +typedef enum { CUDNN_CONVOLUTION = 0, CUDNN_CROSS_CORRELATION = 1 } cudnnConvolutionMode_t; + + +// define cudnn Graph API functions +typedef cudnnStatus_t (*cudnnCreate_t)(cudnnHandle_t *handle); +typedef const char * (*cudnnGetErrorString_t)(cudnnStatus_t status); +typedef cudnnStatus_t 
(*cudnnSetStream_t)( + cudnnHandle_t handle, + cudaStream_t streamId); +typedef cudnnStatus_t (*cudnnBackendCreateDescriptor_t)( + cudnnBackendDescriptorType_t descriptorType, + cudnnBackendDescriptor_t *descriptor); +typedef cudnnStatus_t (*cudnnBackendDestroyDescriptor_t)(cudnnBackendDescriptor_t descriptor); +typedef cudnnStatus_t (*cudnnBackendSetAttribute_t)( + cudnnBackendDescriptor_t descriptor, + cudnnBackendAttributeName_t attributeName, + cudnnBackendAttributeType_t attributeType, + int64_t elementCount, + void *arrayOfElements); +typedef cudnnStatus_t (*cudnnBackendGetAttribute_t)( + cudnnBackendDescriptor_t descriptor, + cudnnBackendAttributeName_t attributeName, + cudnnBackendAttributeType_t attributeType, + int64_t requestedElementCount, + int64_t *elementCount, + void *arrayOfElements); +typedef cudnnStatus_t (*cudnnBackendFinalize_t)(cudnnBackendDescriptor_t descriptor); +typedef cudnnStatus_t (*cudnnBackendExecute_t)( + cudnnHandle_t handle, cudnnBackendDescriptor_t executionPlan, cudnnBackendDescriptor_t varianPack); + + +// cudnn api functions +static cudnnCreate_t cudnnCreate; +static cudnnGetErrorString_t cudnnGetErrorString; +static cudnnSetStream_t cudnnSetStream; +static cudnnBackendCreateDescriptor_t cudnnBackendCreateDescriptor; +static cudnnBackendDestroyDescriptor_t cudnnBackendDestroyDescriptor; +static cudnnBackendSetAttribute_t cudnnBackendSetAttribute; +static cudnnBackendGetAttribute_t cudnnBackendGetAttribute; +static cudnnBackendFinalize_t cudnnBackendFinalize; +static cudnnBackendExecute_t cudnnBackendExecute; + +static std::string library_path; +static void* libcudnn = nullptr; + +// utility functions +#define CHECK_CUDNN(status) do { \ + cudnnStatus_t err = (status); \ + if(err != 0) { \ + LOG(FATAL) << "cuDNN error: " << cudnnGetErrorString(err); \ + } \ +} while(0) + +static cudnnBackendAttributeType_t get_attribute_type_from_compute_type(cudnnDataType_t computeType) { + switch (computeType) { + case CUDNN_DATA_FLOAT: + 
return CUDNN_TYPE_FLOAT; + case CUDNN_DATA_DOUBLE: + return CUDNN_TYPE_DOUBLE; + case CUDNN_DATA_INT64: + case CUDNN_DATA_INT32: + return CUDNN_TYPE_INT64; + default: + LOG(FATAL) << "Unsupported compute type: " << computeType; + return CUDNN_TYPE_VOID_PTR; + } +} + +static void set_alpha_beta(void** p_alpha, void** p_beta, cudnnDataType_t c) { + // There's no such thing as a cudnnComputeType_t type. As per the official example, the computeType is defined + // in terms of cudnnDataType_t + // cudnnBackendAttributeType_t only has support for FLOAT, DOUBLE, and INT64. + if(c == CUDNN_DATA_FLOAT) { + static float alpha = 1.0f; + static float beta = 0.0f; + *p_alpha = &alpha; + *p_beta = &beta; + } else if(c == CUDNN_DATA_DOUBLE) { + static double alpha = 1.0; + static double beta = 0.0; + *p_alpha = &alpha; + *p_beta = &beta; + } else if(c == CUDNN_DATA_INT64 || c == CUDNN_DATA_INT32) { + static int64_t alpha = 1; + static int64_t beta = 0; + *p_alpha = &alpha; + *p_beta = &beta; + } else { + LOG(FATAL) << "Unsupported compute type: " << c; + } +} + +static void lazy_load_cudnn() { + if(libcudnn == nullptr) { + // load cudnn shared library + const char* libpath; + if(library_path.empty()) { + libpath = "libcudnn.so"; + } else { + libpath = library_path.c_str(); + } + libcudnn = dlopen(libpath, RTLD_LAZY); + if(libcudnn == nullptr) { + LOG(FATAL) << "Failed to load cudnn library: " << libpath << dlerror(); + } + + // load api functions + cudnnCreate = get_symbol<cudnnCreate_t>(libcudnn, "cudnnCreate"); + cudnnGetErrorString = get_symbol<cudnnGetErrorString_t>(libcudnn, "cudnnGetErrorString"); + cudnnSetStream = get_symbol<cudnnSetStream_t>(libcudnn, "cudnnSetStream"); + cudnnBackendCreateDescriptor = get_symbol<cudnnBackendCreateDescriptor_t>(libcudnn, "cudnnBackendCreateDescriptor"); + cudnnBackendDestroyDescriptor = get_symbol<cudnnBackendDestroyDescriptor_t>(libcudnn, "cudnnBackendDestroyDescriptor"); + cudnnBackendSetAttribute = get_symbol<cudnnBackendSetAttribute_t>(libcudnn, "cudnnBackendSetAttribute"); + cudnnBackendGetAttribute = get_symbol<cudnnBackendGetAttribute_t>(libcudnn, "cudnnBackendGetAttribute"); + cudnnBackendFinalize = get_symbol<cudnnBackendFinalize_t>(libcudnn, 
"cudnnBackendFinalize"); + cudnnBackendExecute = get_symbol<cudnnBackendExecute_t>(libcudnn, "cudnnBackendExecute"); + } +} + + +CudnnContext* CudnnContext::global() { + static CudnnContext instance; + static bool initialized = false; + + if(!initialized) { + // create cudnn handle for each gpu + int count = hidet_cuda_device_count(); + assert(count <= HIDET_CUDNN_MAX_GPUS); + + int current_device = hidet_cuda_get_device(); + for(int i = 0; i < count; i++) { + hidet_cuda_set_device(i); + CHECK_CUDNN(cudnnCreate(&instance.handles[i])); + } + hidet_cuda_set_device(current_device); + + initialized = true; + } + return &instance; +} + +cudnnHandle_t CudnnContext::current_handle() { + return CudnnContext::global()->handles[hidet_cuda_get_device()]; +} + + +// hidet cudnn api functions +DLL void hidet_cudnn_set_library_path(const char* path) { + if(path) { + library_path = path; + } +} + +DLL void hidet_cudnn_conv2d( + int n, int c, int h, int w, int k, int r, int s, int p, int q, + void *ptr_x, void *ptr_w, void *ptr_y, + int tx, int tw, int ty, int compute_type, + int pad_dim1, int pad_dim2, int str_dim1, int str_dim2, int dil_dim1, int dil_dim2 +) { + lazy_load_cudnn(); + + cudnnHandle_t cur_handle = CudnnContext::current_handle(); + + // Set the stream to the current stream + cudaStream_t cur_stream = get_cuda_stream(); + CHECK_CUDNN(cudnnSetStream(cur_handle, cur_stream)); + + // Build the descriptor for x + int64_t xDim[] = {n, c, h, w}; + int64_t xStr[] = {c * h * w, h * w, w, 1}; + int64_t xUi = 'x'; + int64_t alignment = 8; + cudnnBackendDescriptor_t xDesc; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_TENSOR_DESCRIPTOR, &xDesc)); + cudnnDataType_t xDtype = cudnnDataType_t(tx); + CHECK_CUDNN(cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_DATA_TYPE, + CUDNN_TYPE_DATA_TYPE, 1, &xDtype)); + CHECK_CUDNN(cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_DIMENSIONS, + CUDNN_TYPE_INT64, 4, xDim)); + CHECK_CUDNN(cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_STRIDES, + 
CUDNN_TYPE_INT64, 4, xStr)); + CHECK_CUDNN(cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_UNIQUE_ID, + CUDNN_TYPE_INT64, 1, &xUi)); + CHECK_CUDNN(cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT, + CUDNN_TYPE_INT64, 1, &alignment)); + CHECK_CUDNN(cudnnBackendFinalize(xDesc)); + + // Build the descriptor for w + int64_t wDim[] = {k, c, r, s}; + int64_t wStr[] = {c * r * s, r * s, s, 1}; + int64_t wUi = 'w'; + cudnnBackendDescriptor_t wDesc; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_TENSOR_DESCRIPTOR, &wDesc)); + cudnnDataType_t wDtype = cudnnDataType_t(tw); + CHECK_CUDNN(cudnnBackendSetAttribute(wDesc, CUDNN_ATTR_TENSOR_DATA_TYPE, + CUDNN_TYPE_DATA_TYPE, 1, &wDtype)); + CHECK_CUDNN(cudnnBackendSetAttribute(wDesc, CUDNN_ATTR_TENSOR_DIMENSIONS, + CUDNN_TYPE_INT64, 4, wDim)); + CHECK_CUDNN(cudnnBackendSetAttribute(wDesc, CUDNN_ATTR_TENSOR_STRIDES, + CUDNN_TYPE_INT64, 4, wStr)); + CHECK_CUDNN(cudnnBackendSetAttribute(wDesc, CUDNN_ATTR_TENSOR_UNIQUE_ID, + CUDNN_TYPE_INT64, 1, &wUi)); + CHECK_CUDNN(cudnnBackendSetAttribute(wDesc, CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT, + CUDNN_TYPE_INT64, 1, &alignment)); + CHECK_CUDNN(cudnnBackendFinalize(wDesc)); + + // Build the descriptor for y + int64_t yDim[] = {n, k, p, q}; + int64_t yStr[] = {k * p * q, p * q, q, 1}; + int64_t yUi = 'y'; + cudnnBackendDescriptor_t yDesc; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_TENSOR_DESCRIPTOR, &yDesc)); + cudnnDataType_t yDtype = cudnnDataType_t(ty); + CHECK_CUDNN(cudnnBackendSetAttribute(yDesc, CUDNN_ATTR_TENSOR_DATA_TYPE, + CUDNN_TYPE_DATA_TYPE, 1, &yDtype)); + CHECK_CUDNN(cudnnBackendSetAttribute(yDesc, CUDNN_ATTR_TENSOR_DIMENSIONS, + CUDNN_TYPE_INT64, 4, yDim)); + CHECK_CUDNN(cudnnBackendSetAttribute(yDesc, CUDNN_ATTR_TENSOR_STRIDES, + CUDNN_TYPE_INT64, 4, yStr)); + CHECK_CUDNN(cudnnBackendSetAttribute(yDesc, CUDNN_ATTR_TENSOR_UNIQUE_ID, + CUDNN_TYPE_INT64, 1, &yUi)); + CHECK_CUDNN(cudnnBackendSetAttribute(yDesc, CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT, 
+ CUDNN_TYPE_INT64, 1, &alignment)); + CHECK_CUDNN(cudnnBackendFinalize(yDesc)); + + // Build the descriptor for the convolution operator + cudnnBackendDescriptor_t cDesc; + int64_t nbDims = 2; + cudnnDataType_t compType = cudnnDataType_t(compute_type); + cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; + int64_t pad[] = {pad_dim1, pad_dim2}; + int64_t filterStr[] = {str_dim1, str_dim2}; + int64_t dilation[] = {dil_dim1, dil_dim2}; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR, &cDesc)); + CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS, + CUDNN_TYPE_INT64, 1, &nbDims)); + CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_COMP_TYPE, + CUDNN_TYPE_DATA_TYPE, 1, &compType)); + CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_CONV_MODE, + CUDNN_TYPE_CONVOLUTION_MODE, 1, &mode)); + CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS, + CUDNN_TYPE_INT64, nbDims, pad)); + CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_POST_PADDINGS, + CUDNN_TYPE_INT64, nbDims, pad)); + CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_DILATIONS, + CUDNN_TYPE_INT64, nbDims, dilation)); + CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES, + CUDNN_TYPE_INT64, nbDims, filterStr)); + CHECK_CUDNN(cudnnBackendFinalize(cDesc)); + + // Build the descriptor for the convolution forward operation + cudnnBackendDescriptor_t fprop; + void *p_alpha = nullptr; + void *p_beta = nullptr; + set_alpha_beta(&p_alpha, &p_beta, compType); + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, + &fprop)); + CHECK_CUDNN(cudnnBackendSetAttribute(fprop, CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &xDesc)); + CHECK_CUDNN(cudnnBackendSetAttribute(fprop, CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, 
&wDesc)); + CHECK_CUDNN(cudnnBackendSetAttribute(fprop, CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &yDesc)); + CHECK_CUDNN(cudnnBackendSetAttribute(fprop, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &cDesc)); + CHECK_CUDNN(cudnnBackendSetAttribute(fprop, CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA, + get_attribute_type_from_compute_type(compType), 1, p_alpha)); + CHECK_CUDNN(cudnnBackendSetAttribute(fprop, CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA, + get_attribute_type_from_compute_type(compType), 1, p_beta)); + CHECK_CUDNN(cudnnBackendFinalize(fprop)); + + // Build the operation graph descriptor + cudnnBackendDescriptor_t op_graph; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR, &op_graph)); + CHECK_CUDNN(cudnnBackendSetAttribute(op_graph, CUDNN_ATTR_OPERATIONGRAPH_OPS, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &fprop)); + CHECK_CUDNN(cudnnBackendSetAttribute(op_graph, CUDNN_ATTR_OPERATIONGRAPH_HANDLE, + CUDNN_TYPE_HANDLE, 1, &cur_handle)); + CHECK_CUDNN(cudnnBackendFinalize(op_graph)); + + // Set up engine config + cudnnBackendDescriptor_t engine; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_ENGINE_DESCRIPTOR, &engine)); + CHECK_CUDNN(cudnnBackendSetAttribute(engine, CUDNN_ATTR_ENGINE_OPERATION_GRAPH, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &op_graph)); + // TODO: Is it okay to hardcode the engine to be CUDNN_ATTR_ENGINE_GLOBAL_INDEX 0? + // As mentioned here: https://docs.nvidia.com/deeplearning/cudnn/developer/graph-api.html, + // Engine selection should be determined based on some heuristics. 
+ int64_t gidx = 0; + CHECK_CUDNN(cudnnBackendSetAttribute(engine, CUDNN_ATTR_ENGINE_GLOBAL_INDEX, + CUDNN_TYPE_INT64, 1, &gidx)); + CHECK_CUDNN(cudnnBackendFinalize(engine)); + + cudnnBackendDescriptor_t engcfg; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_ENGINECFG_DESCRIPTOR, &engcfg)); + CHECK_CUDNN(cudnnBackendSetAttribute(engcfg, CUDNN_ATTR_ENGINECFG_ENGINE, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &engine)); + CHECK_CUDNN(cudnnBackendFinalize(engcfg)); + + // Set up the execution plan + cudnnBackendDescriptor_t plan; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR, &plan)); + CHECK_CUDNN(cudnnBackendSetAttribute(plan, CUDNN_ATTR_EXECUTION_PLAN_HANDLE, CUDNN_TYPE_HANDLE, 1, &cur_handle)); + CHECK_CUDNN(cudnnBackendSetAttribute(plan, CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &engcfg)); + CHECK_CUDNN(cudnnBackendFinalize(plan)); + + int64_t workspaceSize; + CHECK_CUDNN(cudnnBackendGetAttribute(plan, CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE, + CUDNN_TYPE_INT64, 1, NULL, &workspaceSize)); + + void *dev_ptrs[3] = {ptr_x, ptr_w, ptr_y}; // device pointers + int64_t uids[3] = {'x', 'w', 'y'}; + void *workspace = hidet_cuda_malloc_async(workspaceSize, cur_stream); + + cudnnBackendDescriptor_t varpack; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR, &varpack)); + CHECK_CUDNN(cudnnBackendSetAttribute(varpack, CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS, + CUDNN_TYPE_VOID_PTR, 3, dev_ptrs)); + CHECK_CUDNN(cudnnBackendSetAttribute(varpack, CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS, + CUDNN_TYPE_INT64, 3, uids)); + CHECK_CUDNN(cudnnBackendSetAttribute(varpack, CUDNN_ATTR_VARIANT_PACK_WORKSPACE, + CUDNN_TYPE_VOID_PTR, 1, &workspace)); + CHECK_CUDNN(cudnnBackendFinalize(varpack)); + + // Execute the plan + CHECK_CUDNN(cudnnBackendExecute(cur_handle, plan, varpack)); + + // Cleanup + hidet_cuda_free_async(workspace, cur_stream); + 
CHECK_CUDNN(cudnnBackendDestroyDescriptor(xDesc)); + CHECK_CUDNN(cudnnBackendDestroyDescriptor(wDesc)); + CHECK_CUDNN(cudnnBackendDestroyDescriptor(yDesc)); + CHECK_CUDNN(cudnnBackendDestroyDescriptor(cDesc)); + CHECK_CUDNN(cudnnBackendDestroyDescriptor(fprop)); + CHECK_CUDNN(cudnnBackendDestroyDescriptor(op_graph)); + CHECK_CUDNN(cudnnBackendDestroyDescriptor(engine)); + CHECK_CUDNN(cudnnBackendDestroyDescriptor(engcfg)); + CHECK_CUDNN(cudnnBackendDestroyDescriptor(plan)); +} + + diff --git a/tests/cuda/test_cudnn.py b/tests/cuda/test_cudnn.py new file mode 100644 index 000000000..4c9b3bfb6 --- /dev/null +++ b/tests/cuda/test_cudnn.py @@ -0,0 +1,73 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +import math +import hidet +from hidet import ops +from hidet.cuda.cudnn import cudnnDataType + + +@pytest.mark.parametrize( + "n, c, h, w, k, p, q, r, s, padding, stride, dilations", + [ + [1, 3, 32, 32, 12, 30, 30, 3, 3, [0, 0], [1, 1], [1, 1]], # kernel 3, + [2, 3, 32, 32, 12, 11, 6, 7, 7, [1, 2], [2, 3], [2, 3]], # kernel 7, batch size 2 + [1, 3, 32, 32, 12, 16, 11, 1, 1, [0, 0], [2, 3], [1, 1]], # kernel 1, + ], +) +@pytest.mark.parametrize( + 'dtype, compute_type, tol', + [(hidet.float32, cudnnDataType.CUDNN_DATA_FLOAT, 1e-5), (hidet.float64, cudnnDataType.CUDNN_DATA_DOUBLE, 1e-8)], +) +def test_cudnn_conv2d(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, stride, dilations, tol): + tx = tw = ty = dtype + pad_dim1, pad_dim2 = padding + str_dim1, str_dim2 = stride + dil_dim1, dil_dim2 = dilations + + tensor_x = hidet.randn((n, c, h, w), device='cuda', dtype=tx) + tensor_w = hidet.randn((k, c, r, s), device='cuda', dtype=tw) + tensor_y = hidet.empty((n, k, p, q), device='cuda', dtype=ty) + + golden = ops.conv2d( + tensor_x, tensor_w, stride=(str_dim1, str_dim2), dilations=(dil_dim1, dil_dim2), padding=(pad_dim1, pad_dim2) + ) + hidet.cuda.cudnn.conv2d( + n, + c, + h, + w, + k, + r, + s, + p, + q, + tensor_x, + tensor_w, + tensor_y, + tx, + tw, + ty, + compute_type, + pad_dim1, + pad_dim2, + str_dim1, + str_dim2, + dil_dim1, + dil_dim2, + ) + + hidet.utils.assert_close(actual=tensor_y, expected=golden, rtol=tol, atol=tol) + + +if __name__ == '__main__': + pytest.main([__file__]) From 68192fdfd9a5174d20daf005adc102c21976ed0f Mon Sep 17 00:00:00 2001 From: Yudi Sun Date: Thu, 14 Mar 2024 18:50:53 -0400 Subject: [PATCH 02/12] [CUDNN] Add CuDNN performance benchmarks --- python/hidet/cuda/cudnn/benchmark.py | 108 +++++++++++++++++++++++++++ src/hidet/runtime/cuda/cudnn.cpp | 2 +- tests/cuda/test_cudnn.py | 1 + 3 files changed, 110 insertions(+), 1 deletion(-) create mode 100644 python/hidet/cuda/cudnn/benchmark.py diff --git 
a/python/hidet/cuda/cudnn/benchmark.py b/python/hidet/cuda/cudnn/benchmark.py new file mode 100644 index 000000000..83666115b --- /dev/null +++ b/python/hidet/cuda/cudnn/benchmark.py @@ -0,0 +1,108 @@ +import numpy as np +import torch + +import hidet +from hidet.cuda.cudnn import cudnnDataType +from hidet.utils.benchmark import do_bench + + +def benchmark_cudnn_conv2d(dtype_str, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations): + if dtype_str == "float32": + dtype = hidet.float32 + elif dtype_str == "float64": + dtype = hidet.float64 + else: + raise Exception("Unsupported DataType") + + tx = tw = ty = dtype + pad_dim1, pad_dim2 = padding + str_dim1, str_dim2 = stride + dil_dim1, dil_dim2 = dilations + + tensor_x = hidet.randn((n, c, h, w), device='cuda', dtype=tx) + tensor_w = hidet.randn((k, c, r, s), device='cuda', dtype=tw) + tensor_y = hidet.empty((n, k, p, q), device='cuda', dtype=ty) + + latencies = do_bench( + lambda: hidet.cuda.cudnn.conv2d( + n, + c, + h, + w, + k, + r, + s, + p, + q, + tensor_x, + tensor_w, + tensor_y, + tx, + tw, + ty, + compute_type, + pad_dim1, + pad_dim2, + str_dim1, + str_dim2, + dil_dim1, + dil_dim2, + ), + warmup=10, + rep=100, + ) + + print( + f"CuDNN Results for Configuration: dtype = {dtype_str}, input shape = {[n,c,h,w]}, " + f"weight shape = {[k,c,r,s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:" + ) + print("20th Percentile Latency Is: " + str(latencies[0]) + " milliseconds") + print("50th Percentile Latency Is: " + str(latencies[1]) + " milliseconds") + print("80th Percentile Latency Is: " + str(latencies[2]) + " milliseconds") + print("-------------------------------------------------") + + +def benchmark_torch_conv2d(dtype_str, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations): + if dtype_str == "float32": + dtype = np.float32 + elif dtype_str == "float64": + dtype = np.float64 + else: + raise Exception("Unsupported DataType") + + data = 
np.array(np.random.randn(n, c, h, w)).astype(dtype) + weight = np.array(np.random.randn(k, c, r, s)).astype(dtype) + + data_torch, weight_torch = torch.from_numpy(data), torch.from_numpy(weight) + data_torch = data_torch.cuda() + weight_torch = weight_torch.cuda() + + latencies = do_bench( + lambda: torch.nn.functional.conv2d( + data_torch, weight_torch, bias=None, stride=stride, padding=padding, dilation=dilations, groups=1 + ), + warmup=10, + rep=100, + ) + + print( + f"PyTorch Results for Configuration: dtype = {dtype_str}, input shape = {[n,c,h,w]}, " + f"weight shape = {[k,c,r,s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:" + ) + print("20th Percentile Latency Is: " + str(latencies[0]) + " milliseconds") + print("50th Percentile Latency Is: " + str(latencies[1]) + " milliseconds") + print("80th Percentile Latency Is: " + str(latencies[2]) + " milliseconds") + print("-------------------------------------------------") + + +if __name__ == '__main__': + sizes = [ + [1, 3, 32, 32, 12, 30, 30, 3, 3, [0, 0], [1, 1], [1, 1]], + [2, 3, 224, 224, 16, 109, 109, 7, 7, [0, 0], [2, 2], [1, 1]], + ] + dtypes = [['float32', cudnnDataType.CUDNN_DATA_FLOAT], ['float64', cudnnDataType.CUDNN_DATA_DOUBLE]] + + for data_type in dtypes: + for size in sizes: + benchmark_cudnn_conv2d(*(data_type + size)) + benchmark_torch_conv2d(*(data_type + size)) diff --git a/src/hidet/runtime/cuda/cudnn.cpp b/src/hidet/runtime/cuda/cudnn.cpp index 468e5c2dc..c33e1b3e1 100644 --- a/src/hidet/runtime/cuda/cudnn.cpp +++ b/src/hidet/runtime/cuda/cudnn.cpp @@ -585,7 +585,7 @@ DLL void hidet_cudnn_conv2d( void *dev_ptrs[3] = {ptr_x, ptr_w, ptr_y}; // device pointers int64_t uids[3] = {'x', 'w', 'y'}; - void *workspace = hidet_cuda_malloc_async(workspaceSize, cur_stream); + void *workspace = request_cuda_workspace(workspaceSize, false); cudnnBackendDescriptor_t varpack; CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR, &varpack)); diff --git 
a/tests/cuda/test_cudnn.py b/tests/cuda/test_cudnn.py index 4c9b3bfb6..abaacbc94 100644 --- a/tests/cuda/test_cudnn.py +++ b/tests/cuda/test_cudnn.py @@ -22,6 +22,7 @@ [1, 3, 32, 32, 12, 30, 30, 3, 3, [0, 0], [1, 1], [1, 1]], # kernel 3, [2, 3, 32, 32, 12, 11, 6, 7, 7, [1, 2], [2, 3], [2, 3]], # kernel 7, batch size 2 [1, 3, 32, 32, 12, 16, 11, 1, 1, [0, 0], [2, 3], [1, 1]], # kernel 1, + [2, 3, 224, 224, 16, 109, 109, 7, 7, [0, 0], [2, 2], [1, 1]], ], ) @pytest.mark.parametrize( From 90a1791dc9ddc6fa4795cb8092ee8ad36272b331 Mon Sep 17 00:00:00 2001 From: Yudi Sun Date: Wed, 6 Mar 2024 21:23:33 -0500 Subject: [PATCH 03/12] Add cudnn conv2d --- CMakeLists.txt | 1 + include/hidet/runtime/cuda/cudnn.h | 18 + python/hidet/cuda/__init__.py | 1 + python/hidet/cuda/cudnn/__init__.py | 13 + python/hidet/cuda/cudnn/ffi.py | 82 ++++ python/hidet/cuda/cudnn/kernels.py | 117 ++++++ python/hidet/cuda/cudnn/utils.py | 41 ++ src/hidet/runtime/cuda/cudnn.cpp | 616 ++++++++++++++++++++++++++++ tests/cuda/test_cudnn.py | 73 ++++ 9 files changed, 962 insertions(+) create mode 100644 python/hidet/cuda/cudnn/__init__.py create mode 100644 python/hidet/cuda/cudnn/ffi.py create mode 100644 python/hidet/cuda/cudnn/kernels.py create mode 100644 python/hidet/cuda/cudnn/utils.py create mode 100644 src/hidet/runtime/cuda/cudnn.cpp create mode 100644 tests/cuda/test_cudnn.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 6dab979b7..3089d698d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,6 +19,7 @@ message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") add_library(hidet_runtime SHARED src/hidet/runtime/cuda/context.cpp src/hidet/runtime/cuda/cublas.cpp + src/hidet/runtime/cuda/cudnn.cpp src/hidet/runtime/cuda/cuda.cpp src/hidet/runtime/cpu/context.cpp src/hidet/runtime/callbacks.cpp diff --git a/include/hidet/runtime/cuda/cudnn.h b/include/hidet/runtime/cuda/cudnn.h index 5653e0cb5..fcbc43697 100644 --- a/include/hidet/runtime/cuda/cudnn.h +++ b/include/hidet/runtime/cuda/cudnn.h @@ 
-9,3 +9,21 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#pragma once +#define HIDET_CUDNN_MAX_GPUS 32 + +#include + +struct cudnnContext; +typedef struct cudnnContext *cudnnHandle_t; + +typedef void *cudnnBackendDescriptor_t; + +struct CudnnContext { + cudnnHandle_t handles[HIDET_CUDNN_MAX_GPUS]; + static CudnnContext* global(); + static cudnnHandle_t current_handle(); +}; + +DLL void hidet_cudnn_set_library_path(const char* path); + diff --git a/python/hidet/cuda/__init__.py b/python/hidet/cuda/__init__.py index 7e6efbfa5..6f8c77b12 100644 --- a/python/hidet/cuda/__init__.py +++ b/python/hidet/cuda/__init__.py @@ -18,3 +18,4 @@ from .event import Event from . import cublas +from . import cudnn diff --git a/python/hidet/cuda/cudnn/__init__.py b/python/hidet/cuda/cudnn/__init__.py new file mode 100644 index 000000000..de471d1a6 --- /dev/null +++ b/python/hidet/cuda/cudnn/__init__.py @@ -0,0 +1,13 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .ffi import cudnnDataType +from .kernels import conv2d diff --git a/python/hidet/cuda/cudnn/ffi.py b/python/hidet/cuda/cudnn/ffi.py new file mode 100644 index 000000000..bb559fcda --- /dev/null +++ b/python/hidet/cuda/cudnn/ffi.py @@ -0,0 +1,82 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +import glob +from enum import IntEnum +from ctypes import c_int32, c_void_p, c_char_p +from hidet.ffi.ffi import get_func + +from hidet.utils.py import initialize + + +class cudnnDataType(IntEnum): + """ + defined in cudnn_ops_infer_v8.h + """ + + CUDNN_DATA_FLOAT = 0 + CUDNN_DATA_DOUBLE = 1 + CUDNN_DATA_HALF = 2 + CUDNN_DATA_INT8 = 3 + CUDNN_DATA_INT32 = 4 + CUDNN_DATA_INT8x4 = 5 + CUDNN_DATA_UINT8 = 6 + CUDNN_DATA_UINT8x4 = 7 + CUDNN_DATA_INT8x32 = 8 + CUDNN_DATA_BFLOAT16 = 9 + CUDNN_DATA_INT64 = 10 + + +set_library_path = get_func(func_name='hidet_cudnn_set_library_path', arg_types=[c_char_p], restype=None) + +conv2d = get_func( + func_name='hidet_cudnn_conv2d', + arg_types=[ + c_int32, # n + c_int32, # c + c_int32, # h + c_int32, # w + c_int32, # k + c_int32, # r + c_int32, # s + c_int32, # p + c_int32, # q + c_void_p, # ptr_x + c_void_p, # ptr_w + c_void_p, # ptr_y + c_int32, # tx + c_int32, # tw + c_int32, # ty + c_int32, # compute_type + c_int32, # pad_dim1 + c_int32, # pad_dim2 + c_int32, # str_dim1 + c_int32, # str_dim2 + c_int32, # dil_dim1 + c_int32, # dil_dim2 + ], + restype=None, +) + + +@initialize() +def set_cudnn_library_path(): + # use nvidia-cuda-cudnn + for path in sys.path: + nvidia_path = os.path.join(path, 'nvidia') + if not os.path.exists(nvidia_path): + continue + cudnn_path = glob.glob(os.path.join(nvidia_path, 'cudnn', 'lib', 'libcudnn.so.[0-9]*')) + if cudnn_path: + set_library_path(cudnn_path[0].encode('utf-8')) + return diff --git a/python/hidet/cuda/cudnn/kernels.py b/python/hidet/cuda/cudnn/kernels.py 
new file mode 100644 index 000000000..2b1f3b2e9 --- /dev/null +++ b/python/hidet/cuda/cudnn/kernels.py @@ -0,0 +1,117 @@ +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union +from hidet.ir.dtypes import DataType +from .ffi import cudnnDataType +from . import ffi +from .utils import as_pointer, as_cudnn_type + + +def conv2d( + n: int, + c: int, + h: int, + w: int, + k: int, + r: int, + s: int, + p: int, + q: int, + ptr_x, + ptr_w, + ptr_y, + tx: Union[int, DataType], + tw: Union[int, DataType], + ty: Union[int, DataType], + compute_type: Union[int, cudnnDataType], + pad_dim1: int, + pad_dim2: int, + str_dim1: int, + str_dim2: int, + dil_dim1: int, + dil_dim2: int, +): + """ + Calculates the 2D convolution of tensor x with filter w, stores the result in tensor y. + + Parameters + ---------- + n: int + Batch number. + c: int + Number of channels in the input tensor x. + h: int + Height of the input tensor x. + w: int + Width of the input tensor x. + k: int + Number of channels in the output tensor y. + r: int + Height of the filter w. + s: int + Width of the filter w. + p: int + Height of the output tensor y. + q: int + Width of the output tensor y. + ptr_x: hidet.Tensor or int + Input tensor x, can be either a Tensor or an integer (the address of the tensor). + ptr_w: hidet.Tensor or int + Weight tensor w, can be either a Tensor or an integer (the address of the tensor). + ptr_y: hidet.Tensor or int + Output tensor y, can be either a Tensor or an integer (the address of the tensor). 
+ tx: Union[int, DataType] + Type of elements in tensor x. + tw: Union[int, DataType] + Type of elements in tensor w. + ty: Union[int, DataType] + Type of elements in tensor y. + compute_type: Union[int, cudnnDataType] + The compute type of the operation. + For cuDNN, there's no such thing as a cudnnComputeType_t type. + As per the official example, the computeType is defined in terms of cudnnDataType_t + pad_dim1: int + The value to use for padding along the height dimension + pad_dim2: int + The value to use for padding along the width dimension + str_dim1: int + The stride to use for the height dimension + str_dim2: int + The stride to use for the width dimension + dil_dim1: int + The dilation to use for the height dimension + dil_dim2: int + The dilation to use for the width dimension + """ + ffi.conv2d( + n, + c, + h, + w, + k, + r, + s, + p, + q, + as_pointer(ptr_x), + as_pointer(ptr_w), + as_pointer(ptr_y), + as_cudnn_type(tx), + as_cudnn_type(tw), + as_cudnn_type(ty), + compute_type, + pad_dim1, + pad_dim2, + str_dim1, + str_dim2, + dil_dim1, + dil_dim2, + ) diff --git a/python/hidet/cuda/cudnn/utils.py b/python/hidet/cuda/cudnn/utils.py new file mode 100644 index 000000000..a599a880b --- /dev/null +++ b/python/hidet/cuda/cudnn/utils.py @@ -0,0 +1,41 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from hidet.ir import dtypes +from hidet.ir.dtypes import DataType +from .ffi import cudnnDataType + +_cudnn_type_dict = { + dtypes.float32: cudnnDataType.CUDNN_DATA_FLOAT, + dtypes.float64: cudnnDataType.CUDNN_DATA_DOUBLE, + dtypes.int32: cudnnDataType.CUDNN_DATA_INT32, + dtypes.int64: cudnnDataType.CUDNN_DATA_INT64, +} + + +def as_pointer(obj) -> int: + from hidet.graph.tensor import Tensor + + if isinstance(obj, Tensor): + return obj.storage.addr + elif isinstance(obj, int): + return obj + else: + raise TypeError(f'Expected Tensor or int, but got {type(obj)}') + + +def as_cudnn_type(obj) -> int: + if isinstance(obj, DataType): + return _cudnn_type_dict[obj] + elif isinstance(obj, int): + return obj + else: + raise TypeError(f'Expected DataType or int, but got {type(obj)}') diff --git a/src/hidet/runtime/cuda/cudnn.cpp b/src/hidet/runtime/cuda/cudnn.cpp new file mode 100644 index 000000000..468e5c2dc --- /dev/null +++ b/src/hidet/runtime/cuda/cudnn.cpp @@ -0,0 +1,616 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include "./utils.h" + +/* + * CUDNN return codes - defined in cudnn_ops_infer_v8.h + */ +typedef enum { + CUDNN_STATUS_SUCCESS = 0, + CUDNN_STATUS_NOT_INITIALIZED = 1, + CUDNN_STATUS_ALLOC_FAILED = 2, + CUDNN_STATUS_BAD_PARAM = 3, + CUDNN_STATUS_INTERNAL_ERROR = 4, + CUDNN_STATUS_INVALID_VALUE = 5, + CUDNN_STATUS_ARCH_MISMATCH = 6, + CUDNN_STATUS_MAPPING_ERROR = 7, + CUDNN_STATUS_EXECUTION_FAILED = 8, + CUDNN_STATUS_NOT_SUPPORTED = 9, + CUDNN_STATUS_LICENSE_ERROR = 10, + CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING = 11, + CUDNN_STATUS_RUNTIME_IN_PROGRESS = 12, + CUDNN_STATUS_RUNTIME_FP_OVERFLOW = 13, + CUDNN_STATUS_VERSION_MISMATCH = 14, +} cudnnStatus_t; + +/* +* CUDNN Descriptor Types - defined in cudnn_backend_v8.h +*/ +typedef enum { + CUDNN_BACKEND_POINTWISE_DESCRIPTOR = 0, + CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR, + CUDNN_BACKEND_ENGINE_DESCRIPTOR, + CUDNN_BACKEND_ENGINECFG_DESCRIPTOR, + CUDNN_BACKEND_ENGINEHEUR_DESCRIPTOR, + CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR, + CUDNN_BACKEND_INTERMEDIATE_INFO_DESCRIPTOR, + CUDNN_BACKEND_KNOB_CHOICE_DESCRIPTOR, + CUDNN_BACKEND_KNOB_INFO_DESCRIPTOR, + CUDNN_BACKEND_LAYOUT_INFO_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_GEN_STATS_DESCRIPTOR, + CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR, + CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR, + CUDNN_BACKEND_TENSOR_DESCRIPTOR, + CUDNN_BACKEND_MATMUL_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_BN_FINALIZE_STATISTICS_DESCRIPTOR, + CUDNN_BACKEND_REDUCTION_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR, + CUDNN_BACKEND_OPERATION_BN_BWD_WEIGHTS_DESCRIPTOR, +} cudnnBackendDescriptorType_t; + +/* + * CUDNN data type - defined in cudnn_ops_infer_v8.h + */ +typedef enum { + 
CUDNN_DATA_FLOAT = 0, + CUDNN_DATA_DOUBLE = 1, + CUDNN_DATA_HALF = 2, + CUDNN_DATA_INT8 = 3, + CUDNN_DATA_INT32 = 4, + CUDNN_DATA_INT8x4 = 5, + CUDNN_DATA_UINT8 = 6, + CUDNN_DATA_UINT8x4 = 7, + CUDNN_DATA_INT8x32 = 8, + CUDNN_DATA_BFLOAT16 = 9, + CUDNN_DATA_INT64 = 10, +} cudnnDataType_t; + +/* +* CUDNN Backend Attribute Names - defined in cudnn_backend_v8.h +*/ +typedef enum { + CUDNN_ATTR_POINTWISE_MODE = 0, + CUDNN_ATTR_POINTWISE_MATH_PREC = 1, + CUDNN_ATTR_POINTWISE_NAN_PROPAGATION = 2, + CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP = 3, + CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP = 4, + CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP_SLOPE = 5, + CUDNN_ATTR_POINTWISE_ELU_ALPHA = 6, + CUDNN_ATTR_POINTWISE_SOFTPLUS_BETA = 7, + CUDNN_ATTR_POINTWISE_SWISH_BETA = 8, + + CUDNN_ATTR_CONVOLUTION_COMP_TYPE = 100, + CUDNN_ATTR_CONVOLUTION_CONV_MODE = 101, + CUDNN_ATTR_CONVOLUTION_DILATIONS = 102, + CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES = 103, + CUDNN_ATTR_CONVOLUTION_POST_PADDINGS = 104, + CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS = 105, + CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS = 106, + + CUDNN_ATTR_ENGINEHEUR_MODE = 200, + CUDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH = 201, + CUDNN_ATTR_ENGINEHEUR_RESULTS = 202, + + CUDNN_ATTR_ENGINECFG_ENGINE = 300, + CUDNN_ATTR_ENGINECFG_INTERMEDIATE_INFO = 301, + CUDNN_ATTR_ENGINECFG_KNOB_CHOICES = 302, + + CUDNN_ATTR_EXECUTION_PLAN_HANDLE = 400, + CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG = 401, + CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE = 402, + CUDNN_ATTR_EXECUTION_PLAN_COMPUTED_INTERMEDIATE_UIDS = 403, + CUDNN_ATTR_EXECUTION_PLAN_RUN_ONLY_INTERMEDIATE_UIDS = 404, + + CUDNN_ATTR_INTERMEDIATE_INFO_UNIQUE_ID = 500, + CUDNN_ATTR_INTERMEDIATE_INFO_SIZE = 501, + CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_DATA_UIDS = 502, + CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_ATTRIBUTES = 503, + + CUDNN_ATTR_KNOB_CHOICE_KNOB_TYPE = 600, + CUDNN_ATTR_KNOB_CHOICE_KNOB_VALUE = 601, + + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA = 700, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA = 701, + 
CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC = 702, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W = 703, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X = 704, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y = 705, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA = 706, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA = 707, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC = 708, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W = 709, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX = 710, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY = 711, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA = 712, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA = 713, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC = 714, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW = 715, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X = 716, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY = 717, + + CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR = 750, + CUDNN_ATTR_OPERATION_POINTWISE_XDESC = 751, + CUDNN_ATTR_OPERATION_POINTWISE_BDESC = 752, + CUDNN_ATTR_OPERATION_POINTWISE_YDESC = 753, + CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1 = 754, + CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2 = 755, + CUDNN_ATTR_OPERATION_POINTWISE_DXDESC = 756, + CUDNN_ATTR_OPERATION_POINTWISE_DYDESC = 757, + + CUDNN_ATTR_OPERATION_GENSTATS_MODE = 770, + CUDNN_ATTR_OPERATION_GENSTATS_MATH_PREC = 771, + CUDNN_ATTR_OPERATION_GENSTATS_XDESC = 772, + CUDNN_ATTR_OPERATION_GENSTATS_SUMDESC = 773, + CUDNN_ATTR_OPERATION_GENSTATS_SQSUMDESC = 774, + + CUDNN_ATTR_OPERATION_BN_FINALIZE_STATS_MODE = 780, + CUDNN_ATTR_OPERATION_BN_FINALIZE_MATH_PREC = 781, + CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SUM_DESC = 782, + CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SQ_SUM_DESC = 783, + CUDNN_ATTR_OPERATION_BN_FINALIZE_SCALE_DESC = 784, + CUDNN_ATTR_OPERATION_BN_FINALIZE_BIAS_DESC = 785, + CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_MEAN_DESC = 786, + CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_VAR_DESC = 787, + 
CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_MEAN_DESC = 788, + CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_VAR_DESC = 789, + CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_MEAN_DESC = 790, + CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_INV_STD_DESC = 791, + CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_SCALE_DESC = 792, + CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_BIAS_DESC = 793, + CUDNN_ATTR_OPERATION_BN_FINALIZE_ACCUM_COUNT_DESC = 794, + CUDNN_ATTR_OPERATION_BN_FINALIZE_EPSILON_DESC = 795, + CUDNN_ATTR_OPERATION_BN_FINALIZE_EXP_AVERATE_FACTOR_DESC = 796, + + CUDNN_ATTR_OPERATIONGRAPH_HANDLE = 800, + CUDNN_ATTR_OPERATIONGRAPH_OPS = 801, + CUDNN_ATTR_OPERATIONGRAPH_ENGINE_GLOBAL_COUNT = 802, + + CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT = 900, + CUDNN_ATTR_TENSOR_DATA_TYPE = 901, + CUDNN_ATTR_TENSOR_DIMENSIONS = 902, + CUDNN_ATTR_TENSOR_STRIDES = 903, + CUDNN_ATTR_TENSOR_VECTOR_COUNT = 904, + CUDNN_ATTR_TENSOR_VECTORIZED_DIMENSION = 905, + CUDNN_ATTR_TENSOR_UNIQUE_ID = 906, + CUDNN_ATTR_TENSOR_IS_VIRTUAL = 907, + CUDNN_ATTR_TENSOR_IS_BY_VALUE = 908, + + CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS = 1000, + CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS = 1001, + CUDNN_ATTR_VARIANT_PACK_INTERMEDIATES = 1002, + CUDNN_ATTR_VARIANT_PACK_WORKSPACE = 1003, + + CUDNN_ATTR_LAYOUT_INFO_TENSOR_UID = 1100, + CUDNN_ATTR_LAYOUT_INFO_TYPES = 1101, + + CUDNN_ATTR_KNOB_INFO_TYPE = 1200, + CUDNN_ATTR_KNOB_INFO_MAXIMUM_VALUE = 1201, + CUDNN_ATTR_KNOB_INFO_MINIMUM_VALUE = 1202, + CUDNN_ATTR_KNOB_INFO_STRIDE = 1203, + + CUDNN_ATTR_ENGINE_OPERATION_GRAPH = 1300, + CUDNN_ATTR_ENGINE_GLOBAL_INDEX = 1301, + CUDNN_ATTR_ENGINE_KNOB_INFO = 1302, + CUDNN_ATTR_ENGINE_NUMERICAL_NOTE = 1303, + CUDNN_ATTR_ENGINE_LAYOUT_INFO = 1304, + CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE = 1305, + + CUDNN_ATTR_MATMUL_COMP_TYPE = 1500, + + CUDNN_ATTR_OPERATION_MATMUL_ADESC = 1520, + CUDNN_ATTR_OPERATION_MATMUL_BDESC = 1521, + CUDNN_ATTR_OPERATION_MATMUL_CDESC = 1522, + CUDNN_ATTR_OPERATION_MATMUL_DESC = 1523, + 
CUDNN_ATTR_OPERATION_MATMUL_IRREGULARLY_STRIDED_BATCH_COUNT = 1524, + + CUDNN_ATTR_REDUCTION_OPERATOR = 1600, + CUDNN_ATTR_REDUCTION_COMP_TYPE = 1601, + + CUDNN_ATTR_OPERATION_REDUCTION_XDESC = 1610, + CUDNN_ATTR_OPERATION_REDUCTION_YDESC = 1611, + CUDNN_ATTR_OPERATION_REDUCTION_DESC = 1612, + + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MATH_PREC = 1620, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MEAN_DESC = 1621, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_INVSTD_DESC = 1622, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_BN_SCALE_DESC = 1623, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_X_DESC = 1624, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DY_DESC = 1625, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_SCALE_DESC = 1626, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_BIAS_DESC = 1627, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_DY_SCALE_DESC = 1628, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_X_SCALE_DESC = 1629, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_BIAS = 1630, +} cudnnBackendAttributeName_t; + +/* +* CUDNN Backend Attribute Type - defined in cudnn_backend_v8.h +*/ +typedef enum { + CUDNN_TYPE_HANDLE = 0, + CUDNN_TYPE_DATA_TYPE, + CUDNN_TYPE_BOOLEAN, + CUDNN_TYPE_INT64, + CUDNN_TYPE_FLOAT, + CUDNN_TYPE_DOUBLE, + CUDNN_TYPE_VOID_PTR, + CUDNN_TYPE_CONVOLUTION_MODE, + CUDNN_TYPE_HEUR_MODE, + CUDNN_TYPE_KNOB_TYPE, + CUDNN_TYPE_NAN_PROPOGATION, + CUDNN_TYPE_NUMERICAL_NOTE, + CUDNN_TYPE_LAYOUT_TYPE, + CUDNN_TYPE_ATTRIB_NAME, + CUDNN_TYPE_POINTWISE_MODE, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + CUDNN_TYPE_GENSTATS_MODE, + CUDNN_TYPE_BN_FINALIZE_STATS_MODE, + CUDNN_TYPE_REDUCTION_OPERATOR_TYPE, + CUDNN_TYPE_BEHAVIOR_NOTE, +} cudnnBackendAttributeType_t; + +/* + * convolution mode - defined in cudnn_cnn_infer_v8.h + */ +typedef enum { CUDNN_CONVOLUTION = 0, CUDNN_CROSS_CORRELATION = 1 } cudnnConvolutionMode_t; + + +// define cudnn Graph API functions +typedef cudnnStatus_t (*cudnnCreate_t)(cudnnHandle_t *handle); +typedef const char * (*cudnnGetErrorString_t)(cudnnStatus_t status); +typedef cudnnStatus_t 
(*cudnnSetStream_t)( + cudnnHandle_t handle, + cudaStream_t streamId); +typedef cudnnStatus_t (*cudnnBackendCreateDescriptor_t)( + cudnnBackendDescriptorType_t descriptorType, + cudnnBackendDescriptor_t *descriptor); +typedef cudnnStatus_t (*cudnnBackendDestroyDescriptor_t)(cudnnBackendDescriptor_t descriptor); +typedef cudnnStatus_t (*cudnnBackendSetAttribute_t)( + cudnnBackendDescriptor_t descriptor, + cudnnBackendAttributeName_t attributeName, + cudnnBackendAttributeType_t attributeType, + int64_t elementCount, + void *arrayOfElements); +typedef cudnnStatus_t (*cudnnBackendGetAttribute_t)( + cudnnBackendDescriptor_t descriptor, + cudnnBackendAttributeName_t attributeName, + cudnnBackendAttributeType_t attributeType, + int64_t requestedElementCount, + int64_t *elementCount, + void *arrayOfElements); +typedef cudnnStatus_t (*cudnnBackendFinalize_t)(cudnnBackendDescriptor_t descriptor); +typedef cudnnStatus_t (*cudnnBackendExecute_t)( + cudnnHandle_t handle, cudnnBackendDescriptor_t executionPlan, cudnnBackendDescriptor_t varianPack); + + +// cudnn api functions +static cudnnCreate_t cudnnCreate; +static cudnnGetErrorString_t cudnnGetErrorString; +static cudnnSetStream_t cudnnSetStream; +static cudnnBackendCreateDescriptor_t cudnnBackendCreateDescriptor; +static cudnnBackendDestroyDescriptor_t cudnnBackendDestroyDescriptor; +static cudnnBackendSetAttribute_t cudnnBackendSetAttribute; +static cudnnBackendGetAttribute_t cudnnBackendGetAttribute; +static cudnnBackendFinalize_t cudnnBackendFinalize; +static cudnnBackendExecute_t cudnnBackendExecute; + +static std::string library_path; +static void* libcudnn = nullptr; + +// utility functions +#define CHECK_CUDNN(status) do { \ + cudnnStatus_t err = (status); \ + if(err != 0) { \ + LOG(FATAL) << "cuDNN error: " << cudnnGetErrorString(err); \ + } \ +} while(0) + +static cudnnBackendAttributeType_t get_attribute_type_from_compute_type(cudnnDataType_t computeType) { + switch (computeType) { + case CUDNN_DATA_FLOAT: + 
return CUDNN_TYPE_FLOAT; + case CUDNN_DATA_DOUBLE: + return CUDNN_TYPE_DOUBLE; + case CUDNN_DATA_INT64: + case CUDNN_DATA_INT32: + return CUDNN_TYPE_INT64; + default: + LOG(FATAL) << "Unsupported compute type: " << computeType; + return CUDNN_TYPE_VOID_PTR; + } +} + +static void set_alpha_beta(void** p_alpha, void** p_beta, cudnnDataType_t c) { + // There's no such thing as a cudnnComputeType_t type. As per the official example, the computeType is defined + // in terms of cudnnDataType_t + // cudnnBackendAttributeType_t only has support for FLOAT, DOUBLE, and INT64. + if(c == CUDNN_DATA_FLOAT) { + static float alpha = 1.0f; + static float beta = 0.0f; + *p_alpha = α + *p_beta = β + } else if(c == CUDNN_DATA_DOUBLE) { + static double alpha = 1.0; + static double beta = 0.0; + *p_alpha = α + *p_beta = β + } else if(c == CUDNN_DATA_INT64 || c == CUDNN_DATA_INT32) { + static int64_t alpha = 1; + static int64_t beta = 0; + *p_alpha = α + *p_beta = β + } else { + LOG(FATAL) << "Unsupported compute type: " << c; + } +} + +static void lazy_load_cudnn() { + if(libcudnn == nullptr) { + // load cudnn shared library + const char* libpath; + if(library_path.empty()) { + libpath = "libcudnn.so"; + } else { + libpath = library_path.c_str(); + } + libcudnn = dlopen(libpath, RTLD_LAZY); + if(libcudnn == nullptr) { + LOG(FATAL) << "Failed to load cublas library: " << libpath << dlerror(); + } + + // load api functions + cudnnCreate = get_symbol(libcudnn, "cudnnCreate"); + cudnnGetErrorString = get_symbol(libcudnn, "cudnnGetErrorString"); + cudnnSetStream = get_symbol(libcudnn, "cudnnSetStream"); + cudnnBackendCreateDescriptor = get_symbol(libcudnn, "cudnnBackendCreateDescriptor"); + cudnnBackendDestroyDescriptor = get_symbol(libcudnn, "cudnnBackendDestroyDescriptor"); + cudnnBackendSetAttribute = get_symbol(libcudnn, "cudnnBackendSetAttribute"); + cudnnBackendGetAttribute = get_symbol(libcudnn, "cudnnBackendGetAttribute"); + cudnnBackendFinalize = get_symbol(libcudnn, 
"cudnnBackendFinalize"); + cudnnBackendExecute = get_symbol(libcudnn, "cudnnBackendExecute"); + } +} + + +CudnnContext* CudnnContext::global() { + static CudnnContext instance; + static bool initialized = false; + + if(!initialized) { + // create cudnn handle for each gpu + int count = hidet_cuda_device_count(); + assert(count <= HIDET_CUBLAS_MAX_GPUS); + + int current_device = hidet_cuda_get_device(); + for(int i = 0; i < count; i++) { + hidet_cuda_set_device(i); + CHECK_CUDNN(cudnnCreate(&instance.handles[i])); + } + hidet_cuda_set_device(current_device); + + initialized = true; + } + return &instance; +} + +cudnnHandle_t CudnnContext::current_handle() { + return CudnnContext::global()->handles[hidet_cuda_get_device()]; +} + + +// hidet cudnn api functions +DLL void hidet_cudnn_set_library_path(const char* path) { + if(path) { + library_path = path; + } +} + +DLL void hidet_cudnn_conv2d( + int n, int c, int h, int w, int k, int r, int s, int p, int q, + void *ptr_x, void *ptr_w, void *ptr_y, + int tx, int tw, int ty, int compute_type, + int pad_dim1, int pad_dim2, int str_dim1, int str_dim2, int dil_dim1, int dil_dim2 +) { + lazy_load_cudnn(); + + cudnnHandle_t cur_handle = CudnnContext::current_handle(); + + // Set the stream to the current stream + cudaStream_t cur_stream = get_cuda_stream(); + CHECK_CUDNN(cudnnSetStream(cur_handle, cur_stream)); + + // Build the descriptor for x + int64_t xDim[] = {n, c, h, w}; + int64_t xStr[] = {c * h * w, h * w, w, 1}; + int64_t xUi = 'x'; + int64_t alignment = 8; + cudnnBackendDescriptor_t xDesc; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_TENSOR_DESCRIPTOR, &xDesc)); + cudnnDataType_t xDtype = cudnnDataType_t(tx); + CHECK_CUDNN(cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_DATA_TYPE, + CUDNN_TYPE_DATA_TYPE, 1, &xDtype)); + CHECK_CUDNN(cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_DIMENSIONS, + CUDNN_TYPE_INT64, 4, xDim)); + CHECK_CUDNN(cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_STRIDES, + 
CUDNN_TYPE_INT64, 4, xStr)); + CHECK_CUDNN(cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_UNIQUE_ID, + CUDNN_TYPE_INT64, 1, &xUi)); + CHECK_CUDNN(cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT, + CUDNN_TYPE_INT64, 1, &alignment)); + CHECK_CUDNN(cudnnBackendFinalize(xDesc)); + + // Build the descriptor for w + int64_t wDim[] = {k, c, r, s}; + int64_t wStr[] = {c * r * s, r * s, s, 1}; + int64_t wUi = 'w'; + cudnnBackendDescriptor_t wDesc; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_TENSOR_DESCRIPTOR, &wDesc)); + cudnnDataType_t wDtype = cudnnDataType_t(tw); + CHECK_CUDNN(cudnnBackendSetAttribute(wDesc, CUDNN_ATTR_TENSOR_DATA_TYPE, + CUDNN_TYPE_DATA_TYPE, 1, &wDtype)); + CHECK_CUDNN(cudnnBackendSetAttribute(wDesc, CUDNN_ATTR_TENSOR_DIMENSIONS, + CUDNN_TYPE_INT64, 4, wDim)); + CHECK_CUDNN(cudnnBackendSetAttribute(wDesc, CUDNN_ATTR_TENSOR_STRIDES, + CUDNN_TYPE_INT64, 4, wStr)); + CHECK_CUDNN(cudnnBackendSetAttribute(wDesc, CUDNN_ATTR_TENSOR_UNIQUE_ID, + CUDNN_TYPE_INT64, 1, &wUi)); + CHECK_CUDNN(cudnnBackendSetAttribute(wDesc, CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT, + CUDNN_TYPE_INT64, 1, &alignment)); + CHECK_CUDNN(cudnnBackendFinalize(wDesc)); + + // Build the descriptor for y + int64_t yDim[] = {n, k, p, q}; + int64_t yStr[] = {k * p * q, p * q, q, 1}; + int64_t yUi = 'y'; + cudnnBackendDescriptor_t yDesc; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_TENSOR_DESCRIPTOR, &yDesc)); + cudnnDataType_t yDtype = cudnnDataType_t(ty); + CHECK_CUDNN(cudnnBackendSetAttribute(yDesc, CUDNN_ATTR_TENSOR_DATA_TYPE, + CUDNN_TYPE_DATA_TYPE, 1, &yDtype)); + CHECK_CUDNN(cudnnBackendSetAttribute(yDesc, CUDNN_ATTR_TENSOR_DIMENSIONS, + CUDNN_TYPE_INT64, 4, yDim)); + CHECK_CUDNN(cudnnBackendSetAttribute(yDesc, CUDNN_ATTR_TENSOR_STRIDES, + CUDNN_TYPE_INT64, 4, yStr)); + CHECK_CUDNN(cudnnBackendSetAttribute(yDesc, CUDNN_ATTR_TENSOR_UNIQUE_ID, + CUDNN_TYPE_INT64, 1, &yUi)); + CHECK_CUDNN(cudnnBackendSetAttribute(yDesc, CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT, 
+ CUDNN_TYPE_INT64, 1, &alignment)); + CHECK_CUDNN(cudnnBackendFinalize(yDesc)); + + // Build the descriptor for the convolution operator + cudnnBackendDescriptor_t cDesc; + int64_t nbDims = 2; + cudnnDataType_t compType = cudnnDataType_t(compute_type); + cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; + int64_t pad[] = {pad_dim1, pad_dim2}; + int64_t filterStr[] = {str_dim1, str_dim2}; + int64_t dilation[] = {dil_dim1, dil_dim2}; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR, &cDesc)); + CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS, + CUDNN_TYPE_INT64, 1, &nbDims)); + CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_COMP_TYPE, + CUDNN_TYPE_DATA_TYPE, 1, &compType)); + CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_CONV_MODE, + CUDNN_TYPE_CONVOLUTION_MODE, 1, &mode)); + CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS, + CUDNN_TYPE_INT64, nbDims, pad)); + CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_POST_PADDINGS, + CUDNN_TYPE_INT64, nbDims, pad)); + CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_DILATIONS, + CUDNN_TYPE_INT64, nbDims, dilation)); + CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES, + CUDNN_TYPE_INT64, nbDims, filterStr)); + CHECK_CUDNN(cudnnBackendFinalize(cDesc)); + + // Build the descriptor for the convolution forward operation + cudnnBackendDescriptor_t fprop; + void *p_alpha = nullptr; + void *p_beta = nullptr; + set_alpha_beta(&p_alpha, &p_beta, compType); + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, + &fprop)); + CHECK_CUDNN(cudnnBackendSetAttribute(fprop, CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &xDesc)); + CHECK_CUDNN(cudnnBackendSetAttribute(fprop, CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, 
&wDesc)); + CHECK_CUDNN(cudnnBackendSetAttribute(fprop, CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &yDesc)); + CHECK_CUDNN(cudnnBackendSetAttribute(fprop, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &cDesc)); + CHECK_CUDNN(cudnnBackendSetAttribute(fprop, CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA, + get_attribute_type_from_compute_type(compType), 1, p_alpha)); + CHECK_CUDNN(cudnnBackendSetAttribute(fprop, CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA, + get_attribute_type_from_compute_type(compType), 1, p_beta)); + CHECK_CUDNN(cudnnBackendFinalize(fprop)); + + // Build the operation graph descriptor + cudnnBackendDescriptor_t op_graph; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR, &op_graph)); + CHECK_CUDNN(cudnnBackendSetAttribute(op_graph, CUDNN_ATTR_OPERATIONGRAPH_OPS, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &fprop)); + CHECK_CUDNN(cudnnBackendSetAttribute(op_graph, CUDNN_ATTR_OPERATIONGRAPH_HANDLE, + CUDNN_TYPE_HANDLE, 1, &cur_handle)); + CHECK_CUDNN(cudnnBackendFinalize(op_graph)); + + // Set up engine config + cudnnBackendDescriptor_t engine; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_ENGINE_DESCRIPTOR, &engine)); + CHECK_CUDNN(cudnnBackendSetAttribute(engine, CUDNN_ATTR_ENGINE_OPERATION_GRAPH, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &op_graph)); + // TODO: Is it okay to hardcode the engine to be CUDNN_ATTR_ENGINE_GLOBAL_INDEX 0? + // As mentioned here: https://docs.nvidia.com/deeplearning/cudnn/developer/graph-api.html, + // Engine selection should be determined based on some heuristics. 
+ int64_t gidx = 0; + CHECK_CUDNN(cudnnBackendSetAttribute(engine, CUDNN_ATTR_ENGINE_GLOBAL_INDEX, + CUDNN_TYPE_INT64, 1, &gidx)); + CHECK_CUDNN(cudnnBackendFinalize(engine)); + + cudnnBackendDescriptor_t engcfg; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_ENGINECFG_DESCRIPTOR, &engcfg)); + CHECK_CUDNN(cudnnBackendSetAttribute(engcfg, CUDNN_ATTR_ENGINECFG_ENGINE, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &engine)); + CHECK_CUDNN(cudnnBackendFinalize(engcfg)); + + // Set up the execution plan + cudnnBackendDescriptor_t plan; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR, &plan)); + CHECK_CUDNN(cudnnBackendSetAttribute(plan, CUDNN_ATTR_EXECUTION_PLAN_HANDLE, CUDNN_TYPE_HANDLE, 1, &cur_handle)); + CHECK_CUDNN(cudnnBackendSetAttribute(plan, CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &engcfg)); + CHECK_CUDNN(cudnnBackendFinalize(plan)); + + int64_t workspaceSize; + CHECK_CUDNN(cudnnBackendGetAttribute(plan, CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE, + CUDNN_TYPE_INT64, 1, NULL, &workspaceSize)); + + void *dev_ptrs[3] = {ptr_x, ptr_w, ptr_y}; // device pointers + int64_t uids[3] = {'x', 'w', 'y'}; + void *workspace = hidet_cuda_malloc_async(workspaceSize, cur_stream); + + cudnnBackendDescriptor_t varpack; + CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR, &varpack)); + CHECK_CUDNN(cudnnBackendSetAttribute(varpack, CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS, + CUDNN_TYPE_VOID_PTR, 3, dev_ptrs)); + CHECK_CUDNN(cudnnBackendSetAttribute(varpack, CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS, + CUDNN_TYPE_INT64, 3, uids)); + CHECK_CUDNN(cudnnBackendSetAttribute(varpack, CUDNN_ATTR_VARIANT_PACK_WORKSPACE, + CUDNN_TYPE_VOID_PTR, 1, &workspace)); + CHECK_CUDNN(cudnnBackendFinalize(varpack)); + + // Execute the plan + CHECK_CUDNN(cudnnBackendExecute(cur_handle, plan, varpack)); + + // Cleanup + hidet_cuda_free_async(workspace, cur_stream); + 
CHECK_CUDNN(cudnnBackendDestroyDescriptor(xDesc)); + CHECK_CUDNN(cudnnBackendDestroyDescriptor(wDesc)); + CHECK_CUDNN(cudnnBackendDestroyDescriptor(yDesc)); + CHECK_CUDNN(cudnnBackendDestroyDescriptor(cDesc)); + CHECK_CUDNN(cudnnBackendDestroyDescriptor(fprop)); + CHECK_CUDNN(cudnnBackendDestroyDescriptor(op_graph)); + CHECK_CUDNN(cudnnBackendDestroyDescriptor(engine)); + CHECK_CUDNN(cudnnBackendDestroyDescriptor(engcfg)); + CHECK_CUDNN(cudnnBackendDestroyDescriptor(plan)); +} + + diff --git a/tests/cuda/test_cudnn.py b/tests/cuda/test_cudnn.py new file mode 100644 index 000000000..4c9b3bfb6 --- /dev/null +++ b/tests/cuda/test_cudnn.py @@ -0,0 +1,73 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +import math +import hidet +from hidet import ops +from hidet.cuda.cudnn import cudnnDataType + + +@pytest.mark.parametrize( + "n, c, h, w, k, p, q, r, s, padding, stride, dilations", + [ + [1, 3, 32, 32, 12, 30, 30, 3, 3, [0, 0], [1, 1], [1, 1]], # kernel 3, + [2, 3, 32, 32, 12, 11, 6, 7, 7, [1, 2], [2, 3], [2, 3]], # kernel 7, batch size 2 + [1, 3, 32, 32, 12, 16, 11, 1, 1, [0, 0], [2, 3], [1, 1]], # kernel 1, + ], +) +@pytest.mark.parametrize( + 'dtype, compute_type, tol', + [(hidet.float32, cudnnDataType.CUDNN_DATA_FLOAT, 1e-5), (hidet.float64, cudnnDataType.CUDNN_DATA_DOUBLE, 1e-8)], +) +def test_cudnn_conv2d(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, stride, dilations, tol): + tx = tw = ty = dtype + pad_dim1, pad_dim2 = padding + str_dim1, str_dim2 = stride + dil_dim1, dil_dim2 = dilations + + tensor_x = hidet.randn((n, c, h, w), device='cuda', dtype=tx) + tensor_w = hidet.randn((k, c, r, s), device='cuda', dtype=tw) + tensor_y = hidet.empty((n, k, p, q), device='cuda', dtype=ty) + + golden = ops.conv2d( + tensor_x, tensor_w, stride=(str_dim1, str_dim2), dilations=(dil_dim1, dil_dim2), padding=(pad_dim1, pad_dim2) + ) + hidet.cuda.cudnn.conv2d( + n, + c, + h, + w, + k, + r, + s, + p, + q, + tensor_x, + tensor_w, + tensor_y, + tx, + tw, + ty, + compute_type, + pad_dim1, + pad_dim2, + str_dim1, + str_dim2, + dil_dim1, + dil_dim2, + ) + + hidet.utils.assert_close(actual=tensor_y, expected=golden, rtol=tol, atol=tol) + + +if __name__ == '__main__': + pytest.main([__file__]) From 43fbeab8d9039eb08c3555e5e1184c323faf6453 Mon Sep 17 00:00:00 2001 From: Yudi Sun Date: Thu, 14 Mar 2024 18:50:53 -0400 Subject: [PATCH 04/12] [CUDNN] Add CuDNN performance benchmarks --- python/hidet/cuda/cudnn/benchmark.py | 108 +++++++++++++++++++++++++++ src/hidet/runtime/cuda/cudnn.cpp | 2 +- tests/cuda/test_cudnn.py | 1 + 3 files changed, 110 insertions(+), 1 deletion(-) create mode 100644 python/hidet/cuda/cudnn/benchmark.py diff --git 
a/python/hidet/cuda/cudnn/benchmark.py b/python/hidet/cuda/cudnn/benchmark.py new file mode 100644 index 000000000..83666115b --- /dev/null +++ b/python/hidet/cuda/cudnn/benchmark.py @@ -0,0 +1,108 @@ +import numpy as np +import torch + +import hidet +from hidet.cuda.cudnn import cudnnDataType +from hidet.utils.benchmark import do_bench + + +def benchmark_cudnn_conv2d(dtype_str, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations): + if dtype_str == "float32": + dtype = hidet.float32 + elif dtype_str == "float64": + dtype = hidet.float64 + else: + raise Exception("Unsupported DataType") + + tx = tw = ty = dtype + pad_dim1, pad_dim2 = padding + str_dim1, str_dim2 = stride + dil_dim1, dil_dim2 = dilations + + tensor_x = hidet.randn((n, c, h, w), device='cuda', dtype=tx) + tensor_w = hidet.randn((k, c, r, s), device='cuda', dtype=tw) + tensor_y = hidet.empty((n, k, p, q), device='cuda', dtype=ty) + + latencies = do_bench( + lambda: hidet.cuda.cudnn.conv2d( + n, + c, + h, + w, + k, + r, + s, + p, + q, + tensor_x, + tensor_w, + tensor_y, + tx, + tw, + ty, + compute_type, + pad_dim1, + pad_dim2, + str_dim1, + str_dim2, + dil_dim1, + dil_dim2, + ), + warmup=10, + rep=100, + ) + + print( + f"CuDNN Results for Configuration: dtype = {dtype_str}, input shape = {[n,c,h,w]}, " + f"weight shape = {[k,c,r,s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:" + ) + print("20th Percentile Latency Is: " + str(latencies[0]) + " milliseconds") + print("50th Percentile Latency Is: " + str(latencies[1]) + " milliseconds") + print("80th Percentile Latency Is: " + str(latencies[2]) + " milliseconds") + print("-------------------------------------------------") + + +def benchmark_torch_conv2d(dtype_str, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations): + if dtype_str == "float32": + dtype = np.float32 + elif dtype_str == "float64": + dtype = np.float64 + else: + raise Exception("Unsupported DataType") + + data = 
np.array(np.random.randn(n, c, h, w)).astype(dtype) + weight = np.array(np.random.randn(k, c, r, s)).astype(dtype) + + data_torch, weight_torch = torch.from_numpy(data), torch.from_numpy(weight) + data_torch = data_torch.cuda() + weight_torch = weight_torch.cuda() + + latencies = do_bench( + lambda: torch.nn.functional.conv2d( + data_torch, weight_torch, bias=None, stride=stride, padding=padding, dilation=dilations, groups=1 + ), + warmup=10, + rep=100, + ) + + print( + f"PyTorch Results for Configuration: dtype = {dtype_str}, input shape = {[n,c,h,w]}, " + f"weight shape = {[k,c,r,s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:" + ) + print("20th Percentile Latency Is: " + str(latencies[0]) + " milliseconds") + print("50th Percentile Latency Is: " + str(latencies[1]) + " milliseconds") + print("80th Percentile Latency Is: " + str(latencies[2]) + " milliseconds") + print("-------------------------------------------------") + + +if __name__ == '__main__': + sizes = [ + [1, 3, 32, 32, 12, 30, 30, 3, 3, [0, 0], [1, 1], [1, 1]], + [2, 3, 224, 224, 16, 109, 109, 7, 7, [0, 0], [2, 2], [1, 1]], + ] + dtypes = [['float32', cudnnDataType.CUDNN_DATA_FLOAT], ['float64', cudnnDataType.CUDNN_DATA_DOUBLE]] + + for data_type in dtypes: + for size in sizes: + benchmark_cudnn_conv2d(*(data_type + size)) + benchmark_torch_conv2d(*(data_type + size)) diff --git a/src/hidet/runtime/cuda/cudnn.cpp b/src/hidet/runtime/cuda/cudnn.cpp index 468e5c2dc..c33e1b3e1 100644 --- a/src/hidet/runtime/cuda/cudnn.cpp +++ b/src/hidet/runtime/cuda/cudnn.cpp @@ -585,7 +585,7 @@ DLL void hidet_cudnn_conv2d( void *dev_ptrs[3] = {ptr_x, ptr_w, ptr_y}; // device pointers int64_t uids[3] = {'x', 'w', 'y'}; - void *workspace = hidet_cuda_malloc_async(workspaceSize, cur_stream); + void *workspace = request_cuda_workspace(workspaceSize, false); cudnnBackendDescriptor_t varpack; CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR, &varpack)); diff --git 
a/tests/cuda/test_cudnn.py b/tests/cuda/test_cudnn.py index 4c9b3bfb6..abaacbc94 100644 --- a/tests/cuda/test_cudnn.py +++ b/tests/cuda/test_cudnn.py @@ -22,6 +22,7 @@ [1, 3, 32, 32, 12, 30, 30, 3, 3, [0, 0], [1, 1], [1, 1]], # kernel 3, [2, 3, 32, 32, 12, 11, 6, 7, 7, [1, 2], [2, 3], [2, 3]], # kernel 7, batch size 2 [1, 3, 32, 32, 12, 16, 11, 1, 1, [0, 0], [2, 3], [1, 1]], # kernel 1, + [2, 3, 224, 224, 16, 109, 109, 7, 7, [0, 0], [2, 2], [1, 1]], ], ) @pytest.mark.parametrize( From 0d4de20ed6070edc7b54b8dbd27ee7ba73a26834 Mon Sep 17 00:00:00 2001 From: Yudi Sun Date: Sun, 17 Mar 2024 19:59:01 -0400 Subject: [PATCH 05/12] [CuDNN] Support float16 --- python/hidet/cuda/cudnn/utils.py | 1 + src/hidet/runtime/cuda/cudnn.cpp | 6 ++++-- tests/cuda/test_cudnn.py | 9 ++++++--- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/python/hidet/cuda/cudnn/utils.py b/python/hidet/cuda/cudnn/utils.py index a599a880b..97c3e5958 100644 --- a/python/hidet/cuda/cudnn/utils.py +++ b/python/hidet/cuda/cudnn/utils.py @@ -14,6 +14,7 @@ from .ffi import cudnnDataType _cudnn_type_dict = { + dtypes.float16: cudnnDataType.CUDNN_DATA_HALF, dtypes.float32: cudnnDataType.CUDNN_DATA_FLOAT, dtypes.float64: cudnnDataType.CUDNN_DATA_DOUBLE, dtypes.int32: cudnnDataType.CUDNN_DATA_INT32, diff --git a/src/hidet/runtime/cuda/cudnn.cpp b/src/hidet/runtime/cuda/cudnn.cpp index c33e1b3e1..8e82af0ec 100644 --- a/src/hidet/runtime/cuda/cudnn.cpp +++ b/src/hidet/runtime/cuda/cudnn.cpp @@ -327,6 +327,7 @@ static void* libcudnn = nullptr; static cudnnBackendAttributeType_t get_attribute_type_from_compute_type(cudnnDataType_t computeType) { switch (computeType) { case CUDNN_DATA_FLOAT: + case CUDNN_DATA_HALF: return CUDNN_TYPE_FLOAT; case CUDNN_DATA_DOUBLE: return CUDNN_TYPE_DOUBLE; @@ -342,8 +343,9 @@ static cudnnBackendAttributeType_t get_attribute_type_from_compute_type(cudnnDat static void set_alpha_beta(void** p_alpha, void** p_beta, cudnnDataType_t c) { // There's no such thing as a 
cudnnComputeType_t type. As per the official example, the computeType is defined // in terms of cudnnDataType_t - // cudnnBackendAttributeType_t only has support for FLOAT, DOUBLE, and INT64. - if(c == CUDNN_DATA_FLOAT) { + if(c == CUDNN_DATA_FLOAT || c == CUDNN_DATA_HALF) { + // cudnnBackendAttributeType_t only has support for FLOAT, DOUBLE, and INT64. There is no HALF attribute type. + // See get_attribute_type_from_compute_type above. static float alpha = 1.0f; static float beta = 0.0f; *p_alpha = α diff --git a/tests/cuda/test_cudnn.py b/tests/cuda/test_cudnn.py index abaacbc94..0d355cf53 100644 --- a/tests/cuda/test_cudnn.py +++ b/tests/cuda/test_cudnn.py @@ -21,13 +21,16 @@ [ [1, 3, 32, 32, 12, 30, 30, 3, 3, [0, 0], [1, 1], [1, 1]], # kernel 3, [2, 3, 32, 32, 12, 11, 6, 7, 7, [1, 2], [2, 3], [2, 3]], # kernel 7, batch size 2 - [1, 3, 32, 32, 12, 16, 11, 1, 1, [0, 0], [2, 3], [1, 1]], # kernel 1, - [2, 3, 224, 224, 16, 109, 109, 7, 7, [0, 0], [2, 2], [1, 1]], + [1, 3, 224, 224, 64, 112, 112, 7, 7, [3, 3], [2, 2], [1, 1]], # resnet layer 1 + [1, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]], # resnet layer 2 - kernel size 1 ], ) @pytest.mark.parametrize( 'dtype, compute_type, tol', - [(hidet.float32, cudnnDataType.CUDNN_DATA_FLOAT, 1e-5), (hidet.float64, cudnnDataType.CUDNN_DATA_DOUBLE, 1e-8)], + [(hidet.float16, cudnnDataType.CUDNN_DATA_HALF, 1e-2), + (hidet.float32, cudnnDataType.CUDNN_DATA_FLOAT, 1e-5), + (hidet.float64, cudnnDataType.CUDNN_DATA_DOUBLE, 1e-8), + ] ) def test_cudnn_conv2d(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, stride, dilations, tol): tx = tw = ty = dtype From b2156692fb5efe3ff40fca6fb4d76b62cf086655 Mon Sep 17 00:00:00 2001 From: Yudi Sun Date: Tue, 19 Mar 2024 15:53:01 -0400 Subject: [PATCH 06/12] [CuDNN] Add legacy APIs for conv2d --- src/hidet/runtime/cuda/cudnn.cpp | 785 ++++++++++++++++++++++--------- 1 file changed, 551 insertions(+), 234 deletions(-) diff --git a/src/hidet/runtime/cuda/cudnn.cpp 
b/src/hidet/runtime/cuda/cudnn.cpp index 8e82af0ec..9b57d3868 100644 --- a/src/hidet/runtime/cuda/cudnn.cpp +++ b/src/hidet/runtime/cuda/cudnn.cpp @@ -19,28 +19,30 @@ /* * CUDNN return codes - defined in cudnn_ops_infer_v8.h */ -typedef enum { - CUDNN_STATUS_SUCCESS = 0, - CUDNN_STATUS_NOT_INITIALIZED = 1, - CUDNN_STATUS_ALLOC_FAILED = 2, - CUDNN_STATUS_BAD_PARAM = 3, - CUDNN_STATUS_INTERNAL_ERROR = 4, - CUDNN_STATUS_INVALID_VALUE = 5, - CUDNN_STATUS_ARCH_MISMATCH = 6, - CUDNN_STATUS_MAPPING_ERROR = 7, - CUDNN_STATUS_EXECUTION_FAILED = 8, - CUDNN_STATUS_NOT_SUPPORTED = 9, - CUDNN_STATUS_LICENSE_ERROR = 10, +typedef enum +{ + CUDNN_STATUS_SUCCESS = 0, + CUDNN_STATUS_NOT_INITIALIZED = 1, + CUDNN_STATUS_ALLOC_FAILED = 2, + CUDNN_STATUS_BAD_PARAM = 3, + CUDNN_STATUS_INTERNAL_ERROR = 4, + CUDNN_STATUS_INVALID_VALUE = 5, + CUDNN_STATUS_ARCH_MISMATCH = 6, + CUDNN_STATUS_MAPPING_ERROR = 7, + CUDNN_STATUS_EXECUTION_FAILED = 8, + CUDNN_STATUS_NOT_SUPPORTED = 9, + CUDNN_STATUS_LICENSE_ERROR = 10, CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING = 11, - CUDNN_STATUS_RUNTIME_IN_PROGRESS = 12, - CUDNN_STATUS_RUNTIME_FP_OVERFLOW = 13, - CUDNN_STATUS_VERSION_MISMATCH = 14, + CUDNN_STATUS_RUNTIME_IN_PROGRESS = 12, + CUDNN_STATUS_RUNTIME_FP_OVERFLOW = 13, + CUDNN_STATUS_VERSION_MISMATCH = 14, } cudnnStatus_t; /* -* CUDNN Descriptor Types - defined in cudnn_backend_v8.h -*/ -typedef enum { + * CUDNN Descriptor Types - defined in cudnn_backend_v8.h + */ +typedef enum +{ CUDNN_BACKEND_POINTWISE_DESCRIPTOR = 0, CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR, CUDNN_BACKEND_ENGINE_DESCRIPTOR, @@ -70,182 +72,185 @@ typedef enum { /* * CUDNN data type - defined in cudnn_ops_infer_v8.h */ -typedef enum { - CUDNN_DATA_FLOAT = 0, - CUDNN_DATA_DOUBLE = 1, - CUDNN_DATA_HALF = 2, - CUDNN_DATA_INT8 = 3, - CUDNN_DATA_INT32 = 4, - CUDNN_DATA_INT8x4 = 5, - CUDNN_DATA_UINT8 = 6, - CUDNN_DATA_UINT8x4 = 7, - CUDNN_DATA_INT8x32 = 8, +typedef enum +{ + CUDNN_DATA_FLOAT = 0, + CUDNN_DATA_DOUBLE = 1, + CUDNN_DATA_HALF = 2, 
+ CUDNN_DATA_INT8 = 3, + CUDNN_DATA_INT32 = 4, + CUDNN_DATA_INT8x4 = 5, + CUDNN_DATA_UINT8 = 6, + CUDNN_DATA_UINT8x4 = 7, + CUDNN_DATA_INT8x32 = 8, CUDNN_DATA_BFLOAT16 = 9, - CUDNN_DATA_INT64 = 10, + CUDNN_DATA_INT64 = 10, } cudnnDataType_t; /* -* CUDNN Backend Attribute Names - defined in cudnn_backend_v8.h -*/ -typedef enum { - CUDNN_ATTR_POINTWISE_MODE = 0, - CUDNN_ATTR_POINTWISE_MATH_PREC = 1, - CUDNN_ATTR_POINTWISE_NAN_PROPAGATION = 2, - CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP = 3, - CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP = 4, + * CUDNN Backend Attribute Names - defined in cudnn_backend_v8.h + */ +typedef enum +{ + CUDNN_ATTR_POINTWISE_MODE = 0, + CUDNN_ATTR_POINTWISE_MATH_PREC = 1, + CUDNN_ATTR_POINTWISE_NAN_PROPAGATION = 2, + CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP = 3, + CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP = 4, CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP_SLOPE = 5, - CUDNN_ATTR_POINTWISE_ELU_ALPHA = 6, - CUDNN_ATTR_POINTWISE_SOFTPLUS_BETA = 7, - CUDNN_ATTR_POINTWISE_SWISH_BETA = 8, + CUDNN_ATTR_POINTWISE_ELU_ALPHA = 6, + CUDNN_ATTR_POINTWISE_SOFTPLUS_BETA = 7, + CUDNN_ATTR_POINTWISE_SWISH_BETA = 8, - CUDNN_ATTR_CONVOLUTION_COMP_TYPE = 100, - CUDNN_ATTR_CONVOLUTION_CONV_MODE = 101, - CUDNN_ATTR_CONVOLUTION_DILATIONS = 102, + CUDNN_ATTR_CONVOLUTION_COMP_TYPE = 100, + CUDNN_ATTR_CONVOLUTION_CONV_MODE = 101, + CUDNN_ATTR_CONVOLUTION_DILATIONS = 102, CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES = 103, - CUDNN_ATTR_CONVOLUTION_POST_PADDINGS = 104, - CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS = 105, - CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS = 106, + CUDNN_ATTR_CONVOLUTION_POST_PADDINGS = 104, + CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS = 105, + CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS = 106, - CUDNN_ATTR_ENGINEHEUR_MODE = 200, + CUDNN_ATTR_ENGINEHEUR_MODE = 200, CUDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH = 201, - CUDNN_ATTR_ENGINEHEUR_RESULTS = 202, + CUDNN_ATTR_ENGINEHEUR_RESULTS = 202, - CUDNN_ATTR_ENGINECFG_ENGINE = 300, + CUDNN_ATTR_ENGINECFG_ENGINE = 300, CUDNN_ATTR_ENGINECFG_INTERMEDIATE_INFO = 301, - 
CUDNN_ATTR_ENGINECFG_KNOB_CHOICES = 302, + CUDNN_ATTR_ENGINECFG_KNOB_CHOICES = 302, - CUDNN_ATTR_EXECUTION_PLAN_HANDLE = 400, - CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG = 401, - CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE = 402, + CUDNN_ATTR_EXECUTION_PLAN_HANDLE = 400, + CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG = 401, + CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE = 402, CUDNN_ATTR_EXECUTION_PLAN_COMPUTED_INTERMEDIATE_UIDS = 403, CUDNN_ATTR_EXECUTION_PLAN_RUN_ONLY_INTERMEDIATE_UIDS = 404, - CUDNN_ATTR_INTERMEDIATE_INFO_UNIQUE_ID = 500, - CUDNN_ATTR_INTERMEDIATE_INFO_SIZE = 501, - CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_DATA_UIDS = 502, + CUDNN_ATTR_INTERMEDIATE_INFO_UNIQUE_ID = 500, + CUDNN_ATTR_INTERMEDIATE_INFO_SIZE = 501, + CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_DATA_UIDS = 502, CUDNN_ATTR_INTERMEDIATE_INFO_DEPENDENT_ATTRIBUTES = 503, - CUDNN_ATTR_KNOB_CHOICE_KNOB_TYPE = 600, + CUDNN_ATTR_KNOB_CHOICE_KNOB_TYPE = 600, CUDNN_ATTR_KNOB_CHOICE_KNOB_VALUE = 601, - CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA = 700, - CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA = 701, - CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC = 702, - CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W = 703, - CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X = 704, - CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y = 705, - CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA = 706, - CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA = 707, - CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC = 708, - CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W = 709, - CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX = 710, - CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY = 711, - CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA = 712, - CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA = 713, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA = 700, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA = 701, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC = 702, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W = 703, + 
CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X = 704, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y = 705, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA = 706, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA = 707, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC = 708, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W = 709, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX = 710, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY = 711, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA = 712, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA = 713, CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC = 714, - CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW = 715, - CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X = 716, - CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY = 717, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW = 715, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X = 716, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY = 717, CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR = 750, - CUDNN_ATTR_OPERATION_POINTWISE_XDESC = 751, - CUDNN_ATTR_OPERATION_POINTWISE_BDESC = 752, - CUDNN_ATTR_OPERATION_POINTWISE_YDESC = 753, - CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1 = 754, - CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2 = 755, - CUDNN_ATTR_OPERATION_POINTWISE_DXDESC = 756, - CUDNN_ATTR_OPERATION_POINTWISE_DYDESC = 757, - - CUDNN_ATTR_OPERATION_GENSTATS_MODE = 770, + CUDNN_ATTR_OPERATION_POINTWISE_XDESC = 751, + CUDNN_ATTR_OPERATION_POINTWISE_BDESC = 752, + CUDNN_ATTR_OPERATION_POINTWISE_YDESC = 753, + CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1 = 754, + CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2 = 755, + CUDNN_ATTR_OPERATION_POINTWISE_DXDESC = 756, + CUDNN_ATTR_OPERATION_POINTWISE_DYDESC = 757, + + CUDNN_ATTR_OPERATION_GENSTATS_MODE = 770, CUDNN_ATTR_OPERATION_GENSTATS_MATH_PREC = 771, - CUDNN_ATTR_OPERATION_GENSTATS_XDESC = 772, - CUDNN_ATTR_OPERATION_GENSTATS_SUMDESC = 773, + CUDNN_ATTR_OPERATION_GENSTATS_XDESC = 772, + CUDNN_ATTR_OPERATION_GENSTATS_SUMDESC = 773, 
CUDNN_ATTR_OPERATION_GENSTATS_SQSUMDESC = 774, - CUDNN_ATTR_OPERATION_BN_FINALIZE_STATS_MODE = 780, - CUDNN_ATTR_OPERATION_BN_FINALIZE_MATH_PREC = 781, - CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SUM_DESC = 782, - CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SQ_SUM_DESC = 783, - CUDNN_ATTR_OPERATION_BN_FINALIZE_SCALE_DESC = 784, - CUDNN_ATTR_OPERATION_BN_FINALIZE_BIAS_DESC = 785, - CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_MEAN_DESC = 786, - CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_VAR_DESC = 787, + CUDNN_ATTR_OPERATION_BN_FINALIZE_STATS_MODE = 780, + CUDNN_ATTR_OPERATION_BN_FINALIZE_MATH_PREC = 781, + CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SUM_DESC = 782, + CUDNN_ATTR_OPERATION_BN_FINALIZE_Y_SQ_SUM_DESC = 783, + CUDNN_ATTR_OPERATION_BN_FINALIZE_SCALE_DESC = 784, + CUDNN_ATTR_OPERATION_BN_FINALIZE_BIAS_DESC = 785, + CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_MEAN_DESC = 786, + CUDNN_ATTR_OPERATION_BN_FINALIZE_PREV_RUNNING_VAR_DESC = 787, CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_MEAN_DESC = 788, - CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_VAR_DESC = 789, - CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_MEAN_DESC = 790, - CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_INV_STD_DESC = 791, - CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_SCALE_DESC = 792, - CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_BIAS_DESC = 793, - CUDNN_ATTR_OPERATION_BN_FINALIZE_ACCUM_COUNT_DESC = 794, - CUDNN_ATTR_OPERATION_BN_FINALIZE_EPSILON_DESC = 795, - CUDNN_ATTR_OPERATION_BN_FINALIZE_EXP_AVERATE_FACTOR_DESC = 796, - - CUDNN_ATTR_OPERATIONGRAPH_HANDLE = 800, - CUDNN_ATTR_OPERATIONGRAPH_OPS = 801, + CUDNN_ATTR_OPERATION_BN_FINALIZE_UPDATED_RUNNING_VAR_DESC = 789, + CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_MEAN_DESC = 790, + CUDNN_ATTR_OPERATION_BN_FINALIZE_SAVED_INV_STD_DESC = 791, + CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_SCALE_DESC = 792, + CUDNN_ATTR_OPERATION_BN_FINALIZE_EQ_BIAS_DESC = 793, + CUDNN_ATTR_OPERATION_BN_FINALIZE_ACCUM_COUNT_DESC = 794, + CUDNN_ATTR_OPERATION_BN_FINALIZE_EPSILON_DESC = 795, + 
CUDNN_ATTR_OPERATION_BN_FINALIZE_EXP_AVERATE_FACTOR_DESC = 796, + + CUDNN_ATTR_OPERATIONGRAPH_HANDLE = 800, + CUDNN_ATTR_OPERATIONGRAPH_OPS = 801, CUDNN_ATTR_OPERATIONGRAPH_ENGINE_GLOBAL_COUNT = 802, - CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT = 900, - CUDNN_ATTR_TENSOR_DATA_TYPE = 901, - CUDNN_ATTR_TENSOR_DIMENSIONS = 902, - CUDNN_ATTR_TENSOR_STRIDES = 903, - CUDNN_ATTR_TENSOR_VECTOR_COUNT = 904, + CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT = 900, + CUDNN_ATTR_TENSOR_DATA_TYPE = 901, + CUDNN_ATTR_TENSOR_DIMENSIONS = 902, + CUDNN_ATTR_TENSOR_STRIDES = 903, + CUDNN_ATTR_TENSOR_VECTOR_COUNT = 904, CUDNN_ATTR_TENSOR_VECTORIZED_DIMENSION = 905, - CUDNN_ATTR_TENSOR_UNIQUE_ID = 906, - CUDNN_ATTR_TENSOR_IS_VIRTUAL = 907, - CUDNN_ATTR_TENSOR_IS_BY_VALUE = 908, + CUDNN_ATTR_TENSOR_UNIQUE_ID = 906, + CUDNN_ATTR_TENSOR_IS_VIRTUAL = 907, + CUDNN_ATTR_TENSOR_IS_BY_VALUE = 908, - CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS = 1000, + CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS = 1000, CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS = 1001, CUDNN_ATTR_VARIANT_PACK_INTERMEDIATES = 1002, - CUDNN_ATTR_VARIANT_PACK_WORKSPACE = 1003, + CUDNN_ATTR_VARIANT_PACK_WORKSPACE = 1003, CUDNN_ATTR_LAYOUT_INFO_TENSOR_UID = 1100, - CUDNN_ATTR_LAYOUT_INFO_TYPES = 1101, + CUDNN_ATTR_LAYOUT_INFO_TYPES = 1101, - CUDNN_ATTR_KNOB_INFO_TYPE = 1200, + CUDNN_ATTR_KNOB_INFO_TYPE = 1200, CUDNN_ATTR_KNOB_INFO_MAXIMUM_VALUE = 1201, CUDNN_ATTR_KNOB_INFO_MINIMUM_VALUE = 1202, - CUDNN_ATTR_KNOB_INFO_STRIDE = 1203, + CUDNN_ATTR_KNOB_INFO_STRIDE = 1203, CUDNN_ATTR_ENGINE_OPERATION_GRAPH = 1300, - CUDNN_ATTR_ENGINE_GLOBAL_INDEX = 1301, - CUDNN_ATTR_ENGINE_KNOB_INFO = 1302, - CUDNN_ATTR_ENGINE_NUMERICAL_NOTE = 1303, - CUDNN_ATTR_ENGINE_LAYOUT_INFO = 1304, - CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE = 1305, + CUDNN_ATTR_ENGINE_GLOBAL_INDEX = 1301, + CUDNN_ATTR_ENGINE_KNOB_INFO = 1302, + CUDNN_ATTR_ENGINE_NUMERICAL_NOTE = 1303, + CUDNN_ATTR_ENGINE_LAYOUT_INFO = 1304, + CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE = 1305, CUDNN_ATTR_MATMUL_COMP_TYPE = 1500, - 
CUDNN_ATTR_OPERATION_MATMUL_ADESC = 1520, - CUDNN_ATTR_OPERATION_MATMUL_BDESC = 1521, - CUDNN_ATTR_OPERATION_MATMUL_CDESC = 1522, - CUDNN_ATTR_OPERATION_MATMUL_DESC = 1523, + CUDNN_ATTR_OPERATION_MATMUL_ADESC = 1520, + CUDNN_ATTR_OPERATION_MATMUL_BDESC = 1521, + CUDNN_ATTR_OPERATION_MATMUL_CDESC = 1522, + CUDNN_ATTR_OPERATION_MATMUL_DESC = 1523, CUDNN_ATTR_OPERATION_MATMUL_IRREGULARLY_STRIDED_BATCH_COUNT = 1524, - CUDNN_ATTR_REDUCTION_OPERATOR = 1600, + CUDNN_ATTR_REDUCTION_OPERATOR = 1600, CUDNN_ATTR_REDUCTION_COMP_TYPE = 1601, CUDNN_ATTR_OPERATION_REDUCTION_XDESC = 1610, CUDNN_ATTR_OPERATION_REDUCTION_YDESC = 1611, - CUDNN_ATTR_OPERATION_REDUCTION_DESC = 1612, - - CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MATH_PREC = 1620, - CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MEAN_DESC = 1621, - CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_INVSTD_DESC = 1622, - CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_BN_SCALE_DESC = 1623, - CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_X_DESC = 1624, - CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DY_DESC = 1625, - CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_SCALE_DESC = 1626, - CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_BIAS_DESC = 1627, + CUDNN_ATTR_OPERATION_REDUCTION_DESC = 1612, + + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MATH_PREC = 1620, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_MEAN_DESC = 1621, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_INVSTD_DESC = 1622, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_BN_SCALE_DESC = 1623, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_X_DESC = 1624, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DY_DESC = 1625, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_SCALE_DESC = 1626, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_DBN_BIAS_DESC = 1627, CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_DY_SCALE_DESC = 1628, - CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_X_SCALE_DESC = 1629, - CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_BIAS = 1630, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_X_SCALE_DESC = 1629, + CUDNN_ATTR_OPERATION_BN_BWD_WEIGHTS_EQ_BIAS = 1630, } cudnnBackendAttributeName_t; /* -* CUDNN Backend Attribute Type - 
defined in cudnn_backend_v8.h -*/ -typedef enum { + * CUDNN Backend Attribute Type - defined in cudnn_backend_v8.h + */ +typedef enum +{ CUDNN_TYPE_HANDLE = 0, CUDNN_TYPE_DATA_TYPE, CUDNN_TYPE_BOOLEAN, @@ -271,17 +276,85 @@ typedef enum { /* * convolution mode - defined in cudnn_cnn_infer_v8.h */ -typedef enum { CUDNN_CONVOLUTION = 0, CUDNN_CROSS_CORRELATION = 1 } cudnnConvolutionMode_t; +typedef enum +{ + CUDNN_CONVOLUTION = 0, + CUDNN_CROSS_CORRELATION = 1 +} cudnnConvolutionMode_t; + +/* +================================================== +The following types are used by the Legacy APIs +================================================== +*/ + +/* + * convolution forward algorithms - defined in cudnn_ops_infer_v8.h + */ +typedef enum +{ + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM = 0, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM = 1, + CUDNN_CONVOLUTION_FWD_ALGO_GEMM = 2, + CUDNN_CONVOLUTION_FWD_ALGO_DIRECT = 3, + CUDNN_CONVOLUTION_FWD_ALGO_FFT = 4, + CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING = 5, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD = 6, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED = 7, + CUDNN_CONVOLUTION_FWD_ALGO_COUNT = 8 +} cudnnConvolutionFwdAlgo_t; + +/* + * Tensor formats - defined in cudnn_ops_infer_v8.h + */ +typedef enum +{ + CUDNN_TENSOR_NCHW = 0, /* row major (wStride = 1, hStride = w) */ + CUDNN_TENSOR_NHWC = 1, /* feature maps interleaved ( cStride = 1 )*/ + CUDNN_TENSOR_NCHW_VECT_C = 2, /* each image point is vector of element of C, vector length in data type */ +} cudnnTensorFormat_t; + + +/* + * CUDNN Determinism - defined in cudnn_ops_infer_v8.h + */ +typedef enum { + CUDNN_NON_DETERMINISTIC = 0, + CUDNN_DETERMINISTIC = 1, +} cudnnDeterminism_t; + +/* + * CUDNN math type - defined in cudnn_ops_infer_v8.h + */ +typedef enum { + CUDNN_DEFAULT_MATH = 0, + CUDNN_TENSOR_OP_MATH = 1, + CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION = 2, + CUDNN_FMA_MATH = 3, +} cudnnMathType_t; + +/* +* Convolution Algorithm Performance - defined in cudnn_cnn_infer_v8.h 
+*/ +typedef struct cudnnConvolutionFwdAlgoPerfStruct { + cudnnConvolutionFwdAlgo_t algo; + cudnnStatus_t status; + float time; + size_t memory; + cudnnDeterminism_t determinism; + cudnnMathType_t mathType; + int reserved[3]; +} cudnnConvolutionFwdAlgoPerf_t; // define cudnn Graph API functions typedef cudnnStatus_t (*cudnnCreate_t)(cudnnHandle_t *handle); -typedef const char * (*cudnnGetErrorString_t)(cudnnStatus_t status); +typedef const char *(*cudnnGetErrorString_t)(cudnnStatus_t status); typedef cudnnStatus_t (*cudnnSetStream_t)( cudnnHandle_t handle, cudaStream_t streamId); typedef cudnnStatus_t (*cudnnBackendCreateDescriptor_t)( - cudnnBackendDescriptorType_t descriptorType, + cudnnBackendDescriptorType_t descriptorType, cudnnBackendDescriptor_t *descriptor); typedef cudnnStatus_t (*cudnnBackendDestroyDescriptor_t)(cudnnBackendDescriptor_t descriptor); typedef cudnnStatus_t (*cudnnBackendSetAttribute_t)( @@ -301,8 +374,83 @@ typedef cudnnStatus_t (*cudnnBackendFinalize_t)(cudnnBackendDescriptor_t descrip typedef cudnnStatus_t (*cudnnBackendExecute_t)( cudnnHandle_t handle, cudnnBackendDescriptor_t executionPlan, cudnnBackendDescriptor_t varianPack); - -// cudnn api functions +// Legacy API functions +typedef cudnnStatus_t (*cudnnCreateTensorDescriptor_t)(cudnnTensorDescriptor_t *tensorDesc); +typedef cudnnStatus_t (*cudnnSetTensor4dDescriptor_t)( + cudnnTensorDescriptor_t tensorDesc, + cudnnTensorFormat_t format, + cudnnDataType_t dataType, /* image data type */ + int n, /* number of inputs (batch size) */ + int c, /* number of input feature maps */ + int h, /* height of input section */ + int w /* width of input section */ +); +typedef cudnnStatus_t (*cudnnCreateFilterDescriptor_t)(cudnnFilterDescriptor_t *filterDesc); +typedef cudnnStatus_t (*cudnnSetFilter4dDescriptor_t)( + cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t dataType, /* image data type */ + cudnnTensorFormat_t format, + int k, /* number of output feature maps */ + int c, /* number of 
input feature maps */ + int h, /* height of each input filter */ + int w); /* width of each input filter */ +typedef cudnnStatus_t (*cudnnCreateConvolutionDescriptor_t)(cudnnConvolutionDescriptor_t *convDesc); +typedef cudnnStatus_t (*cudnnSetConvolution2dDescriptor_t)( + cudnnConvolutionDescriptor_t convDesc, + int pad_h, /* zero-padding height */ + int pad_w, /* zero-padding width */ + int u, /* vertical filter stride */ + int v, /* horizontal filter stride */ + int dilation_h, /* filter dilation in the vertical dimension */ + int dilation_w, /* filter dilation in the horizontal dimension */ + cudnnConvolutionMode_t mode, + cudnnDataType_t computeType); +typedef cudnnStatus_t (*cudnnGetConvolution2dForwardOutputDim_t)( + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + const cudnnFilterDescriptor_t filterDesc, + int *n, + int *c, + int *h, + int *w); +typedef cudnnStatus_t (*cudnnGetConvolutionForwardAlgorithm_v7_t)( + cudnnHandle_t handle, + const cudnnTensorDescriptor_t srcDesc, + const cudnnFilterDescriptor_t filterDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t destDesc, + const int requestedAlgoCount, + int *returnedAlgoCount, + cudnnConvolutionFwdAlgoPerf_t *perfResults +); +typedef cudnnStatus_t (*cudnnGetConvolutionForwardWorkspaceSize_t)( + cudnnHandle_t handle, + const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, + cudnnConvolutionFwdAlgo_t algo, + size_t *sizeInBytes); +typedef cudnnStatus_t (*cudnnConvolutionForward_t)( + cudnnHandle_t handle, + const void *alpha, + const cudnnTensorDescriptor_t xDesc, + const void *x, + const cudnnFilterDescriptor_t wDesc, + const void *w, + const cudnnConvolutionDescriptor_t convDesc, + cudnnConvolutionFwdAlgo_t algo, + void *workSpace, + size_t workSpaceSizeInBytes, + const void *beta, + const cudnnTensorDescriptor_t yDesc, 
+ void *y); +typedef cudnnStatus_t (*cudnnDestroyTensorDescriptor_t)(cudnnTensorDescriptor_t tensorDesc); +typedef cudnnStatus_t (*cudnnDestroyFilterDescriptor_t)(cudnnFilterDescriptor_t filterDesc); +typedef cudnnStatus_t (*cudnnDestroyConvolutionDescriptor_t)(cudnnConvolutionDescriptor_t convDesc); + + +// Graph APIs static cudnnCreate_t cudnnCreate; static cudnnGetErrorString_t cudnnGetErrorString; static cudnnSetStream_t cudnnSetStream; @@ -313,69 +461,103 @@ static cudnnBackendGetAttribute_t cudnnBackendGetAttribute; static cudnnBackendFinalize_t cudnnBackendFinalize; static cudnnBackendExecute_t cudnnBackendExecute; +// Legacy APIs +static cudnnCreateTensorDescriptor_t cudnnCreateTensorDescriptor; +static cudnnSetTensor4dDescriptor_t cudnnSetTensor4dDescriptor; +static cudnnCreateFilterDescriptor_t cudnnCreateFilterDescriptor; +static cudnnSetFilter4dDescriptor_t cudnnSetFilter4dDescriptor; +static cudnnCreateConvolutionDescriptor_t cudnnCreateConvolutionDescriptor; +static cudnnSetConvolution2dDescriptor_t cudnnSetConvolution2dDescriptor; +static cudnnConvolutionForward_t cudnnConvolutionForward; +static cudnnGetConvolutionForwardWorkspaceSize_t cudnnGetConvolutionForwardWorkspaceSize; +static cudnnGetConvolution2dForwardOutputDim_t cudnnGetConvolution2dForwardOutputDim; +static cudnnDestroyTensorDescriptor_t cudnnDestroyTensorDescriptor; +static cudnnDestroyFilterDescriptor_t cudnnDestroyFilterDescriptor; +static cudnnDestroyConvolutionDescriptor_t cudnnDestroyConvolutionDescriptor; +static cudnnGetConvolutionForwardAlgorithm_v7_t cudnnGetConvolutionForwardAlgorithm_v7; + static std::string library_path; -static void* libcudnn = nullptr; +static void *libcudnn = nullptr; // utility functions -#define CHECK_CUDNN(status) do { \ - cudnnStatus_t err = (status); \ - if(err != 0) { \ - LOG(FATAL) << "cuDNN error: " << cudnnGetErrorString(err); \ - } \ -} while(0) - -static cudnnBackendAttributeType_t get_attribute_type_from_compute_type(cudnnDataType_t 
computeType) { - switch (computeType) { - case CUDNN_DATA_FLOAT: - case CUDNN_DATA_HALF: - return CUDNN_TYPE_FLOAT; - case CUDNN_DATA_DOUBLE: - return CUDNN_TYPE_DOUBLE; - case CUDNN_DATA_INT64: - case CUDNN_DATA_INT32: - return CUDNN_TYPE_INT64; - default: - LOG(FATAL) << "Unsupported compute type: " << computeType; - return CUDNN_TYPE_VOID_PTR; +#define CHECK_CUDNN(status) \ + do \ + { \ + cudnnStatus_t err = (status); \ + if (err != 0) \ + { \ + LOG(FATAL) << "cuDNN error: " << cudnnGetErrorString(err); \ + } \ + } while (0) + +static cudnnBackendAttributeType_t get_attribute_type_from_compute_type(cudnnDataType_t computeType) +{ + switch (computeType) + { + case CUDNN_DATA_FLOAT: + case CUDNN_DATA_HALF: + return CUDNN_TYPE_FLOAT; + case CUDNN_DATA_DOUBLE: + return CUDNN_TYPE_DOUBLE; + case CUDNN_DATA_INT64: + case CUDNN_DATA_INT32: + return CUDNN_TYPE_INT64; + default: + LOG(FATAL) << "Unsupported compute type: " << computeType; + return CUDNN_TYPE_VOID_PTR; } } -static void set_alpha_beta(void** p_alpha, void** p_beta, cudnnDataType_t c) { +static void set_alpha_beta(void **p_alpha, void **p_beta, cudnnDataType_t c) +{ // There's no such thing as a cudnnComputeType_t type. As per the official example, the computeType is defined // in terms of cudnnDataType_t - if(c == CUDNN_DATA_FLOAT || c == CUDNN_DATA_HALF) { + if (c == CUDNN_DATA_FLOAT || c == CUDNN_DATA_HALF) + { // cudnnBackendAttributeType_t only has support for FLOAT, DOUBLE, and INT64. There is no HALF attribute type. // See get_attribute_type_from_compute_type above. 
static float alpha = 1.0f; static float beta = 0.0f; *p_alpha = α *p_beta = β - } else if(c == CUDNN_DATA_DOUBLE) { + } + else if (c == CUDNN_DATA_DOUBLE) + { static double alpha = 1.0; static double beta = 0.0; *p_alpha = α *p_beta = β - } else if(c == CUDNN_DATA_INT64 || c == CUDNN_DATA_INT32) { + } + else if (c == CUDNN_DATA_INT64 || c == CUDNN_DATA_INT32) + { static int64_t alpha = 1; static int64_t beta = 0; *p_alpha = α *p_beta = β - } else { + } + else + { LOG(FATAL) << "Unsupported compute type: " << c; } } -static void lazy_load_cudnn() { - if(libcudnn == nullptr) { +static void lazy_load_cudnn() +{ + if (libcudnn == nullptr) + { // load cudnn shared library - const char* libpath; - if(library_path.empty()) { + const char *libpath; + if (library_path.empty()) + { libpath = "libcudnn.so"; - } else { + } + else + { libpath = library_path.c_str(); } libcudnn = dlopen(libpath, RTLD_LAZY); - if(libcudnn == nullptr) { + if (libcudnn == nullptr) + { LOG(FATAL) << "Failed to load cublas library: " << libpath << dlerror(); } @@ -389,21 +571,37 @@ static void lazy_load_cudnn() { cudnnBackendGetAttribute = get_symbol(libcudnn, "cudnnBackendGetAttribute"); cudnnBackendFinalize = get_symbol(libcudnn, "cudnnBackendFinalize"); cudnnBackendExecute = get_symbol(libcudnn, "cudnnBackendExecute"); + + cudnnCreateTensorDescriptor = get_symbol(libcudnn, "cudnnCreateTensorDescriptor"); + cudnnSetTensor4dDescriptor = get_symbol(libcudnn, "cudnnSetTensor4dDescriptor"); + cudnnCreateFilterDescriptor = get_symbol(libcudnn, "cudnnCreateFilterDescriptor"); + cudnnSetFilter4dDescriptor = get_symbol(libcudnn, "cudnnSetFilter4dDescriptor"); + cudnnCreateConvolutionDescriptor = get_symbol(libcudnn, "cudnnCreateConvolutionDescriptor"); + cudnnSetConvolution2dDescriptor = get_symbol(libcudnn, "cudnnSetConvolution2dDescriptor"); + cudnnGetConvolution2dForwardOutputDim = get_symbol(libcudnn, "cudnnGetConvolution2dForwardOutputDim"); + cudnnGetConvolutionForwardWorkspaceSize = 
get_symbol(libcudnn, "cudnnGetConvolutionForwardWorkspaceSize"); + cudnnConvolutionForward = get_symbol(libcudnn, "cudnnConvolutionForward"); + cudnnDestroyTensorDescriptor = get_symbol(libcudnn, "cudnnDestroyTensorDescriptor"); + cudnnDestroyFilterDescriptor = get_symbol(libcudnn, "cudnnDestroyFilterDescriptor"); + cudnnDestroyConvolutionDescriptor = get_symbol(libcudnn, "cudnnDestroyConvolutionDescriptor"); + cudnnGetConvolutionForwardAlgorithm_v7 = get_symbol(libcudnn, "cudnnGetConvolutionForwardAlgorithm_v7"); } } - -CudnnContext* CudnnContext::global() { +CudnnContext *CudnnContext::global() +{ static CudnnContext instance; static bool initialized = false; - if(!initialized) { + if (!initialized) + { // create cudnn handle for each gpu int count = hidet_cuda_device_count(); assert(count <= HIDET_CUBLAS_MAX_GPUS); int current_device = hidet_cuda_get_device(); - for(int i = 0; i < count; i++) { + for (int i = 0; i < count; i++) + { hidet_cuda_set_device(i); CHECK_CUDNN(cudnnCreate(&instance.handles[i])); } @@ -414,26 +612,147 @@ CudnnContext* CudnnContext::global() { return &instance; } -cudnnHandle_t CudnnContext::current_handle() { +cudnnHandle_t CudnnContext::current_handle() +{ return CudnnContext::global()->handles[hidet_cuda_get_device()]; } - // hidet cudnn api functions -DLL void hidet_cudnn_set_library_path(const char* path) { - if(path) { +DLL void hidet_cudnn_set_library_path(const char *path) +{ + if (path) + { library_path = path; } } +DLL void hidet_cudnn_conv2d_gemm( + int n, int c, int h, int w, int k, int r, int s, + void *ptr_x, void *ptr_w, void *ptr_y, + int tx, int tw, int ty, int compute_type, + int pad_dim1, int pad_dim2, int str_dim1, int str_dim2, int dil_dim1, int dil_dim2) +{ + lazy_load_cudnn(); + + cudnnHandle_t cur_handle = CudnnContext::current_handle(); + + // Set the stream to the current stream + cudaStream_t cur_stream = get_cuda_stream(); + CHECK_CUDNN(cudnnSetStream(cur_handle, cur_stream)); + + // Build descriptors and 
launch the kernel + cudnnTensorDescriptor_t input_descriptor; + CHECK_CUDNN(cudnnCreateTensorDescriptor(&input_descriptor)); + CHECK_CUDNN(cudnnSetTensor4dDescriptor(input_descriptor, CUDNN_TENSOR_NCHW, cudnnDataType_t(tx), n, c, h, w)); + cudnnFilterDescriptor_t kernel_descriptor; + CHECK_CUDNN(cudnnCreateFilterDescriptor(&kernel_descriptor)); + CHECK_CUDNN(cudnnSetFilter4dDescriptor(kernel_descriptor, cudnnDataType_t(tw), CUDNN_TENSOR_NCHW, k, c, r, s)); + cudnnConvolutionDescriptor_t convolution_descriptor; + CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&convolution_descriptor)); + CHECK_CUDNN(cudnnSetConvolution2dDescriptor(convolution_descriptor, pad_dim1, pad_dim2, str_dim1, str_dim2, dil_dim1, dil_dim2, + CUDNN_CROSS_CORRELATION, cudnnDataType_t(compute_type))); + + int out_n{0}, out_c{0}, out_h{0}, out_w{0}; + CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(convolution_descriptor, input_descriptor, kernel_descriptor, + &out_n, &out_c, &out_h, &out_w)); + cudnnTensorDescriptor_t output_descriptor; + CHECK_CUDNN(cudnnCreateTensorDescriptor(&output_descriptor)); + CHECK_CUDNN(cudnnSetTensor4dDescriptor(output_descriptor, CUDNN_TENSOR_NCHW, cudnnDataType_t(ty), + out_n, out_c, out_h, out_w)); + size_t workspaceSize{0}; + CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cur_handle, input_descriptor, kernel_descriptor, + convolution_descriptor, output_descriptor, CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + &workspaceSize)); + void *workspace = request_cuda_workspace(workspaceSize, false); + + void *p_alpha = nullptr; + void *p_beta = nullptr; + cudnnDataType_t compType = cudnnDataType_t(compute_type); + set_alpha_beta(&p_alpha, &p_beta, compType); + + CHECK_CUDNN(cudnnConvolutionForward(cur_handle, p_alpha, input_descriptor, ptr_x, kernel_descriptor, ptr_w, + convolution_descriptor, CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + workspace, workspaceSize, + p_beta, output_descriptor, ptr_y)); + + 
CHECK_CUDNN(cudnnDestroyTensorDescriptor(input_descriptor)); + CHECK_CUDNN(cudnnDestroyTensorDescriptor(output_descriptor)); + CHECK_CUDNN(cudnnDestroyFilterDescriptor(kernel_descriptor)); + CHECK_CUDNN(cudnnDestroyConvolutionDescriptor(convolution_descriptor)); +} + +DLL void hidet_cudnn_conv2d_autoselect_algo( + int n, int c, int h, int w, int k, int r, int s, + void *ptr_x, void *ptr_w, void *ptr_y, + int tx, int tw, int ty, int compute_type, + int pad_dim1, int pad_dim2, int str_dim1, int str_dim2, int dil_dim1, int dil_dim2) +{ + lazy_load_cudnn(); + + cudnnHandle_t cur_handle = CudnnContext::current_handle(); + + // Set the stream to the current stream + cudaStream_t cur_stream = get_cuda_stream(); + CHECK_CUDNN(cudnnSetStream(cur_handle, cur_stream)); + + // Build descriptors and launch the kernel + cudnnTensorDescriptor_t input_descriptor; + CHECK_CUDNN(cudnnCreateTensorDescriptor(&input_descriptor)); + CHECK_CUDNN(cudnnSetTensor4dDescriptor(input_descriptor, CUDNN_TENSOR_NCHW, cudnnDataType_t(tx), n, c, h, w)); + cudnnFilterDescriptor_t kernel_descriptor; + CHECK_CUDNN(cudnnCreateFilterDescriptor(&kernel_descriptor)); + CHECK_CUDNN(cudnnSetFilter4dDescriptor(kernel_descriptor, cudnnDataType_t(tw), CUDNN_TENSOR_NCHW, k, c, r, s)); + cudnnConvolutionDescriptor_t convolution_descriptor; + CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&convolution_descriptor)); + CHECK_CUDNN(cudnnSetConvolution2dDescriptor(convolution_descriptor, pad_dim1, pad_dim2, str_dim1, str_dim2, dil_dim1, dil_dim2, + CUDNN_CROSS_CORRELATION, cudnnDataType_t(compute_type))); + + int out_n{0}, out_c{0}, out_h{0}, out_w{0}; + CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(convolution_descriptor, input_descriptor, kernel_descriptor, + &out_n, &out_c, &out_h, &out_w)); + cudnnTensorDescriptor_t output_descriptor; + CHECK_CUDNN(cudnnCreateTensorDescriptor(&output_descriptor)); + CHECK_CUDNN(cudnnSetTensor4dDescriptor(output_descriptor, CUDNN_TENSOR_NCHW, cudnnDataType_t(ty), + out_n, out_c, 
out_h, out_w)); + + int returnedAlgoCount; + cudnnConvolutionFwdAlgoPerf_t perfResults; + + CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm_v7(cur_handle, input_descriptor, kernel_descriptor, + convolution_descriptor, output_descriptor, + 1, &returnedAlgoCount, &perfResults)); + cudnnConvolutionFwdAlgo_t convolution_algorithm = perfResults.algo; + + size_t workspaceSize{0}; + CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cur_handle, input_descriptor, kernel_descriptor, + convolution_descriptor, output_descriptor, convolution_algorithm, + &workspaceSize)); + void *workspace = request_cuda_workspace(workspaceSize, false); + + void *p_alpha = nullptr; + void *p_beta = nullptr; + cudnnDataType_t compType = cudnnDataType_t(compute_type); + set_alpha_beta(&p_alpha, &p_beta, compType); + + CHECK_CUDNN(cudnnConvolutionForward(cur_handle, p_alpha, input_descriptor, ptr_x, kernel_descriptor, ptr_w, + convolution_descriptor, convolution_algorithm, + workspace, workspaceSize, + p_beta, output_descriptor, ptr_y)); + + CHECK_CUDNN(cudnnDestroyTensorDescriptor(input_descriptor)); + CHECK_CUDNN(cudnnDestroyTensorDescriptor(output_descriptor)); + CHECK_CUDNN(cudnnDestroyFilterDescriptor(kernel_descriptor)); + CHECK_CUDNN(cudnnDestroyConvolutionDescriptor(convolution_descriptor)); +} + DLL void hidet_cudnn_conv2d( int n, int c, int h, int w, int k, int r, int s, int p, int q, void *ptr_x, void *ptr_w, void *ptr_y, - int tx, int tw, int ty, int compute_type, - int pad_dim1, int pad_dim2, int str_dim1, int str_dim2, int dil_dim1, int dil_dim2 -) { + int tx, int tw, int ty, int compute_type, + int pad_dim1, int pad_dim2, int str_dim1, int str_dim2, int dil_dim1, int dil_dim2) +{ lazy_load_cudnn(); - + cudnnHandle_t cur_handle = CudnnContext::current_handle(); // Set the stream to the current stream @@ -449,15 +768,15 @@ DLL void hidet_cudnn_conv2d( CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_TENSOR_DESCRIPTOR, &xDesc)); cudnnDataType_t xDtype = cudnnDataType_t(tx); 
CHECK_CUDNN(cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_DATA_TYPE, - CUDNN_TYPE_DATA_TYPE, 1, &xDtype)); + CUDNN_TYPE_DATA_TYPE, 1, &xDtype)); CHECK_CUDNN(cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_DIMENSIONS, - CUDNN_TYPE_INT64, 4, xDim)); + CUDNN_TYPE_INT64, 4, xDim)); CHECK_CUDNN(cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_STRIDES, - CUDNN_TYPE_INT64, 4, xStr)); + CUDNN_TYPE_INT64, 4, xStr)); CHECK_CUDNN(cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_UNIQUE_ID, - CUDNN_TYPE_INT64, 1, &xUi)); + CUDNN_TYPE_INT64, 1, &xUi)); CHECK_CUDNN(cudnnBackendSetAttribute(xDesc, CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT, - CUDNN_TYPE_INT64, 1, &alignment)); + CUDNN_TYPE_INT64, 1, &alignment)); CHECK_CUDNN(cudnnBackendFinalize(xDesc)); // Build the descriptor for w @@ -468,15 +787,15 @@ DLL void hidet_cudnn_conv2d( CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_TENSOR_DESCRIPTOR, &wDesc)); cudnnDataType_t wDtype = cudnnDataType_t(tw); CHECK_CUDNN(cudnnBackendSetAttribute(wDesc, CUDNN_ATTR_TENSOR_DATA_TYPE, - CUDNN_TYPE_DATA_TYPE, 1, &wDtype)); + CUDNN_TYPE_DATA_TYPE, 1, &wDtype)); CHECK_CUDNN(cudnnBackendSetAttribute(wDesc, CUDNN_ATTR_TENSOR_DIMENSIONS, - CUDNN_TYPE_INT64, 4, wDim)); + CUDNN_TYPE_INT64, 4, wDim)); CHECK_CUDNN(cudnnBackendSetAttribute(wDesc, CUDNN_ATTR_TENSOR_STRIDES, - CUDNN_TYPE_INT64, 4, wStr)); + CUDNN_TYPE_INT64, 4, wStr)); CHECK_CUDNN(cudnnBackendSetAttribute(wDesc, CUDNN_ATTR_TENSOR_UNIQUE_ID, - CUDNN_TYPE_INT64, 1, &wUi)); + CUDNN_TYPE_INT64, 1, &wUi)); CHECK_CUDNN(cudnnBackendSetAttribute(wDesc, CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT, - CUDNN_TYPE_INT64, 1, &alignment)); + CUDNN_TYPE_INT64, 1, &alignment)); CHECK_CUDNN(cudnnBackendFinalize(wDesc)); // Build the descriptor for y @@ -487,15 +806,15 @@ DLL void hidet_cudnn_conv2d( CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_TENSOR_DESCRIPTOR, &yDesc)); cudnnDataType_t yDtype = cudnnDataType_t(ty); CHECK_CUDNN(cudnnBackendSetAttribute(yDesc, CUDNN_ATTR_TENSOR_DATA_TYPE, - 
CUDNN_TYPE_DATA_TYPE, 1, &yDtype)); + CUDNN_TYPE_DATA_TYPE, 1, &yDtype)); CHECK_CUDNN(cudnnBackendSetAttribute(yDesc, CUDNN_ATTR_TENSOR_DIMENSIONS, - CUDNN_TYPE_INT64, 4, yDim)); + CUDNN_TYPE_INT64, 4, yDim)); CHECK_CUDNN(cudnnBackendSetAttribute(yDesc, CUDNN_ATTR_TENSOR_STRIDES, - CUDNN_TYPE_INT64, 4, yStr)); + CUDNN_TYPE_INT64, 4, yStr)); CHECK_CUDNN(cudnnBackendSetAttribute(yDesc, CUDNN_ATTR_TENSOR_UNIQUE_ID, - CUDNN_TYPE_INT64, 1, &yUi)); + CUDNN_TYPE_INT64, 1, &yUi)); CHECK_CUDNN(cudnnBackendSetAttribute(yDesc, CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT, - CUDNN_TYPE_INT64, 1, &alignment)); + CUDNN_TYPE_INT64, 1, &alignment)); CHECK_CUDNN(cudnnBackendFinalize(yDesc)); // Build the descriptor for the convolution operator @@ -508,19 +827,19 @@ DLL void hidet_cudnn_conv2d( int64_t dilation[] = {dil_dim1, dil_dim2}; CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR, &cDesc)); CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS, - CUDNN_TYPE_INT64, 1, &nbDims)); + CUDNN_TYPE_INT64, 1, &nbDims)); CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_COMP_TYPE, - CUDNN_TYPE_DATA_TYPE, 1, &compType)); + CUDNN_TYPE_DATA_TYPE, 1, &compType)); CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_CONV_MODE, - CUDNN_TYPE_CONVOLUTION_MODE, 1, &mode)); + CUDNN_TYPE_CONVOLUTION_MODE, 1, &mode)); CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS, - CUDNN_TYPE_INT64, nbDims, pad)); + CUDNN_TYPE_INT64, nbDims, pad)); CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_POST_PADDINGS, - CUDNN_TYPE_INT64, nbDims, pad)); + CUDNN_TYPE_INT64, nbDims, pad)); CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_DILATIONS, - CUDNN_TYPE_INT64, nbDims, dilation)); + CUDNN_TYPE_INT64, nbDims, dilation)); CHECK_CUDNN(cudnnBackendSetAttribute(cDesc, CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES, - CUDNN_TYPE_INT64, nbDims, filterStr)); + CUDNN_TYPE_INT64, nbDims, 
filterStr)); CHECK_CUDNN(cudnnBackendFinalize(cDesc)); // Build the descriptor for the convolution forward operation @@ -529,48 +848,48 @@ DLL void hidet_cudnn_conv2d( void *p_beta = nullptr; set_alpha_beta(&p_alpha, &p_beta, compType); CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, - &fprop)); + &fprop)); CHECK_CUDNN(cudnnBackendSetAttribute(fprop, CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X, - CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &xDesc)); + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &xDesc)); CHECK_CUDNN(cudnnBackendSetAttribute(fprop, CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W, - CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &wDesc)); + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &wDesc)); CHECK_CUDNN(cudnnBackendSetAttribute(fprop, CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y, - CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &yDesc)); + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &yDesc)); CHECK_CUDNN(cudnnBackendSetAttribute(fprop, - CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC, - CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &cDesc)); + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC, + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &cDesc)); CHECK_CUDNN(cudnnBackendSetAttribute(fprop, CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA, - get_attribute_type_from_compute_type(compType), 1, p_alpha)); + get_attribute_type_from_compute_type(compType), 1, p_alpha)); CHECK_CUDNN(cudnnBackendSetAttribute(fprop, CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA, - get_attribute_type_from_compute_type(compType), 1, p_beta)); + get_attribute_type_from_compute_type(compType), 1, p_beta)); CHECK_CUDNN(cudnnBackendFinalize(fprop)); // Build the operation graph descriptor cudnnBackendDescriptor_t op_graph; CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR, &op_graph)); CHECK_CUDNN(cudnnBackendSetAttribute(op_graph, CUDNN_ATTR_OPERATIONGRAPH_OPS, - CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &fprop)); + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &fprop)); 
CHECK_CUDNN(cudnnBackendSetAttribute(op_graph, CUDNN_ATTR_OPERATIONGRAPH_HANDLE, - CUDNN_TYPE_HANDLE, 1, &cur_handle)); + CUDNN_TYPE_HANDLE, 1, &cur_handle)); CHECK_CUDNN(cudnnBackendFinalize(op_graph)); // Set up engine config cudnnBackendDescriptor_t engine; CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_ENGINE_DESCRIPTOR, &engine)); CHECK_CUDNN(cudnnBackendSetAttribute(engine, CUDNN_ATTR_ENGINE_OPERATION_GRAPH, - CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &op_graph)); + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &op_graph)); // TODO: Is it okay to hardcode the engine to be CUDNN_ATTR_ENGINE_GLOBAL_INDEX 0? // As mentioned here: https://docs.nvidia.com/deeplearning/cudnn/developer/graph-api.html, // Engine selection should be determined based on some heuristics. int64_t gidx = 0; CHECK_CUDNN(cudnnBackendSetAttribute(engine, CUDNN_ATTR_ENGINE_GLOBAL_INDEX, - CUDNN_TYPE_INT64, 1, &gidx)); + CUDNN_TYPE_INT64, 1, &gidx)); CHECK_CUDNN(cudnnBackendFinalize(engine)); cudnnBackendDescriptor_t engcfg; CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_ENGINECFG_DESCRIPTOR, &engcfg)); CHECK_CUDNN(cudnnBackendSetAttribute(engcfg, CUDNN_ATTR_ENGINECFG_ENGINE, - CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &engine)); + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &engine)); CHECK_CUDNN(cudnnBackendFinalize(engcfg)); // Set up the execution plan @@ -578,13 +897,13 @@ DLL void hidet_cudnn_conv2d( CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR, &plan)); CHECK_CUDNN(cudnnBackendSetAttribute(plan, CUDNN_ATTR_EXECUTION_PLAN_HANDLE, CUDNN_TYPE_HANDLE, 1, &cur_handle)); CHECK_CUDNN(cudnnBackendSetAttribute(plan, CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG, - CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &engcfg)); + CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &engcfg)); CHECK_CUDNN(cudnnBackendFinalize(plan)); int64_t workspaceSize; CHECK_CUDNN(cudnnBackendGetAttribute(plan, CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE, - CUDNN_TYPE_INT64, 1, NULL, &workspaceSize)); - + CUDNN_TYPE_INT64, 1, NULL, 
&workspaceSize)); + void *dev_ptrs[3] = {ptr_x, ptr_w, ptr_y}; // device pointers int64_t uids[3] = {'x', 'w', 'y'}; void *workspace = request_cuda_workspace(workspaceSize, false); @@ -592,11 +911,11 @@ DLL void hidet_cudnn_conv2d( cudnnBackendDescriptor_t varpack; CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR, &varpack)); CHECK_CUDNN(cudnnBackendSetAttribute(varpack, CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS, - CUDNN_TYPE_VOID_PTR, 3, dev_ptrs)); + CUDNN_TYPE_VOID_PTR, 3, dev_ptrs)); CHECK_CUDNN(cudnnBackendSetAttribute(varpack, CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS, - CUDNN_TYPE_INT64, 3, uids)); + CUDNN_TYPE_INT64, 3, uids)); CHECK_CUDNN(cudnnBackendSetAttribute(varpack, CUDNN_ATTR_VARIANT_PACK_WORKSPACE, - CUDNN_TYPE_VOID_PTR, 1, &workspace)); + CUDNN_TYPE_VOID_PTR, 1, &workspace)); CHECK_CUDNN(cudnnBackendFinalize(varpack)); // Execute the plan @@ -614,5 +933,3 @@ DLL void hidet_cudnn_conv2d( CHECK_CUDNN(cudnnBackendDestroyDescriptor(engcfg)); CHECK_CUDNN(cudnnBackendDestroyDescriptor(plan)); } - - From 730d30fa02697647f5615170405538a4cd59f6ff Mon Sep 17 00:00:00 2001 From: Yudi Sun Date: Wed, 3 Apr 2024 18:17:13 -0400 Subject: [PATCH 07/12] Add cudnn_gemm --- include/hidet/runtime/cuda/cudnn.h | 9 ++ python/hidet/cuda/cudnn/__init__.py | 2 +- python/hidet/cuda/cudnn/benchmark.py | 172 ++++++++++++++++++++---- python/hidet/cuda/cudnn/ffi.py | 54 ++++++++ python/hidet/cuda/cudnn/kernels.py | 188 +++++++++++++++++++++++++++ python/hidet/cuda/cudnn/utils.py | 12 +- tests/cuda/test_cudnn.py | 110 ++++++++++++++++ 7 files changed, 518 insertions(+), 29 deletions(-) diff --git a/include/hidet/runtime/cuda/cudnn.h b/include/hidet/runtime/cuda/cudnn.h index fcbc43697..391f10985 100644 --- a/include/hidet/runtime/cuda/cudnn.h +++ b/include/hidet/runtime/cuda/cudnn.h @@ -19,6 +19,15 @@ typedef struct cudnnContext *cudnnHandle_t; typedef void *cudnnBackendDescriptor_t; +/* Legacy API */ +struct cudnnTensorStruct; +struct cudnnFilterStruct; 
+struct cudnnConvolutionStruct; + +typedef struct cudnnTensorStruct *cudnnTensorDescriptor_t; +typedef struct cudnnFilterStruct *cudnnFilterDescriptor_t; +typedef struct cudnnConvolutionStruct *cudnnConvolutionDescriptor_t; + struct CudnnContext { cudnnHandle_t handles[HIDET_CUDNN_MAX_GPUS]; static CudnnContext* global(); diff --git a/python/hidet/cuda/cudnn/__init__.py b/python/hidet/cuda/cudnn/__init__.py index de471d1a6..b22207b94 100644 --- a/python/hidet/cuda/cudnn/__init__.py +++ b/python/hidet/cuda/cudnn/__init__.py @@ -10,4 +10,4 @@ # See the License for the specific language governing permissions and # limitations under the License. from .ffi import cudnnDataType -from .kernels import conv2d +from .kernels import conv2d, conv2d_gemm, conv2d_autoselect_algo diff --git a/python/hidet/cuda/cudnn/benchmark.py b/python/hidet/cuda/cudnn/benchmark.py index 83666115b..44bc2067f 100644 --- a/python/hidet/cuda/cudnn/benchmark.py +++ b/python/hidet/cuda/cudnn/benchmark.py @@ -4,16 +4,11 @@ import hidet from hidet.cuda.cudnn import cudnnDataType from hidet.utils.benchmark import do_bench +from hidet import ops -def benchmark_cudnn_conv2d(dtype_str, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations): - if dtype_str == "float32": - dtype = hidet.float32 - elif dtype_str == "float64": - dtype = hidet.float64 - else: - raise Exception("Unsupported DataType") - +def benchmark_cudnn_conv2d(dtype, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations): + # Uses ordinary cudnn.conv2d implemented with Graph-API tx = tw = ty = dtype pad_dim1, pad_dim2 = padding str_dim1, str_dim2 = stride @@ -49,27 +44,111 @@ def benchmark_cudnn_conv2d(dtype_str, compute_type, n, c, h, w, k, p, q, r, s, p dil_dim2, ), warmup=10, + rep=1, + ) + + print( + f"CuDNN Results for Configuration: dtype = {dtype}, input shape = {[n,c,h,w]}, " + f"weight shape = {[k,c,r,s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:" + ) + print("Median Latency Is: " 
+ str(latencies[1]) + " milliseconds") + print("-------------------------------------------------") + + +def benchmark_cudnn_conv2d_gemm(dtype, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations): + # Uses cudnn.conv2d_gemm implemented with Legacy-API + tx = tw = ty = dtype + pad_dim1, pad_dim2 = padding + str_dim1, str_dim2 = stride + dil_dim1, dil_dim2 = dilations + + tensor_x = hidet.randn((n, c, h, w), device='cuda', dtype=tx) + tensor_w = hidet.randn((k, c, r, s), device='cuda', dtype=tw) + tensor_y = hidet.empty((n, k, p, q), device='cuda', dtype=ty) + + latencies = do_bench( + lambda: hidet.cuda.cudnn.conv2d_gemm( + n, + c, + h, + w, + k, + r, + s, + tensor_x, + tensor_w, + tensor_y, + tx, + tw, + ty, + compute_type, + pad_dim1, + pad_dim2, + str_dim1, + str_dim2, + dil_dim1, + dil_dim2, + ), + warmup=10, rep=100, ) print( - f"CuDNN Results for Configuration: dtype = {dtype_str}, input shape = {[n,c,h,w]}, " + f"cudnn_gemm Results for Configuration: dtype = {dtype}, input shape = {[n,c,h,w]}, " f"weight shape = {[k,c,r,s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:" ) - print("20th Percentile Latency Is: " + str(latencies[0]) + " milliseconds") - print("50th Percentile Latency Is: " + str(latencies[1]) + " milliseconds") - print("80th Percentile Latency Is: " + str(latencies[2]) + " milliseconds") + print("Median Latency Is: " + str(latencies[1]) + " milliseconds") print("-------------------------------------------------") -def benchmark_torch_conv2d(dtype_str, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations): - if dtype_str == "float32": - dtype = np.float32 - elif dtype_str == "float64": - dtype = np.float64 - else: - raise Exception("Unsupported DataType") +def benchmark_cudnn_conv2d_autoselect_algo(dtype, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations): + # Uses cudnn Legacy-API to autoselect the fastest algorithm + tx = tw = ty = dtype + pad_dim1, pad_dim2 = padding + 
str_dim1, str_dim2 = stride + dil_dim1, dil_dim2 = dilations + + tensor_x = hidet.randn((n, c, h, w), device='cuda', dtype=tx) + tensor_w = hidet.randn((k, c, r, s), device='cuda', dtype=tw) + tensor_y = hidet.empty((n, k, p, q), device='cuda', dtype=ty) + + latencies = do_bench( + lambda: hidet.cuda.cudnn.conv2d_autoselect_algo( + n, + c, + h, + w, + k, + r, + s, + tensor_x, + tensor_w, + tensor_y, + tx, + tw, + ty, + compute_type, + pad_dim1, + pad_dim2, + str_dim1, + str_dim2, + dil_dim1, + dil_dim2, + ), + warmup=10, + rep=100, + ) + + print( + f"cudnn_autoselect_algo Results for Configuration: dtype = {dtype}, input shape = {[n,c,h,w]}, " + f"weight shape = {[k,c,r,s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:" + ) + print("Median Latency Is: " + str(latencies[1]) + " milliseconds") + print("-------------------------------------------------") + +def benchmark_torch_conv2d(dtype, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations): + # Native PyTorch Eager-mode Execution data = np.array(np.random.randn(n, c, h, w)).astype(dtype) weight = np.array(np.random.randn(k, c, r, s)).astype(dtype) @@ -86,23 +165,62 @@ def benchmark_torch_conv2d(dtype_str, compute_type, n, c, h, w, k, p, q, r, s, p ) print( - f"PyTorch Results for Configuration: dtype = {dtype_str}, input shape = {[n,c,h,w]}, " + f"PyTorch Results for Configuration: dtype = {dtype}, input shape = {[n,c,h,w]}, " + f"weight shape = {[k,c,r,s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:" + ) + print("Median Latency Is: " + str(latencies[1]) + " milliseconds") + print("-------------------------------------------------") + + +def benchmark_hidet_conv2d(dtype, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations): + # Uses optimized Hidet Graph implementation + tx = tw = ty = dtype + pad_dim1, pad_dim2 = padding + str_dim1, str_dim2 = stride + dil_dim1, dil_dim2 = dilations + + hidet.option.search_space(2) + tensor_x = 
hidet.symbol((n, c, h, w), device='cuda', dtype=tx) + tensor_w = hidet.randn((k, c, r, s), device='cuda', dtype=tw) + output = ops.conv2d( + tensor_x, tensor_w, stride=(str_dim1, str_dim2), dilations=(dil_dim1, dil_dim2), padding=(pad_dim1, pad_dim2) + ) + graph = hidet.trace_from(output, inputs=[tensor_x, tensor_w]) + graph = hidet.graph.optimize(graph) + graph = graph.cuda_graph() + + latencies = do_bench( + lambda: graph.run_async(), + warmup=10, + rep=100, + ) + + print( + f"Optimized Hidet Results for Configuration: dtype = {dtype}, input shape = {[n,c,h,w]}, " f"weight shape = {[k,c,r,s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:" ) - print("20th Percentile Latency Is: " + str(latencies[0]) + " milliseconds") - print("50th Percentile Latency Is: " + str(latencies[1]) + " milliseconds") - print("80th Percentile Latency Is: " + str(latencies[2]) + " milliseconds") + print("Median Latency Is: " + str(latencies[1]) + " milliseconds") print("-------------------------------------------------") if __name__ == '__main__': sizes = [ - [1, 3, 32, 32, 12, 30, 30, 3, 3, [0, 0], [1, 1], [1, 1]], - [2, 3, 224, 224, 16, 109, 109, 7, 7, [0, 0], [2, 2], [1, 1]], + # Group 1 + [1, 3, 224, 224, 64, 112, 112, 7, 7, [3, 3], [2, 2], [1, 1]], + [2, 3, 224, 224, 64, 112, 112, 7, 7, [3, 3], [2, 2], [1, 1]], + [4, 3, 224, 224, 64, 112, 112, 7, 7, [3, 3], [2, 2], [1, 1]], + [8, 3, 224, 224, 64, 112, 112, 7, 7, [3, 3], [2, 2], [1, 1]], + # Group 2 + [1, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]], + [2, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]], + [4, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]], + [8, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]], ] - dtypes = [['float32', cudnnDataType.CUDNN_DATA_FLOAT], ['float64', cudnnDataType.CUDNN_DATA_DOUBLE]] + dtypes = [['float32', cudnnDataType.CUDNN_DATA_FLOAT], ['float16', cudnnDataType.CUDNN_DATA_HALF]]#, ['float64', cudnnDataType.CUDNN_DATA_DOUBLE]] for data_type in 
dtypes: for size in sizes: - benchmark_cudnn_conv2d(*(data_type + size)) + benchmark_cudnn_conv2d_gemm(*(data_type + size)) benchmark_torch_conv2d(*(data_type + size)) + benchmark_cudnn_conv2d_autoselect_algo(*(data_type + size)) + benchmark_hidet_conv2d(*(data_type + size)) diff --git a/python/hidet/cuda/cudnn/ffi.py b/python/hidet/cuda/cudnn/ffi.py index bb559fcda..46f512902 100644 --- a/python/hidet/cuda/cudnn/ffi.py +++ b/python/hidet/cuda/cudnn/ffi.py @@ -68,6 +68,60 @@ class cudnnDataType(IntEnum): restype=None, ) +conv2d_gemm = get_func( + func_name='hidet_cudnn_conv2d_gemm', + arg_types=[ + c_int32, # n + c_int32, # c + c_int32, # h + c_int32, # w + c_int32, # k + c_int32, # r + c_int32, # s + c_void_p, # ptr_x + c_void_p, # ptr_w + c_void_p, # ptr_y + c_int32, # tx + c_int32, # tw + c_int32, # ty + c_int32, # compute_type + c_int32, # pad_dim1 + c_int32, # pad_dim2 + c_int32, # str_dim1 + c_int32, # str_dim2 + c_int32, # dil_dim1 + c_int32, # dil_dim2 + ], + restype=None, +) + +conv2d_autoselect_algo = get_func( + func_name='hidet_cudnn_conv2d_autoselect_algo', + arg_types=[ + c_int32, # n + c_int32, # c + c_int32, # h + c_int32, # w + c_int32, # k + c_int32, # r + c_int32, # s + c_void_p, # ptr_x + c_void_p, # ptr_w + c_void_p, # ptr_y + c_int32, # tx + c_int32, # tw + c_int32, # ty + c_int32, # compute_type + c_int32, # pad_dim1 + c_int32, # pad_dim2 + c_int32, # str_dim1 + c_int32, # str_dim2 + c_int32, # dil_dim1 + c_int32, # dil_dim2 + ], + restype=None, +) + @initialize() def set_cudnn_library_path(): diff --git a/python/hidet/cuda/cudnn/kernels.py b/python/hidet/cuda/cudnn/kernels.py index 2b1f3b2e9..c772dd7de 100644 --- a/python/hidet/cuda/cudnn/kernels.py +++ b/python/hidet/cuda/cudnn/kernels.py @@ -115,3 +115,191 @@ def conv2d( dil_dim1, dil_dim2, ) + + +def conv2d_gemm( + n: int, + c: int, + h: int, + w: int, + k: int, + r: int, + s: int, + ptr_x, + ptr_w, + ptr_y, + tx: Union[int, DataType], + tw: Union[int, DataType], + ty: Union[int, 
DataType], + compute_type: Union[int, cudnnDataType], + pad_dim1: int, + pad_dim2: int, + str_dim1: int, + str_dim2: int, + dil_dim1: int, + dil_dim2: int, +): + """ + Calculates the 2D convolution of tensor x with filter w, stores the result in tensor y. + + Parameters + ---------- + n: int + Batch number. + c: int + Number of channels in the input tensor x. + h: int + Height of the input tensor x. + w: int + Width of the input tensor x. + k: int + Number of channels in the output tensor y. + r: int + Height of the filter w. + s: int + Width of the filter w. + ptr_x: hidet.Tensor or int + Input tensor x, can be either a Tensor or an integer (the address of the tensor). + ptr_w: hidet.Tensor or int + Weight tensor w, can be either a Tensor or an integer (the address of the tensor). + ptr_y: hidet.Tensor or int + Output tensor y, can be either a Tensor or an integer (the address of the tensor). + tx: Union[int, DataType] + Type of elements in tensor x. + tw: Union[int, DataType] + Type of elements in tensor w. + ty: Union[int, DataType] + Type of elements in tensor y. + compute_type: Union[int, cudnnDataType] + The compute type of the operation. + For cuDNN, there's no such thing as a cudnnComputeType_t type. 
+ As per the official example, the computeType is defined in terms of cudnnDataType_t + pad_dim1: int + The value to use for padding along the height dimension + pad_dim2: int + The value to use for padding along the width dimension + str_dim1: int + The stride to use for the height dimension + str_dim2: int + The stride to use for the width dimension + dil_dim1: int + The dilation to use for the height dimension + dil_dim2: int + The dilation to use for the width dimension + """ + ffi.conv2d_gemm( + n, + c, + h, + w, + k, + r, + s, + as_pointer(ptr_x), + as_pointer(ptr_w), + as_pointer(ptr_y), + as_cudnn_type(tx), + as_cudnn_type(tw), + as_cudnn_type(ty), + compute_type, + pad_dim1, + pad_dim2, + str_dim1, + str_dim2, + dil_dim1, + dil_dim2, + ) + + +def conv2d_autoselect_algo( + n: int, + c: int, + h: int, + w: int, + k: int, + r: int, + s: int, + ptr_x, + ptr_w, + ptr_y, + tx: Union[int, DataType], + tw: Union[int, DataType], + ty: Union[int, DataType], + compute_type: Union[int, cudnnDataType], + pad_dim1: int, + pad_dim2: int, + str_dim1: int, + str_dim2: int, + dil_dim1: int, + dil_dim2: int, +): + """ + Calculates the 2D convolution of tensor x with filter w, stores the result in tensor y. + + Parameters + ---------- + n: int + Batch number. + c: int + Number of channels in the input tensor x. + h: int + Height of the input tensor x. + w: int + Width of the input tensor x. + k: int + Number of channels in the output tensor y. + r: int + Height of the filter w. + s: int + Width of the filter w. + ptr_x: hidet.Tensor or int + Input tensor x, can be either a Tensor or an integer (the address of the tensor). + ptr_w: hidet.Tensor or int + Weight tensor w, can be either a Tensor or an integer (the address of the tensor). + ptr_y: hidet.Tensor or int + Output tensor y, can be either a Tensor or an integer (the address of the tensor). + tx: Union[int, DataType] + Type of elements in tensor x. + tw: Union[int, DataType] + Type of elements in tensor w. 
+ ty: Union[int, DataType] + Type of elements in tensor y. + compute_type: Union[int, cudnnDataType] + The compute type of the operation. + For cuDNN, there's no such thing as a cudnnComputeType_t type. + As per the official example, the computeType is defined in terms of cudnnDataType_t + pad_dim1: int + The value to use for padding along the height dimension + pad_dim2: int + The value to use for padding along the width dimension + str_dim1: int + The stride to use for the height dimension + str_dim2: int + The stride to use for the width dimension + dil_dim1: int + The dilation to use for the height dimension + dil_dim2: int + The dilation to use for the width dimension + """ + ffi.conv2d_autoselect_algo( + n, + c, + h, + w, + k, + r, + s, + as_pointer(ptr_x), + as_pointer(ptr_w), + as_pointer(ptr_y), + as_cudnn_type(tx), + as_cudnn_type(tw), + as_cudnn_type(ty), + compute_type, + pad_dim1, + pad_dim2, + str_dim1, + str_dim2, + dil_dim1, + dil_dim2, + ) diff --git a/python/hidet/cuda/cudnn/utils.py b/python/hidet/cuda/cudnn/utils.py index 97c3e5958..50670310e 100644 --- a/python/hidet/cuda/cudnn/utils.py +++ b/python/hidet/cuda/cudnn/utils.py @@ -21,6 +21,14 @@ dtypes.int64: cudnnDataType.CUDNN_DATA_INT64, } +_cudnn_type_dict_str = { + "float16": cudnnDataType.CUDNN_DATA_HALF, + "float32": cudnnDataType.CUDNN_DATA_FLOAT, + "float64": cudnnDataType.CUDNN_DATA_DOUBLE, + "int32": cudnnDataType.CUDNN_DATA_INT32, + "int64": cudnnDataType.CUDNN_DATA_INT64, +} + def as_pointer(obj) -> int: from hidet.graph.tensor import Tensor @@ -36,7 +44,9 @@ def as_pointer(obj) -> int: def as_cudnn_type(obj) -> int: if isinstance(obj, DataType): return _cudnn_type_dict[obj] + elif isinstance(obj, str): + return _cudnn_type_dict_str[obj] elif isinstance(obj, int): return obj else: - raise TypeError(f'Expected DataType or int, but got {type(obj)}') + raise TypeError(f'Expected DataType, int, or str, but got {type(obj)}') diff --git a/tests/cuda/test_cudnn.py b/tests/cuda/test_cudnn.py 
index 0d355cf53..343d5e01b 100644 --- a/tests/cuda/test_cudnn.py +++ b/tests/cuda/test_cudnn.py @@ -73,5 +73,115 @@ def test_cudnn_conv2d(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, s hidet.utils.assert_close(actual=tensor_y, expected=golden, rtol=tol, atol=tol) +@pytest.mark.parametrize( + "n, c, h, w, k, p, q, r, s, padding, stride, dilations", + [ + [1, 3, 32, 32, 12, 30, 30, 3, 3, [0, 0], [1, 1], [1, 1]], # kernel 3, + [2, 3, 32, 32, 12, 11, 6, 7, 7, [1, 2], [2, 3], [2, 3]], # kernel 7, batch size 2 + [1, 3, 224, 224, 64, 112, 112, 7, 7, [3, 3], [2, 2], [1, 1]], # resnet layer 1 + [1, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]], # resnet layer 2 - kernel size 1 + ], +) +@pytest.mark.parametrize( + 'dtype, compute_type, tol', + [(hidet.float16, cudnnDataType.CUDNN_DATA_HALF, 1e-2), + (hidet.float32, cudnnDataType.CUDNN_DATA_FLOAT, 1e-5), + (hidet.float64, cudnnDataType.CUDNN_DATA_DOUBLE, 1e-8), + ] +) +def test_cudnn_conv2d_gemm(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, stride, dilations, tol): + tx = tw = ty = dtype + pad_dim1, pad_dim2 = padding + str_dim1, str_dim2 = stride + dil_dim1, dil_dim2 = dilations + + tensor_x = hidet.randn((n, c, h, w), device='cuda', dtype=tx) + tensor_w = hidet.randn((k, c, r, s), device='cuda', dtype=tw) + tensor_y = hidet.empty((n, k, p, q), device='cuda', dtype=ty) + + golden = ops.conv2d( + tensor_x, tensor_w, stride=(str_dim1, str_dim2), dilations=(dil_dim1, dil_dim2), padding=(pad_dim1, pad_dim2) + ) + hidet.cuda.cudnn.conv2d_gemm( + n, + c, + h, + w, + k, + r, + s, + tensor_x, + tensor_w, + tensor_y, + tx, + tw, + ty, + compute_type, + pad_dim1, + pad_dim2, + str_dim1, + str_dim2, + dil_dim1, + dil_dim2, + ) + + hidet.utils.assert_close(actual=tensor_y, expected=golden, rtol=tol, atol=tol) + + +@pytest.mark.parametrize( + "n, c, h, w, k, p, q, r, s, padding, stride, dilations", + [ + [1, 3, 32, 32, 12, 30, 30, 3, 3, [0, 0], [1, 1], [1, 1]], # kernel 3, + [2, 3, 32, 32, 12, 11, 6, 7, 7, 
[1, 2], [2, 3], [2, 3]], # kernel 7, batch size 2 + [1, 3, 224, 224, 64, 112, 112, 7, 7, [3, 3], [2, 2], [1, 1]], # resnet layer 1 + [1, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]], # resnet layer 2 - kernel size 1 + ], +) +@pytest.mark.parametrize( + 'dtype, compute_type, tol', + [(hidet.float16, cudnnDataType.CUDNN_DATA_HALF, 1e-2), + (hidet.float32, cudnnDataType.CUDNN_DATA_FLOAT, 1e-5), + (hidet.float64, cudnnDataType.CUDNN_DATA_DOUBLE, 1e-8), + ] +) +def test_cudnn_conv2d_autoselect_algo(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, stride, dilations, tol): + tx = tw = ty = dtype + pad_dim1, pad_dim2 = padding + str_dim1, str_dim2 = stride + dil_dim1, dil_dim2 = dilations + + tensor_x = hidet.randn((n, c, h, w), device='cuda', dtype=tx) + tensor_w = hidet.randn((k, c, r, s), device='cuda', dtype=tw) + tensor_y = hidet.empty((n, k, p, q), device='cuda', dtype=ty) + + golden = ops.conv2d( + tensor_x, tensor_w, stride=(str_dim1, str_dim2), dilations=(dil_dim1, dil_dim2), padding=(pad_dim1, pad_dim2) + ) + hidet.cuda.cudnn.conv2d_autoselect_algo( + n, + c, + h, + w, + k, + r, + s, + tensor_x, + tensor_w, + tensor_y, + tx, + tw, + ty, + compute_type, + pad_dim1, + pad_dim2, + str_dim1, + str_dim2, + dil_dim1, + dil_dim2, + ) + + hidet.utils.assert_close(actual=tensor_y, expected=golden, rtol=tol, atol=tol) + + if __name__ == '__main__': pytest.main([__file__]) From 40a61497b87a25646ae9eb4bb9c78648a7cc725c Mon Sep 17 00:00:00 2001 From: Yudi Sun Date: Tue, 25 Jun 2024 13:10:32 -0400 Subject: [PATCH 08/12] cuDNN cleanup --- python/hidet/cuda/cudnn/__init__.py | 2 +- python/hidet/cuda/cudnn/benchmark.py | 54 +--------------- python/hidet/cuda/cudnn/ffi.py | 27 -------- python/hidet/cuda/cudnn/kernels.py | 93 ---------------------------- src/hidet/runtime/cuda/cudnn.cpp | 90 +++++++-------------------- tests/cuda/test_cudnn.py | 55 ---------------- 6 files changed, 24 insertions(+), 297 deletions(-) diff --git a/python/hidet/cuda/cudnn/__init__.py 
b/python/hidet/cuda/cudnn/__init__.py index b22207b94..6e3dd210c 100644 --- a/python/hidet/cuda/cudnn/__init__.py +++ b/python/hidet/cuda/cudnn/__init__.py @@ -10,4 +10,4 @@ # See the License for the specific language governing permissions and # limitations under the License. from .ffi import cudnnDataType -from .kernels import conv2d, conv2d_gemm, conv2d_autoselect_algo +from .kernels import conv2d, conv2d_gemm diff --git a/python/hidet/cuda/cudnn/benchmark.py b/python/hidet/cuda/cudnn/benchmark.py index 44bc2067f..aeb95bcc8 100644 --- a/python/hidet/cuda/cudnn/benchmark.py +++ b/python/hidet/cuda/cudnn/benchmark.py @@ -101,51 +101,6 @@ def benchmark_cudnn_conv2d_gemm(dtype, compute_type, n, c, h, w, k, p, q, r, s, print("-------------------------------------------------") -def benchmark_cudnn_conv2d_autoselect_algo(dtype, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations): - # Uses cudnn Legacy-API to autoselect the fastest algorithm - tx = tw = ty = dtype - pad_dim1, pad_dim2 = padding - str_dim1, str_dim2 = stride - dil_dim1, dil_dim2 = dilations - - tensor_x = hidet.randn((n, c, h, w), device='cuda', dtype=tx) - tensor_w = hidet.randn((k, c, r, s), device='cuda', dtype=tw) - tensor_y = hidet.empty((n, k, p, q), device='cuda', dtype=ty) - - latencies = do_bench( - lambda: hidet.cuda.cudnn.conv2d_autoselect_algo( - n, - c, - h, - w, - k, - r, - s, - tensor_x, - tensor_w, - tensor_y, - tx, - tw, - ty, - compute_type, - pad_dim1, - pad_dim2, - str_dim1, - str_dim2, - dil_dim1, - dil_dim2, - ), - warmup=10, - rep=100, - ) - - print( - f"cudnn_autoselect_algo Results for Configuration: dtype = {dtype}, input shape = {[n,c,h,w]}, " - f"weight shape = {[k,c,r,s]}, padding = {padding}, stride = {stride}, dilations = {dilations}:" - ) - print("Median Latency Is: " + str(latencies[1]) + " milliseconds") - print("-------------------------------------------------") - def benchmark_torch_conv2d(dtype, compute_type, n, c, h, w, k, p, q, r, s, padding, 
stride, dilations): # Native PyTorch Eager-mode Execution @@ -189,11 +144,7 @@ def benchmark_hidet_conv2d(dtype, compute_type, n, c, h, w, k, p, q, r, s, paddi graph = hidet.graph.optimize(graph) graph = graph.cuda_graph() - latencies = do_bench( - lambda: graph.run_async(), - warmup=10, - rep=100, - ) + latencies = do_bench(lambda: graph.run_async(), warmup=10, rep=100) print( f"Optimized Hidet Results for Configuration: dtype = {dtype}, input shape = {[n,c,h,w]}, " @@ -216,11 +167,10 @@ def benchmark_hidet_conv2d(dtype, compute_type, n, c, h, w, k, p, q, r, s, paddi [4, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]], [8, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]], ] - dtypes = [['float32', cudnnDataType.CUDNN_DATA_FLOAT], ['float16', cudnnDataType.CUDNN_DATA_HALF]]#, ['float64', cudnnDataType.CUDNN_DATA_DOUBLE]] + dtypes = [['float32', cudnnDataType.CUDNN_DATA_FLOAT], ['float16', cudnnDataType.CUDNN_DATA_HALF]] for data_type in dtypes: for size in sizes: benchmark_cudnn_conv2d_gemm(*(data_type + size)) benchmark_torch_conv2d(*(data_type + size)) - benchmark_cudnn_conv2d_autoselect_algo(*(data_type + size)) benchmark_hidet_conv2d(*(data_type + size)) diff --git a/python/hidet/cuda/cudnn/ffi.py b/python/hidet/cuda/cudnn/ffi.py index 46f512902..31d172e23 100644 --- a/python/hidet/cuda/cudnn/ffi.py +++ b/python/hidet/cuda/cudnn/ffi.py @@ -95,33 +95,6 @@ class cudnnDataType(IntEnum): restype=None, ) -conv2d_autoselect_algo = get_func( - func_name='hidet_cudnn_conv2d_autoselect_algo', - arg_types=[ - c_int32, # n - c_int32, # c - c_int32, # h - c_int32, # w - c_int32, # k - c_int32, # r - c_int32, # s - c_void_p, # ptr_x - c_void_p, # ptr_w - c_void_p, # ptr_y - c_int32, # tx - c_int32, # tw - c_int32, # ty - c_int32, # compute_type - c_int32, # pad_dim1 - c_int32, # pad_dim2 - c_int32, # str_dim1 - c_int32, # str_dim2 - c_int32, # dil_dim1 - c_int32, # dil_dim2 - ], - restype=None, -) - @initialize() def set_cudnn_library_path(): diff --git 
a/python/hidet/cuda/cudnn/kernels.py b/python/hidet/cuda/cudnn/kernels.py index c772dd7de..53a646f75 100644 --- a/python/hidet/cuda/cudnn/kernels.py +++ b/python/hidet/cuda/cudnn/kernels.py @@ -210,96 +210,3 @@ def conv2d_gemm( dil_dim2, ) - -def conv2d_autoselect_algo( - n: int, - c: int, - h: int, - w: int, - k: int, - r: int, - s: int, - ptr_x, - ptr_w, - ptr_y, - tx: Union[int, DataType], - tw: Union[int, DataType], - ty: Union[int, DataType], - compute_type: Union[int, cudnnDataType], - pad_dim1: int, - pad_dim2: int, - str_dim1: int, - str_dim2: int, - dil_dim1: int, - dil_dim2: int, -): - """ - Calculates the 2D convolution of tensor x with filter w, stores the result in tensor y. - - Parameters - ---------- - n: int - Batch number. - c: int - Number of channels in the input tensor x. - h: int - Height of the input tensor x. - w: int - Width of the input tensor x. - k: int - Number of channels in the output tensor y. - r: int - Height of the filter w. - s: int - Width of the filter w. - ptr_x: hidet.Tensor or int - Input tensor x, can be either a Tensor or an integer (the address of the tensor). - ptr_w: hidet.Tensor or int - Weight tensor w, can be either a Tensor or an integer (the address of the tensor). - ptr_y: hidet.Tensor or int - Output tensor y, can be either a Tensor or an integer (the address of the tensor). - tx: Union[int, DataType] - Type of elements in tensor x. - tw: Union[int, DataType] - Type of elements in tensor w. - ty: Union[int, DataType] - Type of elements in tensor y. - compute_type: Union[int, cudnnDataType] - The compute type of the operation. - For cuDNN, there's no such thing as a cudnnComputeType_t type. 
- As per the official example, the computeType is defined in terms of cudnnDataType_t - pad_dim1: int - The value to use for padding along the height dimension - pad_dim2: int - The value to use for padding along the width dimension - str_dim1: int - The stride to use for the height dimension - str_dim2: int - The stride to use for the width dimension - dil_dim1: int - The dilation to use for the height dimension - dil_dim2: int - The dilation to use for the width dimension - """ - ffi.conv2d_autoselect_algo( - n, - c, - h, - w, - k, - r, - s, - as_pointer(ptr_x), - as_pointer(ptr_w), - as_pointer(ptr_y), - as_cudnn_type(tx), - as_cudnn_type(tw), - as_cudnn_type(ty), - compute_type, - pad_dim1, - pad_dim2, - str_dim1, - str_dim2, - dil_dim1, - dil_dim2, - ) diff --git a/src/hidet/runtime/cuda/cudnn.cpp b/src/hidet/runtime/cuda/cudnn.cpp index 9b57d3868..21c834091 100644 --- a/src/hidet/runtime/cuda/cudnn.cpp +++ b/src/hidet/runtime/cuda/cudnn.cpp @@ -10,6 +10,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include #include #include #include @@ -632,6 +633,7 @@ DLL void hidet_cudnn_conv2d_gemm( int tx, int tw, int ty, int compute_type, int pad_dim1, int pad_dim2, int str_dim1, int str_dim2, int dil_dim1, int dil_dim2) { + auto begin1 = std::chrono::steady_clock::now(); lazy_load_cudnn(); cudnnHandle_t cur_handle = CudnnContext::current_handle(); @@ -659,90 +661,40 @@ DLL void hidet_cudnn_conv2d_gemm( CHECK_CUDNN(cudnnCreateTensorDescriptor(&output_descriptor)); CHECK_CUDNN(cudnnSetTensor4dDescriptor(output_descriptor, CUDNN_TENSOR_NCHW, cudnnDataType_t(ty), out_n, out_c, out_h, out_w)); - size_t workspaceSize{0}; - CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cur_handle, input_descriptor, kernel_descriptor, - convolution_descriptor, output_descriptor, CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, - &workspaceSize)); - void *workspace = request_cuda_workspace(workspaceSize, false); + + // size_t workspaceSize{0}; + // CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cur_handle, input_descriptor, kernel_descriptor, + // convolution_descriptor, output_descriptor, CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + // &workspaceSize)); + + size_t workspaceSize{2000000}; + // std::cout << workspaceSize << std::endl; + // void *workspace = request_cuda_workspace(workspaceSize, false); + void *workspace = hidet_cuda_malloc_async(workspaceSize, cur_stream); void *p_alpha = nullptr; void *p_beta = nullptr; cudnnDataType_t compType = cudnnDataType_t(compute_type); set_alpha_beta(&p_alpha, &p_beta, compType); + auto end1 = std::chrono::steady_clock::now(); + std::cout << "Time difference 1 = " << std::chrono::duration_cast(end1 - begin1).count() << "[µs]" << std::endl; + auto begin2 = std::chrono::steady_clock::now(); CHECK_CUDNN(cudnnConvolutionForward(cur_handle, p_alpha, input_descriptor, ptr_x, kernel_descriptor, ptr_w, convolution_descriptor, CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, workspace, workspaceSize, p_beta, output_descriptor, ptr_y)); + 
auto end2 = std::chrono::steady_clock::now(); + std::cout << "Time difference 2 = " << std::chrono::duration_cast(end2 - begin2).count() << "[µs]" << std::endl; + auto begin3 = std::chrono::steady_clock::now(); CHECK_CUDNN(cudnnDestroyTensorDescriptor(input_descriptor)); CHECK_CUDNN(cudnnDestroyTensorDescriptor(output_descriptor)); CHECK_CUDNN(cudnnDestroyFilterDescriptor(kernel_descriptor)); CHECK_CUDNN(cudnnDestroyConvolutionDescriptor(convolution_descriptor)); -} - -DLL void hidet_cudnn_conv2d_autoselect_algo( - int n, int c, int h, int w, int k, int r, int s, - void *ptr_x, void *ptr_w, void *ptr_y, - int tx, int tw, int ty, int compute_type, - int pad_dim1, int pad_dim2, int str_dim1, int str_dim2, int dil_dim1, int dil_dim2) -{ - lazy_load_cudnn(); - - cudnnHandle_t cur_handle = CudnnContext::current_handle(); - - // Set the stream to the current stream - cudaStream_t cur_stream = get_cuda_stream(); - CHECK_CUDNN(cudnnSetStream(cur_handle, cur_stream)); - - // Build descriptors and launch the kernel - cudnnTensorDescriptor_t input_descriptor; - CHECK_CUDNN(cudnnCreateTensorDescriptor(&input_descriptor)); - CHECK_CUDNN(cudnnSetTensor4dDescriptor(input_descriptor, CUDNN_TENSOR_NCHW, cudnnDataType_t(tx), n, c, h, w)); - cudnnFilterDescriptor_t kernel_descriptor; - CHECK_CUDNN(cudnnCreateFilterDescriptor(&kernel_descriptor)); - CHECK_CUDNN(cudnnSetFilter4dDescriptor(kernel_descriptor, cudnnDataType_t(tw), CUDNN_TENSOR_NCHW, k, c, r, s)); - cudnnConvolutionDescriptor_t convolution_descriptor; - CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&convolution_descriptor)); - CHECK_CUDNN(cudnnSetConvolution2dDescriptor(convolution_descriptor, pad_dim1, pad_dim2, str_dim1, str_dim2, dil_dim1, dil_dim2, - CUDNN_CROSS_CORRELATION, cudnnDataType_t(compute_type))); - - int out_n{0}, out_c{0}, out_h{0}, out_w{0}; - CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(convolution_descriptor, input_descriptor, kernel_descriptor, - &out_n, &out_c, &out_h, &out_w)); - 
cudnnTensorDescriptor_t output_descriptor; - CHECK_CUDNN(cudnnCreateTensorDescriptor(&output_descriptor)); - CHECK_CUDNN(cudnnSetTensor4dDescriptor(output_descriptor, CUDNN_TENSOR_NCHW, cudnnDataType_t(ty), - out_n, out_c, out_h, out_w)); - - int returnedAlgoCount; - cudnnConvolutionFwdAlgoPerf_t perfResults; - - CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm_v7(cur_handle, input_descriptor, kernel_descriptor, - convolution_descriptor, output_descriptor, - 1, &returnedAlgoCount, &perfResults)); - cudnnConvolutionFwdAlgo_t convolution_algorithm = perfResults.algo; - - size_t workspaceSize{0}; - CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cur_handle, input_descriptor, kernel_descriptor, - convolution_descriptor, output_descriptor, convolution_algorithm, - &workspaceSize)); - void *workspace = request_cuda_workspace(workspaceSize, false); - - void *p_alpha = nullptr; - void *p_beta = nullptr; - cudnnDataType_t compType = cudnnDataType_t(compute_type); - set_alpha_beta(&p_alpha, &p_beta, compType); - - CHECK_CUDNN(cudnnConvolutionForward(cur_handle, p_alpha, input_descriptor, ptr_x, kernel_descriptor, ptr_w, - convolution_descriptor, convolution_algorithm, - workspace, workspaceSize, - p_beta, output_descriptor, ptr_y)); - - CHECK_CUDNN(cudnnDestroyTensorDescriptor(input_descriptor)); - CHECK_CUDNN(cudnnDestroyTensorDescriptor(output_descriptor)); - CHECK_CUDNN(cudnnDestroyFilterDescriptor(kernel_descriptor)); - CHECK_CUDNN(cudnnDestroyConvolutionDescriptor(convolution_descriptor)); + hidet_cuda_free_async(workspace, cur_stream); + auto end3 = std::chrono::steady_clock::now(); + std::cout << "Time difference 3 = " << std::chrono::duration_cast(end3 - begin3).count() << "[µs]" << std::endl; } DLL void hidet_cudnn_conv2d( diff --git a/tests/cuda/test_cudnn.py b/tests/cuda/test_cudnn.py index 343d5e01b..4579a5738 100644 --- a/tests/cuda/test_cudnn.py +++ b/tests/cuda/test_cudnn.py @@ -128,60 +128,5 @@ def test_cudnn_conv2d_gemm(n, c, h, w, k, p, q, r, s, dtype, 
compute_type, paddi hidet.utils.assert_close(actual=tensor_y, expected=golden, rtol=tol, atol=tol) -@pytest.mark.parametrize( - "n, c, h, w, k, p, q, r, s, padding, stride, dilations", - [ - [1, 3, 32, 32, 12, 30, 30, 3, 3, [0, 0], [1, 1], [1, 1]], # kernel 3, - [2, 3, 32, 32, 12, 11, 6, 7, 7, [1, 2], [2, 3], [2, 3]], # kernel 7, batch size 2 - [1, 3, 224, 224, 64, 112, 112, 7, 7, [3, 3], [2, 2], [1, 1]], # resnet layer 1 - [1, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]], # resnet layer 2 - kernel size 1 - ], -) -@pytest.mark.parametrize( - 'dtype, compute_type, tol', - [(hidet.float16, cudnnDataType.CUDNN_DATA_HALF, 1e-2), - (hidet.float32, cudnnDataType.CUDNN_DATA_FLOAT, 1e-5), - (hidet.float64, cudnnDataType.CUDNN_DATA_DOUBLE, 1e-8), - ] -) -def test_cudnn_conv2d_autoselect_algo(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, stride, dilations, tol): - tx = tw = ty = dtype - pad_dim1, pad_dim2 = padding - str_dim1, str_dim2 = stride - dil_dim1, dil_dim2 = dilations - - tensor_x = hidet.randn((n, c, h, w), device='cuda', dtype=tx) - tensor_w = hidet.randn((k, c, r, s), device='cuda', dtype=tw) - tensor_y = hidet.empty((n, k, p, q), device='cuda', dtype=ty) - - golden = ops.conv2d( - tensor_x, tensor_w, stride=(str_dim1, str_dim2), dilations=(dil_dim1, dil_dim2), padding=(pad_dim1, pad_dim2) - ) - hidet.cuda.cudnn.conv2d_autoselect_algo( - n, - c, - h, - w, - k, - r, - s, - tensor_x, - tensor_w, - tensor_y, - tx, - tw, - ty, - compute_type, - pad_dim1, - pad_dim2, - str_dim1, - str_dim2, - dil_dim1, - dil_dim2, - ) - - hidet.utils.assert_close(actual=tensor_y, expected=golden, rtol=tol, atol=tol) - - if __name__ == '__main__': pytest.main([__file__]) From 5d5d8a63b6b230c2bee027c0b7abb56d94f95787 Mon Sep 17 00:00:00 2001 From: Yudi Sun Date: Thu, 27 Jun 2024 12:41:27 -0400 Subject: [PATCH 09/12] [CUDNN] Cleanup --- src/hidet/runtime/cuda/cudnn.cpp | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git 
a/src/hidet/runtime/cuda/cudnn.cpp b/src/hidet/runtime/cuda/cudnn.cpp index 2ed0f0ff3..07f581e6a 100644 --- a/src/hidet/runtime/cuda/cudnn.cpp +++ b/src/hidet/runtime/cuda/cudnn.cpp @@ -633,7 +633,6 @@ DLL void hidet_cudnn_conv2d_gemm( int tx, int tw, int ty, int compute_type, int pad_dim1, int pad_dim2, int str_dim1, int str_dim2, int dil_dim1, int dil_dim2) { - auto begin1 = std::chrono::steady_clock::now(); lazy_load_cudnn(); cudnnHandle_t cur_handle = CudnnContext::current_handle(); @@ -662,39 +661,27 @@ DLL void hidet_cudnn_conv2d_gemm( CHECK_CUDNN(cudnnSetTensor4dDescriptor(output_descriptor, CUDNN_TENSOR_NCHW, cudnnDataType_t(ty), out_n, out_c, out_h, out_w)); - // size_t workspaceSize{0}; - // CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cur_handle, input_descriptor, kernel_descriptor, - // convolution_descriptor, output_descriptor, CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, - // &workspaceSize)); - - size_t workspaceSize{2000000}; - // std::cout << workspaceSize << std::endl; - // void *workspace = request_cuda_workspace(workspaceSize, false); + size_t workspaceSize{0}; + CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cur_handle, input_descriptor, kernel_descriptor, + convolution_descriptor, output_descriptor, CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + &workspaceSize)); void *workspace = hidet_cuda_malloc_async(workspaceSize, cur_stream); void *p_alpha = nullptr; void *p_beta = nullptr; cudnnDataType_t compType = cudnnDataType_t(compute_type); set_alpha_beta(&p_alpha, &p_beta, compType); - auto end1 = std::chrono::steady_clock::now(); - std::cout << "Time difference 1 = " << std::chrono::duration_cast(end1 - begin1).count() << "[µs]" << std::endl; - auto begin2 = std::chrono::steady_clock::now(); CHECK_CUDNN(cudnnConvolutionForward(cur_handle, p_alpha, input_descriptor, ptr_x, kernel_descriptor, ptr_w, convolution_descriptor, CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, workspace, workspaceSize, p_beta, output_descriptor, 
ptr_y)); - auto end2 = std::chrono::steady_clock::now(); - std::cout << "Time difference 2 = " << std::chrono::duration_cast(end2 - begin2).count() << "[µs]" << std::endl; - auto begin3 = std::chrono::steady_clock::now(); CHECK_CUDNN(cudnnDestroyTensorDescriptor(input_descriptor)); CHECK_CUDNN(cudnnDestroyTensorDescriptor(output_descriptor)); CHECK_CUDNN(cudnnDestroyFilterDescriptor(kernel_descriptor)); CHECK_CUDNN(cudnnDestroyConvolutionDescriptor(convolution_descriptor)); hidet_cuda_free_async(workspace, cur_stream); - auto end3 = std::chrono::steady_clock::now(); - std::cout << "Time difference 3 = " << std::chrono::duration_cast(end3 - begin3).count() << "[µs]" << std::endl; } DLL void hidet_cudnn_conv2d( @@ -830,9 +817,7 @@ DLL void hidet_cudnn_conv2d( CHECK_CUDNN(cudnnBackendCreateDescriptor(CUDNN_BACKEND_ENGINE_DESCRIPTOR, &engine)); CHECK_CUDNN(cudnnBackendSetAttribute(engine, CUDNN_ATTR_ENGINE_OPERATION_GRAPH, CUDNN_TYPE_BACKEND_DESCRIPTOR, 1, &op_graph)); - // TODO: Is it okay to hardcode the engine to be CUDNN_ATTR_ENGINE_GLOBAL_INDEX 0? - // As mentioned here: https://docs.nvidia.com/deeplearning/cudnn/developer/graph-api.html, - // Engine selection should be determined based on some heuristics. 
+ int64_t gidx = 0; CHECK_CUDNN(cudnnBackendSetAttribute(engine, CUDNN_ATTR_ENGINE_GLOBAL_INDEX, CUDNN_TYPE_INT64, 1, &gidx)); From 2f8fdaace380daeb7c33db0b7b8897c57cea4bd4 Mon Sep 17 00:00:00 2001 From: Yudi Sun Date: Thu, 27 Jun 2024 12:48:46 -0400 Subject: [PATCH 10/12] [CUDNN] Format and lint --- python/hidet/cuda/cudnn/benchmark.py | 3 +-- python/hidet/cuda/cudnn/kernels.py | 1 - tests/cuda/test_cudnn.py | 26 ++++++++++++++------------ 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/python/hidet/cuda/cudnn/benchmark.py b/python/hidet/cuda/cudnn/benchmark.py index aeb95bcc8..446173734 100644 --- a/python/hidet/cuda/cudnn/benchmark.py +++ b/python/hidet/cuda/cudnn/benchmark.py @@ -101,7 +101,6 @@ def benchmark_cudnn_conv2d_gemm(dtype, compute_type, n, c, h, w, k, p, q, r, s, print("-------------------------------------------------") - def benchmark_torch_conv2d(dtype, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations): # Native PyTorch Eager-mode Execution data = np.array(np.random.randn(n, c, h, w)).astype(dtype) @@ -129,7 +128,7 @@ def benchmark_torch_conv2d(dtype, compute_type, n, c, h, w, k, p, q, r, s, paddi def benchmark_hidet_conv2d(dtype, compute_type, n, c, h, w, k, p, q, r, s, padding, stride, dilations): # Uses optimized Hidet Graph implementation - tx = tw = ty = dtype + tx = tw = dtype pad_dim1, pad_dim2 = padding str_dim1, str_dim2 = stride dil_dim1, dil_dim2 = dilations diff --git a/python/hidet/cuda/cudnn/kernels.py b/python/hidet/cuda/cudnn/kernels.py index 53a646f75..781d9bd33 100644 --- a/python/hidet/cuda/cudnn/kernels.py +++ b/python/hidet/cuda/cudnn/kernels.py @@ -209,4 +209,3 @@ def conv2d_gemm( dil_dim1, dil_dim2, ) - diff --git a/tests/cuda/test_cudnn.py b/tests/cuda/test_cudnn.py index 4579a5738..985bcd82a 100644 --- a/tests/cuda/test_cudnn.py +++ b/tests/cuda/test_cudnn.py @@ -21,16 +21,17 @@ [ [1, 3, 32, 32, 12, 30, 30, 3, 3, [0, 0], [1, 1], [1, 1]], # kernel 3, [2, 3, 32, 32, 12, 11, 6, 7, 7, [1, 2], 
[2, 3], [2, 3]], # kernel 7, batch size 2 - [1, 3, 224, 224, 64, 112, 112, 7, 7, [3, 3], [2, 2], [1, 1]], # resnet layer 1 - [1, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]], # resnet layer 2 - kernel size 1 + [1, 3, 224, 224, 64, 112, 112, 7, 7, [3, 3], [2, 2], [1, 1]], # resnet layer 1 + [1, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]], # resnet layer 2 - kernel size 1 ], ) @pytest.mark.parametrize( 'dtype, compute_type, tol', - [(hidet.float16, cudnnDataType.CUDNN_DATA_HALF, 1e-2), - (hidet.float32, cudnnDataType.CUDNN_DATA_FLOAT, 1e-5), - (hidet.float64, cudnnDataType.CUDNN_DATA_DOUBLE, 1e-8), - ] + [ + (hidet.float16, cudnnDataType.CUDNN_DATA_HALF, 1e-2), + (hidet.float32, cudnnDataType.CUDNN_DATA_FLOAT, 1e-5), + (hidet.float64, cudnnDataType.CUDNN_DATA_DOUBLE, 1e-8), + ], ) def test_cudnn_conv2d(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, stride, dilations, tol): tx = tw = ty = dtype @@ -78,16 +79,17 @@ def test_cudnn_conv2d(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, s [ [1, 3, 32, 32, 12, 30, 30, 3, 3, [0, 0], [1, 1], [1, 1]], # kernel 3, [2, 3, 32, 32, 12, 11, 6, 7, 7, [1, 2], [2, 3], [2, 3]], # kernel 7, batch size 2 - [1, 3, 224, 224, 64, 112, 112, 7, 7, [3, 3], [2, 2], [1, 1]], # resnet layer 1 - [1, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]], # resnet layer 2 - kernel size 1 + [1, 3, 224, 224, 64, 112, 112, 7, 7, [3, 3], [2, 2], [1, 1]], # resnet layer 1 + [1, 64, 56, 56, 128, 56, 56, 1, 1, [0, 0], [1, 1], [1, 1]], # resnet layer 2 - kernel size 1 ], ) @pytest.mark.parametrize( 'dtype, compute_type, tol', - [(hidet.float16, cudnnDataType.CUDNN_DATA_HALF, 1e-2), - (hidet.float32, cudnnDataType.CUDNN_DATA_FLOAT, 1e-5), - (hidet.float64, cudnnDataType.CUDNN_DATA_DOUBLE, 1e-8), - ] + [ + (hidet.float16, cudnnDataType.CUDNN_DATA_HALF, 1e-2), + (hidet.float32, cudnnDataType.CUDNN_DATA_FLOAT, 1e-5), + (hidet.float64, cudnnDataType.CUDNN_DATA_DOUBLE, 1e-8), + ], ) def test_cudnn_conv2d_gemm(n, c, h, w, 
k, p, q, r, s, dtype, compute_type, padding, stride, dilations, tol): tx = tw = ty = dtype From 23d80fe952b93a239e540c313e2e53fb82523ca3 Mon Sep 17 00:00:00 2001 From: Yudi Sun Date: Fri, 28 Jun 2024 12:38:10 -0400 Subject: [PATCH 11/12] [CUDNN] Disable TF32 operations on Ampere architecture to avoid losing precision --- tests/cuda/test_cudnn.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/cuda/test_cudnn.py b/tests/cuda/test_cudnn.py index 985bcd82a..118c56fd3 100644 --- a/tests/cuda/test_cudnn.py +++ b/tests/cuda/test_cudnn.py @@ -92,6 +92,11 @@ def test_cudnn_conv2d(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, s ], ) def test_cudnn_conv2d_gemm(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, stride, dilations, tol): + # Disable TF32 operations on Ampere architecture to avoid losing precision. + import os + + os.environ["NVIDIA_TF32_OVERRIDE"] = "0" + tx = tw = ty = dtype pad_dim1, pad_dim2 = padding str_dim1, str_dim2 = stride From a48012c0f6e3efffe0fb1df76115991926386493 Mon Sep 17 00:00:00 2001 From: Yudi Sun Date: Tue, 2 Jul 2024 18:57:36 -0400 Subject: [PATCH 12/12] [CuDNN] Increase test tol for sm80 and higher --- tests/cuda/test_cudnn.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/cuda/test_cudnn.py b/tests/cuda/test_cudnn.py index 118c56fd3..86269a576 100644 --- a/tests/cuda/test_cudnn.py +++ b/tests/cuda/test_cudnn.py @@ -11,6 +11,7 @@ # limitations under the License. import pytest import math +import torch import hidet from hidet import ops from hidet.cuda.cudnn import cudnnDataType @@ -92,11 +93,6 @@ def test_cudnn_conv2d(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, s ], ) def test_cudnn_conv2d_gemm(n, c, h, w, k, p, q, r, s, dtype, compute_type, padding, stride, dilations, tol): - # Disable TF32 operations on Ampere architecture to avoid losing precision. 
- import os - - os.environ["NVIDIA_TF32_OVERRIDE"] = "0" - tx = tw = ty = dtype pad_dim1, pad_dim2 = padding str_dim1, str_dim2 = stride @@ -132,6 +128,9 @@ def test_cudnn_conv2d_gemm(n, c, h, w, k, p, q, r, s, dtype, compute_type, paddi dil_dim2, ) + if dtype == hidet.float32 and torch.cuda.get_device_capability()[0] >= 8: + tol = 1e-2 + hidet.utils.assert_close(actual=tensor_y, expected=golden, rtol=tol, atol=tol)