diff --git a/clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp b/clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp index 461e660f5056..5d25b4550399 100644 --- a/clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp +++ b/clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp @@ -1624,6 +1624,18 @@ void MapNamesBlas::setExplicitNamespaceMap( {"CUBLASLT_EPILOGUE_RELU", MapNames::getLibraryHelperNamespace() + "blas_gemm::experimental::epilogue_t::relu"}, + {"CUBLASLT_EPILOGUE_GELU", + MapNames::getLibraryHelperNamespace() + + "blas_gemm::experimental::epilogue_t::gelu"}, + {"CUBLASLT_EPILOGUE_GELU_AUX", + MapNames::getLibraryHelperNamespace() + + "blas_gemm::experimental::epilogue_t::gelu_aux"}, + {"CUBLASLT_EPILOGUE_BIAS", + MapNames::getLibraryHelperNamespace() + + "blas_gemm::experimental::epilogue_t::bias"}, + {"CUBLASLT_EPILOGUE_GELU_AUX_BIAS", + MapNames::getLibraryHelperNamespace() + + "blas_gemm::experimental::epilogue_t::gelu_aux_bias"}, {"CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE", MapNames::getLibraryHelperNamespace() + "blas_gemm::experimental::transform_desc_t::attribute::scale_type"}, diff --git a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp index 44972b6568d7..991eab06e547 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp @@ -13,6 +13,7 @@ #ifndef __DPCT_BLAS_GEMM_UTILS_HPP__ #define __DPCT_BLAS_GEMM_UTILS_HPP__ +#include "blas_utils.hpp" #include "compat_service.hpp" #include "dnnl_utils.hpp" @@ -33,7 +34,7 @@ enum class pointer_mode_t { alpha_device_vector_beta_zero, alpha_device_vector_beta_host }; -enum class epilogue_t { nop = 1, relu }; +enum class epilogue_t { nop = 1, relu, bias, gelu, gelu_aux, gelu_aux_bias }; class descriptor; using descriptor_ptr = descriptor *; @@ -695,9 +696,9 @@ template struct absmax_impl { } // namespace detail /// This function does the following operations: -/// (1) D_temp = epilogue(alpha * scale_a * op_a(A) * scale_b * op_b(B) + beta * C) -/// (2) Amax = absmax(D_temp) when matmul_desc_t::attribute::absmax_d_pointer is specified -/// (3) D = scale_d * D_temp +/// (1) D_temp = epilogue(alpha * scale_a * op_a(A) * scale_b * op_b(B) + beta * +/// C) (2) Amax = absmax(D_temp) when matmul_desc_t::attribute::absmax_d_pointer +/// is specified (3) D = scale_d * D_temp /// "op_a" is specified by the matmul_desc_t::attribute::trans_a /// (default is nontrans) /// "op_b" is specified by the matmul_desc_t::attribute::trans_b @@ -780,9 +781,14 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc, } if (compute_desc->_epilogue != epilogue_t::nop && - compute_desc->_epilogue != epilogue_t::relu) { - throw std::runtime_error("dpct::blas_gemm::experimental::matmul() only " - "supports relu epilogue currently."); + compute_desc->_epilogue != epilogue_t::relu && + compute_desc->_epilogue != epilogue_t::bias && + compute_desc->_epilogue != epilogue_t::gelu && + compute_desc->_epilogue != epilogue_t::gelu_aux && + compute_desc->_epilogue != epilogue_t::gelu_aux_bias) { + throw std::runtime_error( + "dpct::blas_gemm::experimental::matmul() only " + "supports relu, gelu, gelu_aux, gelu with bias epilogue currently."); } if (!(compute_desc->_scale_type == library_data_t::real_int32 && @@ -912,6 +918,7 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc, ::dnnl::memory::dims src_dims = {M, K}; ::dnnl::memory::dims weights_dims = {K, N}; ::dnnl::memory::dims bias_dims = {M, N}; + ::dnnl::memory::dims epilogue_bias_dims = {M, N}; ::dnnl::memory::dims dst_dims = {M, N}; const ::dnnl::memory::dims src_strides = @@ -1022,18 +1029,37 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc, matmul_args.insert( {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *scales_alpha}); } - ::dnnl::post_ops matmul_ops; if (!beta_is_zero) { matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md); matmul_args.insert( {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, *bias_mem}); } + sycl::queue &queue = ::dpct::cs::get_default_queue(); if (compute_desc->_epilogue != epilogue_t::nop) { - matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f); + ::dnnl::post_ops matmul_ops; + if (compute_desc->_epilogue == epilogue_t::relu) { + matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f); + } else if (compute_desc->_epilogue == epilogue_t::gelu) { + matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f); + } else if (compute_desc->_epilogue == epilogue_t::gelu_aux) { + matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f); + dpct::blas::matrix_mem_copy(compute_desc->_epilogue_aux_pointer, new_c, + compute_desc->_epilogue_aux_ld, new_ldc, m, n, + sizeof(size_t), dpct::device_to_device, queue, + false); + } else if (compute_desc->_epilogue == epilogue_t::bias) { + matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md); + } else if (compute_desc->_epilogue == epilogue_t::gelu_aux_bias) { + matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f); + matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md); + dpct::blas::matrix_mem_copy(compute_desc->_epilogue_aux_pointer, new_c, + compute_desc->_epilogue_aux_ld, new_ldc, m, n, + sizeof(size_t), dpct::device_to_device, queue, + false); + } + matmul_attr.set_post_ops(matmul_ops); } - matmul_attr.set_post_ops(matmul_ops); - auto matmul_pd = ::dnnl::matmul::primitive_desc( handle->get_engine(), src_md, weights_md, dst_md, matmul_attr); auto matmul_prim = ::dnnl::matmul(matmul_pd); diff --git a/clang/test/dpct/cublaslt.cu b/clang/test/dpct/cublaslt.cu index 30fb82c05b4d..951ef29dba3e 100644 --- a/clang/test/dpct/cublaslt.cu +++ b/clang/test/dpct/cublaslt.cu @@ -234,9 +234,17 @@ void foo3() { // CHECK: dpct::blas_gemm::experimental::epilogue_t e; // CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::nop; // CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::relu; + // CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu; + // CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::bias; + // CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu_aux_bias; + // CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu_aux; cublasLtEpilogue_t e; e = CUBLASLT_EPILOGUE_DEFAULT; e = CUBLASLT_EPILOGUE_RELU; + e = CUBLASLT_EPILOGUE_GELU; + e = CUBLASLT_EPILOGUE_BIAS; + e = CUBLASLT_EPILOGUE_GELU_AUX_BIAS; + e = CUBLASLT_EPILOGUE_GELU_AUX; } void foo4() {