Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions clang/lib/DPCT/MapNames.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1985,6 +1985,10 @@ void MapNames::setExplicitNamespaceMap(
getLibraryHelperNamespace() + "blas_gemm::experimental::epilogue_t::nop"},
{"CUBLASLT_EPILOGUE_RELU",
getLibraryHelperNamespace() + "blas_gemm::experimental::epilogue_t::relu"},
{"CUBLASLT_EPILOGUE_GELU",
getLibraryHelperNamespace() + "blas_gemm::experimental::epilogue_t::gelu"},
{"CUBLASLT_EPILOGUE_GELU_AUX",
getLibraryHelperNamespace() + "blas_gemm::experimental::epilogue_t::gelu"},
{"CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE",
getLibraryHelperNamespace() +
"blas_gemm::experimental::transform_desc_t::attribute::scale_type"},
Expand Down
15 changes: 11 additions & 4 deletions clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ enum class pointer_mode_t {
alpha_device_vector_beta_zero,
alpha_device_vector_beta_host
};
enum class epilogue_t { nop = 1, relu };
enum class epilogue_t { nop = 1, relu, gelu };

class descriptor;
using descriptor_ptr = descriptor *;
Expand Down Expand Up @@ -690,7 +690,7 @@ template <typename T> struct absmax_impl {
/// scale_type==float && a_type==int8 && b_type==int8 && c_type==int32;
/// scale_type==float && a_type==float && b_type==float && c_type==float.
/// Currently, this function only supports beta==0 or beta==1.
/// Currently, this function only supports the relu epilogue.
/// Currently, this function only supports the relu and gelu epilogues.
/// NOTE: Non-col-major matrix will be converted to col-major matrix before
/// multiplication and converted back after multiplication.
/// TODO: Impl row-major matmul without layout conversion.
Expand Down Expand Up @@ -753,9 +753,10 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
}

if (compute_desc->_epilogue != epilogue_t::nop &&
compute_desc->_epilogue != epilogue_t::relu) {
compute_desc->_epilogue != epilogue_t::relu &&
compute_desc->epilogue != epilogue_t::gelu) {
throw std::runtime_error("dpct::blas_gemm::experimental::matmul() only "
"supports relu epilogue currently.");
"supports relu and gelu epilogue currently.");
}

if (!(compute_desc->_scale_type == library_data_t::real_int32 &&
Expand Down Expand Up @@ -999,8 +1000,14 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
}

if (compute_desc->_epilogue != epilogue_t::nop) {

::dnnl::post_ops matmul_ops;
if(compute_desc->_epilogue == epilogue_t::relu){
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f);
}
else if(compute_desc->_epilogue == epilogue_t::gelu){
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
}
matmul_attr.set_post_ops(matmul_ops);
}

Expand Down
4 changes: 4 additions & 0 deletions clang/test/dpct/cublaslt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,13 @@ void foo3() {
// CHECK: dpct::blas_gemm::experimental::epilogue_t e;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::nop;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::relu;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu;
cublasLtEpilogue_t e;
e = CUBLASLT_EPILOGUE_DEFAULT;
e = CUBLASLT_EPILOGUE_RELU;
e = CUBLASLT_EPILOGUE_GELU;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

e = CUBLASLT_EPILOGUE_GELU_AUX;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GELU_AUX is not the same as GELU.
According to the cuBLAS documentation, GELU_AUX is defined as:

Apply GELU point-wise transform to the results (x := GELU(x)). This epilogue mode outputs GELU input as a separate matrix (useful for training). See CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER of [cublasLtMatmulDescAttributes_t](https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmuldescattributes-t).

So you need to update the implementation to copy the GELU input to the buffer specified by CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.

}

void foo4() {
Expand Down