Skip to content

Commit

Permalink
Add cudnn conv2d
Browse files Browse the repository at this point in the history
  • Loading branch information
Yudi Sun committed Mar 7, 2024
1 parent b7c9026 commit 075da22
Show file tree
Hide file tree
Showing 9 changed files with 940 additions and 0 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
add_library(hidet_runtime SHARED
src/hidet/runtime/cuda/context.cpp
src/hidet/runtime/cuda/cublas.cpp
src/hidet/runtime/cuda/cudnn.cpp
src/hidet/runtime/cuda/cuda.cpp
src/hidet/runtime/cpu/context.cpp
src/hidet/runtime/callbacks.cpp
Expand Down
18 changes: 18 additions & 0 deletions include/hidet/runtime/cuda/cudnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,21 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#define HIDET_CUDNN_MAX_GPUS 32

#include <hidet/runtime/common.h>

struct cudnnContext;
typedef struct cudnnContext *cudnnHandle_t;

typedef void *cudnnBackendDescriptor_t;

// Container for the cuDNN handles declared by this header.
struct CudnnContext {
// Fixed-size table of cuDNN handles with HIDET_CUDNN_MAX_GPUS slots.
// NOTE(review): presumably one slot per CUDA device, indexed by device id -- confirm in cudnn.cpp.
cudnnHandle_t handles[HIDET_CUDNN_MAX_GPUS];
// Static accessor returning a pointer to a CudnnContext.
// NOTE(review): presumably a single process-wide instance -- confirm in cudnn.cpp.
static CudnnContext* global();
// Static accessor returning a cuDNN handle.
// NOTE(review): presumably the handle belonging to the currently active device -- confirm in cudnn.cpp.
static cudnnHandle_t current_handle();
};

// Exported entry point that receives a library path as a NUL-terminated C string.
// NOTE(review): presumably the path of the cuDNN shared library to load, and presumably
// only effective before the library is first loaded -- confirm in cudnn.cpp.
DLL void hidet_cudnn_set_library_path(const char* path);

1 change: 1 addition & 0 deletions python/hidet/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
from .event import Event

from . import cublas
from . import cudnn
13 changes: 13 additions & 0 deletions python/hidet/cuda/cudnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .ffi import cudnnDataType
from .kernels import conv2d
80 changes: 80 additions & 0 deletions python/hidet/cuda/cudnn/ffi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import glob
from enum import IntEnum
from ctypes import c_int32, c_void_p, c_char_p
from hidet.ffi.ffi import get_func

from hidet.utils.py import initialize

class cudnnDataType(IntEnum):
    """
    Integer codes identifying cuDNN element data types.

    The values follow the data-type enumeration defined in cudnn_ops_infer_v8.h.
    Because this is an IntEnum, every member is also a plain ``int``.
    """

    CUDNN_DATA_FLOAT = 0
    CUDNN_DATA_DOUBLE = 1
    CUDNN_DATA_HALF = 2
    CUDNN_DATA_INT8 = 3
    CUDNN_DATA_INT32 = 4
    CUDNN_DATA_INT8x4 = 5
    CUDNN_DATA_UINT8 = 6
    CUDNN_DATA_UINT8x4 = 7
    CUDNN_DATA_INT8x32 = 8
    CUDNN_DATA_BFLOAT16 = 9
    CUDNN_DATA_INT64 = 10


# Binding for the runtime symbol `hidet_cudnn_set_library_path`:
# one C-string argument, no return value.
set_library_path = get_func(
    func_name='hidet_cudnn_set_library_path',
    arg_types=[c_char_p],
    restype=None,
)

# Binding for the runtime symbol `hidet_cudnn_conv2d`. The 22 positional
# arguments are listed group by group, in call order.
conv2d = get_func(
    func_name='hidet_cudnn_conv2d',
    arg_types=(
        [c_int32] * 9  # n, c, h, w, k, r, s, p, q
        + [c_void_p] * 3  # ptr_x, ptr_w, ptr_y
        + [c_int32] * 3  # tx, tw, ty
        + [c_int32]  # compute_type
        + [c_int32] * 6  # pad_dim1, pad_dim2, str_dim1, str_dim2, dil_dim1, dil_dim2
    ),
    restype=None,
)

@initialize()
def set_cudnn_library_path():
    """
    Register a cuDNN shared library found under an ``nvidia`` directory on ``sys.path``.

    Every ``sys.path`` entry is inspected in order for files matching
    ``<entry>/nvidia/cudnn/lib/libcudnn.so.[0-9]*``. The first entry that yields at
    least one match wins: the chosen file path is UTF-8 encoded and handed to
    ``set_library_path``, and the search stops. If no entry matches, the function
    returns without calling ``set_library_path``.
    """
    for path in sys.path:
        nvidia_path = os.path.join(path, 'nvidia')
        if not os.path.exists(nvidia_path):
            continue
        # Escape the directory part so that characters such as '[' or '*' in a
        # sys.path entry are matched literally instead of being read as glob syntax.
        pattern = os.path.join(glob.escape(nvidia_path), 'cudnn', 'lib', 'libcudnn.so.[0-9]*')
        # glob.glob gives no ordering guarantee; sort so the same file is picked on
        # every run when several versioned libraries sit side by side.
        candidates = sorted(glob.glob(pattern))
        if candidates:
            set_library_path(candidates[0].encode('utf-8'))
            return
116 changes: 116 additions & 0 deletions python/hidet/cuda/cudnn/kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from hidet.ir.dtypes import DataType
from .ffi import cudnnDataType
from . import ffi
from .utils import as_pointer, as_cudnn_type

def conv2d(
    n: int,
    c: int,
    h: int,
    w: int,
    k: int,
    r: int,
    s: int,
    p: int,
    q: int,
    ptr_x,
    ptr_w,
    ptr_y,
    tx: Union[int, DataType],
    tw: Union[int, DataType],
    ty: Union[int, DataType],
    compute_type: Union[int, cudnnDataType],
    pad_dim1: int,
    pad_dim2: int,
    str_dim1: int,
    str_dim2: int,
    dil_dim1: int,
    dil_dim2: int,
):
    """
    Convolve input tensor x with filter w in 2D and write the result into tensor y.

    The tensor arguments are resolved to addresses with ``as_pointer``, the element
    types are resolved with ``as_cudnn_type``, and everything is forwarded to
    ``ffi.conv2d``. ``compute_type`` is forwarded as given.

    Parameters
    ----------
    n: int
        Batch size.
    c: int
        Channel count of the input tensor x.
    h: int
        Height of the input tensor x.
    w: int
        Width of the input tensor x.
    k: int
        Channel count of the output tensor y.
    r: int
        Height of the filter w.
    s: int
        Width of the filter w.
    p: int
        Height of the output tensor y.
    q: int
        Width of the output tensor y.
    ptr_x: hidet.Tensor or int
        Input tensor x, given either as a Tensor or as its integer address.
    ptr_w: hidet.Tensor or int
        Filter tensor w, given either as a Tensor or as its integer address.
    ptr_y: hidet.Tensor or int
        Output tensor y, given either as a Tensor or as its integer address.
    tx: Union[int, DataType]
        Element type of tensor x.
    tw: Union[int, DataType]
        Element type of tensor w.
    ty: Union[int, DataType]
        Element type of tensor y.
    compute_type: Union[int, cudnnDataType]
        Compute type of the operation. cuDNN has no dedicated cudnnComputeType_t;
        following the official example, it is expressed as a cudnnDataType_t value.
    pad_dim1: int
        Padding along the height dimension.
    pad_dim2: int
        Padding along the width dimension.
    str_dim1: int
        Stride along the height dimension.
    str_dim2: int
        Stride along the width dimension.
    dil_dim1: int
        Dilation along the height dimension.
    dil_dim2: int
        Dilation along the width dimension.
    """
    # Resolve the tensors first and the element types second, so conversion
    # errors surface in argument order.
    x_addr = as_pointer(ptr_x)
    w_addr = as_pointer(ptr_w)
    y_addr = as_pointer(ptr_y)

    x_type = as_cudnn_type(tx)
    w_type = as_cudnn_type(tw)
    y_type = as_cudnn_type(ty)

    shape_args = (n, c, h, w, k, r, s, p, q)
    conv_args = (pad_dim1, pad_dim2, str_dim1, str_dim2, dil_dim1, dil_dim2)

    ffi.conv2d(*shape_args, x_addr, w_addr, y_addr, x_type, w_type, y_type, compute_type, *conv_args)
40 changes: 40 additions & 0 deletions python/hidet/cuda/cudnn/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hidet.ir import dtypes
from hidet.ir.dtypes import DataType
from .ffi import cudnnDataType

# Lookup table from hidet data types to the corresponding cudnnDataType member.
# Only float32, float64, int32 and int64 have entries; any other data type is
# absent from the table and cannot be translated by key.
_cudnn_type_dict = {
dtypes.float32: cudnnDataType.CUDNN_DATA_FLOAT,
dtypes.float64: cudnnDataType.CUDNN_DATA_DOUBLE,
dtypes.int32: cudnnDataType.CUDNN_DATA_INT32,
dtypes.int64: cudnnDataType.CUDNN_DATA_INT64
}

def as_pointer(obj) -> int:
    """
    Resolve ``obj`` to an integer address.

    A hidet ``Tensor`` yields ``obj.storage.addr``; an ``int`` is returned unchanged.

    Raises
    ------
    TypeError
        If ``obj`` is neither a Tensor nor an int.
    """
    # NOTE(review): imported inside the function, presumably to avoid a circular
    # import at module load time -- confirm.
    from hidet.graph.tensor import Tensor

    if isinstance(obj, Tensor):
        return obj.storage.addr
    if isinstance(obj, int):
        return obj
    raise TypeError(f'Expected Tensor or int, but got {type(obj)}')


def as_cudnn_type(obj) -> int:
    """
    Resolve ``obj`` to a cuDNN data-type code.

    A hidet ``DataType`` is looked up in ``_cudnn_type_dict``; an ``int`` is
    returned unchanged.

    Raises
    ------
    KeyError
        If ``obj`` is a DataType with no entry in ``_cudnn_type_dict``.
    TypeError
        If ``obj`` is neither a DataType nor an int.
    """
    if isinstance(obj, DataType):
        return _cudnn_type_dict[obj]
    if isinstance(obj, int):
        return obj
    raise TypeError(f'Expected DataType or int, but got {type(obj)}')
Loading

0 comments on commit 075da22

Please sign in to comment.