Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,12 @@ void MapNamesBlas::setExplicitNamespaceMap(
{"CUBLASLT_EPILOGUE_RELU",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::relu"},
{"CUBLASLT_EPILOGUE_BIAS",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::bias"},
{"CUBLASLT_EPILOGUE_GELU_AUX_BIAS",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::epilogue_t::gelu_aux_bias"},
{"CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE",
MapNames::getLibraryHelperNamespace() +
"blas_gemm::experimental::transform_desc_t::attribute::scale_type"},
Expand Down
20 changes: 16 additions & 4 deletions clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ enum class pointer_mode_t {
alpha_device_vector_beta_zero,
alpha_device_vector_beta_host
};
enum class epilogue_t { nop = 1, relu };
enum class epilogue_t { nop = 1, relu, bias, gelu_aux_bias };

class descriptor;
using descriptor_ptr = descriptor *;
Expand Down Expand Up @@ -780,9 +780,11 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
}

if (compute_desc->_epilogue != epilogue_t::nop &&
compute_desc->_epilogue != epilogue_t::relu) {
compute_desc->_epilogue != epilogue_t::relu &&
compute_desc->_epilogue != epilogue_t::bias &&
compute_desc->_epilogue != epilogue_t::gelu_aux_bias) {
throw std::runtime_error("dpct::blas_gemm::experimental::matmul() only "
"supports relu epilogue currently.");
"supports relu, gelu, gelu with bias epilogue currently.");
}

if (!(compute_desc->_scale_type == library_data_t::real_int32 &&
Expand Down Expand Up @@ -1027,7 +1029,17 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,

if (compute_desc->_epilogue != epilogue_t::nop) {
::dnnl::post_ops matmul_ops;
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f);
if (compute_desc->_epilogue == epilogue_t::relu) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_relu, 0.f, 0.f);
} else if (compute_desc->_epilogue == epilogue_t::bias) {
matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md);
} else if (compute_desc->_epilogue == epilogue_t::gelu_aux_bias) {
matmul_ops.append_eltwise(::dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
matmul_ops.append_binary(::dnnl::algorithm::binary_add, bias_md);
dpct::blas::matrix_mem_copy(matmul_desc_t::attribute::epilogue_aux_pointer, bias_mem,
matmul_desc_t::attribute::epilogue_aux_ld, new_ldc, m, n,
sizeof(size_t) , q_ptr);
}
matmul_attr.set_post_ops(matmul_ops);
}

Expand Down
4 changes: 4 additions & 0 deletions clang/test/dpct/cublaslt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,13 @@ void foo3() {
// CHECK: dpct::blas_gemm::experimental::epilogue_t e;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::nop;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::relu;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::bias;
// CHECK-NEXT: e = dpct::blas_gemm::experimental::epilogue_t::gelu_aux_bias;
cublasLtEpilogue_t e;
e = CUBLASLT_EPILOGUE_DEFAULT;
e = CUBLASLT_EPILOGUE_RELU;
e = CUBLASLT_EPILOGUE_BIAS;
e = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
}

void foo4() {
Expand Down
Loading