Skip to content

Commit d847c7b

Browse files
authored
[SYCLomatic] Use macro __INTEL_MKL__ to distinguish functionality of helper functions, make it buildable with Intel open source MKL library (#785)
open source MKL: oneAPI Math Kernel Library (oneMKL) Interfaces Signed-off-by: Tang, Jiajun [email protected]
1 parent e2544cb commit d847c7b

File tree

12 files changed

+220
-0
lines changed

12 files changed

+220
-0
lines changed

clang/runtime/dpct-rt/include/blas_utils.hpp.inc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ public:
147147
template <typename Tx, typename Tr>
148148
inline void nrm2_impl(sycl::queue &q, int n, const void *x, int incx,
149149
void *result) {
150+
#ifndef __INTEL_MKL__
151+
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
152+
"Project does not support this API.");
153+
#else
150154
#ifdef DPCT_USM_LEVEL_NONE
151155
auto x_buffer = dpct::get_buffer<Tx>(x);
152156
auto r_buffer =
@@ -159,6 +163,7 @@ inline void nrm2_impl(sycl::queue &q, int n, const void *x, int incx,
159163
oneapi::mkl::blas::column_major::nrm2(q, n, reinterpret_cast<const Tx *>(x),
160164
incx, res_mem.get_ptr());
161165
#endif
166+
#endif
162167
}
163168
// DPCT_LABEL_END
164169

@@ -172,6 +177,10 @@ inline void nrm2_impl(sycl::queue &q, int n, const void *x, int incx,
172177
template <bool is_conjugate, class Txy, class Tr>
173178
inline void dotuc_impl(sycl::queue &q, int n, const Txy *x, int incx,
174179
const Txy *y, int incy, Tr *result) {
180+
#ifndef __INTEL_MKL__
181+
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
182+
"Project does not support this API.");
183+
#else
175184
#ifdef DPCT_USM_LEVEL_NONE
176185
auto x_buffer = dpct::get_buffer<Txy>(x);
177186
auto y_buffer = dpct::get_buffer<Txy>(y);
@@ -200,6 +209,7 @@ inline void dotuc_impl(sycl::queue &q, int n, const Txy *x, int incx,
200209
} else
201210
oneapi::mkl::blas::column_major::dot(q, n, x, incx, y, incy, res_mem.get_ptr());
202211
#endif
212+
#endif
203213
}
204214
// DPCT_LABEL_END
205215

@@ -273,10 +283,15 @@ inline void dotuc(sycl::queue &q, int n, const void *x,
273283
template <class Tx, class Te>
274284
inline void scal_impl(sycl::queue &q, int n, const void *alpha, void *x,
275285
int incx) {
286+
#ifndef __INTEL_MKL__
287+
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
288+
"Project does not support this API.");
289+
#else
276290
Te alpha_val = dpct::get_value(reinterpret_cast<const Te *>(alpha), q);
277291
auto data_x = get_memory(reinterpret_cast<Tx *>(x));
278292
oneapi::mkl::blas::column_major::scal(q, n, alpha_val,
279293
data_x, incx);
294+
#endif
280295
}
281296
// DPCT_LABEL_END
282297

@@ -289,12 +304,17 @@ inline void scal_impl(sycl::queue &q, int n, const void *alpha, void *x,
289304
template <class Txy, class Te>
290305
inline void axpy_impl(sycl::queue &q, int n, const void *alpha, const void *x,
291306
int incx, void *y, int incy) {
307+
#ifndef __INTEL_MKL__
308+
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
309+
"Project does not support this API.");
310+
#else
292311
Te alpha_val = dpct::get_value(reinterpret_cast<const Te *>(alpha), q);
293312
auto data_x = get_memory(reinterpret_cast<const Txy *>(x));
294313
auto data_y = get_memory(reinterpret_cast<Txy *>(y));
295314
oneapi::mkl::blas::column_major::axpy(q, n, alpha_val,
296315
data_x, incx,
297316
data_y, incy);
317+
#endif
298318
}
299319
// DPCT_LABEL_END
300320

@@ -307,13 +327,18 @@ inline void axpy_impl(sycl::queue &q, int n, const void *alpha, const void *x,
307327
template <class Txy, class Tc, class Ts>
308328
inline void rot_impl(sycl::queue &q, int n, void *x, int incx, void *y,
309329
int incy, const void *c, const void *s) {
330+
#ifndef __INTEL_MKL__
331+
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
332+
"Project does not support this API.");
333+
#else
310334
Tc c_value = dpct::get_value(reinterpret_cast<const Tc *>(c), q);
311335
Ts s_value = dpct::get_value(reinterpret_cast<const Ts *>(s), q);
312336
auto data_x = get_memory(reinterpret_cast<Txy *>(x));
313337
auto data_y = get_memory(reinterpret_cast<Txy *>(y));
314338
oneapi::mkl::blas::column_major::rot(q, n, data_x, incx,
315339
data_y, incy, c_value,
316340
s_value);
341+
#endif
317342
}
318343
// DPCT_LABEL_END
319344

@@ -328,6 +353,10 @@ inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
328353
oneapi::mkl::transpose b_trans, int m, int n, int k,
329354
const void *alpha, const void *a, int lda, const void *b,
330355
int ldb, const void *beta, void *c, int ldc) {
356+
#ifndef __INTEL_MKL__
357+
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
358+
"Project does not support this API.");
359+
#else
331360
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
332361
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
333362
auto data_a = get_memory(reinterpret_cast<const Ta *>(a));
@@ -336,6 +365,7 @@ inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
336365
oneapi::mkl::blas::column_major::gemm(
337366
q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
338367
data_b, ldb, beta_value, data_c, ldc);
368+
#endif
339369
}
340370
// DPCT_LABEL_END
341371

@@ -350,6 +380,10 @@ inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
350380
const void *alpha, const void **a, int lda,
351381
const void **b, int ldb, const void *beta, void **c,
352382
int ldc, int batch_size) {
383+
#ifndef __INTEL_MKL__
384+
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
385+
"Project does not support this API.");
386+
#else
353387
struct matrix_info_t {
354388
oneapi::mkl::transpose transpose_info[2];
355389
Ts value_info[2];
@@ -388,6 +422,7 @@ inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
388422
cgh.depends_on(e);
389423
cgh.host_task([=] { std::free(matrix_info); });
390424
});
425+
#endif
391426
}
392427
// DPCT_LABEL_END
393428

@@ -405,6 +440,10 @@ gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
405440
long long int stride_a, const void *b, int ldb,
406441
long long int stride_b, const void *beta, void *c,
407442
int ldc, long long int stride_c, int batch_size) {
443+
#ifndef __INTEL_MKL__
444+
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
445+
"Project does not support this API.");
446+
#else
408447
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
409448
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
410449
auto data_a = get_memory(reinterpret_cast<const Ta *>(a));
@@ -414,6 +453,7 @@ gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
414453
q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
415454
stride_a, data_b, ldb, stride_b, beta_value,
416455
data_c, ldc, stride_c, batch_size);
456+
#endif
417457
}
418458
// DPCT_LABEL_END
419459

@@ -566,6 +606,10 @@ trsm_batch_impl(sycl::queue &q, oneapi::mkl::side left_right,
566606
template <typename T>
567607
inline void getrfnp_batch_wrapper(sycl::queue &exec_queue, int n, T *a[],
568608
int lda, int *info, int batch_size) {
609+
#ifndef __INTEL_MKL__
610+
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
611+
"Project does not support this API.");
612+
#else
569613
using Ty = typename DataType<T>::T2;
570614
// Set the info array value to 0
571615
detail::dpct_memset(exec_queue, info, 0, sizeof(int) * batch_size);
@@ -614,6 +658,7 @@ inline void getrfnp_batch_wrapper(sycl::queue &exec_queue, int n, T *a[],
614658
cgh.depends_on(events);
615659
cgh.host_task([=] { free(host_a); });
616660
});
661+
#endif
617662
}
618663
// DPCT_LABEL_END
619664

clang/runtime/dpct-rt/include/fft_utils.hpp.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ enum fft_type : int {
8383
// Device|get_default_queue
8484
// DPCT_DEPENDENCY_END
8585
// DPCT_CODE
86+
#ifdef __INTEL_MKL__ // The oneMKL Interfaces Project does not support this.
8687
/// A class to perform FFT calculation.
8788
class fft_engine {
8889
public:
@@ -1292,6 +1293,7 @@ private:
12921293
};
12931294

12941295
using fft_engine_ptr = fft_engine *;
1296+
#endif
12951297
// DPCT_LABEL_END
12961298
} // namespace fft
12971299
} // namespace dpct

clang/runtime/dpct-rt/include/lapack_utils.hpp.inc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,13 +678,19 @@ template <typename T> struct getrfnp_impl {
678678
library_data_t a_type, void *a, std::int64_t lda,
679679
std::int64_t *ipiv, void *device_ws,
680680
std::size_t device_ws_size, int *info) {
681+
#ifndef __INTEL_MKL__
682+
throw std::runtime_error(
683+
"The oneAPI Math Kernel Library (oneMKL) Interfaces "
684+
"Project does not support this API.");
685+
#else
681686
std::int64_t a_stride = m * lda;
682687
auto a_data = dpct::detail::get_memory(reinterpret_cast<T *>(a));
683688
auto device_ws_data =
684689
dpct::detail::get_memory(reinterpret_cast<T *>(device_ws));
685690
oneapi::mkl::lapack::getrfnp_batch(q, m, n, a_data, lda, a_stride, 1,
686691
device_ws_data, device_ws_size);
687692
dpct::detail::dpct_memset(q, info, 0, sizeof(int));
693+
#endif
688694
}
689695
};
690696
// DPCT_LABEL_END
@@ -749,13 +755,19 @@ template <typename T> struct gesvd_conj_impl : public gesvd_impl<T> {
749755
void *u, std::int64_t ldu, library_data_t vt_type, void *vt,
750756
std::int64_t ldvt, void *device_ws,
751757
std::size_t device_ws_size, int *info) {
758+
#ifndef __INTEL_MKL__
759+
throw std::runtime_error(
760+
"The oneAPI Math Kernel Library (oneMKL) Interfaces "
761+
"Project does not support this API.");
762+
#else
752763
using base = gesvd_impl<T>;
753764
base::operator()(q, jobu, jobvt, m, n, a_type, a, lda, s_type, s, u_type, u,
754765
ldu, vt_type, vt, ldvt, device_ws, device_ws_size, info);
755766
auto vt_data = dpct::detail::get_memory(reinterpret_cast<T *>(vt));
756767
oneapi::mkl::blas::row_major::imatcopy(q, oneapi::mkl::transpose::conjtrans,
757768
n, n, T(1.0f), vt_data, ldvt, ldvt);
758769
dpct::detail::dpct_memset(q, info, 0, sizeof(int));
770+
#endif
759771
}
760772
};
761773
// DPCT_LABEL_END
@@ -883,6 +895,10 @@ inline int getrf(sycl::queue &q, std::int64_t m, std::int64_t n,
883895
library_data_t a_type, void *a, std::int64_t lda,
884896
std::int64_t *ipiv, void *device_ws,
885897
std::size_t device_ws_size, int *info) {
898+
#ifndef __INTEL_MKL__
899+
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
900+
"Project does not support this API.");
901+
#else
886902
std::size_t device_ws_size_in_element_number =
887903
detail::byte_to_element_number(device_ws_size, a_type);
888904
if (ipiv == nullptr) {
@@ -893,6 +909,7 @@ inline int getrf(sycl::queue &q, std::int64_t m, std::int64_t n,
893909
return detail::lapack_shim<detail::getrf_impl>(
894910
q, a_type, info, "getrf", q, m, n, a_type, a, lda, ipiv, device_ws,
895911
device_ws_size_in_element_number, info);
912+
#endif
896913
}
897914
// DPCT_LABEL_END
898915

clang/runtime/dpct-rt/include/lib_common_utils.hpp.inc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ enum class version_field : int {
9999
/// \param field The version information field (major, minor, update or patch).
100100
/// \param result The result value.
101101
inline void mkl_get_version(version_field field, int *result) {
102+
#ifndef __INTEL_MKL__
103+
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
104+
"Project does not support this API.");
105+
#else
102106
MKLVersion version;
103107
mkl_get_version(&version);
104108
if (version_field::major == field) {
@@ -112,6 +116,7 @@ inline void mkl_get_version(version_field field, int *result) {
112116
} else {
113117
throw std::runtime_error("unknown field");
114118
}
119+
#endif
115120
}
116121
// DPCT_LABEL_END
117122

clang/runtime/dpct-rt/include/rng_utils.hpp.inc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@
4040
// DPCT_CODE
4141
#include <sycl/sycl.hpp>
4242
#include <oneapi/mkl.hpp>
43+
#ifdef __INTEL_MKL__ // The oneMKL Interfaces Project does not support this.
4344
#include <oneapi/mkl/rng/device.hpp>
45+
#endif
4446
// DPCT_LABEL_END
4547
// DPCT_LABEL_BEGIN|local_include_dependency|
4648
// DPCT_DEPENDENCY_EMPTY
@@ -51,6 +53,7 @@
5153

5254
namespace dpct {
5355
namespace rng {
56+
#ifdef __INTEL_MKL__ // The oneMKL Interfaces Project does not support this.
5457
namespace device {
5558
// DPCT_LABEL_BEGIN|rng_generator|dpct::rng::device
5659
// DPCT_DEPENDENCY_BEGIN
@@ -265,6 +268,7 @@ private:
265268
// DPCT_LABEL_END
266269

267270
} // namespace device
271+
#endif
268272

269273
namespace host {
270274
namespace detail {
@@ -407,9 +411,14 @@ public:
407411
/// \param output The pointer of the first random number.
408412
/// \param n The number of random numbers.
409413
inline void generate_uniform_bits(unsigned int *output, std::int64_t n) {
414+
#ifndef __INTEL_MKL__
415+
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) "
416+
"Interfaces Project does not support this API.");
417+
#else
410418
static_assert(sizeof(unsigned int) == sizeof(std::uint32_t));
411419
generate<oneapi::mkl::rng::uniform_bits<std::uint32_t>>(
412420
(std::uint32_t *)output, n);
421+
#endif
413422
}
414423

415424
/// Generate unsigned long long random number(s) with 'uniform_bits'
@@ -418,9 +427,14 @@ public:
418427
/// \param n The number of random numbers.
419428
inline void generate_uniform_bits(unsigned long long *output,
420429
std::int64_t n) {
430+
#ifndef __INTEL_MKL__
431+
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) "
432+
"Interfaces Project does not support this API.");
433+
#else
421434
static_assert(sizeof(unsigned long long) == sizeof(std::uint64_t));
422435
generate<oneapi::mkl::rng::uniform_bits<std::uint64_t>>(
423436
(std::uint64_t *)output, n);
437+
#endif
424438
}
425439

426440
/// Generate float random number(s) with 'lognormal' distribution.
@@ -489,19 +503,27 @@ public:
489503
/// Skip ahead several random number(s).
490504
/// \param num_to_skip The number of random numbers to be skipped.
491505
void skip_ahead(const std::uint64_t num_to_skip) {
506+
#ifndef __INTEL_MKL__
507+
oneapi::mkl::rng::skip_ahead(_engine, num_to_skip);
508+
#else
492509
if constexpr (std::is_same_v<engine_t, oneapi::mkl::rng::mt2203>)
493510
throw std::runtime_error("no skip_ahead method of mt2203 engine.");
494511
else
495512
oneapi::mkl::rng::skip_ahead(_engine, num_to_skip);
513+
#endif
496514
}
497515

498516
private:
499517
static inline engine_t create_engine(sycl::queue *queue,
500518
const std::uint64_t seed,
501519
const std::uint32_t dimensions) {
520+
#ifdef __INTEL_MKL__
502521
return std::is_same_v<engine_t, oneapi::mkl::rng::sobol>
503522
? engine_t(*queue, dimensions)
504523
: engine_t(*queue, seed);
524+
#else
525+
return engine_t(*queue, seed);
526+
#endif
505527
}
506528

507529
template <typename distr_t, typename buffer_t, class... distr_params_t>
@@ -554,6 +576,10 @@ inline host_rng_ptr create_host_rng(const random_engine_type type) {
554576
case random_engine_type::mrg32k3a:
555577
return std::make_shared<
556578
rng::host::detail::rng_generator<oneapi::mkl::rng::mrg32k3a>>();
579+
#ifndef __INTEL_MKL__
580+
throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) "
581+
"Interfaces Project does not support this API.");
582+
#else
557583
case random_engine_type::mt2203:
558584
return std::make_shared<
559585
rng::host::detail::rng_generator<oneapi::mkl::rng::mt2203>>();
@@ -566,6 +592,7 @@ inline host_rng_ptr create_host_rng(const random_engine_type type) {
566592
case random_engine_type::mcg59:
567593
return std::make_shared<
568594
rng::host::detail::rng_generator<oneapi::mkl::rng::mcg59>>();
595+
#endif
569596
}
570597
}
571598
// DPCT_LABEL_END

0 commit comments

Comments
 (0)