12 changes: 12 additions & 0 deletions clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp
@@ -1624,6 +1624,18 @@ void MapNamesBlas::setExplicitNamespaceMap(
{"CUBLASLT_EPILOGUE_RELU",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::relu"},
{"CUBLASLT_EPILOGUE_GELU",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::gelu"},
{"CUBLASLT_EPILOGUE_GELU_AUX",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::gelu_aux"},
{"CUBLASLT_EPILOGUE_BIAS",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::bias"},
{"CUBLASLT_EPILOGUE_GELU_AUX_BIAS",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::gelu_aux_bias"},
{"CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::transform_desc_t::attribute::scale_type"},
48 changes: 37 additions & 11 deletions clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
@@ -13,6 +13,7 @@
#ifndef __DPCT_BLAS_GEMM_UTILS_HPP__
#define __DPCT_BLAS_GEMM_UTILS_HPP__

#include "blas_utils.hpp"
#include "compat_service.hpp"
#include "dnnl_utils.hpp"

@@ -33,7 +34,7 @@ enum class pointer_mode_t {
alpha_device_vector_beta_zero,
alpha_device_vector_beta_host
};
enum class epilogue_t { nop = 1, relu };
enum class epilogue_t { nop = 1, relu, bias, gelu, gelu_aux, gelu_aux_bias };

class descriptor;
using descriptor_ptr = descriptor *;
@@ -695,9 +696,9 @@ template <typename T> struct absmax_impl {
} // namespace detail

/// This function does the following operations:
/// (1) D_temp = epilogue(alpha * scale_a * op_a(A) * scale_b * op_b(B) + beta * C)
/// (2) Amax = absmax(D_temp) when matmul_desc_t::attribute::absmax_d_pointer is specified
/// (3) D = scale_d * D_temp
/// (1) D_temp = epilogue(alpha * scale_a * op_a(A) * scale_b * op_b(B) +
///     beta * C)
/// (2) Amax = absmax(D_temp) when matmul_desc_t::attribute::absmax_d_pointer
///     is specified
/// (3) D = scale_d * D_temp
/// "op_a" is specified by the matmul_desc_t::attribute::trans_a
/// (default is nontrans)
/// "op_b" is specified by the matmul_desc_t::attribute::trans_b
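(Editorial aside, not part of the patch: the three operations documented in the comment above, restated in display form for readability.)

\[
\begin{aligned}
D_{\text{temp}} &= \operatorname{epilogue}\bigl(\alpha \cdot \text{scale}_a \cdot \operatorname{op}_a(A) \cdot \text{scale}_b \cdot \operatorname{op}_b(B) + \beta \cdot C\bigr) \\
\text{Amax} &= \operatorname{absmax}(D_{\text{temp}}) \quad \text{(only when absmax\_d\_pointer is set)} \\
D &= \text{scale}_d \cdot D_{\text{temp}}
\end{aligned}
\]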
@@ -780,9 +781,14 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
}

if (compute_desc->_epilogue != epilogue_t::nop &&
compute_desc->_epilogue != epilogue_t::relu) {
throw std::runtime_error("dpct::blas_gemm::experimental::matmul() only "
"supports relu epilogue currently.");
compute_desc->_epilogue != epilogue_t::relu &&
compute_desc->_epilogue != epilogue_t::bias &&
compute_desc->_epilogue != epilogue_t::gelu &&
compute_desc->_epilogue != epilogue_t::gelu_aux &&
compute_desc->_epilogue != epilogue_t::gelu_aux_bias) {
    throw std::runtime_error(
        "dpct::blas_gemm::experimental::matmul() only supports relu, bias, "
        "gelu, gelu_aux and gelu_aux_bias epilogues currently.");
}

if (!(compute_desc->_scale_type == library_data_t::real_int32 &&
Expand Down Expand Up @@ -912,6 +918,7 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
::dnnl::memory::dims src_dims = {M, K};
::dnnl::memory::dims weights_dims = {K, N};
::dnnl::memory::dims bias_dims = {M, N};
::dnnl::memory::dims epilogue_bias_dims = {M, N};
::dnnl::memory::dims dst_dims = {M, N};

const ::dnnl::memory::dims src_strides =
@@ -1022,18 +1029,37 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
matmul_args.insert(
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *scales_alpha});
}

::dnnl::post_ops matmul_ops;
if (!beta_is_zero) {
matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md);
matmul_args.insert(
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, *bias_mem});
}
sycl::queue &queue = ::dpct::cs::get_default_queue();
if (compute_desc->_epilogue != epilogue_t::nop) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f);
::dnnl::post_ops matmul_ops;
if (compute_desc->_epilogue == epilogue_t::relu) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f);
} else if (compute_desc->_epilogue == epilogue_t::gelu) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f);
} else if (compute_desc->_epilogue == epilogue_t::gelu_aux) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f);
dpct::blas::matrix_mem_copy(compute_desc->_epilogue_aux_pointer, new_c,
compute_desc->_epilogue_aux_ld, new_ldc, m, n,
sizeof(size_t), dpct::device_to_device, queue,
false);
} else if (compute_desc->_epilogue == epilogue_t::bias) {
matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md);
} else if (compute_desc->_epilogue == epilogue_t::gelu_aux_bias) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f);
matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md);
dpct::blas::matrix_mem_copy(compute_desc->_epilogue_aux_pointer, new_c,
compute_desc->_epilogue_aux_ld, new_ldc, m, n,
sizeof(size_t), dpct::device_to_device, queue,
false);
}
matmul_attr.set_post_ops(matmul_ops);
}
matmul_attr.set_post_ops(matmul_ops);

auto matmul_pd = ::dnnl::matmul::primitive_desc(
handle->get_engine(), src_md, weights_md, dst_md, matmul_attr);
auto matmul_prim = ::dnnl::matmul(matmul_pd);
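For reviewers, below is a minimal standalone sketch (not part of the patch) of how the new epilogue_t values are lowered to oneDNN post-ops, mirroring the branch added in matmul() above. The helper name make_epilogue_post_ops, the bias_md parameter, and the include path are illustrative assumptions; in the actual patch, gelu_aux and gelu_aux_bias additionally copy data to the epilogue aux pointer, which this sketch omits.

#include "dpct/blas_gemm_utils.hpp"
#include "oneapi/dnnl/dnnl.hpp"

using dpct::blas_gemm::experimental::epilogue_t;

// Build the oneDNN post-op chain for a given epilogue, following the same
// mapping used in matmul() above: relu/gelu variants become an eltwise
// post-op, bias variants append a binary_add with the bias memory descriptor.
inline ::dnnl::post_ops
make_epilogue_post_ops(epilogue_t epilogue,
                       const ::dnnl::memory::desc &bias_md) {
  ::dnnl::post_ops ops;
  switch (epilogue) {
  case epilogue_t::relu:
    ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f);
    break;
  case epilogue_t::gelu:
  case epilogue_t::gelu_aux: // aux-buffer bookkeeping is handled by the caller
    ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f);
    break;
  case epilogue_t::bias:
    ops.append_binary(::dnnl::algorithm::binary_add, bias_md);
    break;
  case epilogue_t::gelu_aux_bias:
    ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f);
    ops.append_binary(::dnnl::algorithm::binary_add, bias_md);
    break;
  default: // epilogue_t::nop: no post-op
    break;
  }
  return ops;
}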
8 changes: 8 additions & 0 deletions clang/test/dpct/cublaslt.cu
@@ -234,9 +234,17 @@ void foo3() {
// CHECK: dpct::blas_gemm::experimental::epilogue_t e;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::nop;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::relu;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::bias;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu_aux_bias;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu_aux;
cublasLtEpilogue_t e;
e = CUBLASLT_EPILOGUE_DEFAULT;
e = CUBLASLT_EPILOGUE_RELU;
e = CUBLASLT_EPILOGUE_GELU;
e = CUBLASLT_EPILOGUE_BIAS;
e = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
e = CUBLASLT_EPILOGUE_GELU_AUX;
}

void foo4() {