[cuDNN] Add cudnn conv2d #435
Conversation
Thanks @yudi0201 !
Overall looks good to me. After merging this PR, we can add a primitive function that calls conv2d_cudnn in our runtime library and expose an operator like hidet.ops.conv2d_cudnn.
src/hidet/runtime/cuda/cudnn.cpp
Outdated
void *dev_ptrs[3] = {ptr_x, ptr_w, ptr_y};  // device pointers
int64_t uids[3] = {'x', 'w', 'y'};
void *workspace = hidet_cuda_malloc_async(workspaceSize, cur_stream);
It might be better to use the workspace shared by all hidet operators (i.e., https://github.com/hidet-org/hidet/blob/main/include/hidet/runtime/cuda/context.h#L46).
When we run the operator a second time, there will be no memory allocation, so it can also be used with cudaGraph.
CHECK_CUDNN(cudnnBackendDestroyDescriptor(xDesc));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(wDesc));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(yDesc));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(cDesc));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(fprop));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(op_graph));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(engine));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(engcfg));
CHECK_CUDNN(cudnnBackendDestroyDescriptor(plan));
It might be good to benchmark our implementation against PyTorch's conv2d. I am not sure whether the overhead of creating/destroying descriptors is large enough to affect performance.
@yaoyaoding
It's doable, similar to cublas gemm: 072a606
What about adding cudnn*, cublas* etc. to the search space?
That's exactly what the commit I mentioned before does.
If cuDNN needs to be installed, could it be added to the README? It doesn't seem to be included in the CUDA Toolkit: link
We can add the
Similar to other packages like cublas.
Hi @c-fteixeira, the CI seems unable to initialize the VM for the tests. Could you help us take a look? Thank you!
LGTM, thanks @yudi0201 !
No description provided.