diff --git a/clang/lib/DPCT/RuleInfra/MapNames.cpp b/clang/lib/DPCT/RuleInfra/MapNames.cpp index 456acfb5123a..d79d067268f8 100644 --- a/clang/lib/DPCT/RuleInfra/MapNames.cpp +++ b/clang/lib/DPCT/RuleInfra/MapNames.cpp @@ -2356,4 +2356,4 @@ const std::unordered_map } // namespace dpct -} // namespace clang \ No newline at end of file +} // namespace clang diff --git a/clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp b/clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp index c7524985a95f..148df4365e69 100644 --- a/clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp +++ b/clang/lib/DPCT/RulesMathLib/MapNamesBlas.cpp @@ -1586,6 +1586,10 @@ void MapNamesBlas::setExplicitNamespaceMap( MapNames::getLibraryHelperNamespace() + "blas_gemm::experimental::matmul_desc_t::" "attribute::epilogue_aux_pointer"}, + {"CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE", + MapNames::getLibraryHelperNamespace() + + "blas_gemm::experimental::matmul_desc_t::" + "attribute::epilogue_aux_data_type"}, {"CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET", MapNames::getLibraryHelperNamespace() + "blas_gemm::experimental::matmul_desc_t::attribute::unsupport"}, @@ -1708,4 +1712,4 @@ void MapNamesBlas::setExplicitNamespaceMap( } } // namespace dpct -} // namespace clang \ No newline at end of file +} // namespace clang diff --git a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp index 9d9d9718eacf..a9ed88104b15 100644 --- a/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp @@ -141,6 +141,7 @@ class matmul_desc_t { epilogue, epilogue_aux_ld, epilogue_aux_pointer, + epilogue_aux_data_type, a_scale_pointer, b_scale_pointer, d_scale_pointer, @@ -186,6 +187,7 @@ class matmul_desc_t { CASE(absmax_d_pointer) CASE(epilogue_aux_ld) CASE(epilogue_aux_pointer) + CASE(epilogue_aux_data_type) default: break; } @@ -200,6 +202,7 @@ class matmul_desc_t { oneapi::mkl::transpose _trans_b = oneapi::mkl::transpose::nontrans; oneapi::mkl::transpose _trans_c = oneapi::mkl::transpose::nontrans; epilogue_t _epilogue = epilogue_t::nop; + library_data_t _epilogue_aux_data_type = library_data_t::real_float; size_t _epilogue_aux_ld = 0; void *_a_scale_pointer = nullptr; void *_b_scale_pointer = nullptr; diff --git a/clang/test/dpct/cublaslt.cu b/clang/test/dpct/cublaslt.cu index a753badb08a9..30fb82c05b4d 100644 --- a/clang/test/dpct/cublaslt.cu +++ b/clang/test/dpct/cublaslt.cu @@ -196,6 +196,7 @@ void foo3() { // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue; // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_ld; // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_pointer; + // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue_aux_data_type; // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::unsupport; // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::unsupport; // CHECK-NEXT: d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::a_scale_pointer; @@ -218,6 +219,7 @@ void foo3() { d = CUBLASLT_MATMUL_DESC_EPILOGUE; d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD; d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER; + d = CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE; d = CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET; d = CUBLASLT_MATMUL_DESC_FAST_ACCUM; d = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER;