diff --git a/clang/runtime/dpct-rt/include/dpct/fft_utils.hpp b/clang/runtime/dpct-rt/include/dpct/fft_utils.hpp index 9e45cece3d2e..742bb84c2e52 100644 --- a/clang/runtime/dpct-rt/include/dpct/fft_utils.hpp +++ b/clang/runtime/dpct-rt/include/dpct/fft_utils.hpp @@ -15,14 +15,10 @@ #include #include - namespace dpct { namespace fft { /// An enumeration type to describe the FFT direction is forward or backward. -enum fft_direction : int { - forward = 0, - backward -}; +enum fft_direction : int { forward = 0, backward }; /// An enumeration type to describe the types of FFT input and output data. enum fft_type : int { real_float_to_complex_float = 0, @@ -711,13 +707,13 @@ class fft_engine { distance); _desc_sc->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, _batch); -#ifdef __INTEL_MKL__ if (_is_user_specified_dir_and_placement && _is_inplace) _desc_sc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_INPLACE); + oneapi::mkl::dft::config_value::INPLACE); else _desc_sc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); + oneapi::mkl::dft::config_value::NOT_INPLACE); +#ifdef __INTEL_MKL__ if (_use_external_workspace) { if (_q->get_device().is_gpu()) { _desc_sc->set_value( @@ -739,12 +735,6 @@ class fft_engine { } } #else - if (_is_user_specified_dir_and_placement && _is_inplace) - _desc_sc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); - else - _desc_sc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::NOT_INPLACE); _desc_sc->commit(*_q); #endif } else if (_input_type == library_data_t::complex_double && @@ -763,13 +753,13 @@ class fft_engine { distance); _desc_dc->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, _batch); -#ifdef __INTEL_MKL__ if (_is_user_specified_dir_and_placement && _is_inplace) _desc_dc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_INPLACE); + oneapi::mkl::dft::config_value::INPLACE); else _desc_dc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); + oneapi::mkl::dft::config_value::NOT_INPLACE); +#ifdef __INTEL_MKL__ if (_use_external_workspace) { if (_q->get_device().is_gpu()) { _desc_dc->set_value( @@ -791,12 +781,6 @@ class fft_engine { } } #else - if (_is_user_specified_dir_and_placement && _is_inplace) - _desc_dc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); - else - _desc_dc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::NOT_INPLACE); _desc_dc->commit(*_q); #endif } else if ((_input_type == library_data_t::real_float && @@ -813,16 +797,16 @@ class fft_engine { _direction = fft_direction::backward; _desc_sr->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, _batch); -#ifdef __INTEL_MKL__ if (_is_user_specified_dir_and_placement && _is_inplace) { _desc_sr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_INPLACE); + oneapi::mkl::dft::config_value::INPLACE); set_stride_and_distance_basic(_desc_sr); } else { _desc_sr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); + oneapi::mkl::dft::config_value::NOT_INPLACE); set_stride_and_distance_basic(_desc_sr); } +#ifdef __INTEL_MKL__ if (_use_external_workspace) { if (_q->get_device().is_gpu()) { _desc_sr->set_value( @@ -844,15 +828,6 @@ class fft_engine { } } #else - if (_is_user_specified_dir_and_placement && _is_inplace) { - _desc_sr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); - set_stride_and_distance_basic(_desc_sr); - } else { - _desc_sr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::NOT_INPLACE); - set_stride_and_distance_basic(_desc_sr); - } _desc_sr->commit(*_q); #endif } else if ((_input_type == library_data_t::real_double && @@ -869,16 +844,16 @@ class fft_engine { _direction = fft_direction::backward; _desc_dr->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, _batch); -#ifdef __INTEL_MKL__ if (_is_user_specified_dir_and_placement && _is_inplace) { _desc_dr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_INPLACE); + oneapi::mkl::dft::config_value::INPLACE); set_stride_and_distance_basic(_desc_dr); } else { _desc_dr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); + oneapi::mkl::dft::config_value::NOT_INPLACE); set_stride_and_distance_basic(_desc_dr); } +#ifdef __INTEL_MKL__ if (_use_external_workspace) { if (_q->get_device().is_gpu()) { _desc_dr->set_value( @@ -900,15 +875,6 @@ class fft_engine { } } #else - if (_is_user_specified_dir_and_placement && _is_inplace) { - _desc_dr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); - set_stride_and_distance_basic(_desc_dr); - } else { - _desc_dr->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::NOT_INPLACE); - set_stride_and_distance_basic(_desc_dr); - } _desc_dr->commit(*_q); #endif } else { @@ -918,27 +884,25 @@ class fft_engine { } void config_and_commit_advanced() { +#define CONFIG_LAYOUT_AND_PLACEMENT(DESC, PREC, DOM) \ + DESC = std::make_shared>(_n); \ + set_stride_advanced(DESC); \ + DESC->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, _fwd_dist); \ + DESC->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, _bwd_dist); \ + DESC->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, \ + _batch); \ + if (_is_user_specified_dir_and_placement && _is_inplace) \ + DESC->set_value(oneapi::mkl::dft::config_param::PLACEMENT, \ + oneapi::mkl::dft::config_value::INPLACE); \ + else \ + DESC->set_value(oneapi::mkl::dft::config_param::PLACEMENT, \ + oneapi::mkl::dft::config_value::NOT_INPLACE); + #ifdef __INTEL_MKL__ -#define CONFIG_AND_COMMIT(DESC, PREC, DOM, TYPE) \ +#define CONFIG_AND_COMMIT(DESC, PREC, DOM) \ { \ - DESC = std::make_shared>( \ - _n); \ - set_stride_advanced(DESC); \ - DESC->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, _fwd_dist); \ - DESC->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, _bwd_dist); \ - DESC->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, \ - _batch); \ - if (_is_user_specified_dir_and_placement && _is_inplace) \ - DESC->set_value(oneapi::mkl::dft::config_param::PLACEMENT, \ - DFTI_CONFIG_VALUE::DFTI_INPLACE); \ - else \ - DESC->set_value(oneapi::mkl::dft::config_param::PLACEMENT, \ - DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); \ - if (_use_external_workspace) { \ - DESC->set_value(oneapi::mkl::dft::config_param::WORKSPACE, \ - oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL); \ - } \ + CONFIG_LAYOUT_AND_PLACEMENT(DESC, PREC, DOM) \ if (_is_estimate_call) { \ if (_q->get_device().is_gpu()) { \ DESC->get_value( \ @@ -954,46 +918,34 @@ class fft_engine { } \ } #else -#define CONFIG_AND_COMMIT(DESC, PREC, DOM, TYPE) \ +#define CONFIG_AND_COMMIT(DESC, PREC, DOM) \ { \ - DESC = std::make_shared>( \ - _n); \ - set_stride_advanced(DESC); \ - DESC->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, _fwd_dist); \ - DESC->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, _bwd_dist); \ - DESC->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, \ - _batch); \ - if (_is_user_specified_dir_and_placement && _is_inplace) \ - DESC->set_value(oneapi::mkl::dft::config_param::PLACEMENT, \ - oneapi::mkl::dft::config_value::INPLACE); \ - else \ - DESC->set_value(oneapi::mkl::dft::config_param::PLACEMENT, \ - oneapi::mkl::dft::config_value::NOT_INPLACE); \ + CONFIG_LAYOUT_AND_PLACEMENT(DESC, PREC, DOM) \ DESC->commit(*_q); \ } #endif if (_input_type == library_data_t::complex_float && _output_type == library_data_t::complex_float) { - CONFIG_AND_COMMIT(_desc_sc, SINGLE, COMPLEX, float); + CONFIG_AND_COMMIT(_desc_sc, SINGLE, COMPLEX); } else if (_input_type == library_data_t::complex_double && _output_type == library_data_t::complex_double) { - CONFIG_AND_COMMIT(_desc_dc, DOUBLE, COMPLEX, double); + CONFIG_AND_COMMIT(_desc_dc, DOUBLE, COMPLEX); } else if ((_input_type == library_data_t::real_float && _output_type == library_data_t::complex_float) || (_input_type == library_data_t::complex_float && _output_type == library_data_t::real_float)) { - CONFIG_AND_COMMIT(_desc_sr, SINGLE, REAL, float); + CONFIG_AND_COMMIT(_desc_sr, SINGLE, REAL); } else if ((_input_type == library_data_t::real_double && _output_type == library_data_t::complex_double) || (_input_type == library_data_t::complex_double && _output_type == library_data_t::real_double)) { - CONFIG_AND_COMMIT(_desc_dr, DOUBLE, REAL, double); + CONFIG_AND_COMMIT(_desc_dr, DOUBLE, REAL); } else { throw sycl::exception(sycl::make_error_code(sycl::errc::invalid), "invalid fft type"); } +#undef CONFIG_LAYOUT_AND_PLACEMENT #undef CONFIG_AND_COMMIT } @@ -1062,31 +1014,29 @@ class fft_engine { template void set_stride_advanced(std::shared_ptr desc) { if (_dim == 1) { - _fwd_strides = {0, _istride, 0, 0}; - _bwd_strides = {0, _ostride, 0, 0}; + _fwd_strides = {0, _istride}; + _bwd_strides = {0, _ostride}; } else if (_dim == 2) { - _fwd_strides = {0, _inembed[1] * _istride, _istride, 0}; - _bwd_strides = {0, _onembed[1] * _ostride, _ostride, 0}; + _fwd_strides = {0, _inembed[1] * _istride, _istride}; + _bwd_strides = {0, _onembed[1] * _ostride, _ostride}; } else if (_dim == 3) { _fwd_strides = {0, _inembed[2] * _inembed[1] * _istride, _inembed[2] * _istride, _istride}; _bwd_strides = {0, _onembed[2] * _onembed[1] * _ostride, _onembed[2] * _ostride, _ostride}; } -#ifdef __INTEL_MKL__ if (_direction == fft_direction::backward) { std::swap_ranges(_fwd_strides.begin(), _fwd_strides.end(), _bwd_strides.begin()); } +#ifdef __INTEL_MKL__ + desc->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, _fwd_strides); + desc->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, _bwd_strides); +#else desc->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, _fwd_strides.data()); desc->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, _bwd_strides.data()); -#else - desc->set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, - _fwd_strides.data()); - desc->set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, - _bwd_strides.data()); #endif } @@ -1096,16 +1046,19 @@ class fft_engine { desc->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, _bwd_dist); } -#ifdef __INTEL_MKL__ template void swap_strides(std::shared_ptr desc) { std::swap_ranges(_fwd_strides.begin(), _fwd_strides.end(), _bwd_strides.begin()); +#ifdef __INTEL_MKL__ + desc->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, _fwd_strides); + desc->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, _bwd_strides); +#else desc->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, _fwd_strides.data()); desc->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, _bwd_strides.data()); - } #endif + } template void set_stride_and_distance_basic(std::shared_ptr desc) { @@ -1113,25 +1066,25 @@ class fft_engine { std::int64_t backward_distance = 0; if (_dim == 1) { if constexpr (Is_inplace) { - _fwd_strides = {0, 1, 0, 0}; - _bwd_strides = {0, 1, 0, 0}; + _fwd_strides = {0, 1}; + _bwd_strides = {0, 1}; forward_distance = 2 * (_n[0] / 2 + 1); backward_distance = _n[0] / 2 + 1; } else { - _fwd_strides = {0, 1, 0, 0}; - _bwd_strides = {0, 1, 0, 0}; + _fwd_strides = {0, 1}; + _bwd_strides = {0, 1}; forward_distance = _n[0]; backward_distance = _n[0] / 2 + 1; } } else if (_dim == 2) { if constexpr (Is_inplace) { - _bwd_strides = {0, _n[1] / 2 + 1, 1, 0}; - _fwd_strides = {0, 2 * (_n[1] / 2 + 1), 1, 0}; + _bwd_strides = {0, _n[1] / 2 + 1, 1}; + _fwd_strides = {0, 2 * (_n[1] / 2 + 1), 1}; forward_distance = _n[0] * 2 * (_n[1] / 2 + 1); backward_distance = _n[0] * (_n[1] / 2 + 1); } else { - _bwd_strides = {0, _n[1] / 2 + 1, 1, 0}; - _fwd_strides = {0, _n[1], 1, 0}; + _bwd_strides = {0, _n[1] / 2 + 1, 1}; + _fwd_strides = {0, _n[1], 1}; forward_distance = _n[0] * _n[1]; backward_distance = _n[0] * (_n[1] / 2 + 1); } @@ -1149,22 +1102,13 @@ class fft_engine { } } #ifdef __INTEL_MKL__ + desc->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, _fwd_strides); + desc->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, _bwd_strides); +#else desc->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, _fwd_strides.data()); desc->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, _bwd_strides.data()); -#else - if (_direction == fft_direction::forward) { - desc->set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, - _fwd_strides.data()); - desc->set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, - _bwd_strides.data()); - } else { - desc->set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, - _bwd_strides.data()); - desc->set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, - _fwd_strides.data()); - } #endif desc->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_distance); @@ -1194,31 +1138,14 @@ class fft_engine { if (direction.value() != _direction) { need_commit = true; swap_distance(desc); -#ifdef __INTEL_MKL__ if (!_is_basic) swap_strides(desc); -#endif _direction = direction.value(); } } if (is_this_compute_inplace != _is_inplace) { need_commit = true; _is_inplace = is_this_compute_inplace; -#ifdef __INTEL_MKL__ - if (_is_inplace) { - desc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_INPLACE); - if constexpr (!is_complex) - if (_is_basic) - set_stride_and_distance_basic(desc); - } else { - desc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_CONFIG_VALUE::DFTI_NOT_INPLACE); - if constexpr (!is_complex) - if (_is_basic) - set_stride_and_distance_basic(desc); - } -#else if (_is_inplace) { desc->set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::INPLACE); @@ -1232,7 +1159,6 @@ class fft_engine { if (_is_basic) set_stride_and_distance_basic(desc); } -#endif } if (need_commit) desc->commit(*_q); @@ -1297,9 +1223,9 @@ class fft_engine { bool _is_user_specified_dir_and_placement = false; bool _use_external_workspace = false; void *_external_workspace_ptr = nullptr; - size_t _workspace_bytes = 0; + std::int64_t _workspace_bytes = 0; bool _is_estimate_call = false; - size_t _workspace_estimate_bytes = 0; + std::int64_t _workspace_estimate_bytes = 0; std::shared_ptr> _desc_sr; @@ -1312,8 +1238,8 @@ class fft_engine { std::shared_ptr> _desc_dc; - std::array _fwd_strides = {0, 0, 0, 0}; - std::array _bwd_strides = {0, 0, 0, 0}; + std::vector _fwd_strides; + std::vector _bwd_strides; }; using fft_engine_ptr = fft_engine *;